forked from mirror/gorm
Don't preload if foreign keys zero
This commit is contained in:
parent
2ca4e91d88
commit
5ec4fee797
|
@ -101,8 +101,10 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
||||||
column, queryValues := schema.ToQueryValues(foreignKeys, values)
|
if len(values) > 0 {
|
||||||
association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap)
|
column, queryValues := schema.ToQueryValues(foreignKeys, values)
|
||||||
|
association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap)
|
||||||
|
}
|
||||||
case schema.Many2Many:
|
case schema.Many2Many:
|
||||||
var primaryFields, relPrimaryFields []*schema.Field
|
var primaryFields, relPrimaryFields []*schema.Field
|
||||||
var foreignKeys, relForeignKeys []string
|
var foreignKeys, relForeignKeys []string
|
||||||
|
@ -200,13 +202,13 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
if _, zero := rel.Field.ValueOf(data); !zero {
|
if _, zero := rel.Field.ValueOf(data); !zero {
|
||||||
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
|
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
|
||||||
|
|
||||||
fieldValues := make([]reflect.Value, len(relFields))
|
fieldValues := make([]interface{}, len(relFields))
|
||||||
switch fieldValue.Kind() {
|
switch fieldValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
validFieldValues := reflect.Zero(rel.Field.FieldType)
|
validFieldValues := reflect.Zero(rel.Field.FieldType)
|
||||||
for i := 0; i < fieldValue.Len(); i++ {
|
for i := 0; i < fieldValue.Len(); i++ {
|
||||||
for idx, field := range relFields {
|
for idx, field := range relFields {
|
||||||
fieldValues[idx] = field.ReflectValueOf(fieldValue.Index(i))
|
fieldValues[idx], _ = field.ValueOf(fieldValue.Index(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; !ok {
|
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; !ok {
|
||||||
|
@ -217,7 +219,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
rel.Field.Set(data, validFieldValues)
|
rel.Field.Set(data, validFieldValues)
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
for idx, field := range relFields {
|
for idx, field := range relFields {
|
||||||
fieldValues[idx] = field.ReflectValueOf(data)
|
fieldValues[idx], _ = field.ValueOf(data)
|
||||||
}
|
}
|
||||||
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok {
|
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||||
rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType))
|
rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType))
|
||||||
|
|
|
@ -276,7 +276,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if joins.Len() > 0 {
|
if joins.Len() > 0 {
|
||||||
db.Session(&gorm.Session{}).Debug().Create(joins.Interface())
|
db.Session(&gorm.Session{}).Create(joins.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,22 +42,25 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, joinForeignFields)
|
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
||||||
|
if len(joinForeignValues) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
joinResults := rel.JoinTable.MakeSlice().Elem()
|
joinResults := rel.JoinTable.MakeSlice().Elem()
|
||||||
column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues)
|
column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues)
|
||||||
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
|
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
|
||||||
|
|
||||||
// convert join identity map to relation identity map
|
// convert join identity map to relation identity map
|
||||||
fieldValues := make([]reflect.Value, len(foreignFields))
|
fieldValues := make([]interface{}, len(foreignFields))
|
||||||
joinFieldValues := make([]reflect.Value, len(joinForeignFields))
|
joinFieldValues := make([]interface{}, len(joinForeignFields))
|
||||||
for i := 0; i < joinResults.Len(); i++ {
|
for i := 0; i < joinResults.Len(); i++ {
|
||||||
for idx, field := range foreignFields {
|
for idx, field := range joinForeignFields {
|
||||||
fieldValues[idx] = field.ReflectValueOf(joinResults.Index(i))
|
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, field := range joinForeignFields {
|
for idx, field := range joinRelForeignFields {
|
||||||
joinFieldValues[idx] = field.ReflectValueOf(joinResults.Index(i))
|
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||||
|
@ -82,16 +85,19 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
||||||
|
if len(foreignValues) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
||||||
column, values := schema.ToQueryValues(relForeignKeys, foreignValues)
|
column, values := schema.ToQueryValues(relForeignKeys, foreignValues)
|
||||||
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...)
|
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...)
|
||||||
|
|
||||||
fieldValues := make([]reflect.Value, len(foreignFields))
|
fieldValues := make([]interface{}, len(foreignFields))
|
||||||
for i := 0; i < reflectResults.Len(); i++ {
|
for i := 0; i < reflectResults.Len(); i++ {
|
||||||
for idx, field := range relForeignFields {
|
for idx, field := range relForeignFields {
|
||||||
fieldValues[idx] = field.ReflectValueOf(reflectResults.Index(i))
|
fieldValues[idx], _ = field.ValueOf(reflectResults.Index(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
|
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
|
||||||
|
|
|
@ -89,9 +89,9 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle
|
||||||
// GetIdentityFieldValuesMap get identity map from fields
|
// GetIdentityFieldValuesMap get identity map from fields
|
||||||
func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
||||||
var (
|
var (
|
||||||
fieldValues = make([]reflect.Value, len(fields))
|
results = [][]interface{}{}
|
||||||
results = [][]interface{}{}
|
dataResults = map[string][]reflect.Value{}
|
||||||
dataResults = map[string][]reflect.Value{}
|
notZero, zero bool
|
||||||
)
|
)
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
|
@ -99,28 +99,33 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
|
||||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||||
|
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
fieldValues[idx] = field.ReflectValueOf(reflectValue)
|
results[0][idx], zero = field.ValueOf(reflectValue)
|
||||||
results[0][idx] = fieldValues[idx].Interface()
|
notZero = notZero || !zero
|
||||||
}
|
}
|
||||||
|
|
||||||
dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue}
|
if !notZero {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue}
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
fieldValues := make([]interface{}, len(fields))
|
||||||
|
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
|
notZero = false
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i))
|
fieldValues[idx], zero = field.ValueOf(reflectValue.Index(idx))
|
||||||
|
notZero = notZero || !zero
|
||||||
}
|
}
|
||||||
|
|
||||||
dataKey := utils.ToStringKey(fieldValues...)
|
if notZero {
|
||||||
if _, ok := dataResults[dataKey]; !ok {
|
dataKey := utils.ToStringKey(fieldValues...)
|
||||||
result := make([]interface{}, len(fieldValues))
|
if _, ok := dataResults[dataKey]; !ok {
|
||||||
for idx, fieldValue := range fieldValues {
|
results = append(results, fieldValues[:])
|
||||||
result[idx] = fieldValue.Interface()
|
dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)}
|
||||||
|
} else {
|
||||||
|
dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i))
|
||||||
}
|
}
|
||||||
results = append(results, result)
|
|
||||||
|
|
||||||
dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)}
|
|
||||||
} else {
|
|
||||||
dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,7 +52,7 @@ func GetUser(name string, config Config) *User {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < config.Languages; i++ {
|
for i := 0; i < config.Languages; i++ {
|
||||||
name := name + "_locale_" + strconv.Itoa(i+0)
|
name := name + "_locale_" + strconv.Itoa(i+1)
|
||||||
language := Language{Code: name, Name: name}
|
language := Language{Code: name, Name: name}
|
||||||
DB.Create(&language)
|
DB.Create(&language)
|
||||||
user.Languages = append(user.Languages, language)
|
user.Languages = append(user.Languages, language)
|
||||||
|
|
|
@ -34,7 +34,7 @@ func TestCreate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateWithAssociations(t *testing.T) {
|
func TestCreateWithAssociations(t *testing.T) {
|
||||||
var user = *GetUser("create_with_belongs_to", Config{
|
var user = *GetUser("create_with_associations", Config{
|
||||||
Account: true,
|
Account: true,
|
||||||
Pets: 2,
|
Pets: 2,
|
||||||
Toys: 3,
|
Toys: 3,
|
||||||
|
@ -52,34 +52,38 @@ func TestCreateWithAssociations(t *testing.T) {
|
||||||
CheckUser(t, user, user)
|
CheckUser(t, user, user)
|
||||||
|
|
||||||
var user2 User
|
var user2 User
|
||||||
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Find(&user2, "id = ?", user.ID)
|
DB.Debug().Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
// func TestBulkCreateWithBelongsTo(t *testing.T) {
|
func TestBulkCreateWithAssociations(t *testing.T) {
|
||||||
// users := []User{
|
users := []User{
|
||||||
// *GetUser("create_with_belongs_to_1", Config{Company: true, Manager: true}),
|
*GetUser("bulk_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}),
|
||||||
// *GetUser("create_with_belongs_to_2", Config{Company: true, Manager: false}),
|
*GetUser("bulk_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}),
|
||||||
// *GetUser("create_with_belongs_to_3", Config{Company: false, Manager: true}),
|
*GetUser("bulk_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}),
|
||||||
// *GetUser("create_with_belongs_to_4", Config{Company: true, Manager: true}),
|
*GetUser("bulk_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}),
|
||||||
// }
|
*GetUser("bulk_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}),
|
||||||
|
*GetUser("bulk_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}),
|
||||||
|
*GetUser("bulk_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}),
|
||||||
|
*GetUser("bulk_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}),
|
||||||
|
}
|
||||||
|
|
||||||
// if err := DB.Create(&users).Error; err != nil {
|
if err := DB.Create(&users).Error; err != nil {
|
||||||
// t.Fatalf("errors happened when create: %v", err)
|
t.Fatalf("errors happened when create: %v", err)
|
||||||
// }
|
}
|
||||||
|
|
||||||
// var userIDs []uint
|
var userIDs []uint
|
||||||
// for _, user := range users {
|
for _, user := range users {
|
||||||
// userIDs = append(userIDs, user.ID)
|
userIDs = append(userIDs, user.ID)
|
||||||
// CheckUser(t, user, user)
|
CheckUser(t, user, user)
|
||||||
// }
|
}
|
||||||
|
|
||||||
// var users2 []User
|
var users2 []User
|
||||||
// DB.Preload("Company").Preload("Manager").Find(&users2, "id IN ?", userIDs)
|
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs)
|
||||||
// for idx, user := range users2 {
|
for idx, user := range users2 {
|
||||||
// CheckUser(t, user, users[idx])
|
CheckUser(t, user, users[idx])
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
|
||||||
// func TestBulkCreatePtrDataWithBelongsTo(t *testing.T) {
|
// func TestBulkCreatePtrDataWithBelongsTo(t *testing.T) {
|
||||||
// users := []*User{
|
// users := []*User{
|
||||||
|
|
|
@ -73,7 +73,7 @@ func RunMigrations() {
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||||
|
|
||||||
DB.Migrator().DropTable("user_friends", "user_speak")
|
DB.Migrator().DropTable("user_friends", "user_speaks")
|
||||||
|
|
||||||
if err = DB.Migrator().DropTable(allModels...); err != nil {
|
if err = DB.Migrator().DropTable(allModels...); err != nil {
|
||||||
log.Printf("Failed to drop table, got error %v\n", err)
|
log.Printf("Failed to drop table, got error %v\n", err)
|
||||||
|
|
|
@ -41,16 +41,15 @@ func CheckTruth(val interface{}) bool {
|
||||||
return !reflect.ValueOf(val).IsZero()
|
return !reflect.ValueOf(val).IsZero()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToStringKey(values ...reflect.Value) string {
|
func ToStringKey(values ...interface{}) string {
|
||||||
results := make([]string, len(values))
|
results := make([]string, len(values))
|
||||||
|
|
||||||
for idx, value := range values {
|
for idx, value := range values {
|
||||||
rv := reflect.Indirect(value).Interface()
|
if valuer, ok := value.(driver.Valuer); ok {
|
||||||
if valuer, ok := rv.(driver.Valuer); ok {
|
value, _ = valuer.Value()
|
||||||
rv, _ = valuer.Value()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v := rv.(type) {
|
switch v := value.(type) {
|
||||||
case string:
|
case string:
|
||||||
results[idx] = v
|
results[idx] = v
|
||||||
case []byte:
|
case []byte:
|
||||||
|
@ -58,7 +57,7 @@ func ToStringKey(values ...reflect.Value) string {
|
||||||
case uint:
|
case uint:
|
||||||
results[idx] = strconv.FormatUint(uint64(v), 10)
|
results[idx] = strconv.FormatUint(uint64(v), 10)
|
||||||
default:
|
default:
|
||||||
results[idx] = fmt.Sprint(v)
|
results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue