diff --git a/main_test.go b/main_test.go index 25b5940c..14bf34ac 100644 --- a/main_test.go +++ b/main_test.go @@ -1182,6 +1182,27 @@ func TestFloatColumnPrecision(t *testing.T) { } } +func TestWhereUpdates(t *testing.T) { + type OwnerEntity struct { + gorm.Model + OwnerID uint + OwnerType string + } + + type SomeEntity struct { + gorm.Model + Name string + OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"` + } + + db := DB.Debug() + db.DropTable(&SomeEntity{}) + db.AutoMigrate(&SomeEntity{}) + + a := SomeEntity{Name: "test"} + db.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"}) +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { diff --git a/scope.go b/scope.go index c6c92d5a..9f8820eb 100644 --- a/scope.go +++ b/scope.go @@ -872,7 +872,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { return scope } -func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} { +func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} { var attrs = map[string]interface{}{} switch value := values.(type) { @@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string return value case []interface{}: for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField) { + for key, value := range convertInterfaceToMap(v, withIgnoredField, db) { attrs[key] = value } } @@ -893,7 +893,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: - for _, field := range (&Scope{Value: values}).Fields() { + for _, field := range (&Scope{Value: values, db: db}).Fields() { if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { attrs[field.DBName] = field.Field.Interface() } @@ -905,12 +905,12 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false), true + return convertInterfaceToMap(value, false, scope.db), true } results = map[string]interface{}{} - for key, value := range convertInterfaceToMap(value, true) { + for key, value := range convertInterfaceToMap(value, true, scope.db) { if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { if _, ok := value.(*expr); ok { hasUpdate = true