Support poiner of Scanner

This commit is contained in:
Jinzhu 2016-01-04 18:40:06 +08:00
parent f330da219c
commit d1892d3177
3 changed files with 19 additions and 15 deletions

View File

@ -232,7 +232,9 @@ func TestSqlNullValue(t *testing.T) {
DB.DropTable(&NullValue{}) DB.DropTable(&NullValue{})
DB.AutoMigrate(&NullValue{}) DB.AutoMigrate(&NullValue{})
if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello", Valid: true}, if err := DB.Save(&NullValue{
Name: sql.NullString{String: "hello", Valid: true},
Gender: &sql.NullString{String: "M", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: true}, Age: sql.NullInt64{Int64: 18, Valid: true},
Male: sql.NullBool{Bool: true, Valid: true}, Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true},
@ -244,11 +246,13 @@ func TestSqlNullValue(t *testing.T) {
var nv NullValue var nv NullValue
DB.First(&nv, "name = ?", "hello") DB.First(&nv, "name = ?", "hello")
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
t.Errorf("Should be able to fetch null value") t.Errorf("Should be able to fetch null value")
} }
if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-2", Valid: true}, if err := DB.Save(&NullValue{
Name: sql.NullString{String: "hello-2", Valid: true},
Gender: &sql.NullString{String: "F", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: false}, Age: sql.NullInt64{Int64: 18, Valid: false},
Male: sql.NullBool{Bool: true, Valid: true}, Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true},
@ -259,11 +263,13 @@ func TestSqlNullValue(t *testing.T) {
var nv2 NullValue var nv2 NullValue
DB.First(&nv2, "name = ?", "hello-2") DB.First(&nv2, "name = ?", "hello-2")
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
t.Errorf("Should be able to fetch null value") t.Errorf("Should be able to fetch null value")
} }
if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-3", Valid: false}, if err := DB.Save(&NullValue{
Name: sql.NullString{String: "hello-3", Valid: false},
Gender: &sql.NullString{String: "M", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: false}, Age: sql.NullInt64{Int64: 18, Valid: false},
Male: sql.NullBool{Bool: true, Valid: true}, Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true},

View File

@ -166,16 +166,18 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
field.HasDefaultValue = true field.HasDefaultValue = true
} }
fieldValue := reflect.New(fieldStruct.Type).Interface() indirectType := fieldStruct.Type
for indirectType.Kind() == reflect.Ptr {
indirectType = indirectType.Elem()
}
fieldValue := reflect.New(indirectType).Interface()
if _, isScanner := fieldValue.(sql.Scanner); isScanner { if _, isScanner := fieldValue.(sql.Scanner); isScanner {
// is scanner // is scanner
field.IsScanner, field.IsNormal = true, true field.IsScanner, field.IsNormal = true, true
} else if _, isTime := fieldValue.(*time.Time); isTime { } else if _, isTime := fieldValue.(*time.Time); isTime {
// is time // is time
field.IsNormal = true field.IsNormal = true
} else if _, isTime := fieldValue.(**time.Time); isTime {
// is time
field.IsNormal = true
} else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
// is embedded struct // is embedded struct
for _, subField := range scope.New(fieldValue).GetStructFields() { for _, subField := range scope.New(fieldValue).GetStructFields() {
@ -189,11 +191,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
continue continue
} else { } else {
// build relationships // build relationships
indirectType := fieldStruct.Type
for indirectType.Kind() == reflect.Ptr {
indirectType = indirectType.Elem()
}
switch indirectType.Kind() { switch indirectType.Kind() {
case reflect.Slice: case reflect.Slice:
defer func(field *StructField) { defer func(field *StructField) {

View File

@ -168,7 +168,8 @@ type Comment struct {
// Scanner // Scanner
type NullValue struct { type NullValue struct {
Id int64 Id int64
Name sql.NullString `sql:"not null"` Name sql.NullString `sql:"not null"`
Gender *sql.NullString `sql:"not null"`
Age sql.NullInt64 Age sql.NullInt64
Male sql.NullBool Male sql.NullBool
Height sql.NullFloat64 Height sql.NullFloat64