Add Scopes Support

This commit is contained in:
Jinzhu 2013-11-18 14:35:44 +08:00
parent 1a2eef181a
commit 87f493d080
4 changed files with 74 additions and 2 deletions

View File

@ -13,6 +13,7 @@ Yet Another ORM library for Go, aims for developer friendly
* Logger Support * Logger Support
* Bind struct with tag * Bind struct with tag
* Iteration Support via [Rows](#row--rows) * Iteration Support via [Rows](#row--rows)
* Scopes
* sql.Scanner support * sql.Scanner support
* Every feature comes with tests * Every feature comes with tests
* Convention Over Configuration * Convention Over Configuration
@ -666,6 +667,37 @@ tx.Rollback()
tx.Commit() tx.Commit()
``` ```
## Scopes
```go
func AmountGreaterThan1000(d *gorm.DB) *gorm.DB {
d.Where("amount > ?", 1000)
}
func PaidWithCreditCard(d *gorm.DB) *gorm.DB {
d.Where("pay_mode_sign = ?", "C")
}
func PaidWithCod(d *gorm.DB) *gorm.DB {
d.Where("pay_mode_sign = ?", "C")
}
func OrderStatus(status []string) func (d *gorm.DB) *gorm.DB {
return func (d *gorm.DB) *gorm.DB {
return d.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
}
}
db.Scopes(AmountGreaterThan1000, PaidWithCreditCard).Find(&orders)
// Find all credit card orders and amount greater than 1000
db.Scopes(AmountGreaterThan1000, PaidWithCod).Find(&orders)
// Find all COD orders and amount greater than 1000
db.Scopes(OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// Find all paid, shipped orders and amount greater than 1000
```
## Logger ## Logger
Grom has builtin logger support, enable it with: Grom has builtin logger support, enable it with:
@ -790,7 +822,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111
``` ```
## TODO ## TODO
* Scopes
* Joins * Joins
* Scan * Scan
* AlertColumn, DropColumn, AddIndex, RemoveIndex * AlertColumn, DropColumn, AddIndex, RemoveIndex

View File

@ -1430,6 +1430,38 @@ func TestGroup(t *testing.T) {
} }
} }
func NameIn1And2(d *DB) *DB {
return d.Where("name in (?)", []string{"1", "2"})
}
func NameIn2And3(d *DB) *DB {
return d.Where("name in (?)", []string{"2", "3"})
}
func NameIn(names []string) func(d *DB) *DB {
return func(d *DB) *DB {
return d.Where("name in (?)", names)
}
}
func TestScopes(t *testing.T) {
var users1, users2, users3 []User
db.Scopes(NameIn1And2).Find(&users1)
if len(users1) != 2 {
t.Errorf("Should only have two users's name in 1, 2")
}
db.Scopes(NameIn1And2, NameIn2And3).Find(&users2)
if len(users2) != 1 {
t.Errorf("Should only have two users's name is 2")
}
db.Scopes(NameIn([]string{"1", "2"})).Find(&users3)
if len(users3) != 2 {
t.Errorf("Should only have two users's name is 2")
}
}
func TestHaving(t *testing.T) { func TestHaving(t *testing.T) {
rows, err := db.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() rows, err := db.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows()

View File

@ -93,6 +93,14 @@ func (s *DB) Includes(value interface{}) *DB {
return s.clone().search.includes(value).db return s.clone().search.includes(value).db
} }
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
c := s
for _, f := range funcs {
c = f(c)
}
return c
}
func (s *DB) Unscoped() *DB { func (s *DB) Unscoped() *DB {
return s.clone().search.unscoped().db return s.clone().search.unscoped().db
} }

View File

@ -74,6 +74,8 @@ func (m *Model) fields(operation string) (fields []*Field) {
structs := getStructs(indirect_value.Type()) structs := getStructs(indirect_value.Type())
c := make(chan *Field, len(structs)) c := make(chan *Field, len(structs))
defer close(c)
for _, field_struct := range structs { for _, field_struct := range structs {
go func(field_struct reflect.StructField, c chan *Field) { go func(field_struct reflect.StructField, c chan *Field) {
var field Field var field Field
@ -110,7 +112,6 @@ func (m *Model) fields(operation string) (fields []*Field) {
for i := 0; i < len(structs); i++ { for i := 0; i < len(structs); i++ {
fields = append(fields, <-c) fields = append(fields, <-c)
} }
close(c)
if len(m._cache_fields) == 0 { if len(m._cache_fields) == 0 {
m._cache_fields = map[string][]*Field{} m._cache_fields = map[string][]*Field{}