From 81536f823c055ba293dfbb7a8e90ebf93d32b431 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 19 Mar 2024 11:50:28 +0800 Subject: [PATCH] Fix insert id into map results, fix #6812 --- callbacks/create.go | 23 ++++++++++++++++------- tests/create_test.go | 30 ++++++++++++++---------------- tests/go.mod | 2 +- tests/helper_test.go | 4 ++++ 4 files changed, 35 insertions(+), 24 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index d930e922..afea2cca 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) { pkField *schema.Field pkFieldName = "@id" ) + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + + if !insertOk { + if !supportReturning { + db.AddError(err) + } + return + } + if db.Statement.Schema != nil { if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { return @@ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName } - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) - return - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { @@ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) { } } } + + if config.LastInsertIDReversed { + insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement + } + for _, mapValue := range mapValues { if mapValue != nil { mapValue[pkFieldName] = insertID diff --git a/tests/create_test.go b/tests/create_test.go index 5e97a542..abb82472 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -713,18 +713,16 @@ func TestCreateFromMapWithoutPK(t *testing.T) { } func TestCreateFromMapWithTable(t *testing.T) { - if !isMysql() { - t.Skipf("This test case skipped, because of only supportting for mysql") - } - tableDB := DB.Table("`users`") + tableDB := DB.Table("users") + supportLastInsertID := isMysql() || isSqlite() // case 1: create from map[string]interface{} - record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} + record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18} if err := tableDB.Create(record).Error; err != nil { t.Fatalf("failed to create data from map with table, got error: %v", err) } - if _, ok := record["@id"]; !ok { + if _, ok := record["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -733,8 +731,8 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create from map, got error %v", err) } - if int64(res["id"].(uint64)) != record["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) { + t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"]) } // case 2: create from *map[string]interface{} @@ -743,7 +741,7 @@ func TestCreateFromMapWithTable(t *testing.T) { if err := tableDB2.Create(&record1).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } - if _, ok := record1["@id"]; !ok { + if _, ok := record1["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -752,7 +750,7 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create from map, got error %v", err) } - if int64(res1["id"].(uint64)) != record1["@id"] { + if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) { t.Fatal("failed to create data from map with table, @id != id") } @@ -767,11 +765,11 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create data from slice of map, got error: %v", err) } - if _, ok := records[0]["@id"]; !ok { + if _, ok := records[0]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } - if _, ok := records[1]["@id"]; !ok { + if _, ok := records[1]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -785,11 +783,11 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to query data after create from slice of map, got error %v", err) } - if int64(res2["id"].(uint64)) != records[0]["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"]) } - if int64(res3["id"].(uint64)) != records[1]["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id") } } diff --git a/tests/go.mod b/tests/go.mod index 350152d3..5616ebb3 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -11,7 +11,7 @@ require ( gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.7 + gorm.io/gorm v1.25.8 ) require ( diff --git a/tests/helper_test.go b/tests/helper_test.go index feb67f9e..dc250b7c 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -281,6 +281,10 @@ func isMysql() bool { return os.Getenv("GORM_DIALECT") == "mysql" } +func isSqlite() bool { + return os.Getenv("GORM_DIALECT") == "sqlite" +} + func db(unscoped bool) *gorm.DB { if unscoped { return DB.Unscoped()