make callback query works

This commit is contained in:
Jinzhu 2014-01-28 09:25:30 +08:00
parent 048b8b6abe
commit db68e7a8fe
4 changed files with 118 additions and 21 deletions

View File

@ -1,6 +1,81 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
func Query(scope *Scope) {
defer scope.Trace(time.Now())
inlineCondition, ok := scope.Get("gorm:inline_condition")
if ok {
inlineConditions := inlineCondition.([]interface{})
if len(inlineConditions) > 0 {
scope.Search = scope.Search.clone().where(inlineConditions[0], inlineConditions[1:]...)
}
}
var (
isSlice bool
anyRecordFound bool
destType reflect.Type
)
var dest = reflect.Indirect(reflect.ValueOf(scope.Value))
if dest.Kind() == reflect.Slice {
isSlice = true
destType = dest.Type().Elem()
} else {
scope.Search = scope.Search.clone().limit(1)
}
if scope.Search.raw {
scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE "))
} else {
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.SelectSql(), scope.TableName(), scope.CombinedConditionSql()))
}
if !scope.HasError() {
rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...)
if scope.Err(err) != nil {
return
}
defer rows.Close()
for rows.Next() {
anyRecordFound = true
elem := dest
if isSlice {
elem = reflect.New(destType).Elem()
}
columns, _ := rows.Columns()
var values []interface{}
for _, value := range columns {
field := elem.FieldByName(snakeToUpperCamel(value))
if field.IsValid() {
values = append(values, field.Addr().Interface())
} else {
var ignore interface{}
values = append(values, &ignore)
}
}
scope.Err(rows.Scan(values...))
if isSlice {
dest.Set(reflect.Append(dest, elem))
}
}
if !anyRecordFound && !isSlice {
scope.Err(RecordNotFound)
}
}
}
func AfterQuery(scope *Scope) {

View File

@ -117,7 +117,7 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB {
return s.clone().do(out).where(where...).last().db
}
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.clone().do(out).where(where...).query().db
return s.clone().NewScope(out).Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Row() *sql.Row {

View File

@ -151,24 +151,34 @@ func (scope *Scope) CallMethod(name string) {
return
}
if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() {
fi := fm.Interface()
if f, ok := fi.(func()); ok {
f()
} else if f, ok := fi.(func(s *Scope)); ok {
f(scope)
} else if f, ok := fi.(func(s *DB)); ok {
f(scope.db.new())
} else if f, ok := fi.(func() error); ok {
scope.Err(f())
} else if f, ok := fi.(func(s *Scope) error); ok {
scope.Err(f(scope))
} else if f, ok := fi.(func(s *DB) error); ok {
scope.Err(f(scope.db.new()))
} else {
scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name)))
call := func(value interface{}) {
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
fi := fm.Interface()
if f, ok := fi.(func()); ok {
f()
} else if f, ok := fi.(func(s *Scope)); ok {
f(scope)
} else if f, ok := fi.(func(s *DB)); ok {
f(scope.db.new())
} else if f, ok := fi.(func() error); ok {
scope.Err(f())
} else if f, ok := fi.(func(s *Scope) error); ok {
scope.Err(f(scope))
} else if f, ok := fi.(func(s *DB) error); ok {
scope.Err(f(scope.db.new()))
} else {
scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name)))
}
}
}
if values := reflect.Indirect(reflect.ValueOf(scope.Value)); values.Kind() == reflect.Slice {
for i := 0; i < values.Len(); i++ {
call(values.Index(i).Addr().Interface())
}
} else {
call(scope.Value)
}
}
func (scope *Scope) AddToVars(value interface{}) string {
@ -367,3 +377,11 @@ func (scope *Scope) CommitOrRollback() *Scope {
}
return scope
}
func (scope *Scope) SelectSql() string {
if len(scope.Search.selectStr) == 0 {
return "*"
} else {
return scope.Search.selectStr
}
}

View File

@ -43,8 +43,10 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
for _, field := range scope.Fields() {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.dbName), scope.AddToVars(field.Value)))
for _, field := range scope.New(value).Fields() {
if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.DBName), scope.AddToVars(field.Value)))
}
}
return strings.Join(sqls, " AND ")
}
@ -102,8 +104,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
for _, field := range scope.Fields() {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.dbName), scope.AddToVars(field.Value)))
for _, field := range scope.New(value).Fields() {
if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.DBName), scope.AddToVars(field.Value)))
}
}
return strings.Join(sqls, " AND ")
}