diff --git a/do.go b/do.go index 4548bd93..05c2a681 100644 --- a/do.go +++ b/do.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" @@ -560,13 +561,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { } str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) default: - switch arg.(type) { - case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: - value := reflect.ValueOf(arg).Field(0).Interface() - str = strings.Replace(str, "?", s.addToVars(value), 1) - default: - str = strings.Replace(str, "?", s.addToVars(arg), 1) + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = scanner.Value() } + str = strings.Replace(str, "?", s.addToVars(arg), 1) } } return @@ -624,13 +622,10 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { } str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) default: - switch arg.(type) { - case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: - value := reflect.ValueOf(arg).Field(0).Interface() - str = strings.Replace(not_equal_sql, "?", s.addToVars(value), 1) - default: - str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1) + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = scanner.Value() } + str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1) } } return diff --git a/model.go b/model.go index 90a9d591..1ad17e26 100644 --- a/model.go +++ b/model.go @@ -14,6 +14,7 @@ import ( type Model struct { data interface{} driver string + debug bool _cache_fields map[string][]Field } @@ -106,11 +107,13 @@ func (m *Model) fields(operation string) (fields []Field) { if is_time { field.IsBlank = time_value.IsZero() } else { - switch value.Interface().(type) { - case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: + _, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner) + + if is_scanner { field.IsBlank = !value.FieldByName("Valid").Interface().(bool) - default: + } else { m := &Model{data: value.Interface(), driver: m.driver} + fields := m.columnsHasValue("other") if len(fields) == 0 { field.IsBlank = true @@ -370,25 +373,14 @@ func setFieldValue(field reflect.Value, value interface{}) bool { } field.SetInt(reflect.ValueOf(value).Int()) default: - field_type := field.Type() - if field_type == reflect.TypeOf(value) { - field.Set(reflect.ValueOf(value)) - } else if value == nil { - field.Set(reflect.Zero(field.Type())) - } else if field_type == reflect.TypeOf(sql.NullBool{}) { - field.Set(reflect.ValueOf(sql.NullBool{value.(bool), true})) - } else if field_type == reflect.TypeOf(sql.NullFloat64{}) { - field.Set(reflect.ValueOf(sql.NullFloat64{value.(float64), true})) - } else if field_type == reflect.TypeOf(sql.NullInt64{}) { - field.Set(reflect.ValueOf(sql.NullInt64{value.(int64), true})) - } else if field_type == reflect.TypeOf(sql.NullString{}) { - field.Set(reflect.ValueOf(sql.NullString{value.(string), true})) + if scanner, ok := field.Addr().Interface().(sql.Scanner); ok { + scanner.Scan(value) } else { field.Set(reflect.ValueOf(value)) } } return true - } else { - return false } + + return false } diff --git a/utils.go b/utils.go index 6277bc40..0848ec0b 100644 --- a/utils.go +++ b/utils.go @@ -2,8 +2,8 @@ package gorm import ( "bytes" - "fmt" + "fmt" "strings" )