mirror of https://github.com/go-gorm/gorm.git
make callback query works
This commit is contained in:
parent
048b8b6abe
commit
db68e7a8fe
|
@ -1,6 +1,81 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
func Query(scope *Scope) {
|
func Query(scope *Scope) {
|
||||||
|
defer scope.Trace(time.Now())
|
||||||
|
|
||||||
|
inlineCondition, ok := scope.Get("gorm:inline_condition")
|
||||||
|
if ok {
|
||||||
|
inlineConditions := inlineCondition.([]interface{})
|
||||||
|
if len(inlineConditions) > 0 {
|
||||||
|
scope.Search = scope.Search.clone().where(inlineConditions[0], inlineConditions[1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
isSlice bool
|
||||||
|
anyRecordFound bool
|
||||||
|
destType reflect.Type
|
||||||
|
)
|
||||||
|
|
||||||
|
var dest = reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||||
|
|
||||||
|
if dest.Kind() == reflect.Slice {
|
||||||
|
isSlice = true
|
||||||
|
destType = dest.Type().Elem()
|
||||||
|
} else {
|
||||||
|
scope.Search = scope.Search.clone().limit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if scope.Search.raw {
|
||||||
|
scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE "))
|
||||||
|
} else {
|
||||||
|
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.SelectSql(), scope.TableName(), scope.CombinedConditionSql()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.HasError() {
|
||||||
|
rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...)
|
||||||
|
|
||||||
|
if scope.Err(err) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
anyRecordFound = true
|
||||||
|
elem := dest
|
||||||
|
if isSlice {
|
||||||
|
elem = reflect.New(destType).Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
columns, _ := rows.Columns()
|
||||||
|
var values []interface{}
|
||||||
|
for _, value := range columns {
|
||||||
|
field := elem.FieldByName(snakeToUpperCamel(value))
|
||||||
|
if field.IsValid() {
|
||||||
|
values = append(values, field.Addr().Interface())
|
||||||
|
} else {
|
||||||
|
var ignore interface{}
|
||||||
|
values = append(values, &ignore)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scope.Err(rows.Scan(values...))
|
||||||
|
|
||||||
|
if isSlice {
|
||||||
|
dest.Set(reflect.Append(dest, elem))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !anyRecordFound && !isSlice {
|
||||||
|
scope.Err(RecordNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterQuery(scope *Scope) {
|
func AfterQuery(scope *Scope) {
|
||||||
|
|
2
main.go
2
main.go
|
@ -117,7 +117,7 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
return s.clone().do(out).where(where...).last().db
|
return s.clone().do(out).where(where...).last().db
|
||||||
}
|
}
|
||||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||||
return s.clone().do(out).where(where...).query().db
|
return s.clone().NewScope(out).Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Row() *sql.Row {
|
func (s *DB) Row() *sql.Row {
|
||||||
|
|
20
scope.go
20
scope.go
|
@ -151,7 +151,8 @@ func (scope *Scope) CallMethod(name string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() {
|
call := func(value interface{}) {
|
||||||
|
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
|
||||||
fi := fm.Interface()
|
fi := fm.Interface()
|
||||||
if f, ok := fi.(func()); ok {
|
if f, ok := fi.(func()); ok {
|
||||||
f()
|
f()
|
||||||
|
@ -171,6 +172,15 @@ func (scope *Scope) CallMethod(name string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if values := reflect.Indirect(reflect.ValueOf(scope.Value)); values.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < values.Len(); i++ {
|
||||||
|
call(values.Index(i).Addr().Interface())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
call(scope.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (scope *Scope) AddToVars(value interface{}) string {
|
func (scope *Scope) AddToVars(value interface{}) string {
|
||||||
scope.SqlVars = append(scope.SqlVars, value)
|
scope.SqlVars = append(scope.SqlVars, value)
|
||||||
return scope.Dialect().BinVar(len(scope.SqlVars))
|
return scope.Dialect().BinVar(len(scope.SqlVars))
|
||||||
|
@ -367,3 +377,11 @@ func (scope *Scope) CommitOrRollback() *Scope {
|
||||||
}
|
}
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) SelectSql() string {
|
||||||
|
if len(scope.Search.selectStr) == 0 {
|
||||||
|
return "*"
|
||||||
|
} else {
|
||||||
|
return scope.Search.selectStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -43,8 +43,10 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.New(value).Fields() {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.dbName), scope.AddToVars(field.Value)))
|
if !field.IsBlank {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.DBName), scope.AddToVars(field.Value)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
}
|
}
|
||||||
|
@ -102,8 +104,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.New(value).Fields() {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.dbName), scope.AddToVars(field.Value)))
|
if !field.IsBlank {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.DBName), scope.AddToVars(field.Value)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue