mirror of https://github.com/go-gorm/gorm.git
Fix set scanner to a Field
This commit is contained in:
parent
be45d8312e
commit
ec0aa10bf2
32
field.go
32
field.go
|
@ -22,35 +22,23 @@ func (field *Field) Set(value interface{}) error {
|
|||
return errors.New("unaddressable value")
|
||||
}
|
||||
|
||||
if rvalue, ok := value.(reflect.Value); ok {
|
||||
value = rvalue.Interface()
|
||||
reflectValue, ok := value.(reflect.Value)
|
||||
if !ok {
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
}
|
||||
|
||||
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
||||
if v, ok := value.(reflect.Value); ok {
|
||||
if err := scanner.Scan(v.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := scanner.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
reflectValue, ok := value.(reflect.Value)
|
||||
if !ok {
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
}
|
||||
|
||||
if !reflectValue.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if reflectValue.IsValid() {
|
||||
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
|
||||
field.Field.Set(reflectValue.Convert(field.Field.Type()))
|
||||
} else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
||||
if err := scanner.Scan(reflectValue.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type())
|
||||
}
|
||||
} else {
|
||||
field.Field.Set(reflect.Zero(field.Field.Type()))
|
||||
}
|
||||
|
||||
field.IsBlank = isBlank(field.Field)
|
||||
|
|
26
main_test.go
26
main_test.go
|
@ -228,7 +228,7 @@ func TestTableName(t *testing.T) {
|
|||
DB.SingularTable(false)
|
||||
}
|
||||
|
||||
func TestSqlNullValue(t *testing.T) {
|
||||
func TestNullValues(t *testing.T) {
|
||||
DB.DropTable(&NullValue{})
|
||||
DB.AutoMigrate(&NullValue{})
|
||||
|
||||
|
@ -279,6 +279,30 @@ func TestSqlNullValue(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNullValuesWithFirstOrCreate(t *testing.T) {
|
||||
var nv1 = NullValue{
|
||||
Name: sql.NullString{String: "first_or_create", Valid: true},
|
||||
Gender: &sql.NullString{String: "M", Valid: true},
|
||||
}
|
||||
|
||||
var nv2 NullValue
|
||||
if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil {
|
||||
t.Errorf("Should not raise any error, but got %v", err)
|
||||
}
|
||||
|
||||
if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" {
|
||||
t.Errorf("first or create with nullvalues")
|
||||
}
|
||||
|
||||
if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil {
|
||||
t.Errorf("Should not raise any error, but got %v", err)
|
||||
}
|
||||
|
||||
if nv2.Age.Int64 != 18 {
|
||||
t.Errorf("should update age to 18")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
tx := DB.Begin()
|
||||
u := User{Name: "transcation"}
|
||||
|
|
|
@ -334,9 +334,8 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
|||
}
|
||||
|
||||
var hasExpr bool
|
||||
fields := scope.Fields()
|
||||
for key, value := range values {
|
||||
if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
|
||||
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() {
|
||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||
if _, ok := value.(*expr); ok {
|
||||
hasExpr = true
|
||||
|
@ -347,13 +346,16 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasExpr {
|
||||
var updateMap = map[string]interface{}{}
|
||||
for key, value := range fields {
|
||||
if v, ok := values[key]; ok {
|
||||
updateMap[key] = v
|
||||
} else {
|
||||
updateMap[key] = value.Field.Interface()
|
||||
for key, field := range scope.Fields() {
|
||||
if field.IsNormal {
|
||||
if v, ok := values[key]; ok {
|
||||
updateMap[key] = v
|
||||
} else {
|
||||
updateMap[key] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
return updateMap, true
|
||||
|
|
Loading…
Reference in New Issue