From ecc946be6e93a108bbdcc10cf2719d08baa50f3f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 16:05:06 +0800 Subject: [PATCH] Test update from sub query --- callbacks/update.go | 9 +++++++-- tests/update_test.go | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 12806af6..0ced3ffb 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -174,11 +174,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { sort.Strings(keys) for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + if stmt.Schema != nil { if field := stmt.Schema.LookUpField(k); field != nil { if field.DBName != "" { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) assignValue(field, value[k]) } } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { @@ -189,7 +194,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) } } diff --git a/tests/update_test.go b/tests/update_test.go index 2ff150dd..83a7b9a2 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -545,3 +545,21 @@ func TestUpdatesTableWithIgnoredValues(t *testing.T) { t.Errorf("element's ignored field should not be updated") } } + +func TestUpdateFromSubQuery(t *testing.T) { + user := *GetUser("update_from_sub_query", Config{Company: true}) + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error: %v", err) + } + + if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Company.Name { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } +}