diff --git a/README.md b/README.md index ec5ce59f..28ac7592 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ Yet Another ORM library for Go, aims for developer friendly * Logger Support * Bind struct with tag * Iteration Support via [Rows](#row--rows) +* Scopes * sql.Scanner support * Every feature comes with tests * Convention Over Configuration @@ -666,6 +667,37 @@ tx.Rollback() tx.Commit() ``` +## Scopes + +```go +func AmountGreaterThan1000(d *gorm.DB) *gorm.DB { + d.Where("amount > ?", 1000) +} + +func PaidWithCreditCard(d *gorm.DB) *gorm.DB { + d.Where("pay_mode_sign = ?", "C") +} + +func PaidWithCod(d *gorm.DB) *gorm.DB { + d.Where("pay_mode_sign = ?", "C") +} + +func OrderStatus(status []string) func (d *gorm.DB) *gorm.DB { + return func (d *gorm.DB) *gorm.DB { + return d.Scopes(AmountGreaterThan1000).Where("status in (?)", status) + } +} + +db.Scopes(AmountGreaterThan1000, PaidWithCreditCard).Find(&orders) +// Find all credit card orders and amount greater than 1000 + +db.Scopes(AmountGreaterThan1000, PaidWithCod).Find(&orders) +// Find all COD orders and amount greater than 1000 + +db.Scopes(OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// Find all paid, shipped orders and amount greater than 1000 +``` + ## Logger Grom has builtin logger support, enable it with: @@ -790,7 +822,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111 ``` ## TODO -* Scopes * Joins * Scan * AlertColumn, DropColumn, AddIndex, RemoveIndex diff --git a/gorm_test.go b/gorm_test.go index 1abf2e90..5d4f030c 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1430,6 +1430,38 @@ func TestGroup(t *testing.T) { } } +func NameIn1And2(d *DB) *DB { + return d.Where("name in (?)", []string{"1", "2"}) +} + +func NameIn2And3(d *DB) *DB { + return d.Where("name in (?)", []string{"2", "3"}) +} + +func NameIn(names []string) func(d *DB) *DB { + return func(d *DB) *DB { + return d.Where("name in (?)", names) + } +} + +func TestScopes(t *testing.T) { + var users1, users2, users3 []User + db.Scopes(NameIn1And2).Find(&users1) + if len(users1) != 2 { + t.Errorf("Should only have two users's name in 1, 2") + } + + db.Scopes(NameIn1And2, NameIn2And3).Find(&users2) + if len(users2) != 1 { + t.Errorf("Should only have two users's name is 2") + } + + db.Scopes(NameIn([]string{"1", "2"})).Find(&users3) + if len(users3) != 2 { + t.Errorf("Should only have two users's name is 2") + } +} + func TestHaving(t *testing.T) { rows, err := db.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() diff --git a/main.go b/main.go index eec0fbd6..be8b8981 100644 --- a/main.go +++ b/main.go @@ -93,6 +93,14 @@ func (s *DB) Includes(value interface{}) *DB { return s.clone().search.includes(value).db } +func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { + c := s + for _, f := range funcs { + c = f(c) + } + return c +} + func (s *DB) Unscoped() *DB { return s.clone().search.unscoped().db } diff --git a/model.go b/model.go index 65d8d622..0b54b2bb 100644 --- a/model.go +++ b/model.go @@ -74,6 +74,8 @@ func (m *Model) fields(operation string) (fields []*Field) { structs := getStructs(indirect_value.Type()) c := make(chan *Field, len(structs)) + defer close(c) + for _, field_struct := range structs { go func(field_struct reflect.StructField, c chan *Field) { var field Field @@ -110,7 +112,6 @@ func (m *Model) fields(operation string) (fields []*Field) { for i := 0; i < len(structs); i++ { fields = append(fields, <-c) } - close(c) if len(m._cache_fields) == 0 { m._cache_fields = map[string][]*Field{}