mirror of https://github.com/go-gorm/gorm.git
make callback query works
This commit is contained in:
parent
048b8b6abe
commit
db68e7a8fe
|
@ -1,6 +1,81 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Query(scope *Scope) {
|
||||
defer scope.Trace(time.Now())
|
||||
|
||||
inlineCondition, ok := scope.Get("gorm:inline_condition")
|
||||
if ok {
|
||||
inlineConditions := inlineCondition.([]interface{})
|
||||
if len(inlineConditions) > 0 {
|
||||
scope.Search = scope.Search.clone().where(inlineConditions[0], inlineConditions[1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
isSlice bool
|
||||
anyRecordFound bool
|
||||
destType reflect.Type
|
||||
)
|
||||
|
||||
var dest = reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||
|
||||
if dest.Kind() == reflect.Slice {
|
||||
isSlice = true
|
||||
destType = dest.Type().Elem()
|
||||
} else {
|
||||
scope.Search = scope.Search.clone().limit(1)
|
||||
}
|
||||
|
||||
if scope.Search.raw {
|
||||
scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE "))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.SelectSql(), scope.TableName(), scope.CombinedConditionSql()))
|
||||
}
|
||||
|
||||
if !scope.HasError() {
|
||||
rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...)
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
anyRecordFound = true
|
||||
elem := dest
|
||||
if isSlice {
|
||||
elem = reflect.New(destType).Elem()
|
||||
}
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
var values []interface{}
|
||||
for _, value := range columns {
|
||||
field := elem.FieldByName(snakeToUpperCamel(value))
|
||||
if field.IsValid() {
|
||||
values = append(values, field.Addr().Interface())
|
||||
} else {
|
||||
var ignore interface{}
|
||||
values = append(values, &ignore)
|
||||
}
|
||||
}
|
||||
scope.Err(rows.Scan(values...))
|
||||
|
||||
if isSlice {
|
||||
dest.Set(reflect.Append(dest, elem))
|
||||
}
|
||||
}
|
||||
|
||||
if !anyRecordFound && !isSlice {
|
||||
scope.Err(RecordNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(scope *Scope) {
|
||||
|
|
2
main.go
2
main.go
|
@ -117,7 +117,7 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
|||
return s.clone().do(out).where(where...).last().db
|
||||
}
|
||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||
return s.clone().do(out).where(where...).query().db
|
||||
return s.clone().NewScope(out).Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db
|
||||
}
|
||||
|
||||
func (s *DB) Row() *sql.Row {
|
||||
|
|
20
scope.go
20
scope.go
|
@ -151,7 +151,8 @@ func (scope *Scope) CallMethod(name string) {
|
|||
return
|
||||
}
|
||||
|
||||
if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() {
|
||||
call := func(value interface{}) {
|
||||
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
|
||||
fi := fm.Interface()
|
||||
if f, ok := fi.(func()); ok {
|
||||
f()
|
||||
|
@ -169,6 +170,15 @@ func (scope *Scope) CallMethod(name string) {
|
|||
scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if values := reflect.Indirect(reflect.ValueOf(scope.Value)); values.Kind() == reflect.Slice {
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
call(values.Index(i).Addr().Interface())
|
||||
}
|
||||
} else {
|
||||
call(scope.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) AddToVars(value interface{}) string {
|
||||
|
@ -367,3 +377,11 @@ func (scope *Scope) CommitOrRollback() *Scope {
|
|||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) SelectSql() string {
|
||||
if len(scope.Search.selectStr) == 0 {
|
||||
return "*"
|
||||
} else {
|
||||
return scope.Search.selectStr
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,8 +43,10 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
var sqls []string
|
||||
for _, field := range scope.Fields() {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.dbName), scope.AddToVars(field.Value)))
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.DBName), scope.AddToVars(field.Value)))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
@ -102,8 +104,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
var sqls []string
|
||||
for _, field := range scope.Fields() {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.dbName), scope.AddToVars(field.Value)))
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.DBName), scope.AddToVars(field.Value)))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue