mirror of https://github.com/go-gorm/gorm.git
Respect update permission for OnConflict Create
This commit is contained in:
parent
0329b800b0
commit
2ec7043818
|
@ -37,7 +37,6 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
|
|
||||||
return func(db *gorm.DB) {
|
return func(db *gorm.DB) {
|
||||||
if db.Error != nil {
|
if db.Error != nil {
|
||||||
// maybe record logger TODO
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,11 +63,9 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
if !(db.RowsAffected > 0) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
||||||
|
db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||||
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
|
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
@ -107,7 +104,6 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -349,14 +345,18 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
|
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
|
||||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
|
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
|
||||||
if stmt.Schema != nil && len(values.Columns) > 1 {
|
if stmt.Schema != nil && len(values.Columns) > 1 {
|
||||||
|
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
|
||||||
|
|
||||||
columns := make([]string, 0, len(values.Columns)-1)
|
columns := make([]string, 0, len(values.Columns)-1)
|
||||||
for _, column := range values.Columns {
|
for _, column := range values.Columns {
|
||||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||||
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
||||||
columns = append(columns, column.Name)
|
columns = append(columns, column.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onConflict.DoUpdates = clause.AssignmentColumns(columns)
|
onConflict.DoUpdates = clause.AssignmentColumns(columns)
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
@ -51,6 +53,19 @@ func TestUpsert(t *testing.T) {
|
||||||
if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name {
|
if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name {
|
||||||
t.Fatalf("failed to upsert, got name %v", result.Name)
|
t.Fatalf("failed to upsert, got name %v", result.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if name := DB.Dialector.Name(); name != "sqlserver" {
|
||||||
|
type RestrictedLanguage struct {
|
||||||
|
Code string `gorm:"primarykey"`
|
||||||
|
Name string
|
||||||
|
Lang string `gorm:"<-:create"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"})
|
||||||
|
if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) {
|
||||||
|
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpsertSlice(t *testing.T) {
|
func TestUpsertSlice(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue