From 6e5d46bf37858935d4dc0d2325f4757661002369 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 12 Mar 2015 13:52:29 +0800 Subject: [PATCH] Refactor Search API --- callback_delete.go | 2 +- callback_query.go | 2 +- main.go | 44 +++++++-------- preload.go | 4 +- scope.go | 16 +++--- scope_private.go | 60 ++++++++++---------- search.go | 134 ++++++++++++++++++++++----------------------- search_test.go | 10 ++-- 8 files changed, 136 insertions(+), 136 deletions(-) diff --git a/callback_delete.go b/callback_delete.go index e46e0221..eb1c202b 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -8,7 +8,7 @@ func BeforeDelete(scope *Scope) { func Delete(scope *Scope) { if !scope.HasError() { - if !scope.Search.Unscope && scope.HasColumn("DeletedAt") { + if !scope.Search.unscoped && scope.HasColumn("DeletedAt") { scope.Raw( fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.QuotedTableName(), diff --git a/callback_query.go b/callback_query.go index 888702de..4538b272 100644 --- a/callback_query.go +++ b/callback_query.go @@ -23,7 +23,7 @@ func Query(scope *Scope) { if orderBy, ok := scope.InstanceGet("gorm:order_by_primary_key"); ok { if primaryKey := scope.PrimaryKey(); primaryKey != "" { - scope.Search = scope.Search.clone().order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy)) + scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy)) } } diff --git a/main.go b/main.go index c00bd810..e2ffe417 100644 --- a/main.go +++ b/main.go @@ -94,7 +94,7 @@ func (s *DB) New() *DB { func (db *DB) NewScope(value interface{}) *Scope { dbClone := db.clone() dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search, Value: value} + return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} } // CommonDB Return the underlying sql.DB or sql.Tx instance. @@ -128,43 +128,43 @@ func (s *DB) SingularTable(enable bool) { } func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.where(query, args...).db + return s.clone().search.Where(query, args...).db } func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.or(query, args...).db + return s.clone().search.Or(query, args...).db } func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.not(query, args...).db + return s.clone().search.Not(query, args...).db } func (s *DB) Limit(value interface{}) *DB { - return s.clone().search.limit(value).db + return s.clone().search.Limit(value).db } func (s *DB) Offset(value interface{}) *DB { - return s.clone().search.offset(value).db + return s.clone().search.Offset(value).db } func (s *DB) Order(value string, reorder ...bool) *DB { - return s.clone().search.order(value, reorder...).db + return s.clone().search.Order(value, reorder...).db } func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.selects(query, args...).db + return s.clone().search.Selects(query, args...).db } func (s *DB) Group(query string) *DB { - return s.clone().search.group(query).db + return s.clone().search.Group(query).db } func (s *DB) Having(query string, values ...interface{}) *DB { - return s.clone().search.having(query, values...).db + return s.clone().search.Having(query, values...).db } func (s *DB) Joins(query string) *DB { - return s.clone().search.joins(query).db + return s.clone().search.Joins(query).db } func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { @@ -175,27 +175,27 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { } func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db + return s.clone().search.Unscoped().db } func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.attrs(attrs...).db + return s.clone().search.Attrs(attrs...).db } func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.assign(attrs...).db + return s.clone().search.Assign(attrs...).db } func (s *DB) First(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) - newScope.Search = newScope.Search.clone().limit(1) + newScope.Search.Limit(1) return newScope.InstanceSet("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) - newScope.Search = newScope.Search.clone().limit(1) + newScope.Search.Limit(1) return newScope.InstanceSet("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } @@ -226,7 +226,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { } c.NewScope(out).inlineCondition(where...).initialize() } else { - c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.AssignAttrs), false) + c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.assignAttrs), false) } return c } @@ -238,8 +238,8 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { return result } c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) - } else if len(c.search.AssignAttrs) > 0 { - c.NewScope(out).InstanceSet("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates) + } else if len(c.search.assignAttrs) > 0 { + c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates) } return c } @@ -284,7 +284,7 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB { } func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.raw(true).where(sql, values...).db + return s.clone().search.Raw(true).Where(sql, values...).db } func (s *DB) Exec(sql string, values ...interface{}) *DB { @@ -315,7 +315,7 @@ func (s *DB) Count(value interface{}) *DB { func (s *DB) Table(name string) *DB { clone := s.clone() - clone.search.table(name) + clone.search.Table(name) clone.Value = nil return clone } @@ -447,7 +447,7 @@ func (s *DB) Association(column string) *Association { } func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.preload(column, conditions...).db + return s.clone().search.Preload(column, conditions...).db } // Set set value by name diff --git a/preload.go b/preload.go index 8880870e..d252238a 100644 --- a/preload.go +++ b/preload.go @@ -23,8 +23,8 @@ func Preload(scope *Scope) { fields := scope.Fields() isSlice := scope.IndirectValue().Kind() == reflect.Slice - if scope.Search.Preload != nil { - for key, conditions := range scope.Search.Preload { + if scope.Search.preload != nil { + for key, conditions := range scope.Search.preload { for _, field := range fields { if field.Name == key && field.Relationship != nil { results := makeSlice(field.Struct.Type) diff --git a/scope.go b/scope.go index 2cfeaa9d..b8cbd1f3 100644 --- a/scope.go +++ b/scope.go @@ -10,15 +10,15 @@ import ( ) type Scope struct { - Value interface{} - indirectValue *reflect.Value Search *search + Value interface{} Sql string SqlVars []interface{} db *DB - skipLeft bool - primaryKeyField *Field + indirectValue *reflect.Value instanceId string + primaryKeyField *Field + skipLeft bool fields map[string]*Field } @@ -225,15 +225,15 @@ func (scope *Scope) AddToVars(value interface{}) string { // TableName get table name func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.TableName) > 0 { - return scope.Search.TableName + if scope.Search != nil && len(scope.Search.tableName) > 0 { + return scope.Search.tableName } return scope.GetModelStruct().TableName } func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.TableName) > 0 { - return scope.Quote(scope.Search.TableName) + if scope.Search != nil && len(scope.Search.tableName) > 0 { + return scope.Quote(scope.Search.tableName) } else { return scope.Quote(scope.TableName()) } diff --git a/scope_private.go b/scope_private.go index b5b5bb9c..e3b97509 100644 --- a/scope_private.go +++ b/scope_private.go @@ -159,7 +159,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) func (scope *Scope) whereSql() (sql string) { var primaryConditions, andConditions, orConditions []string - if !scope.Search.Unscope && scope.Fields()["deleted_at"] != nil { + if !scope.Search.unscoped && scope.Fields()["deleted_at"] != nil { sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) primaryConditions = append(primaryConditions, sql) } @@ -168,19 +168,19 @@ func (scope *Scope) whereSql() (sql string) { primaryConditions = append(primaryConditions, scope.primaryCondition(scope.AddToVars(scope.PrimaryKeyValue()))) } - for _, clause := range scope.Search.WhereConditions { + for _, clause := range scope.Search.whereConditions { if sql := scope.buildWhereCondition(clause); sql != "" { andConditions = append(andConditions, sql) } } - for _, clause := range scope.Search.OrConditions { + for _, clause := range scope.Search.orConditions { if sql := scope.buildWhereCondition(clause); sql != "" { orConditions = append(orConditions, sql) } } - for _, clause := range scope.Search.NotConditions { + for _, clause := range scope.Search.notConditions { if sql := scope.buildNotCondition(clause); sql != "" { andConditions = append(andConditions, sql) } @@ -208,76 +208,76 @@ func (scope *Scope) whereSql() (sql string) { } func (scope *Scope) selectSql() string { - if len(scope.Search.Selects) == 0 { + if len(scope.Search.selects) == 0 { return "*" } - return scope.buildSelectQuery(scope.Search.Selects) + return scope.buildSelectQuery(scope.Search.selects) } func (scope *Scope) orderSql() string { - if len(scope.Search.Orders) == 0 { + if len(scope.Search.orders) == 0 { return "" } - return " ORDER BY " + strings.Join(scope.Search.Orders, ",") + return " ORDER BY " + strings.Join(scope.Search.orders, ",") } func (scope *Scope) limitSql() string { if !scope.Dialect().HasTop() { - if len(scope.Search.Limit) == 0 { + if len(scope.Search.limit) == 0 { return "" } - return " LIMIT " + scope.Search.Limit + return " LIMIT " + scope.Search.limit } return "" } func (scope *Scope) topSql() string { - if scope.Dialect().HasTop() && len(scope.Search.Offset) == 0 { - if len(scope.Search.Limit) == 0 { + if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 { + if len(scope.Search.limit) == 0 { return "" } - return " TOP(" + scope.Search.Limit + ")" + return " TOP(" + scope.Search.limit + ")" } return "" } func (scope *Scope) offsetSql() string { - if len(scope.Search.Offset) == 0 { + if len(scope.Search.offset) == 0 { return "" } if scope.Dialect().HasTop() { - sql := " OFFSET " + scope.Search.Offset + " ROW " - if len(scope.Search.Limit) > 0 { - sql += "FETCH NEXT " + scope.Search.Limit + " ROWS ONLY" + sql := " OFFSET " + scope.Search.offset + " ROW " + if len(scope.Search.limit) > 0 { + sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY" } return sql } - return " OFFSET " + scope.Search.Offset + return " OFFSET " + scope.Search.offset } func (scope *Scope) groupSql() string { - if len(scope.Search.Group) == 0 { + if len(scope.Search.group) == 0 { return "" } - return " GROUP BY " + scope.Search.Group + return " GROUP BY " + scope.Search.group } func (scope *Scope) havingSql() string { - if scope.Search.HavingCondition == nil { + if scope.Search.havingCondition == nil { return "" } - return " HAVING " + scope.buildWhereCondition(scope.Search.HavingCondition) + return " HAVING " + scope.buildWhereCondition(scope.Search.havingCondition) } func (scope *Scope) joinsSql() string { - return scope.Search.Joins + " " + return scope.Search.joins + " " } func (scope *Scope) prepareQuerySql() { - if scope.Search.Raw { + if scope.Search.raw { scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) } else { scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) @@ -287,7 +287,7 @@ func (scope *Scope) prepareQuerySql() { func (scope *Scope) inlineCondition(values ...interface{}) *Scope { if len(values) > 0 { - scope.Search = scope.Search.clone().where(values[0], values[1:]...) + scope.Search.Where(values[0], values[1:]...) } return scope } @@ -348,17 +348,17 @@ func (scope *Scope) rows() (*sql.Rows, error) { } func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.WhereConditions { + for _, clause := range scope.Search.whereConditions { scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) } - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.InitAttrs), false) - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.AssignAttrs), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) return scope } func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search = scope.Search.clone().selects(column) + scope.Search.Selects(column) if dest.Kind() != reflect.Slice { scope.Err(errors.New("results should be a slice")) return scope @@ -377,7 +377,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { } func (scope *Scope) count(value interface{}) *Scope { - scope.Search = scope.Search.clone().selects("count(*)") + scope.Search.Selects("count(*)") scope.Err(scope.row().Scan(value)) return scope } diff --git a/search.go b/search.go index 75b1285a..4bf0775e 100644 --- a/search.go +++ b/search.go @@ -4,129 +4,129 @@ import "fmt" type search struct { db *DB - WhereConditions []map[string]interface{} - OrConditions []map[string]interface{} - NotConditions []map[string]interface{} - HavingCondition map[string]interface{} - InitAttrs []interface{} - AssignAttrs []interface{} - Selects map[string]interface{} - Orders []string - Joins string - Preload map[string][]interface{} - Offset string - Limit string - Group string - TableName string - Unscope bool - Raw bool + whereConditions []map[string]interface{} + orConditions []map[string]interface{} + notConditions []map[string]interface{} + havingCondition map[string]interface{} + initAttrs []interface{} + assignAttrs []interface{} + selects map[string]interface{} + orders []string + joins string + preload map[string][]interface{} + offset string + limit string + group string + tableName string + unscoped bool + raw bool } func (s *search) clone() *search { return &search{ - Preload: s.Preload, - WhereConditions: s.WhereConditions, - OrConditions: s.OrConditions, - NotConditions: s.NotConditions, - HavingCondition: s.HavingCondition, - InitAttrs: s.InitAttrs, - AssignAttrs: s.AssignAttrs, - Selects: s.Selects, - Orders: s.Orders, - Joins: s.Joins, - Offset: s.Offset, - Limit: s.Limit, - Group: s.Group, - TableName: s.TableName, - Unscope: s.Unscope, - Raw: s.Raw, + preload: s.preload, + whereConditions: s.whereConditions, + orConditions: s.orConditions, + notConditions: s.notConditions, + havingCondition: s.havingCondition, + initAttrs: s.initAttrs, + assignAttrs: s.assignAttrs, + selects: s.selects, + orders: s.orders, + joins: s.joins, + offset: s.offset, + limit: s.limit, + group: s.group, + tableName: s.tableName, + unscoped: s.unscoped, + raw: s.raw, } } -func (s *search) where(query interface{}, values ...interface{}) *search { - s.WhereConditions = append(s.WhereConditions, map[string]interface{}{"query": query, "args": values}) +func (s *search) Where(query interface{}, values ...interface{}) *search { + s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) return s } -func (s *search) not(query interface{}, values ...interface{}) *search { - s.NotConditions = append(s.NotConditions, map[string]interface{}{"query": query, "args": values}) +func (s *search) Not(query interface{}, values ...interface{}) *search { + s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) return s } -func (s *search) or(query interface{}, values ...interface{}) *search { - s.OrConditions = append(s.OrConditions, map[string]interface{}{"query": query, "args": values}) +func (s *search) Or(query interface{}, values ...interface{}) *search { + s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) return s } -func (s *search) attrs(attrs ...interface{}) *search { - s.InitAttrs = append(s.InitAttrs, toSearchableMap(attrs...)) +func (s *search) Attrs(attrs ...interface{}) *search { + s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) return s } -func (s *search) assign(attrs ...interface{}) *search { - s.AssignAttrs = append(s.AssignAttrs, toSearchableMap(attrs...)) +func (s *search) Assign(attrs ...interface{}) *search { + s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) return s } -func (s *search) order(value string, reorder ...bool) *search { +func (s *search) Order(value string, reorder ...bool) *search { if len(reorder) > 0 && reorder[0] { - s.Orders = []string{value} + s.orders = []string{value} } else { - s.Orders = append(s.Orders, value) + s.orders = append(s.orders, value) } return s } -func (s *search) selects(query interface{}, args ...interface{}) *search { - s.Selects = map[string]interface{}{"query": query, "args": args} +func (s *search) Selects(query interface{}, args ...interface{}) *search { + s.selects = map[string]interface{}{"query": query, "args": args} return s } -func (s *search) limit(value interface{}) *search { - s.Limit = s.getInterfaceAsSql(value) +func (s *search) Limit(value interface{}) *search { + s.limit = s.getInterfaceAsSql(value) return s } -func (s *search) offset(value interface{}) *search { - s.Offset = s.getInterfaceAsSql(value) +func (s *search) Offset(value interface{}) *search { + s.offset = s.getInterfaceAsSql(value) return s } -func (s *search) group(query string) *search { - s.Group = s.getInterfaceAsSql(query) +func (s *search) Group(query string) *search { + s.group = s.getInterfaceAsSql(query) return s } -func (s *search) having(query string, values ...interface{}) *search { - s.HavingCondition = map[string]interface{}{"query": query, "args": values} +func (s *search) Having(query string, values ...interface{}) *search { + s.havingCondition = map[string]interface{}{"query": query, "args": values} return s } -func (s *search) joins(query string) *search { - s.Joins = query +func (s *search) Joins(query string) *search { + s.joins = query return s } -func (s *search) preload(column string, values ...interface{}) *search { - if s.Preload == nil { - s.Preload = map[string][]interface{}{} +func (s *search) Preload(column string, values ...interface{}) *search { + if s.preload == nil { + s.preload = map[string][]interface{}{} } - s.Preload[column] = values + s.preload[column] = values return s } -func (s *search) raw(b bool) *search { - s.Raw = b +func (s *search) Raw(b bool) *search { + s.raw = b return s } -func (s *search) unscoped() *search { - s.Unscope = true +func (s *search) Unscoped() *search { + s.unscoped = true return s } -func (s *search) table(name string) *search { - s.TableName = name +func (s *search) Table(name string) *search { + s.tableName = name return s } diff --git a/search_test.go b/search_test.go index 2d0af08b..79be3a07 100644 --- a/search_test.go +++ b/search_test.go @@ -7,20 +7,20 @@ import ( func TestCloneSearch(t *testing.T) { s := new(search) - s.where("name = ?", "jinzhu").order("name").attrs("name", "jinzhu").selects("name, age") + s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Selects("name, age") s1 := s.clone() - s1.where("age = ?", 20).order("age").attrs("email", "a@e.org").selects("email") + s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Selects("email") - if reflect.DeepEqual(s.WhereConditions, s1.WhereConditions) { + if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { t.Errorf("Where should be copied") } - if reflect.DeepEqual(s.Orders, s1.Orders) { + if reflect.DeepEqual(s.orders, s1.orders) { t.Errorf("Order should be copied") } - if reflect.DeepEqual(s.InitAttrs, s1.InitAttrs) { + if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { t.Errorf("InitAttrs should be copied") }