mirror of https://github.com/go-gorm/gorm.git
Better support for sql.Scanner
This commit is contained in:
parent
8e0b125cb1
commit
f82d036f14
|
@ -27,7 +27,7 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
|
|||
|
||||
Email []Email // Embedded structs
|
||||
BillingAddress Address // Embedded struct
|
||||
BillingAddressId int64 // Embedded struct BillingAddress's foreign key
|
||||
BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
|
||||
ShippingAddress Address // Embedded struct
|
||||
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
|
||||
}
|
||||
|
|
7
do.go
7
do.go
|
@ -64,7 +64,7 @@ func (s *Do) hasError() bool {
|
|||
}
|
||||
|
||||
func (s *Do) setModel(value interface{}) *Do {
|
||||
s.model = &Model{data: value, driver: s.driver}
|
||||
s.model = &Model{data: value, driver: s.driver, debug: s.debug}
|
||||
s.value = value
|
||||
return s
|
||||
}
|
||||
|
@ -561,8 +561,8 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
}
|
||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||
default:
|
||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = scanner.Value()
|
||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
||||
}
|
||||
|
@ -725,6 +725,7 @@ func (s *Do) createTable() *Do {
|
|||
s.tableName(),
|
||||
strings.Join(sqls, ","),
|
||||
)
|
||||
|
||||
s.exec()
|
||||
return s
|
||||
}
|
||||
|
|
31
gorm_test.go
31
gorm_test.go
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
|
@ -1270,36 +1271,58 @@ func TestAutoMigration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
nt.Valid = false
|
||||
return nil
|
||||
}
|
||||
nt.Time, nt.Valid = value.(time.Time), true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
||||
type NullValue struct {
|
||||
Id int64
|
||||
Name sql.NullString
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
AddedAt NullTime
|
||||
}
|
||||
|
||||
func TestSqlNullValue(t *testing.T) {
|
||||
db.DropTable(&NullValue{})
|
||||
db.AutoMigrate(&NullValue{})
|
||||
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
|
||||
t.Errorf("Not error should raise when test null value", err)
|
||||
}
|
||||
|
||||
var nv NullValue
|
||||
db.First(&nv, "name = ?", "hello")
|
||||
|
||||
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 {
|
||||
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
|
||||
t.Errorf("Should be able to fetch null value")
|
||||
}
|
||||
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
|
||||
t.Errorf("Not error should raise when test null value", err)
|
||||
}
|
||||
|
||||
var nv2 NullValue
|
||||
db.First(&nv2, "name = ?", "hello-2")
|
||||
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 {
|
||||
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
|
||||
t.Errorf("Should be able to fetch null value")
|
||||
}
|
||||
}
|
||||
|
|
32
model.go
32
model.go
|
@ -81,7 +81,6 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
|
||||
typ := indirect_value.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
p := typ.Field(i)
|
||||
if !p.Anonymous && ast.IsExported(p.Name) {
|
||||
|
@ -137,19 +136,11 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
field.Value = value.Interface()
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0)
|
||||
field.SqlType = getSqlType(m.driver, value, 0)
|
||||
} else if field.IsPrimaryKey {
|
||||
field.SqlType = getPrimaryKeySqlType(m.driver, value, 0)
|
||||
} else {
|
||||
field_value := reflect.ValueOf(field.Value)
|
||||
if field_value.Kind() == reflect.Ptr {
|
||||
if field_value.CanAddr() {
|
||||
field_value = field_value.Elem()
|
||||
}
|
||||
}
|
||||
field_value := reflect.Indirect(value)
|
||||
|
||||
switch field_value.Kind() {
|
||||
case reflect.Slice:
|
||||
|
@ -159,10 +150,11 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
field.afterAssociation = true
|
||||
case reflect.Struct:
|
||||
switch value.Interface().(type) {
|
||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time:
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
default:
|
||||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||
|
||||
if is_scanner {
|
||||
field.SqlType = getSqlType(m.driver, value, 0)
|
||||
} else {
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.foreignKey = p.Name + "Id"
|
||||
field.beforeAssociation = true
|
||||
|
@ -174,12 +166,12 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.afterAssociation = true
|
||||
}
|
||||
}
|
||||
case reflect.Ptr:
|
||||
debug("Errors when handle ptr sub structs")
|
||||
default:
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
field.SqlType = getSqlType(m.driver, value, 0)
|
||||
}
|
||||
}
|
||||
|
||||
field.Value = value.Interface()
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
|
|
17
sql_type.go
17
sql_type.go
|
@ -2,11 +2,26 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
func formatColumnValue(column interface{}) interface{} {
|
||||
if v, ok := column.(reflect.Value); ok {
|
||||
column = v.Interface()
|
||||
}
|
||||
|
||||
if valuer, ok := interface{}(column).(driver.Valuer); ok {
|
||||
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
|
||||
}
|
||||
return column
|
||||
}
|
||||
|
||||
func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
|
||||
column = formatColumnValue(column)
|
||||
|
||||
switch adaptor {
|
||||
case "sqlite3":
|
||||
return "INTEGER PRIMARY KEY"
|
||||
|
@ -30,6 +45,8 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
|
|||
}
|
||||
|
||||
func getSqlType(adaptor string, column interface{}, size int) string {
|
||||
column = formatColumnValue(column)
|
||||
|
||||
switch adaptor {
|
||||
case "sqlite3":
|
||||
switch column.(type) {
|
||||
|
|
Loading…
Reference in New Issue