Better support for sql.Scanner

This commit is contained in:
Jinzhu 2013-11-10 19:38:28 +08:00
parent 8e0b125cb1
commit f82d036f14
5 changed files with 70 additions and 37 deletions

View File

@ -27,7 +27,7 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
Email []Email // Embedded structs
BillingAddress Address // Embedded struct
BillingAddressId int64 // Embedded struct BillingAddress's foreign key
BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
}

7
do.go
View File

@ -64,7 +64,7 @@ func (s *Do) hasError() bool {
}
func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, driver: s.driver}
s.model = &Model{data: value, driver: s.driver, debug: s.debug}
s.value = value
return s
}
@ -561,8 +561,8 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value()
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", s.addToVars(arg), 1)
}
@ -725,6 +725,7 @@ func (s *Do) createTable() *Do {
s.tableName(),
strings.Join(sqls, ","),
)
s.exec()
return s
}

View File

@ -2,6 +2,7 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
_ "github.com/go-sql-driver/mysql"
@ -1270,36 +1271,58 @@ func TestAutoMigration(t *testing.T) {
}
}
type NullTime struct {
Time time.Time
Valid bool
}
func (nt *NullTime) Scan(value interface{}) error {
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
}
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
type NullValue struct {
Id int64
Name sql.NullString
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
AddedAt NullTime
}
func TestSqlNullValue(t *testing.T) {
db.DropTable(&NullValue{})
db.AutoMigrate(&NullValue{})
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
var nv NullValue
db.First(&nv, "name = ?", "hello")
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 {
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
t.Errorf("Should be able to fetch null value")
}
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
var nv2 NullValue
db.First(&nv2, "name = ?", "hello-2")
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 {
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
t.Errorf("Should be able to fetch null value")
}
}

View File

@ -81,7 +81,6 @@ func (m *Model) fields(operation string) (fields []Field) {
}
typ := indirect_value.Type()
for i := 0; i < typ.NumField(); i++ {
p := typ.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) {
@ -137,19 +136,11 @@ func (m *Model) fields(operation string) (fields []Field) {
value.Set(reflect.ValueOf(time.Now()))
}
}
}
field.Value = value.Interface()
if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0)
field.SqlType = getSqlType(m.driver, value, 0)
} else if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.driver, value, 0)
} else {
field_value := reflect.ValueOf(field.Value)
if field_value.Kind() == reflect.Ptr {
if field_value.CanAddr() {
field_value = field_value.Elem()
}
}
field_value := reflect.Indirect(value)
switch field_value.Kind() {
case reflect.Slice:
@ -159,10 +150,11 @@ func (m *Model) fields(operation string) (fields []Field) {
}
field.afterAssociation = true
case reflect.Struct:
switch value.Interface().(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time:
field.SqlType = getSqlType(m.driver, field.Value, 0)
default:
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if is_scanner {
field.SqlType = getSqlType(m.driver, value, 0)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true
@ -174,12 +166,12 @@ func (m *Model) fields(operation string) (fields []Field) {
field.afterAssociation = true
}
}
case reflect.Ptr:
debug("Errors when handle ptr sub structs")
default:
field.SqlType = getSqlType(m.driver, field.Value, 0)
field.SqlType = getSqlType(m.driver, value, 0)
}
}
field.Value = value.Interface()
fields = append(fields, field)
}
}

View File

@ -2,11 +2,26 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"time"
)
func formatColumnValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok {
column = v.Interface()
}
if valuer, ok := interface{}(column).(driver.Valuer); ok {
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
}
return column
}
func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
column = formatColumnValue(column)
switch adaptor {
case "sqlite3":
return "INTEGER PRIMARY KEY"
@ -30,6 +45,8 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
}
func getSqlType(adaptor string, column interface{}, size int) string {
column = formatColumnValue(column)
switch adaptor {
case "sqlite3":
switch column.(type) {