diff --git a/association.go b/association.go index ab9090ac..abcae47d 100644 --- a/association.go +++ b/association.go @@ -101,8 +101,10 @@ func (association *Association) Replace(values ...interface{}) error { } _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - column, queryValues := schema.ToQueryValues(foreignKeys, values) - association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) + if len(values) > 0 { + column, queryValues := schema.ToQueryValues(foreignKeys, values) + association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) + } case schema.Many2Many: var primaryFields, relPrimaryFields []*schema.Field var foreignKeys, relForeignKeys []string @@ -200,13 +202,13 @@ func (association *Association) Delete(values ...interface{}) error { if _, zero := rel.Field.ValueOf(data); !zero { fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) - fieldValues := make([]reflect.Value, len(relFields)) + fieldValues := make([]interface{}, len(relFields)) switch fieldValue.Kind() { case reflect.Slice, reflect.Array: validFieldValues := reflect.Zero(rel.Field.FieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range relFields { - fieldValues[idx] = field.ReflectValueOf(fieldValue.Index(i)) + fieldValues[idx], _ = field.ValueOf(fieldValue.Index(i)) } if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; !ok { @@ -217,7 +219,7 @@ func (association *Association) Delete(values ...interface{}) error { rel.Field.Set(data, validFieldValues) case reflect.Struct: for idx, field := range relFields { - fieldValues[idx] = field.ReflectValueOf(data) + fieldValues[idx], _ = field.ValueOf(data) } if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType)) diff --git a/callbacks/associations.go b/callbacks/associations.go index 96d9ce22..ef040b71 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -276,7 +276,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.Session(&gorm.Session{}).Debug().Create(joins.Interface()) + db.Session(&gorm.Session{}).Create(joins.Interface()) } } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 9f23a2ca..7e3810b5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -42,22 +42,25 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, joinForeignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + if len(joinForeignValues) == 0 { + return + } joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues) tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) // convert join identity map to relation identity map - fieldValues := make([]reflect.Value, len(foreignFields)) - joinFieldValues := make([]reflect.Value, len(joinForeignFields)) + fieldValues := make([]interface{}, len(foreignFields)) + joinFieldValues := make([]interface{}, len(joinForeignFields)) for i := 0; i < joinResults.Len(); i++ { - for idx, field := range foreignFields { - fieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) + for idx, field := range joinForeignFields { + fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) } - for idx, field := range joinForeignFields { - joinFieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) + for idx, field := range joinRelForeignFields { + joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -82,16 +85,19 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + if len(foreignValues) == 0 { + return + } } reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(relForeignKeys, foreignValues) tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...) - fieldValues := make([]reflect.Value, len(foreignFields)) + fieldValues := make([]interface{}, len(foreignFields)) for i := 0; i < reflectResults.Len(); i++ { for idx, field := range relForeignFields { - fieldValues[idx] = field.ReflectValueOf(reflectResults.Index(i)) + fieldValues[idx], _ = field.ValueOf(reflectResults.Index(i)) } for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { diff --git a/schema/utils.go b/schema/utils.go index 72bd149c..ead83cab 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -89,9 +89,9 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle // GetIdentityFieldValuesMap get identity map from fields func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { var ( - fieldValues = make([]reflect.Value, len(fields)) - results = [][]interface{}{} - dataResults = map[string][]reflect.Value{} + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + notZero, zero bool ) switch reflectValue.Kind() { @@ -99,28 +99,33 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { - fieldValues[idx] = field.ReflectValueOf(reflectValue) - results[0][idx] = fieldValues[idx].Interface() + results[0][idx], zero = field.ValueOf(reflectValue) + notZero = notZero || !zero } - dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} + if !notZero { + return nil, nil + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Slice, reflect.Array: + fieldValues := make([]interface{}, len(fields)) + for i := 0; i < reflectValue.Len(); i++ { + notZero = false for idx, field := range fields { - fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) + fieldValues[idx], zero = field.ValueOf(reflectValue.Index(idx)) + notZero = notZero || !zero } - dataKey := utils.ToStringKey(fieldValues...) - if _, ok := dataResults[dataKey]; !ok { - result := make([]interface{}, len(fieldValues)) - for idx, fieldValue := range fieldValues { - result[idx] = fieldValue.Interface() + if notZero { + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + results = append(results, fieldValues[:]) + dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} + } else { + dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) } - results = append(results, result) - - dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} - } else { - dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) } } } diff --git a/tests/create.go b/tests/create.go index 09464674..0d85a29e 100644 --- a/tests/create.go +++ b/tests/create.go @@ -52,7 +52,7 @@ func GetUser(name string, config Config) *User { } for i := 0; i < config.Languages; i++ { - name := name + "_locale_" + strconv.Itoa(i+0) + name := name + "_locale_" + strconv.Itoa(i+1) language := Language{Code: name, Name: name} DB.Create(&language) user.Languages = append(user.Languages, language) diff --git a/tests/create_test.go b/tests/create_test.go index 9241e0a6..ef9203aa 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -34,7 +34,7 @@ func TestCreate(t *testing.T) { } func TestCreateWithAssociations(t *testing.T) { - var user = *GetUser("create_with_belongs_to", Config{ + var user = *GetUser("create_with_associations", Config{ Account: true, Pets: 2, Toys: 3, @@ -52,34 +52,38 @@ func TestCreateWithAssociations(t *testing.T) { CheckUser(t, user, user) var user2 User - DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Find(&user2, "id = ?", user.ID) + DB.Debug().Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) } -// func TestBulkCreateWithBelongsTo(t *testing.T) { -// users := []User{ -// *GetUser("create_with_belongs_to_1", Config{Company: true, Manager: true}), -// *GetUser("create_with_belongs_to_2", Config{Company: true, Manager: false}), -// *GetUser("create_with_belongs_to_3", Config{Company: false, Manager: true}), -// *GetUser("create_with_belongs_to_4", Config{Company: true, Manager: true}), -// } +func TestBulkCreateWithAssociations(t *testing.T) { + users := []User{ + *GetUser("bulk_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("bulk_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("bulk_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("bulk_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("bulk_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("bulk_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + *GetUser("bulk_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), + *GetUser("bulk_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), + } -// if err := DB.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } -// var userIDs []uint -// for _, user := range users { -// userIDs = append(userIDs, user.ID) -// CheckUser(t, user, user) -// } + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + CheckUser(t, user, user) + } -// var users2 []User -// DB.Preload("Company").Preload("Manager").Find(&users2, "id IN ?", userIDs) -// for idx, user := range users2 { -// CheckUser(t, user, users[idx]) -// } -// } + var users2 []User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) + for idx, user := range users2 { + CheckUser(t, user, users[idx]) + } +} // func TestBulkCreatePtrDataWithBelongsTo(t *testing.T) { // users := []*User{ diff --git a/tests/tests.go b/tests/tests.go index 1ff700c5..2b2bfc20 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -73,7 +73,7 @@ func RunMigrations() { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - DB.Migrator().DropTable("user_friends", "user_speak") + DB.Migrator().DropTable("user_friends", "user_speaks") if err = DB.Migrator().DropTable(allModels...); err != nil { log.Printf("Failed to drop table, got error %v\n", err) diff --git a/utils/utils.go b/utils/utils.go index 5d6c9da2..3924e69e 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -41,16 +41,15 @@ func CheckTruth(val interface{}) bool { return !reflect.ValueOf(val).IsZero() } -func ToStringKey(values ...reflect.Value) string { +func ToStringKey(values ...interface{}) string { results := make([]string, len(values)) for idx, value := range values { - rv := reflect.Indirect(value).Interface() - if valuer, ok := rv.(driver.Valuer); ok { - rv, _ = valuer.Value() + if valuer, ok := value.(driver.Valuer); ok { + value, _ = valuer.Value() } - switch v := rv.(type) { + switch v := value.(type) { case string: results[idx] = v case []byte: @@ -58,7 +57,7 @@ func ToStringKey(values ...reflect.Value) string { case uint: results[idx] = strconv.FormatUint(uint64(v), 10) default: - results[idx] = fmt.Sprint(v) + results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) } }