From ec0aa10bf2d52571c9b41ea9785e01338163c8cb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 4 Jan 2016 21:49:04 +0800 Subject: [PATCH] Fix set scanner to a Field --- field.go | 32 ++++++++++---------------------- main_test.go | 26 +++++++++++++++++++++++++- scope_private.go | 16 +++++++++------- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/field.go b/field.go index 7151f468..79e2c0ec 100644 --- a/field.go +++ b/field.go @@ -22,35 +22,23 @@ func (field *Field) Set(value interface{}) error { return errors.New("unaddressable value") } - if rvalue, ok := value.(reflect.Value); ok { - value = rvalue.Interface() + reflectValue, ok := value.(reflect.Value) + if !ok { + reflectValue = reflect.ValueOf(value) } - if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { - if v, ok := value.(reflect.Value); ok { - if err := scanner.Scan(v.Interface()); err != nil { - return err - } - } else { - if err := scanner.Scan(value); err != nil { - return err - } - } - } else { - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - if !reflectValue.IsValid() { - return nil - } - + if reflectValue.IsValid() { if reflectValue.Type().ConvertibleTo(field.Field.Type()) { field.Field.Set(reflectValue.Convert(field.Field.Type())) + } else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { + if err := scanner.Scan(reflectValue.Interface()); err != nil { + return err + } } else { return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type()) } + } else { + field.Field.Set(reflect.Zero(field.Field.Type())) } field.IsBlank = isBlank(field.Field) diff --git a/main_test.go b/main_test.go index eb11ddb0..e6c703e4 100644 --- a/main_test.go +++ b/main_test.go @@ -228,7 +228,7 @@ func TestTableName(t *testing.T) { DB.SingularTable(false) } -func TestSqlNullValue(t *testing.T) { +func TestNullValues(t *testing.T) { DB.DropTable(&NullValue{}) DB.AutoMigrate(&NullValue{}) @@ -279,6 +279,30 @@ func TestSqlNullValue(t *testing.T) { } } +func TestNullValuesWithFirstOrCreate(t *testing.T) { + var nv1 = NullValue{ + Name: sql.NullString{String: "first_or_create", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + } + + var nv2 NullValue + if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil { + t.Errorf("Should not raise any error, but got %v", err) + } + + if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" { + t.Errorf("first or create with nullvalues") + } + + if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { + t.Errorf("Should not raise any error, but got %v", err) + } + + if nv2.Age.Int64 != 18 { + t.Errorf("should update age to 18") + } +} + func TestTransaction(t *testing.T) { tx := DB.Begin() u := User{Name: "transcation"} diff --git a/scope_private.go b/scope_private.go index 761241af..634c1c8a 100644 --- a/scope_private.go +++ b/scope_private.go @@ -334,9 +334,8 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore } var hasExpr bool - fields := scope.Fields() for key, value := range values { - if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() { + if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { if _, ok := value.(*expr); ok { hasExpr = true @@ -347,13 +346,16 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore } } } + if hasExpr { var updateMap = map[string]interface{}{} - for key, value := range fields { - if v, ok := values[key]; ok { - updateMap[key] = v - } else { - updateMap[key] = value.Field.Interface() + for key, field := range scope.Fields() { + if field.IsNormal { + if v, ok := values[key]; ok { + updateMap[key] = v + } else { + updateMap[key] = field.Field.Interface() + } } } return updateMap, true