Refactor Search API

This commit is contained in:
Jinzhu 2015-03-12 13:52:29 +08:00
parent 9f2959959b
commit 6e5d46bf37
8 changed files with 136 additions and 136 deletions

View File

@ -8,7 +8,7 @@ func BeforeDelete(scope *Scope) {
func Delete(scope *Scope) {
if !scope.HasError() {
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") {
if !scope.Search.unscoped && scope.HasColumn("DeletedAt") {
scope.Raw(
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
scope.QuotedTableName(),

View File

@ -23,7 +23,7 @@ func Query(scope *Scope) {
if orderBy, ok := scope.InstanceGet("gorm:order_by_primary_key"); ok {
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
scope.Search = scope.Search.clone().order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
}
}

44
main.go
View File

@ -94,7 +94,7 @@ func (s *DB) New() *DB {
func (db *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
}
// CommonDB Return the underlying sql.DB or sql.Tx instance.
@ -128,43 +128,43 @@ func (s *DB) SingularTable(enable bool) {
}
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.where(query, args...).db
return s.clone().search.Where(query, args...).db
}
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
return s.clone().search.or(query, args...).db
return s.clone().search.Or(query, args...).db
}
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
return s.clone().search.not(query, args...).db
return s.clone().search.Not(query, args...).db
}
func (s *DB) Limit(value interface{}) *DB {
return s.clone().search.limit(value).db
return s.clone().search.Limit(value).db
}
func (s *DB) Offset(value interface{}) *DB {
return s.clone().search.offset(value).db
return s.clone().search.Offset(value).db
}
func (s *DB) Order(value string, reorder ...bool) *DB {
return s.clone().search.order(value, reorder...).db
return s.clone().search.Order(value, reorder...).db
}
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.selects(query, args...).db
return s.clone().search.Selects(query, args...).db
}
func (s *DB) Group(query string) *DB {
return s.clone().search.group(query).db
return s.clone().search.Group(query).db
}
func (s *DB) Having(query string, values ...interface{}) *DB {
return s.clone().search.having(query, values...).db
return s.clone().search.Having(query, values...).db
}
func (s *DB) Joins(query string) *DB {
return s.clone().search.joins(query).db
return s.clone().search.Joins(query).db
}
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
@ -175,27 +175,27 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
}
func (s *DB) Unscoped() *DB {
return s.clone().search.unscoped().db
return s.clone().search.Unscoped().db
}
func (s *DB) Attrs(attrs ...interface{}) *DB {
return s.clone().search.attrs(attrs...).db
return s.clone().search.Attrs(attrs...).db
}
func (s *DB) Assign(attrs ...interface{}) *DB {
return s.clone().search.assign(attrs...).db
return s.clone().search.Assign(attrs...).db
}
func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search = newScope.Search.clone().limit(1)
newScope.Search.Limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search = newScope.Search.clone().limit(1)
newScope.Search.Limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
@ -226,7 +226,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
}
c.NewScope(out).inlineCondition(where...).initialize()
} else {
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.AssignAttrs), false)
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.assignAttrs), false)
}
return c
}
@ -238,8 +238,8 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
return result
}
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
} else if len(c.search.AssignAttrs) > 0 {
c.NewScope(out).InstanceSet("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates)
} else if len(c.search.assignAttrs) > 0 {
c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates)
}
return c
}
@ -284,7 +284,7 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
}
func (s *DB) Raw(sql string, values ...interface{}) *DB {
return s.clone().search.raw(true).where(sql, values...).db
return s.clone().search.Raw(true).Where(sql, values...).db
}
func (s *DB) Exec(sql string, values ...interface{}) *DB {
@ -315,7 +315,7 @@ func (s *DB) Count(value interface{}) *DB {
func (s *DB) Table(name string) *DB {
clone := s.clone()
clone.search.table(name)
clone.search.Table(name)
clone.Value = nil
return clone
}
@ -447,7 +447,7 @@ func (s *DB) Association(column string) *Association {
}
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
return s.clone().search.preload(column, conditions...).db
return s.clone().search.Preload(column, conditions...).db
}
// Set set value by name

View File

@ -23,8 +23,8 @@ func Preload(scope *Scope) {
fields := scope.Fields()
isSlice := scope.IndirectValue().Kind() == reflect.Slice
if scope.Search.Preload != nil {
for key, conditions := range scope.Search.Preload {
if scope.Search.preload != nil {
for key, conditions := range scope.Search.preload {
for _, field := range fields {
if field.Name == key && field.Relationship != nil {
results := makeSlice(field.Struct.Type)

View File

@ -10,15 +10,15 @@ import (
)
type Scope struct {
Value interface{}
indirectValue *reflect.Value
Search *search
Value interface{}
Sql string
SqlVars []interface{}
db *DB
skipLeft bool
primaryKeyField *Field
indirectValue *reflect.Value
instanceId string
primaryKeyField *Field
skipLeft bool
fields map[string]*Field
}
@ -225,15 +225,15 @@ func (scope *Scope) AddToVars(value interface{}) string {
// TableName get table name
func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Search.TableName
if scope.Search != nil && len(scope.Search.tableName) > 0 {
return scope.Search.tableName
}
return scope.GetModelStruct().TableName
}
func (scope *Scope) QuotedTableName() (name string) {
if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Quote(scope.Search.TableName)
if scope.Search != nil && len(scope.Search.tableName) > 0 {
return scope.Quote(scope.Search.tableName)
} else {
return scope.Quote(scope.TableName())
}

View File

@ -159,7 +159,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
func (scope *Scope) whereSql() (sql string) {
var primaryConditions, andConditions, orConditions []string
if !scope.Search.Unscope && scope.Fields()["deleted_at"] != nil {
if !scope.Search.unscoped && scope.Fields()["deleted_at"] != nil {
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
primaryConditions = append(primaryConditions, sql)
}
@ -168,19 +168,19 @@ func (scope *Scope) whereSql() (sql string) {
primaryConditions = append(primaryConditions, scope.primaryCondition(scope.AddToVars(scope.PrimaryKeyValue())))
}
for _, clause := range scope.Search.WhereConditions {
for _, clause := range scope.Search.whereConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
for _, clause := range scope.Search.OrConditions {
for _, clause := range scope.Search.orConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
orConditions = append(orConditions, sql)
}
}
for _, clause := range scope.Search.NotConditions {
for _, clause := range scope.Search.notConditions {
if sql := scope.buildNotCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
@ -208,76 +208,76 @@ func (scope *Scope) whereSql() (sql string) {
}
func (scope *Scope) selectSql() string {
if len(scope.Search.Selects) == 0 {
if len(scope.Search.selects) == 0 {
return "*"
}
return scope.buildSelectQuery(scope.Search.Selects)
return scope.buildSelectQuery(scope.Search.selects)
}
func (scope *Scope) orderSql() string {
if len(scope.Search.Orders) == 0 {
if len(scope.Search.orders) == 0 {
return ""
}
return " ORDER BY " + strings.Join(scope.Search.Orders, ",")
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
}
func (scope *Scope) limitSql() string {
if !scope.Dialect().HasTop() {
if len(scope.Search.Limit) == 0 {
if len(scope.Search.limit) == 0 {
return ""
}
return " LIMIT " + scope.Search.Limit
return " LIMIT " + scope.Search.limit
}
return ""
}
func (scope *Scope) topSql() string {
if scope.Dialect().HasTop() && len(scope.Search.Offset) == 0 {
if len(scope.Search.Limit) == 0 {
if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
if len(scope.Search.limit) == 0 {
return ""
}
return " TOP(" + scope.Search.Limit + ")"
return " TOP(" + scope.Search.limit + ")"
}
return ""
}
func (scope *Scope) offsetSql() string {
if len(scope.Search.Offset) == 0 {
if len(scope.Search.offset) == 0 {
return ""
}
if scope.Dialect().HasTop() {
sql := " OFFSET " + scope.Search.Offset + " ROW "
if len(scope.Search.Limit) > 0 {
sql += "FETCH NEXT " + scope.Search.Limit + " ROWS ONLY"
sql := " OFFSET " + scope.Search.offset + " ROW "
if len(scope.Search.limit) > 0 {
sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
}
return sql
}
return " OFFSET " + scope.Search.Offset
return " OFFSET " + scope.Search.offset
}
func (scope *Scope) groupSql() string {
if len(scope.Search.Group) == 0 {
if len(scope.Search.group) == 0 {
return ""
}
return " GROUP BY " + scope.Search.Group
return " GROUP BY " + scope.Search.group
}
func (scope *Scope) havingSql() string {
if scope.Search.HavingCondition == nil {
if scope.Search.havingCondition == nil {
return ""
}
return " HAVING " + scope.buildWhereCondition(scope.Search.HavingCondition)
return " HAVING " + scope.buildWhereCondition(scope.Search.havingCondition)
}
func (scope *Scope) joinsSql() string {
return scope.Search.Joins + " "
return scope.Search.joins + " "
}
func (scope *Scope) prepareQuerySql() {
if scope.Search.Raw {
if scope.Search.raw {
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
} else {
scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
@ -287,7 +287,7 @@ func (scope *Scope) prepareQuerySql() {
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
if len(values) > 0 {
scope.Search = scope.Search.clone().where(values[0], values[1:]...)
scope.Search.Where(values[0], values[1:]...)
}
return scope
}
@ -348,17 +348,17 @@ func (scope *Scope) rows() (*sql.Rows, error) {
}
func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.WhereConditions {
for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
}
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.InitAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.AssignAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
return scope
}
func (scope *Scope) pluck(column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value))
scope.Search = scope.Search.clone().selects(column)
scope.Search.Selects(column)
if dest.Kind() != reflect.Slice {
scope.Err(errors.New("results should be a slice"))
return scope
@ -377,7 +377,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
}
func (scope *Scope) count(value interface{}) *Scope {
scope.Search = scope.Search.clone().selects("count(*)")
scope.Search.Selects("count(*)")
scope.Err(scope.row().Scan(value))
return scope
}

134
search.go
View File

@ -4,129 +4,129 @@ import "fmt"
type search struct {
db *DB
WhereConditions []map[string]interface{}
OrConditions []map[string]interface{}
NotConditions []map[string]interface{}
HavingCondition map[string]interface{}
InitAttrs []interface{}
AssignAttrs []interface{}
Selects map[string]interface{}
Orders []string
Joins string
Preload map[string][]interface{}
Offset string
Limit string
Group string
TableName string
Unscope bool
Raw bool
whereConditions []map[string]interface{}
orConditions []map[string]interface{}
notConditions []map[string]interface{}
havingCondition map[string]interface{}
initAttrs []interface{}
assignAttrs []interface{}
selects map[string]interface{}
orders []string
joins string
preload map[string][]interface{}
offset string
limit string
group string
tableName string
unscoped bool
raw bool
}
func (s *search) clone() *search {
return &search{
Preload: s.Preload,
WhereConditions: s.WhereConditions,
OrConditions: s.OrConditions,
NotConditions: s.NotConditions,
HavingCondition: s.HavingCondition,
InitAttrs: s.InitAttrs,
AssignAttrs: s.AssignAttrs,
Selects: s.Selects,
Orders: s.Orders,
Joins: s.Joins,
Offset: s.Offset,
Limit: s.Limit,
Group: s.Group,
TableName: s.TableName,
Unscope: s.Unscope,
Raw: s.Raw,
preload: s.preload,
whereConditions: s.whereConditions,
orConditions: s.orConditions,
notConditions: s.notConditions,
havingCondition: s.havingCondition,
initAttrs: s.initAttrs,
assignAttrs: s.assignAttrs,
selects: s.selects,
orders: s.orders,
joins: s.joins,
offset: s.offset,
limit: s.limit,
group: s.group,
tableName: s.tableName,
unscoped: s.unscoped,
raw: s.raw,
}
}
func (s *search) where(query interface{}, values ...interface{}) *search {
s.WhereConditions = append(s.WhereConditions, map[string]interface{}{"query": query, "args": values})
func (s *search) Where(query interface{}, values ...interface{}) *search {
s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) not(query interface{}, values ...interface{}) *search {
s.NotConditions = append(s.NotConditions, map[string]interface{}{"query": query, "args": values})
func (s *search) Not(query interface{}, values ...interface{}) *search {
s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) or(query interface{}, values ...interface{}) *search {
s.OrConditions = append(s.OrConditions, map[string]interface{}{"query": query, "args": values})
func (s *search) Or(query interface{}, values ...interface{}) *search {
s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) attrs(attrs ...interface{}) *search {
s.InitAttrs = append(s.InitAttrs, toSearchableMap(attrs...))
func (s *search) Attrs(attrs ...interface{}) *search {
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
return s
}
func (s *search) assign(attrs ...interface{}) *search {
s.AssignAttrs = append(s.AssignAttrs, toSearchableMap(attrs...))
func (s *search) Assign(attrs ...interface{}) *search {
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
return s
}
func (s *search) order(value string, reorder ...bool) *search {
func (s *search) Order(value string, reorder ...bool) *search {
if len(reorder) > 0 && reorder[0] {
s.Orders = []string{value}
s.orders = []string{value}
} else {
s.Orders = append(s.Orders, value)
s.orders = append(s.orders, value)
}
return s
}
func (s *search) selects(query interface{}, args ...interface{}) *search {
s.Selects = map[string]interface{}{"query": query, "args": args}
func (s *search) Selects(query interface{}, args ...interface{}) *search {
s.selects = map[string]interface{}{"query": query, "args": args}
return s
}
func (s *search) limit(value interface{}) *search {
s.Limit = s.getInterfaceAsSql(value)
func (s *search) Limit(value interface{}) *search {
s.limit = s.getInterfaceAsSql(value)
return s
}
func (s *search) offset(value interface{}) *search {
s.Offset = s.getInterfaceAsSql(value)
func (s *search) Offset(value interface{}) *search {
s.offset = s.getInterfaceAsSql(value)
return s
}
func (s *search) group(query string) *search {
s.Group = s.getInterfaceAsSql(query)
func (s *search) Group(query string) *search {
s.group = s.getInterfaceAsSql(query)
return s
}
func (s *search) having(query string, values ...interface{}) *search {
s.HavingCondition = map[string]interface{}{"query": query, "args": values}
func (s *search) Having(query string, values ...interface{}) *search {
s.havingCondition = map[string]interface{}{"query": query, "args": values}
return s
}
func (s *search) joins(query string) *search {
s.Joins = query
func (s *search) Joins(query string) *search {
s.joins = query
return s
}
func (s *search) preload(column string, values ...interface{}) *search {
if s.Preload == nil {
s.Preload = map[string][]interface{}{}
func (s *search) Preload(column string, values ...interface{}) *search {
if s.preload == nil {
s.preload = map[string][]interface{}{}
}
s.Preload[column] = values
s.preload[column] = values
return s
}
func (s *search) raw(b bool) *search {
s.Raw = b
func (s *search) Raw(b bool) *search {
s.raw = b
return s
}
func (s *search) unscoped() *search {
s.Unscope = true
func (s *search) Unscoped() *search {
s.unscoped = true
return s
}
func (s *search) table(name string) *search {
s.TableName = name
func (s *search) Table(name string) *search {
s.tableName = name
return s
}

View File

@ -7,20 +7,20 @@ import (
func TestCloneSearch(t *testing.T) {
s := new(search)
s.where("name = ?", "jinzhu").order("name").attrs("name", "jinzhu").selects("name, age")
s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Selects("name, age")
s1 := s.clone()
s1.where("age = ?", 20).order("age").attrs("email", "a@e.org").selects("email")
s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Selects("email")
if reflect.DeepEqual(s.WhereConditions, s1.WhereConditions) {
if reflect.DeepEqual(s.whereConditions, s1.whereConditions) {
t.Errorf("Where should be copied")
}
if reflect.DeepEqual(s.Orders, s1.Orders) {
if reflect.DeepEqual(s.orders, s1.orders) {
t.Errorf("Order should be copied")
}
if reflect.DeepEqual(s.InitAttrs, s1.InitAttrs) {
if reflect.DeepEqual(s.initAttrs, s1.initAttrs) {
t.Errorf("InitAttrs should be copied")
}