From f82d036f14928782a7743c9c2887ed84949222f3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Nov 2013 19:38:28 +0800 Subject: [PATCH] Better support for sql.Scanner --- README.md | 10 +++++----- do.go | 7 ++++--- gorm_test.go | 41 ++++++++++++++++++++++++++++++++--------- model.go | 32 ++++++++++++-------------------- sql_type.go | 17 +++++++++++++++++ 5 files changed, 70 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 035f93a4..ed53f5af 100644 --- a/README.md +++ b/README.md @@ -25,11 +25,11 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n UpdatedAt time.Time // Time of record is updated, will be updated automatically DeletedAt time.Time // Time of record is deleted, refer `Soft Delete` for more - Email []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressId int64 // Embedded struct BillingAddress's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key + Email []Email // Embedded structs + BillingAddress Address // Embedded struct + BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key + ShippingAddress Address // Embedded struct + ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key } type Email struct { // TableName: `emails` diff --git a/do.go b/do.go index 05c2a681..b2f4afd8 100644 --- a/do.go +++ b/do.go @@ -64,7 +64,7 @@ func (s *Do) hasError() bool { } func (s *Do) setModel(value interface{}) *Do { - s.model = &Model{data: value, driver: s.driver} + s.model = &Model{data: value, driver: s.driver, debug: s.debug} s.value = value return s } @@ -561,8 +561,8 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { } str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() } str = strings.Replace(str, "?", s.addToVars(arg), 1) } @@ -725,6 +725,7 @@ func (s *Do) createTable() *Do { s.tableName(), strings.Join(sqls, ","), ) + s.exec() return s } diff --git a/gorm_test.go b/gorm_test.go index b165db15..cc643b8e 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" _ "github.com/go-sql-driver/mysql" @@ -1270,36 +1271,58 @@ func TestAutoMigration(t *testing.T) { } } +type NullTime struct { + Time time.Time + Valid bool +} + +func (nt *NullTime) Scan(value interface{}) error { + if value == nil { + nt.Valid = false + return nil + } + nt.Time, nt.Valid = value.(time.Time), true + return nil +} + +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + type NullValue struct { - Id int64 - Name sql.NullString - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 + Id int64 + Name sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + AddedAt NullTime } func TestSqlNullValue(t *testing.T) { db.DropTable(&NullValue{}) db.AutoMigrate(&NullValue{}) - if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil { + if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil { t.Errorf("Not error should raise when test null value", err) } var nv NullValue db.First(&nv, "name = ?", "hello") - if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 { + if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { t.Errorf("Should be able to fetch null value") } - if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil { + if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil { t.Errorf("Not error should raise when test null value", err) } var nv2 NullValue db.First(&nv2, "name = ?", "hello-2") - if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 { + if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { t.Errorf("Should be able to fetch null value") } } diff --git a/model.go b/model.go index 1ad17e26..653f227e 100644 --- a/model.go +++ b/model.go @@ -81,7 +81,6 @@ func (m *Model) fields(operation string) (fields []Field) { } typ := indirect_value.Type() - for i := 0; i < typ.NumField(); i++ { p := typ.Field(i) if !p.Anonymous && ast.IsExported(p.Name) { @@ -137,19 +136,11 @@ func (m *Model) fields(operation string) (fields []Field) { value.Set(reflect.ValueOf(time.Now())) } } - } - - field.Value = value.Interface() - - if field.IsPrimaryKey { - field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0) + field.SqlType = getSqlType(m.driver, value, 0) + } else if field.IsPrimaryKey { + field.SqlType = getPrimaryKeySqlType(m.driver, value, 0) } else { - field_value := reflect.ValueOf(field.Value) - if field_value.Kind() == reflect.Ptr { - if field_value.CanAddr() { - field_value = field_value.Elem() - } - } + field_value := reflect.Indirect(value) switch field_value.Kind() { case reflect.Slice: @@ -159,10 +150,11 @@ func (m *Model) fields(operation string) (fields []Field) { } field.afterAssociation = true case reflect.Struct: - switch value.Interface().(type) { - case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time: - field.SqlType = getSqlType(m.driver, field.Value, 0) - default: + _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) + + if is_scanner { + field.SqlType = getSqlType(m.driver, value, 0) + } else { if indirect_value.FieldByName(p.Name + "Id").IsValid() { field.foreignKey = p.Name + "Id" field.beforeAssociation = true @@ -174,12 +166,12 @@ func (m *Model) fields(operation string) (fields []Field) { field.afterAssociation = true } } - case reflect.Ptr: - debug("Errors when handle ptr sub structs") default: - field.SqlType = getSqlType(m.driver, field.Value, 0) + field.SqlType = getSqlType(m.driver, value, 0) } } + + field.Value = value.Interface() fields = append(fields, field) } } diff --git a/sql_type.go b/sql_type.go index 4db77102..d82cdc7e 100644 --- a/sql_type.go +++ b/sql_type.go @@ -2,11 +2,26 @@ package gorm import ( "database/sql" + "database/sql/driver" "fmt" + "reflect" "time" ) +func formatColumnValue(column interface{}) interface{} { + if v, ok := column.(reflect.Value); ok { + column = v.Interface() + } + + if valuer, ok := interface{}(column).(driver.Valuer); ok { + column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface() + } + return column +} + func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string { + column = formatColumnValue(column) + switch adaptor { case "sqlite3": return "INTEGER PRIMARY KEY" @@ -30,6 +45,8 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string { } func getSqlType(adaptor string, column interface{}, size int) string { + column = formatColumnValue(column) + switch adaptor { case "sqlite3": switch column.(type) {