mirror of https://github.com/go-gorm/gorm.git
Support sql.Scanner
This commit is contained in:
parent
dc15849313
commit
8e0b125cb1
19
do.go
19
do.go
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
@ -560,13 +561,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
}
|
||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||
default:
|
||||
switch arg.(type) {
|
||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
||||
value := reflect.ValueOf(arg).Field(0).Interface()
|
||||
str = strings.Replace(str, "?", s.addToVars(value), 1)
|
||||
default:
|
||||
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = scanner.Value()
|
||||
}
|
||||
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -624,13 +622,10 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||
}
|
||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||
default:
|
||||
switch arg.(type) {
|
||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
||||
value := reflect.ValueOf(arg).Field(0).Interface()
|
||||
str = strings.Replace(not_equal_sql, "?", s.addToVars(value), 1)
|
||||
default:
|
||||
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
|
||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = scanner.Value()
|
||||
}
|
||||
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
28
model.go
28
model.go
|
@ -14,6 +14,7 @@ import (
|
|||
type Model struct {
|
||||
data interface{}
|
||||
driver string
|
||||
debug bool
|
||||
_cache_fields map[string][]Field
|
||||
}
|
||||
|
||||
|
@ -106,11 +107,13 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
if is_time {
|
||||
field.IsBlank = time_value.IsZero()
|
||||
} else {
|
||||
switch value.Interface().(type) {
|
||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
||||
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
|
||||
|
||||
if is_scanner {
|
||||
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
|
||||
default:
|
||||
} else {
|
||||
m := &Model{data: value.Interface(), driver: m.driver}
|
||||
|
||||
fields := m.columnsHasValue("other")
|
||||
if len(fields) == 0 {
|
||||
field.IsBlank = true
|
||||
|
@ -370,25 +373,14 @@ func setFieldValue(field reflect.Value, value interface{}) bool {
|
|||
}
|
||||
field.SetInt(reflect.ValueOf(value).Int())
|
||||
default:
|
||||
field_type := field.Type()
|
||||
if field_type == reflect.TypeOf(value) {
|
||||
field.Set(reflect.ValueOf(value))
|
||||
} else if value == nil {
|
||||
field.Set(reflect.Zero(field.Type()))
|
||||
} else if field_type == reflect.TypeOf(sql.NullBool{}) {
|
||||
field.Set(reflect.ValueOf(sql.NullBool{value.(bool), true}))
|
||||
} else if field_type == reflect.TypeOf(sql.NullFloat64{}) {
|
||||
field.Set(reflect.ValueOf(sql.NullFloat64{value.(float64), true}))
|
||||
} else if field_type == reflect.TypeOf(sql.NullInt64{}) {
|
||||
field.Set(reflect.ValueOf(sql.NullInt64{value.(int64), true}))
|
||||
} else if field_type == reflect.TypeOf(sql.NullString{}) {
|
||||
field.Set(reflect.ValueOf(sql.NullString{value.(string), true}))
|
||||
if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
|
||||
scanner.Scan(value)
|
||||
} else {
|
||||
field.Set(reflect.ValueOf(value))
|
||||
}
|
||||
}
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue