diff --git a/callbacks/associations.go b/callbacks/associations.go index 5ff63cc4..3ff0f4b0 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -52,21 +52,19 @@ func SaveBeforeAssociations(db *gorm.DB) { obj := db.Statement.ReflectValue.Index(i) if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value rv := rel.Field.ReflectValueOf(obj) // relation reflect value - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) - } + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) } else { - setupReferences(obj, rv) + elems = reflect.Append(elems, rv.Addr()) } } } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -79,10 +77,11 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Session(&gorm.Session{}).Create(rv.Interface()) + if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(rv.Interface()).Error) == nil { + setupReferences(db.Statement.ReflectValue, rv) } - setupReferences(db.Statement.ReflectValue, rv) } } } @@ -130,16 +129,20 @@ func SaveAfterAssociations(db *gorm.DB) { } } - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - elems = reflect.Append(elems, rv) - } else { - db.Session(&gorm.Session{}).Save(rv.Addr().Interface()) - } + elems = reflect.Append(elems, rv) } } if elems.Len() > 0 { - db.Session(&gorm.Session{}).Create(elems.Interface()) + assignmentColumns := []string{} + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + DoUpdates: clause.AssignmentColumns(assignmentColumns), + }).Create(elems.Interface()) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -148,6 +151,7 @@ func SaveAfterAssociations(db *gorm.DB) { f = f.Addr() } + assignmentColumns := []string{} for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) @@ -155,13 +159,13 @@ func SaveAfterAssociations(db *gorm.DB) { } else if ref.PrimaryValue != "" { ref.ForeignKey.Set(f, ref.PrimaryValue) } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { - db.Session(&gorm.Session{}).Create(f.Interface()) - } else { - db.Session(&gorm.Session{}).Save(f.Interface()) - } + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + DoUpdates: clause.AssignmentColumns(assignmentColumns), + }).Create(f.Interface()) } } } @@ -193,14 +197,10 @@ func SaveAfterAssociations(db *gorm.DB) { } } - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) - } + if isPtr { + elems = reflect.Append(elems, elem) } else { - db.Session(&gorm.Session{}).Save(elem.Addr().Interface()) + elems = reflect.Append(elems, elem.Addr()) } } } @@ -216,7 +216,15 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.Session(&gorm.Session{}).Create(elems.Interface()) + assignmentColumns := []string{} + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + DoUpdates: clause.AssignmentColumns(assignmentColumns), + }).Create(elems.Interface()) } } @@ -258,15 +266,11 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < f.Len(); i++ { elem := f.Index(i) - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { - objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) - } + objs = append(objs, v) + if isPtr { + elems = reflect.Append(elems, elem) } else { - appendToJoins(v, elem) + elems = reflect.Append(elems, elem.Addr()) } } } @@ -282,7 +286,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.Session(&gorm.Session{}).Create(elems.Interface()) + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) diff --git a/callbacks/create.go b/callbacks/create.go index 684d5530..283d3fd1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -55,29 +55,44 @@ func Create(config *Config) func(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected > 0 { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } + } else { + allUpdated := int(db.RowsAffected) == db.Statement.ReflectValue.Len() + isZero := true + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + + if !allUpdated { + _, isZero = db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + } + + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } + } } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } else { + db.AddError(err) } - } else { - db.AddError(err) } } - db.RowsAffected, _ = result.RowsAffected() } else { db.AddError(err) } @@ -129,9 +144,19 @@ func CreateWithReturning(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: + c := db.Statement.Clauses["ON CONFLICT"] + onConflict, _ := c.Expression.(clause.OnConflict) + for rows.Next() { + BEGIN: for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + fieldValue := field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))) + if onConflict.DoNothing && !fieldValue.IsZero() { + db.RowsAffected++ + goto BEGIN + } + + values[idx] = fieldValue.Addr().Interface() } db.RowsAffected++ @@ -211,7 +236,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { case reflect.Slice, reflect.Array: stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) - defaultValueFieldsHavingValue := map[string][]interface{}{} + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} for i := 0; i < stmt.ReflectValue.Len(); i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) values.Values[i] = make([]interface{}, len(values.Columns)) @@ -231,20 +256,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(rv); !isZero { - if len(defaultValueFieldsHavingValue[field.DBName]) == 0 { - defaultValueFieldsHavingValue[field.DBName] = make([]interface{}, stmt.ReflectValue.Len()) + if len(defaultValueFieldsHavingValue[field]) == 0 { + defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) } - defaultValueFieldsHavingValue[field.DBName][i] = v + defaultValueFieldsHavingValue[field][i] = v } } } } - for db, vs := range defaultValueFieldsHavingValue { - values.Columns = append(values.Columns, clause.Column{Name: db}) + for field, vs := range defaultValueFieldsHavingValue { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) for idx := range values.Values { if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], clause.Expr{SQL: "DEFAULT"}) + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) } else { values.Values[idx] = append(values.Values[idx], vs[idx]) } diff --git a/clause/clause.go b/clause/clause.go index 64f08d14..c7d1efeb 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) { const ( PrimaryKey string = "@@@py@@@" // primary key CurrentTable string = "@@@ct@@@" // current table + Associations string = "@@@as@@@" // associations ) var ( diff --git a/interfaces.go b/interfaces.go index f3e5c028..96289a90 100644 --- a/interfaces.go +++ b/interfaces.go @@ -14,6 +14,7 @@ type Dialector interface { Initialize(*DB) error Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string + DefaultValueOf(*schema.Field) clause.Expression BindVarTo(writer clause.Writer, stmt *Statement, v interface{}) QuoteTo(clause.Writer, string) Explain(sql string, vars ...interface{}) string diff --git a/migrator/migrator.go b/migrator/migrator.go index a4cc99a6..b598bd93 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -57,10 +57,6 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL = m.DataTypeOf(field) - if field.AutoIncrement { - expr.SQL += " AUTO_INCREMENT" - } - if field.NotNull { expr.SQL += " NOT NULL" } diff --git a/schema/field_test.go b/schema/field_test.go index 0936c0d1..7970b614 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -235,7 +235,7 @@ func TestParseFieldWithPermission(t *testing.T) { } fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, diff --git a/schema/relationship.go b/schema/relationship.go index afa083ed..c69a4a09 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -251,11 +251,13 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } relation.JoinTable.Name = many2many relation.JoinTable.Table = schema.namer.JoinTableName(many2many) + relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields)) // build references - for _, f := range relation.JoinTable.Fields { + for idx, f := range relation.JoinTable.Fields { // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType + relation.JoinTable.PrimaryFields[idx] = f relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], diff --git a/schema/schema.go b/schema/schema.go index 5b360f5e..e5894443 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -188,6 +188,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } field.HasDefaultValue = true + field.AutoIncrement = true } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 4ec7ff0c..99781e47 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -32,7 +32,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true}, + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true, AutoIncrement: true}, {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, @@ -125,7 +125,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index a6dcc6c5..f487bd9e 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -68,6 +68,7 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Account", 0, "after delete") // Prepare Data for Clear + account = Account{Number: "account-has-one-append"} if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { t.Fatalf("Error happened when append Account, got %v", err) } @@ -185,6 +186,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { AssertAssociationCount(t, pet2, "Toy", 0, "after delete") // Prepare Data for Clear + toy = Toy{Name: "toy-has-one-append"} if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { t.Fatalf("Error happened when append Toy, got %v", err) } diff --git a/tests/go.mod b/tests/go.mod index a2121b7a..1cd56f6b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.1 - gorm.io/driver/postgres v0.2.1 - gorm.io/driver/sqlite v1.0.4 - gorm.io/driver/sqlserver v0.2.1 - gorm.io/gorm v0.2.7 + gorm.io/driver/mysql v0.2.2 + gorm.io/driver/postgres v0.2.2 + gorm.io/driver/sqlite v1.0.5 + gorm.io/driver/sqlserver v0.2.2 + gorm.io/gorm v0.2.9 ) replace gorm.io/gorm => ../ diff --git a/tests/helper_test.go b/tests/helper_test.go index b05f5297..cc0d808c 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -58,7 +58,6 @@ func GetUser(name string, config Config) *User { for i := 0; i < config.Languages; i++ { name := name + "_locale_" + strconv.Itoa(i+1) language := Language{Code: name, Name: name} - DB.Create(&language) user.Languages = append(user.Languages, language) } diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 4a25a69b..d40309e7 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "reflect" + "sort" "testing" "gorm.io/gorm" @@ -735,7 +736,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { t.Error(err) } - return if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) @@ -1459,6 +1459,12 @@ func TestPrefixedPreloadDuplication(t *testing.T) { t.Error(err) } + for _, level1 := range append(got, want...) { + sort.Slice(level1.Level2.Level3.Level4s, func(i, j int) bool { + return level1.Level2.Level3.Level4s[i].ID > level1.Level2.Level3.Level4s[j].ID + }) + } + if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } diff --git a/tests/tests_all.sh b/tests/tests_all.sh index fd696e38..a321fe31 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -8,7 +8,7 @@ if [ -d tests ] then cd tests cp go.mod go.mod.bak - sed '/gorm.io\/driver/d' go.mod.bak > go.mod + sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod cd .. fi diff --git a/tests/tests_test.go b/tests/tests_test.go index 9e135b4e..fa8bad5c 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -34,6 +34,9 @@ func init() { } RunMigrations() + if DB.Dialector.Name() == "sqlite" { + DB.Exec("PRAGMA foreign_keys = ON") + } } } @@ -66,7 +69,6 @@ func OpenTestConnection() (db *gorm.DB, err error) { default: log.Println("testing sqlite3...") db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) - db.Exec("PRAGMA foreign_keys = ON") } if debug := os.Getenv("DEBUG"); debug == "true" { diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index cd4bbd45..b8452ef9 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -18,6 +18,10 @@ func (DummyDialector) Initialize(*gorm.DB) error { return nil } +func (DummyDialector) DefaultValueOf(field *schema.Field) clause.Expression { + return clause.Expr{SQL: "DEFAULT"} +} + func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { return nil }