Fix set scanner to a Field

This commit is contained in:
Jinzhu 2016-01-04 21:49:04 +08:00
parent be45d8312e
commit ec0aa10bf2
3 changed files with 44 additions and 30 deletions

View File

@ -22,35 +22,23 @@ func (field *Field) Set(value interface{}) error {
return errors.New("unaddressable value")
}
if rvalue, ok := value.(reflect.Value); ok {
value = rvalue.Interface()
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
if v, ok := value.(reflect.Value); ok {
if err := scanner.Scan(v.Interface()); err != nil {
return err
}
} else {
if err := scanner.Scan(value); err != nil {
return err
}
}
} else {
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
if !reflectValue.IsValid() {
return nil
}
if reflectValue.IsValid() {
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflectValue.Convert(field.Field.Type()))
} else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
if err := scanner.Scan(reflectValue.Interface()); err != nil {
return err
}
} else {
return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type())
}
} else {
field.Field.Set(reflect.Zero(field.Field.Type()))
}
field.IsBlank = isBlank(field.Field)

View File

@ -228,7 +228,7 @@ func TestTableName(t *testing.T) {
DB.SingularTable(false)
}
func TestSqlNullValue(t *testing.T) {
func TestNullValues(t *testing.T) {
DB.DropTable(&NullValue{})
DB.AutoMigrate(&NullValue{})
@ -279,6 +279,30 @@ func TestSqlNullValue(t *testing.T) {
}
}
func TestNullValuesWithFirstOrCreate(t *testing.T) {
var nv1 = NullValue{
Name: sql.NullString{String: "first_or_create", Valid: true},
Gender: &sql.NullString{String: "M", Valid: true},
}
var nv2 NullValue
if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil {
t.Errorf("Should not raise any error, but got %v", err)
}
if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" {
t.Errorf("first or create with nullvalues")
}
if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil {
t.Errorf("Should not raise any error, but got %v", err)
}
if nv2.Age.Int64 != 18 {
t.Errorf("should update age to 18")
}
}
func TestTransaction(t *testing.T) {
tx := DB.Begin()
u := User{Name: "transcation"}

View File

@ -334,9 +334,8 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
}
var hasExpr bool
fields := scope.Fields()
for key, value := range values {
if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if _, ok := value.(*expr); ok {
hasExpr = true
@ -347,13 +346,16 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
}
}
}
if hasExpr {
var updateMap = map[string]interface{}{}
for key, value := range fields {
if v, ok := values[key]; ok {
updateMap[key] = v
} else {
updateMap[key] = value.Field.Interface()
for key, field := range scope.Fields() {
if field.IsNormal {
if v, ok := values[key]; ok {
updateMap[key] = v
} else {
updateMap[key] = field.Field.Interface()
}
}
}
return updateMap, true