mirror of https://github.com/go-gorm/gorm.git
Fix set scanner to a Field
This commit is contained in:
parent
be45d8312e
commit
ec0aa10bf2
26
field.go
26
field.go
|
@ -22,35 +22,23 @@ func (field *Field) Set(value interface{}) error {
|
||||||
return errors.New("unaddressable value")
|
return errors.New("unaddressable value")
|
||||||
}
|
}
|
||||||
|
|
||||||
if rvalue, ok := value.(reflect.Value); ok {
|
|
||||||
value = rvalue.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
|
||||||
if v, ok := value.(reflect.Value); ok {
|
|
||||||
if err := scanner.Scan(v.Interface()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := scanner.Scan(value); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
reflectValue, ok := value.(reflect.Value)
|
reflectValue, ok := value.(reflect.Value)
|
||||||
if !ok {
|
if !ok {
|
||||||
reflectValue = reflect.ValueOf(value)
|
reflectValue = reflect.ValueOf(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflectValue.IsValid() {
|
if reflectValue.IsValid() {
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
|
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
|
||||||
field.Field.Set(reflectValue.Convert(field.Field.Type()))
|
field.Field.Set(reflectValue.Convert(field.Field.Type()))
|
||||||
|
} else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
||||||
|
if err := scanner.Scan(reflectValue.Interface()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type())
|
return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type())
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
field.Field.Set(reflect.Zero(field.Field.Type()))
|
||||||
}
|
}
|
||||||
|
|
||||||
field.IsBlank = isBlank(field.Field)
|
field.IsBlank = isBlank(field.Field)
|
||||||
|
|
26
main_test.go
26
main_test.go
|
@ -228,7 +228,7 @@ func TestTableName(t *testing.T) {
|
||||||
DB.SingularTable(false)
|
DB.SingularTable(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlNullValue(t *testing.T) {
|
func TestNullValues(t *testing.T) {
|
||||||
DB.DropTable(&NullValue{})
|
DB.DropTable(&NullValue{})
|
||||||
DB.AutoMigrate(&NullValue{})
|
DB.AutoMigrate(&NullValue{})
|
||||||
|
|
||||||
|
@ -279,6 +279,30 @@ func TestSqlNullValue(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNullValuesWithFirstOrCreate(t *testing.T) {
|
||||||
|
var nv1 = NullValue{
|
||||||
|
Name: sql.NullString{String: "first_or_create", Valid: true},
|
||||||
|
Gender: &sql.NullString{String: "M", Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
var nv2 NullValue
|
||||||
|
if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil {
|
||||||
|
t.Errorf("Should not raise any error, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" {
|
||||||
|
t.Errorf("first or create with nullvalues")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil {
|
||||||
|
t.Errorf("Should not raise any error, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nv2.Age.Int64 != 18 {
|
||||||
|
t.Errorf("should update age to 18")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTransaction(t *testing.T) {
|
func TestTransaction(t *testing.T) {
|
||||||
tx := DB.Begin()
|
tx := DB.Begin()
|
||||||
u := User{Name: "transcation"}
|
u := User{Name: "transcation"}
|
||||||
|
|
|
@ -334,9 +334,8 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasExpr bool
|
var hasExpr bool
|
||||||
fields := scope.Fields()
|
|
||||||
for key, value := range values {
|
for key, value := range values {
|
||||||
if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
|
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() {
|
||||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||||
if _, ok := value.(*expr); ok {
|
if _, ok := value.(*expr); ok {
|
||||||
hasExpr = true
|
hasExpr = true
|
||||||
|
@ -347,13 +346,16 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasExpr {
|
if hasExpr {
|
||||||
var updateMap = map[string]interface{}{}
|
var updateMap = map[string]interface{}{}
|
||||||
for key, value := range fields {
|
for key, field := range scope.Fields() {
|
||||||
|
if field.IsNormal {
|
||||||
if v, ok := values[key]; ok {
|
if v, ok := values[key]; ok {
|
||||||
updateMap[key] = v
|
updateMap[key] = v
|
||||||
} else {
|
} else {
|
||||||
updateMap[key] = value.Field.Interface()
|
updateMap[key] = field.Field.Interface()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return updateMap, true
|
return updateMap, true
|
||||||
|
|
Loading…
Reference in New Issue