make callback query works

This commit is contained in:
Jinzhu 2014-01-28 09:25:30 +08:00
parent 048b8b6abe
commit db68e7a8fe
4 changed files with 118 additions and 21 deletions

View File

@ -1,6 +1,81 @@
package gorm package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
func Query(scope *Scope) { func Query(scope *Scope) {
defer scope.Trace(time.Now())
inlineCondition, ok := scope.Get("gorm:inline_condition")
if ok {
inlineConditions := inlineCondition.([]interface{})
if len(inlineConditions) > 0 {
scope.Search = scope.Search.clone().where(inlineConditions[0], inlineConditions[1:]...)
}
}
var (
isSlice bool
anyRecordFound bool
destType reflect.Type
)
var dest = reflect.Indirect(reflect.ValueOf(scope.Value))
if dest.Kind() == reflect.Slice {
isSlice = true
destType = dest.Type().Elem()
} else {
scope.Search = scope.Search.clone().limit(1)
}
if scope.Search.raw {
scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE "))
} else {
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.SelectSql(), scope.TableName(), scope.CombinedConditionSql()))
}
if !scope.HasError() {
rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...)
if scope.Err(err) != nil {
return
}
defer rows.Close()
for rows.Next() {
anyRecordFound = true
elem := dest
if isSlice {
elem = reflect.New(destType).Elem()
}
columns, _ := rows.Columns()
var values []interface{}
for _, value := range columns {
field := elem.FieldByName(snakeToUpperCamel(value))
if field.IsValid() {
values = append(values, field.Addr().Interface())
} else {
var ignore interface{}
values = append(values, &ignore)
}
}
scope.Err(rows.Scan(values...))
if isSlice {
dest.Set(reflect.Append(dest, elem))
}
}
if !anyRecordFound && !isSlice {
scope.Err(RecordNotFound)
}
}
} }
func AfterQuery(scope *Scope) { func AfterQuery(scope *Scope) {

View File

@ -117,7 +117,7 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB {
return s.clone().do(out).where(where...).last().db return s.clone().do(out).where(where...).last().db
} }
func (s *DB) Find(out interface{}, where ...interface{}) *DB { func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.clone().do(out).where(where...).query().db return s.clone().NewScope(out).Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db
} }
func (s *DB) Row() *sql.Row { func (s *DB) Row() *sql.Row {

View File

@ -151,7 +151,8 @@ func (scope *Scope) CallMethod(name string) {
return return
} }
if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() { call := func(value interface{}) {
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
fi := fm.Interface() fi := fm.Interface()
if f, ok := fi.(func()); ok { if f, ok := fi.(func()); ok {
f() f()
@ -171,6 +172,15 @@ func (scope *Scope) CallMethod(name string) {
} }
} }
if values := reflect.Indirect(reflect.ValueOf(scope.Value)); values.Kind() == reflect.Slice {
for i := 0; i < values.Len(); i++ {
call(values.Index(i).Addr().Interface())
}
} else {
call(scope.Value)
}
}
func (scope *Scope) AddToVars(value interface{}) string { func (scope *Scope) AddToVars(value interface{}) string {
scope.SqlVars = append(scope.SqlVars, value) scope.SqlVars = append(scope.SqlVars, value)
return scope.Dialect().BinVar(len(scope.SqlVars)) return scope.Dialect().BinVar(len(scope.SqlVars))
@ -367,3 +377,11 @@ func (scope *Scope) CommitOrRollback() *Scope {
} }
return scope return scope
} }
func (scope *Scope) SelectSql() string {
if len(scope.Search.selectStr) == 0 {
return "*"
} else {
return scope.Search.selectStr
}
}

View File

@ -43,8 +43,10 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
var sqls []string var sqls []string
for _, field := range scope.Fields() { for _, field := range scope.New(value).Fields() {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.dbName), scope.AddToVars(field.Value))) if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.DBName), scope.AddToVars(field.Value)))
}
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
} }
@ -102,8 +104,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
var sqls []string var sqls []string
for _, field := range scope.Fields() { for _, field := range scope.New(value).Fields() {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.dbName), scope.AddToVars(field.Value))) if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.DBName), scope.AddToVars(field.Value)))
}
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
} }