forked from mirror/gorm
Better support for sql.Scanner
This commit is contained in:
parent
8e0b125cb1
commit
f82d036f14
|
@ -27,7 +27,7 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
|
||||||
|
|
||||||
Email []Email // Embedded structs
|
Email []Email // Embedded structs
|
||||||
BillingAddress Address // Embedded struct
|
BillingAddress Address // Embedded struct
|
||||||
BillingAddressId int64 // Embedded struct BillingAddress's foreign key
|
BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
|
||||||
ShippingAddress Address // Embedded struct
|
ShippingAddress Address // Embedded struct
|
||||||
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
|
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
|
||||||
}
|
}
|
||||||
|
|
7
do.go
7
do.go
|
@ -64,7 +64,7 @@ func (s *Do) hasError() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) setModel(value interface{}) *Do {
|
func (s *Do) setModel(value interface{}) *Do {
|
||||||
s.model = &Model{data: value, driver: s.driver}
|
s.model = &Model{data: value, driver: s.driver, debug: s.debug}
|
||||||
s.value = value
|
s.value = value
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -561,8 +561,8 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||||
default:
|
default:
|
||||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = scanner.Value()
|
arg, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
||||||
}
|
}
|
||||||
|
@ -725,6 +725,7 @@ func (s *Do) createTable() *Do {
|
||||||
s.tableName(),
|
s.tableName(),
|
||||||
strings.Join(sqls, ","),
|
strings.Join(sqls, ","),
|
||||||
)
|
)
|
||||||
|
|
||||||
s.exec()
|
s.exec()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
31
gorm_test.go
31
gorm_test.go
|
@ -2,6 +2,7 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
@ -1270,36 +1271,58 @@ func TestAutoMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NullTime struct {
|
||||||
|
Time time.Time
|
||||||
|
Valid bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nt *NullTime) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
nt.Valid = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
nt.Time, nt.Valid = value.(time.Time), true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nt NullTime) Value() (driver.Value, error) {
|
||||||
|
if !nt.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nt.Time, nil
|
||||||
|
}
|
||||||
|
|
||||||
type NullValue struct {
|
type NullValue struct {
|
||||||
Id int64
|
Id int64
|
||||||
Name sql.NullString
|
Name sql.NullString
|
||||||
Age sql.NullInt64
|
Age sql.NullInt64
|
||||||
Male sql.NullBool
|
Male sql.NullBool
|
||||||
Height sql.NullFloat64
|
Height sql.NullFloat64
|
||||||
|
AddedAt NullTime
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlNullValue(t *testing.T) {
|
func TestSqlNullValue(t *testing.T) {
|
||||||
db.DropTable(&NullValue{})
|
db.DropTable(&NullValue{})
|
||||||
db.AutoMigrate(&NullValue{})
|
db.AutoMigrate(&NullValue{})
|
||||||
|
|
||||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
|
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
|
||||||
t.Errorf("Not error should raise when test null value", err)
|
t.Errorf("Not error should raise when test null value", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var nv NullValue
|
var nv NullValue
|
||||||
db.First(&nv, "name = ?", "hello")
|
db.First(&nv, "name = ?", "hello")
|
||||||
|
|
||||||
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 {
|
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
|
||||||
t.Errorf("Should be able to fetch null value")
|
t.Errorf("Should be able to fetch null value")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
|
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
|
||||||
t.Errorf("Not error should raise when test null value", err)
|
t.Errorf("Not error should raise when test null value", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var nv2 NullValue
|
var nv2 NullValue
|
||||||
db.First(&nv2, "name = ?", "hello-2")
|
db.First(&nv2, "name = ?", "hello-2")
|
||||||
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 {
|
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
|
||||||
t.Errorf("Should be able to fetch null value")
|
t.Errorf("Should be able to fetch null value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
32
model.go
32
model.go
|
@ -81,7 +81,6 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
}
|
}
|
||||||
|
|
||||||
typ := indirect_value.Type()
|
typ := indirect_value.Type()
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
p := typ.Field(i)
|
p := typ.Field(i)
|
||||||
if !p.Anonymous && ast.IsExported(p.Name) {
|
if !p.Anonymous && ast.IsExported(p.Name) {
|
||||||
|
@ -137,19 +136,11 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
value.Set(reflect.ValueOf(time.Now()))
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
field.SqlType = getSqlType(m.driver, value, 0)
|
||||||
|
} else if field.IsPrimaryKey {
|
||||||
field.Value = value.Interface()
|
field.SqlType = getPrimaryKeySqlType(m.driver, value, 0)
|
||||||
|
|
||||||
if field.IsPrimaryKey {
|
|
||||||
field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0)
|
|
||||||
} else {
|
} else {
|
||||||
field_value := reflect.ValueOf(field.Value)
|
field_value := reflect.Indirect(value)
|
||||||
if field_value.Kind() == reflect.Ptr {
|
|
||||||
if field_value.CanAddr() {
|
|
||||||
field_value = field_value.Elem()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch field_value.Kind() {
|
switch field_value.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
|
@ -159,10 +150,11 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
}
|
}
|
||||||
field.afterAssociation = true
|
field.afterAssociation = true
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
switch value.Interface().(type) {
|
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time:
|
|
||||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
if is_scanner {
|
||||||
default:
|
field.SqlType = getSqlType(m.driver, value, 0)
|
||||||
|
} else {
|
||||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||||
field.foreignKey = p.Name + "Id"
|
field.foreignKey = p.Name + "Id"
|
||||||
field.beforeAssociation = true
|
field.beforeAssociation = true
|
||||||
|
@ -174,12 +166,12 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
field.afterAssociation = true
|
field.afterAssociation = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Ptr:
|
|
||||||
debug("Errors when handle ptr sub structs")
|
|
||||||
default:
|
default:
|
||||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
field.SqlType = getSqlType(m.driver, value, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
field.Value = value.Interface()
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
17
sql_type.go
17
sql_type.go
|
@ -2,11 +2,26 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func formatColumnValue(column interface{}) interface{} {
|
||||||
|
if v, ok := column.(reflect.Value); ok {
|
||||||
|
column = v.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
if valuer, ok := interface{}(column).(driver.Valuer); ok {
|
||||||
|
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
|
||||||
|
}
|
||||||
|
return column
|
||||||
|
}
|
||||||
|
|
||||||
func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
|
func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
|
||||||
|
column = formatColumnValue(column)
|
||||||
|
|
||||||
switch adaptor {
|
switch adaptor {
|
||||||
case "sqlite3":
|
case "sqlite3":
|
||||||
return "INTEGER PRIMARY KEY"
|
return "INTEGER PRIMARY KEY"
|
||||||
|
@ -30,6 +45,8 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSqlType(adaptor string, column interface{}, size int) string {
|
func getSqlType(adaptor string, column interface{}, size int) string {
|
||||||
|
column = formatColumnValue(column)
|
||||||
|
|
||||||
switch adaptor {
|
switch adaptor {
|
||||||
case "sqlite3":
|
case "sqlite3":
|
||||||
switch column.(type) {
|
switch column.(type) {
|
||||||
|
|
Loading…
Reference in New Issue