Fix create from map

This commit is contained in:
Jinzhu 2020-08-17 17:41:36 +08:00
parent 681268cc43
commit 9fcc337bd1
5 changed files with 103 additions and 29 deletions

View File

@ -48,6 +48,8 @@ func SaveBeforeAssociations(db *gorm.DB) {
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj) objs = append(objs, obj)
@ -57,6 +59,9 @@ func SaveBeforeAssociations(db *gorm.DB) {
elems = reflect.Append(elems, rv.Addr()) elems = reflect.Append(elems, rv.Addr())
} }
} }
} else {
break
}
} }
if elems.Len() > 0 { if elems.Len() > 0 {
@ -112,6 +117,7 @@ func SaveAfterAssociations(db *gorm.DB) {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj) rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
@ -130,6 +136,7 @@ func SaveAfterAssociations(db *gorm.DB) {
elems = reflect.Append(elems, rv) elems = reflect.Append(elems, rv)
} }
} }
}
if elems.Len() > 0 { if elems.Len() > 0 {
assignmentColumns := []string{} assignmentColumns := []string{}
@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i)) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
} }
case reflect.Struct: case reflect.Struct:
appendToElems(db.Statement.ReflectValue) appendToElems(db.Statement.ReflectValue)
@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i)) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
} }
case reflect.Struct: case reflect.Struct:
appendToElems(db.Statement.ReflectValue) appendToElems(db.Statement.ReflectValue)

View File

@ -61,16 +61,26 @@ func Create(config *Config) func(db *gorm.DB) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed { if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
if isZero { if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID-- insertID--
} }
} }
} else { } else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero { rv := db.Statement.ReflectValue.Index(i)
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID++ insertID++
} }
} }
@ -140,6 +150,10 @@ func CreateWithReturning(db *gorm.DB) {
for rows.Next() { for rows.Next() {
BEGIN: BEGIN:
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
break
}
for idx, field := range fields { for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue) fieldValue := field.ReflectValueOf(reflectValue)

View File

@ -26,6 +26,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: k}) values.Columns = append(values.Columns, clause.Column{Name: k})
if len(values.Values) == 0 {
values.Values = [][]interface{}{{}}
}
values.Values[0] = append(values.Values[0], value) values.Values[0] = append(values.Values[0], value)
} }
} }
@ -61,11 +65,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
sort.Strings(columns) sort.Strings(columns)
values.Values = make([][]interface{}, len(mapValues)) values.Values = make([][]interface{}, len(mapValues))
values.Columns = make([]clause.Column, len(columns))
for idx, column := range columns { for idx, column := range columns {
values.Columns[idx] = clause.Column{Name: column}
for i, v := range result[column] { for i, v := range result[column] {
if i == 0 { if len(values.Values[i]) == 0 {
values.Values[i] = make([]interface{}, len(columns)) values.Values[i] = make([]interface{}, len(columns))
} }
values.Values[i][idx] = v values.Values[i][idx] = v
} }
} }

View File

@ -39,6 +39,45 @@ func TestCreate(t *testing.T) {
} }
} }
func TestCreateFromMap(t *testing.T) {
if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
var result User
if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 {
t.Fatalf("failed to create from map, got error %v", err)
}
if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
var result1 User
if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 {
t.Fatalf("failed to create from map, got error %v", err)
}
datas := []map[string]interface{}{
{"Name": "create_from_map_2", "Age": 19},
{"name": "create_from_map_3", "Age": 20},
}
if err := DB.Model(&User{}).Create(datas).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}
var result2 User
if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
var result3 User
if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
}
func TestCreateWithAssociations(t *testing.T) { func TestCreateWithAssociations(t *testing.T) {
var user = *GetUser("create_with_associations", Config{ var user = *GetUser("create_with_associations", Config{
Account: true, Account: true,

View File

@ -9,7 +9,7 @@ require (
gorm.io/driver/mysql v0.3.1 gorm.io/driver/mysql v0.3.1
gorm.io/driver/postgres v0.2.6 gorm.io/driver/postgres v0.2.6
gorm.io/driver/sqlite v1.0.9 gorm.io/driver/sqlite v1.0.9
gorm.io/driver/sqlserver v0.2.6 gorm.io/driver/sqlserver v0.2.7
gorm.io/gorm v0.2.19 gorm.io/gorm v0.2.19
) )