From db68e7a8fe8076db860090ad04d4e685bab62c36 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Jan 2014 09:25:30 +0800 Subject: [PATCH] make callback query works --- callback_query.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++ main.go | 2 +- scope.go | 50 +++++++++++++++++++++---------- scope_condition.go | 12 +++++--- 4 files changed, 118 insertions(+), 21 deletions(-) diff --git a/callback_query.go b/callback_query.go index 25cd97b3..2cb2152a 100644 --- a/callback_query.go +++ b/callback_query.go @@ -1,6 +1,81 @@ package gorm +import ( + "fmt" + "reflect" + "strings" + "time" +) + func Query(scope *Scope) { + defer scope.Trace(time.Now()) + + inlineCondition, ok := scope.Get("gorm:inline_condition") + if ok { + inlineConditions := inlineCondition.([]interface{}) + if len(inlineConditions) > 0 { + scope.Search = scope.Search.clone().where(inlineConditions[0], inlineConditions[1:]...) + } + } + + var ( + isSlice bool + anyRecordFound bool + destType reflect.Type + ) + + var dest = reflect.Indirect(reflect.ValueOf(scope.Value)) + + if dest.Kind() == reflect.Slice { + isSlice = true + destType = dest.Type().Elem() + } else { + scope.Search = scope.Search.clone().limit(1) + } + + if scope.Search.raw { + scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) + } else { + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.SelectSql(), scope.TableName(), scope.CombinedConditionSql())) + } + + if !scope.HasError() { + rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...) + + if scope.Err(err) != nil { + return + } + + defer rows.Close() + for rows.Next() { + anyRecordFound = true + elem := dest + if isSlice { + elem = reflect.New(destType).Elem() + } + + columns, _ := rows.Columns() + var values []interface{} + for _, value := range columns { + field := elem.FieldByName(snakeToUpperCamel(value)) + if field.IsValid() { + values = append(values, field.Addr().Interface()) + } else { + var ignore interface{} + values = append(values, &ignore) + } + } + scope.Err(rows.Scan(values...)) + + if isSlice { + dest.Set(reflect.Append(dest, elem)) + } + } + + if !anyRecordFound && !isSlice { + scope.Err(RecordNotFound) + } + } } func AfterQuery(scope *Scope) { diff --git a/main.go b/main.go index c9135eba..b185c379 100644 --- a/main.go +++ b/main.go @@ -117,7 +117,7 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB { return s.clone().do(out).where(where...).last().db } func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().do(out).where(where...).query().db + return s.clone().NewScope(out).Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db } func (s *DB) Row() *sql.Row { diff --git a/scope.go b/scope.go index 5ebcde98..334229f6 100644 --- a/scope.go +++ b/scope.go @@ -151,24 +151,34 @@ func (scope *Scope) CallMethod(name string) { return } - if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() { - fi := fm.Interface() - if f, ok := fi.(func()); ok { - f() - } else if f, ok := fi.(func(s *Scope)); ok { - f(scope) - } else if f, ok := fi.(func(s *DB)); ok { - f(scope.db.new()) - } else if f, ok := fi.(func() error); ok { - scope.Err(f()) - } else if f, ok := fi.(func(s *Scope) error); ok { - scope.Err(f(scope)) - } else if f, ok := fi.(func(s *DB) error); ok { - scope.Err(f(scope.db.new())) - } else { - scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name))) + call := func(value interface{}) { + if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() { + fi := fm.Interface() + if f, ok := fi.(func()); ok { + f() + } else if f, ok := fi.(func(s *Scope)); ok { + f(scope) + } else if f, ok := fi.(func(s *DB)); ok { + f(scope.db.new()) + } else if f, ok := fi.(func() error); ok { + scope.Err(f()) + } else if f, ok := fi.(func(s *Scope) error); ok { + scope.Err(f(scope)) + } else if f, ok := fi.(func(s *DB) error); ok { + scope.Err(f(scope.db.new())) + } else { + scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name))) + } } } + + if values := reflect.Indirect(reflect.ValueOf(scope.Value)); values.Kind() == reflect.Slice { + for i := 0; i < values.Len(); i++ { + call(values.Index(i).Addr().Interface()) + } + } else { + call(scope.Value) + } } func (scope *Scope) AddToVars(value interface{}) string { @@ -367,3 +377,11 @@ func (scope *Scope) CommitOrRollback() *Scope { } return scope } + +func (scope *Scope) SelectSql() string { + if len(scope.Search.selectStr) == 0 { + return "*" + } else { + return scope.Search.selectStr + } +} diff --git a/scope_condition.go b/scope_condition.go index ceb159b7..d57d3a12 100644 --- a/scope_condition.go +++ b/scope_condition.go @@ -43,8 +43,10 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri return strings.Join(sqls, " AND ") case interface{}: var sqls []string - for _, field := range scope.Fields() { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.dbName), scope.AddToVars(field.Value))) + for _, field := range scope.New(value).Fields() { + if !field.IsBlank { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.DBName), scope.AddToVars(field.Value))) + } } return strings.Join(sqls, " AND ") } @@ -102,8 +104,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string return strings.Join(sqls, " AND ") case interface{}: var sqls []string - for _, field := range scope.Fields() { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.dbName), scope.AddToVars(field.Value))) + for _, field := range scope.New(value).Fields() { + if !field.IsBlank { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.DBName), scope.AddToVars(field.Value))) + } } return strings.Join(sqls, " AND ") }