diff --git a/callbacks/create.go b/callbacks/create.go index 04ee6b30..2ebe5cab 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -37,7 +37,6 @@ func Create(config *Config) func(db *gorm.DB) { return func(db *gorm.DB) { if db.Error != nil { - // maybe record logger TODO return } @@ -64,11 +63,9 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() - if !(db.RowsAffected > 0) { - return - } - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: @@ -107,7 +104,6 @@ func Create(config *Config) func(db *gorm.DB) { db.AddError(err) } } - } } } @@ -349,11 +345,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if c, ok := stmt.Clauses["ON CONFLICT"]; ok { if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { if stmt.Schema != nil && len(values.Columns) > 1 { + selectColumns, restricted := stmt.SelectAndOmitColumns(true, true) + columns := make([]string, 0, len(values.Columns)-1) for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { - if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { - columns = append(columns, column.Name) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { + columns = append(columns, column.Name) + } } } } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 0ba8b9f0..867110d8 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -1,9 +1,11 @@ package tests_test import ( + "regexp" "testing" "time" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -51,6 +53,19 @@ func TestUpsert(t *testing.T) { if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { t.Fatalf("failed to upsert, got name %v", result.Name) } + + if name := DB.Dialector.Name(); name != "sqlserver" { + type RestrictedLanguage struct { + Code string `gorm:"primarykey"` + Name string + Lang string `gorm:"<-:create"` + } + + r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } } func TestUpsertSlice(t *testing.T) {