Support sql.Scanner

This commit is contained in:
Jinzhu 2013-11-10 18:33:37 +08:00
parent dc15849313
commit 8e0b125cb1
3 changed files with 18 additions and 31 deletions

19
do.go
View File

@ -2,6 +2,7 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
@ -560,13 +561,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
switch arg.(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
value := reflect.ValueOf(arg).Field(0).Interface()
str = strings.Replace(str, "?", s.addToVars(value), 1)
default:
str = strings.Replace(str, "?", s.addToVars(arg), 1)
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value()
}
str = strings.Replace(str, "?", s.addToVars(arg), 1)
}
}
return
@ -624,13 +622,10 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
switch arg.(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
value := reflect.ValueOf(arg).Field(0).Interface()
str = strings.Replace(not_equal_sql, "?", s.addToVars(value), 1)
default:
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value()
}
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
}
}
return

View File

@ -14,6 +14,7 @@ import (
type Model struct {
data interface{}
driver string
debug bool
_cache_fields map[string][]Field
}
@ -106,11 +107,13 @@ func (m *Model) fields(operation string) (fields []Field) {
if is_time {
field.IsBlank = time_value.IsZero()
} else {
switch value.Interface().(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
if is_scanner {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
default:
} else {
m := &Model{data: value.Interface(), driver: m.driver}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
@ -370,25 +373,14 @@ func setFieldValue(field reflect.Value, value interface{}) bool {
}
field.SetInt(reflect.ValueOf(value).Int())
default:
field_type := field.Type()
if field_type == reflect.TypeOf(value) {
field.Set(reflect.ValueOf(value))
} else if value == nil {
field.Set(reflect.Zero(field.Type()))
} else if field_type == reflect.TypeOf(sql.NullBool{}) {
field.Set(reflect.ValueOf(sql.NullBool{value.(bool), true}))
} else if field_type == reflect.TypeOf(sql.NullFloat64{}) {
field.Set(reflect.ValueOf(sql.NullFloat64{value.(float64), true}))
} else if field_type == reflect.TypeOf(sql.NullInt64{}) {
field.Set(reflect.ValueOf(sql.NullInt64{value.(int64), true}))
} else if field_type == reflect.TypeOf(sql.NullString{}) {
field.Set(reflect.ValueOf(sql.NullString{value.(string), true}))
if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
scanner.Scan(value)
} else {
field.Set(reflect.ValueOf(value))
}
}
return true
} else {
return false
}
return false
}

View File

@ -2,8 +2,8 @@ package gorm
import (
"bytes"
"fmt"
"fmt"
"strings"
)