From 9fcc337bd1ccfccfddcdbd4a9b8b08ad08bf465c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 17:41:36 +0800 Subject: [PATCH] Fix create from map --- callbacks/associations.go | 59 ++++++++++++++++++++++++--------------- callbacks/create.go | 22 ++++++++++++--- callbacks/helper.go | 10 ++++++- tests/create_test.go | 39 ++++++++++++++++++++++++++ tests/go.mod | 2 +- 5 files changed, 103 insertions(+), 29 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3508335a..2710ffe9 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -48,14 +48,19 @@ func SaveBeforeAssociations(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } } + } else { + break } } @@ -112,22 +117,24 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) - } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() } - } - elems = reflect.Append(elems, rv) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } } } @@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) @@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) diff --git a/callbacks/create.go b/callbacks/create.go index 3a414dd7..4cc0f555 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -61,16 +61,26 @@ func Create(config *Config) func(db *gorm.DB) { case reflect.Slice, reflect.Array: if config.LastInsertIDReversed { for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID-- } } } else { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID++ } } @@ -140,6 +150,10 @@ func CreateWithReturning(db *gorm.DB) { for rows.Next() { BEGIN: reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) + if reflect.Indirect(reflectValue).Kind() != reflect.Struct { + break + } + for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) diff --git a/callbacks/helper.go b/callbacks/helper.go index 7bd910f6..80fbc2a1 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -26,6 +26,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + values.Values[0] = append(values.Values[0], value) } } @@ -61,11 +65,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st sort.Strings(columns) values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + for i, v := range result[column] { - if i == 0 { + if len(values.Values[i]) == 0 { values.Values[i] = make([]interface{}, len(columns)) } + values.Values[i][idx] = v } } diff --git a/tests/create_test.go b/tests/create_test.go index ae6e1232..ab0a78d4 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -39,6 +39,45 @@ func TestCreate(t *testing.T) { } } +func TestCreateFromMap(t *testing.T) { + if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result User + if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + datas := []map[string]interface{}{ + {"Name": "create_from_map_2", "Age": 19}, + {"name": "create_from_map_3", "Age": 20}, + } + + if err := DB.Model(&User{}).Create(datas).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var result3 User + if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } +} + func TestCreateWithAssociations(t *testing.T) { var user = *GetUser("create_with_associations", Config{ Account: true, diff --git a/tests/go.mod b/tests/go.mod index 82d4fdc8..54a808d0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v0.3.1 gorm.io/driver/postgres v0.2.6 gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.6 + gorm.io/driver/sqlserver v0.2.7 gorm.io/gorm v0.2.19 )