Better support for sql.Scanner

This commit is contained in:
Jinzhu 2013-11-10 19:38:28 +08:00
parent 8e0b125cb1
commit f82d036f14
5 changed files with 70 additions and 37 deletions

View File

@ -27,7 +27,7 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
Email []Email // Embedded structs Email []Email // Embedded structs
BillingAddress Address // Embedded struct BillingAddress Address // Embedded struct
BillingAddressId int64 // Embedded struct BillingAddress's foreign key BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
ShippingAddress Address // Embedded struct ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
} }

7
do.go
View File

@ -64,7 +64,7 @@ func (s *Do) hasError() bool {
} }
func (s *Do) setModel(value interface{}) *Do { func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, driver: s.driver} s.model = &Model{data: value, driver: s.driver, debug: s.debug}
s.value = value s.value = value
return s return s
} }
@ -561,8 +561,8 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
} }
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default: default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok { if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value() arg, _ = valuer.Value()
} }
str = strings.Replace(str, "?", s.addToVars(arg), 1) str = strings.Replace(str, "?", s.addToVars(arg), 1)
} }
@ -725,6 +725,7 @@ func (s *Do) createTable() *Do {
s.tableName(), s.tableName(),
strings.Join(sqls, ","), strings.Join(sqls, ","),
) )
s.exec() s.exec()
return s return s
} }

View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@ -1270,36 +1271,58 @@ func TestAutoMigration(t *testing.T) {
} }
} }
type NullTime struct {
Time time.Time
Valid bool
}
func (nt *NullTime) Scan(value interface{}) error {
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
}
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
type NullValue struct { type NullValue struct {
Id int64 Id int64
Name sql.NullString Name sql.NullString
Age sql.NullInt64 Age sql.NullInt64
Male sql.NullBool Male sql.NullBool
Height sql.NullFloat64 Height sql.NullFloat64
AddedAt NullTime
} }
func TestSqlNullValue(t *testing.T) { func TestSqlNullValue(t *testing.T) {
db.DropTable(&NullValue{}) db.DropTable(&NullValue{})
db.AutoMigrate(&NullValue{}) db.AutoMigrate(&NullValue{})
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil { if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err) t.Errorf("Not error should raise when test null value", err)
} }
var nv NullValue var nv NullValue
db.First(&nv, "name = ?", "hello") db.First(&nv, "name = ?", "hello")
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 { if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
t.Errorf("Should be able to fetch null value") t.Errorf("Should be able to fetch null value")
} }
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil { if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err) t.Errorf("Not error should raise when test null value", err)
} }
var nv2 NullValue var nv2 NullValue
db.First(&nv2, "name = ?", "hello-2") db.First(&nv2, "name = ?", "hello-2")
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 { if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
t.Errorf("Should be able to fetch null value") t.Errorf("Should be able to fetch null value")
} }
} }

View File

@ -81,7 +81,6 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
typ := indirect_value.Type() typ := indirect_value.Type()
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
p := typ.Field(i) p := typ.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) { if !p.Anonymous && ast.IsExported(p.Name) {
@ -137,19 +136,11 @@ func (m *Model) fields(operation string) (fields []Field) {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
} }
} field.SqlType = getSqlType(m.driver, value, 0)
} else if field.IsPrimaryKey {
field.Value = value.Interface() field.SqlType = getPrimaryKeySqlType(m.driver, value, 0)
if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0)
} else { } else {
field_value := reflect.ValueOf(field.Value) field_value := reflect.Indirect(value)
if field_value.Kind() == reflect.Ptr {
if field_value.CanAddr() {
field_value = field_value.Elem()
}
}
switch field_value.Kind() { switch field_value.Kind() {
case reflect.Slice: case reflect.Slice:
@ -159,10 +150,11 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
field.afterAssociation = true field.afterAssociation = true
case reflect.Struct: case reflect.Struct:
switch value.Interface().(type) { _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time:
field.SqlType = getSqlType(m.driver, field.Value, 0) if is_scanner {
default: field.SqlType = getSqlType(m.driver, value, 0)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() { if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id" field.foreignKey = p.Name + "Id"
field.beforeAssociation = true field.beforeAssociation = true
@ -174,12 +166,12 @@ func (m *Model) fields(operation string) (fields []Field) {
field.afterAssociation = true field.afterAssociation = true
} }
} }
case reflect.Ptr:
debug("Errors when handle ptr sub structs")
default: default:
field.SqlType = getSqlType(m.driver, field.Value, 0) field.SqlType = getSqlType(m.driver, value, 0)
} }
} }
field.Value = value.Interface()
fields = append(fields, field) fields = append(fields, field)
} }
} }

View File

@ -2,11 +2,26 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"fmt" "fmt"
"reflect"
"time" "time"
) )
func formatColumnValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok {
column = v.Interface()
}
if valuer, ok := interface{}(column).(driver.Valuer); ok {
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
}
return column
}
func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string { func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
column = formatColumnValue(column)
switch adaptor { switch adaptor {
case "sqlite3": case "sqlite3":
return "INTEGER PRIMARY KEY" return "INTEGER PRIMARY KEY"
@ -30,6 +45,8 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
} }
func getSqlType(adaptor string, column interface{}, size int) string { func getSqlType(adaptor string, column interface{}, size int) string {
column = formatColumnValue(column)
switch adaptor { switch adaptor {
case "sqlite3": case "sqlite3":
switch column.(type) { switch column.(type) {