mirror of https://github.com/go-gorm/gorm.git
map insert support return increment id (#6662)
This commit is contained in:
parent
c1e911f6ed
commit
3207ad6033
|
@ -103,9 +103,22 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
if db.RowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
pkField *schema.Field
|
||||
pkFieldName = "@id"
|
||||
)
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
return
|
||||
}
|
||||
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
|
||||
}
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
if !insertOk {
|
||||
|
@ -113,6 +126,33 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
return
|
||||
}
|
||||
|
||||
// append @id column with value for auto-increment primary key
|
||||
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
|
||||
switch values := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values[pkFieldName] = insertID
|
||||
case *map[string]interface{}:
|
||||
(*values)[pkFieldName] = insertID
|
||||
case []map[string]interface{}, *[]map[string]interface{}:
|
||||
mapValues, ok := values.([]map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := values.(*[]map[string]interface{}); ok {
|
||||
if *v != nil {
|
||||
mapValues = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, mapValue := range mapValues {
|
||||
if mapValue != nil {
|
||||
mapValue[pkFieldName] = insertID
|
||||
}
|
||||
insertID += schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
default:
|
||||
if pkField == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
|
@ -122,10 +162,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
break
|
||||
}
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -135,16 +175,16 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
break
|
||||
}
|
||||
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if isZero {
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,6 +49,8 @@ const (
|
|||
Bytes DataType = "bytes"
|
||||
)
|
||||
|
||||
const DefaultAutoIncrementIncrement int64 = 1
|
||||
|
||||
// Field is the representation of model schema's field
|
||||
type Field struct {
|
||||
Name string
|
||||
|
@ -119,7 +121,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
|
||||
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
|
||||
Comment: tagSetting["COMMENT"],
|
||||
AutoIncrementIncrement: 1,
|
||||
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
|
||||
}
|
||||
|
||||
for field.IndirectFieldType.Kind() == reflect.Ptr {
|
||||
|
|
|
@ -2,6 +2,7 @@ package tests_test
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -580,7 +581,7 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCreateOnConfilctWithDefalutNull(t *testing.T) {
|
||||
func TestCreateOnConflictWithDefaultNull(t *testing.T) {
|
||||
type OnConfilctUser struct {
|
||||
ID string
|
||||
Name string `gorm:"default:null"`
|
||||
|
@ -615,3 +616,180 @@ func TestCreateOnConfilctWithDefalutNull(t *testing.T) {
|
|||
AssertEqual(t, u2.Email, "on-confilct-user-email-2")
|
||||
AssertEqual(t, u2.Mobile, "133xxxx")
|
||||
}
|
||||
|
||||
func TestCreateFromMapWithoutPK(t *testing.T) {
|
||||
if !isMysql() {
|
||||
t.Skipf("This test case skipped, because of only supportting for mysql")
|
||||
}
|
||||
|
||||
// case 1: one record, create from map[string]interface{}
|
||||
mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
|
||||
if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := mapValue1["id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result1 User
|
||||
if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
var idVal int64
|
||||
_, ok := mapValue1["id"].(uint)
|
||||
if ok {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
idVal, ok = mapValue1["id"].(int64)
|
||||
if !ok {
|
||||
t.Fatal("ret result missing id")
|
||||
}
|
||||
|
||||
if int64(result1.ID) != idVal {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case2: one record, create from *map[string]interface{}
|
||||
mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
|
||||
if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := mapValue2["id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result2 User
|
||||
if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
_, ok = mapValue2["id"].(uint)
|
||||
if ok {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
idVal, ok = mapValue2["id"].(int64)
|
||||
if !ok {
|
||||
t.Fatal("ret result missing id")
|
||||
}
|
||||
|
||||
if int64(result2.ID) != idVal {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case 3: records
|
||||
values := []map[string]interface{}{
|
||||
{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
|
||||
}
|
||||
|
||||
beforeLen := len(values)
|
||||
if err := DB.Model(&User{}).Create(&values).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
// mariadb with returning, values will be appended with id map
|
||||
if len(values) == beforeLen*2 {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
v, ok := values[i]["id"]
|
||||
if !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
if int64(result.ID) != v.(int64) {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFromMapWithTable(t *testing.T) {
|
||||
if !isMysql() {
|
||||
t.Skipf("This test case skipped, because of only supportting for mysql")
|
||||
}
|
||||
tableDB := DB.Table("`users`")
|
||||
|
||||
// case 1: create from map[string]interface{}
|
||||
record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18}
|
||||
if err := tableDB.Create(record).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map with table, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := record["@id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res map[string]interface{}
|
||||
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if int64(res["id"].(uint64)) != record["@id"] {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case 2: create from *map[string]interface{}
|
||||
record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
|
||||
tableDB2 := DB.Table("users")
|
||||
if err := tableDB2.Create(&record1).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
if _, ok := record1["@id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res1 map[string]interface{}
|
||||
if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if int64(res1["id"].(uint64)) != record1["@id"] {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case 3: create from []map[string]interface{}
|
||||
records := []map[string]interface{}{
|
||||
{"name": "create_from_map_with_table_2", "age": 19},
|
||||
{"name": "create_from_map_with_table_3", "age": 20},
|
||||
}
|
||||
|
||||
tableDB = DB.Table("users")
|
||||
if err := tableDB.Create(&records).Error; err != nil {
|
||||
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := records[0]["@id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
if _, ok := records[1]["@id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res2 map[string]interface{}
|
||||
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
|
||||
var res3 map[string]interface{}
|
||||
if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
|
||||
if int64(res2["id"].(uint64)) != records[0]["@id"] {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
if int64(res3["id"].(uint64)) != records[1]["@id"] {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue