Test Association For BelongsTo

This commit is contained in:
Jinzhu 2020-05-24 17:24:23 +08:00
parent cbc4a81140
commit 91a695893c
10 changed files with 265 additions and 45 deletions

View File

@ -19,8 +19,10 @@ type Association struct {
func (db *DB) Association(column string) *Association { func (db *DB) Association(column string) *Association {
association := &Association{DB: db} association := &Association{DB: db}
table := db.Statement.Table
if err := db.Statement.Parse(db.Statement.Model); err == nil { if err := db.Statement.Parse(db.Statement.Model); err == nil {
db.Statement.Table = table
association.Relationship = db.Statement.Schema.Relationships.Relations[column] association.Relationship = db.Statement.Schema.Relationships.Relations[column]
if association.Relationship == nil { if association.Relationship == nil {
@ -83,6 +85,16 @@ func (association *Association) Replace(values ...interface{}) error {
rel := association.Relationship rel := association.Relationship
switch rel.Type { switch rel.Type {
case schema.BelongsTo:
if len(values) == 0 {
updateMap := map[string]interface{}{}
for _, ref := range rel.References {
updateMap[ref.ForeignKey.DBName] = nil
}
association.DB.UpdateColumns(updateMap)
}
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
var ( var (
primaryFields []*schema.Field primaryFields []*schema.Field
@ -90,6 +102,9 @@ func (association *Association) Replace(values ...interface{}) error {
updateMap = map[string]interface{}{} updateMap = map[string]interface{}{}
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
) )
if rel.Type == schema.BelongsTo {
modelValue = reflect.New(rel.Schema.ModelType).Interface()
}
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -101,7 +116,7 @@ func (association *Association) Replace(values ...interface{}) error {
} }
_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
if len(values) > 0 { if len(values) == 0 {
column, queryValues := schema.ToQueryValues(foreignKeys, values) column, queryValues := schema.ToQueryValues(foreignKeys, values)
association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap)
} }
@ -161,8 +176,8 @@ func (association *Association) Delete(values ...interface{}) error {
tx = association.DB tx = association.DB
rel = association.Relationship rel = association.Relationship
reflectValue = tx.Statement.ReflectValue reflectValue = tx.Statement.ReflectValue
conds = rel.ToQueryConditions(reflectValue)
relFields []*schema.Field relFields []*schema.Field
foreignKeyFields []*schema.Field
foreignKeys []string foreignKeys []string
updateAttrs = map[string]interface{}{} updateAttrs = map[string]interface{}{}
) )
@ -174,6 +189,7 @@ func (association *Association) Delete(values ...interface{}) error {
relFields = append(relFields, ref.ForeignKey) relFields = append(relFields, ref.ForeignKey)
} else { } else {
relFields = append(relFields, ref.PrimaryKey) relFields = append(relFields, ref.PrimaryKey)
foreignKeyFields = append(foreignKeyFields, ref.ForeignKey)
} }
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
@ -189,11 +205,14 @@ func (association *Association) Delete(values ...interface{}) error {
switch rel.Type { switch rel.Type {
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
conds := rel.ToQueryConditions(reflectValue)
tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs)
case schema.BelongsTo: case schema.BelongsTo:
tx.Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) modelValue := reflect.New(rel.Schema.ModelType).Interface()
tx.Model(modelValue).UpdateColumns(updateAttrs)
case schema.Many2Many: case schema.Many2Many:
modelValue := reflect.New(rel.JoinTable.ModelType).Interface() modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
conds := rel.ToQueryConditions(reflectValue)
tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue)
} }
@ -216,13 +235,16 @@ func (association *Association) Delete(values ...interface{}) error {
} }
} }
rel.Field.Set(data, validFieldValues) rel.Field.Set(data, validFieldValues.Interface())
case reflect.Struct: case reflect.Struct:
for idx, field := range relFields { for idx, field := range relFields {
fieldValues[idx], _ = field.ValueOf(data) fieldValues[idx], _ = field.ValueOf(fieldValue)
} }
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok {
rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType)) rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface())
for _, field := range foreignKeyFields {
field.Set(data, reflect.Zero(field.FieldType).Interface())
}
} }
} }
} }
@ -275,7 +297,11 @@ func (association *Association) Count() (count int64) {
} }
func (association *Association) saveAssociation(clear bool, values ...interface{}) { func (association *Association) saveAssociation(clear bool, values ...interface{}) {
reflectValue := association.DB.Statement.ReflectValue var (
reflectValue = association.DB.Statement.ReflectValue
assignBacks = [][2]reflect.Value{}
assignBack = association.Relationship.Field.FieldType.Kind() == reflect.Struct
)
appendToRelations := func(source, rv reflect.Value, clear bool) { appendToRelations := func(source, rv reflect.Value, clear bool) {
switch association.Relationship.Type { switch association.Relationship.Type {
@ -283,10 +309,16 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch rv.Kind() { switch rv.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if rv.Len() > 0 { if rv.Len() > 0 {
association.Error = association.Relationship.Field.Set(source, rv.Index(0)) association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
if assignBack {
assignBacks = append(assignBacks, [2]reflect.Value{source, rv.Index(0)})
}
} }
case reflect.Struct: case reflect.Struct:
association.Error = association.Relationship.Field.Set(source, rv) association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
if assignBack {
assignBacks = append(assignBacks, [2]reflect.Value{source, rv})
}
} }
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem() elemType := association.Relationship.Field.IndirectFieldType.Elem()
@ -315,7 +347,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
if association.Error == nil { if association.Error == nil {
association.Error = association.Relationship.Field.Set(source, fieldValue) association.Error = association.Relationship.Field.Set(source, fieldValue.Addr().Interface())
} }
} }
} }
@ -333,7 +365,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
if len(values) != reflectValue.Len() { if len(values) != reflectValue.Len() {
if clear && len(values) == 0 { if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType)) association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
} }
break break
} }
@ -349,19 +381,24 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
case reflect.Struct: case reflect.Struct:
if clear && len(values) == 0 { if clear && len(values) == 0 {
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType)) association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
} }
for idx, value := range values { for idx, value := range values {
appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0) rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0)
} }
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
} }
if hasZero { if hasZero {
association.DB.Save(reflectValue.Interface()) association.DB.Save(reflectValue.Addr().Interface())
} else { } else {
association.DB.Select(selectedColumns).Save(reflectValue.Interface()) association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface())
}
for _, assignBack := range assignBacks {
reflect.Indirect(assignBack[1]).Set(association.Relationship.Field.ReflectValueOf(assignBack[0]))
} }
} }

View File

@ -73,8 +73,8 @@ func SaveBeforeAssociations(db *gorm.DB) {
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Session(&gorm.Session{}).Create(rv.Interface()) db.Session(&gorm.Session{}).Create(rv.Interface())
setupReferences(db.Statement.ReflectValue, rv)
} }
setupReferences(db.Statement.ReflectValue, rv)
} }
} }
} }

View File

@ -22,7 +22,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo
break break
} }
if field := stmt.Schema.LookUpField(column); field != nil { if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true results[field.DBName] = true
} else { } else {
results[column] = true results[column] = true

View File

@ -7,6 +7,7 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
) )
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
@ -91,8 +92,27 @@ func AfterUpdate(db *gorm.DB) {
// ConvertToAssignments convert to update assignments // ConvertToAssignments convert to update assignments
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
selectColumns, restricted := SelectAndOmitColumns(stmt, false, true) var (
reflectModelValue := reflect.ValueOf(stmt.Model) selectColumns, restricted = SelectAndOmitColumns(stmt, false, true)
reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model))
assignValue func(field *schema.Field, value interface{})
)
switch reflectModelValue.Kind() {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < reflectModelValue.Len(); i++ {
field.Set(reflectModelValue.Index(i), value)
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) {
field.Set(reflectModelValue, value)
}
default:
assignValue = func(field *schema.Field, value interface{}) {
}
}
switch value := stmt.Dest.(type) { switch value := stmt.Dest.(type) {
case map[string]interface{}: case map[string]interface{}:
@ -111,7 +131,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
value[k] = time.Now() value[k] = time.Now()
} }
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
field.Set(reflectModelValue, value[k]) assignValue(field, value[k])
} }
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
@ -122,7 +142,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
now := time.Now() now := time.Now()
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
field.Set(reflectModelValue, now) assignValue(field, now)
} }
} }
default: default:
@ -140,7 +160,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if ok || !isZero { if ok || !isZero {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
field.Set(reflectModelValue, value) assignValue(field, value)
} }
} }
} else { } else {

11
gorm.go
View File

@ -105,11 +105,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
func (db *DB) Session(config *Session) *DB { func (db *DB) Session(config *Session) *DB {
var ( var (
tx = db.getInstance() tx = db.getInstance()
stmt = tx.Statement.clone()
txConfig = *tx.Config txConfig = *tx.Config
) )
if config.Context != nil { if config.Context != nil {
tx.Statement.Context = config.Context stmt.Context = config.Context
} }
if config.Logger != nil { if config.Logger != nil {
@ -120,9 +121,11 @@ func (db *DB) Session(config *Session) *DB {
txConfig.NowFunc = config.NowFunc txConfig.NowFunc = config.NowFunc
} }
tx.Config = &txConfig return &DB{
tx.clone = true Config: &txConfig,
return tx Statement: stmt,
clone: true,
}
} }
// WithContext change current instance db's context to ctx // WithContext change current instance db's context to ctx

View File

@ -372,7 +372,11 @@ func (field *Field) setupValuerAndSetter() {
} }
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
if v == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
@ -386,6 +390,7 @@ func (field *Field) setupValuerAndSetter() {
} else { } else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
}
return err return err
} }

View File

@ -387,6 +387,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
column, values := ToQueryValues(relForeignKeys, foreignValues) column, values := ToQueryValues(relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values}) conds = append(conds, clause.IN{Column: column, Values: values})
return return
} }

View File

@ -278,6 +278,39 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
return err return err
} }
func (stmt *Statement) clone() *Statement {
newStmt := &Statement{
DB: stmt.DB,
Table: stmt.Table,
Model: stmt.Model,
Dest: stmt.Dest,
ReflectValue: stmt.ReflectValue,
Clauses: map[string]clause.Clause{},
Selects: stmt.Selects,
Omits: stmt.Omits,
Joins: map[string][]interface{}{},
Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool,
Schema: stmt.Schema,
Context: stmt.Context,
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
}
for k, c := range stmt.Clauses {
newStmt.Clauses[k] = c
}
for k, p := range stmt.Preloads {
newStmt.Preloads[k] = p
}
for k, j := range stmt.Joins {
newStmt.Joins[k] = j
}
return newStmt
}
func (stmt *Statement) reinit() { func (stmt *Statement) reinit() {
// stmt.Table = "" // stmt.Table = ""
// stmt.Model = nil // stmt.Model = nil

View File

@ -15,6 +15,7 @@ func TestAssociationForBelongsTo(t *testing.T) {
CheckUser(t, user, user) CheckUser(t, user, user)
// Find
var user2 User var user2 User
DB.Find(&user2, "id = ?", user.ID) DB.Find(&user2, "id = ?", user.ID)
DB.Model(&user2).Association("Company").Find(&user2.Company) DB.Model(&user2).Association("Company").Find(&user2.Company)
@ -22,6 +23,7 @@ func TestAssociationForBelongsTo(t *testing.T) {
DB.Model(&user2).Association("Manager").Find(user2.Manager) DB.Model(&user2).Association("Manager").Find(user2.Manager)
CheckUser(t, user2, user) CheckUser(t, user2, user)
// Count
if count := DB.Model(&user).Association("Company").Count(); count != 1 { if count := DB.Model(&user).Association("Company").Count(); count != 1 {
t.Errorf("invalid company count, got %v", count) t.Errorf("invalid company count, got %v", count)
} }
@ -29,4 +31,123 @@ func TestAssociationForBelongsTo(t *testing.T) {
if count := DB.Model(&user).Association("Manager").Count(); count != 1 { if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
t.Errorf("invalid manager count, got %v", count) t.Errorf("invalid manager count, got %v", count)
} }
// Append
var company = Company{Name: "company-belongs-to-append"}
var manager = GetUser("manager-belongs-to-append", Config{})
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
t.Fatalf("Error happened when append Company, got %v", err)
}
if company.ID == 0 {
t.Fatalf("Company's ID should be created")
}
if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil {
t.Fatalf("Error happened when append Manager, got %v", err)
}
if manager.ID == 0 {
t.Fatalf("Manager's ID should be created")
}
user.Company = company
user.Manager = manager
user.CompanyID = &company.ID
user.ManagerID = &manager.ID
CheckUser(t, user2, user)
// Replace
var company2 = Company{Name: "company-belongs-to-replace"}
var manager2 = GetUser("manager-belongs-to-replace", Config{})
if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil {
t.Fatalf("Error happened when replace Company, got %v", err)
}
if company2.ID == 0 {
t.Fatalf("Company's ID should be created")
}
if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil {
t.Fatalf("Error happened when replace Manager, got %v", err)
}
if manager2.ID == 0 {
t.Fatalf("Manager's ID should be created")
}
user.Company = company2
user.Manager = manager2
user.CompanyID = &company2.ID
user.ManagerID = &manager2.ID
CheckUser(t, user2, user)
// Delete
if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil {
t.Fatalf("Error happened when delete Company, got %v", err)
}
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
t.Errorf("Invalid company count after delete non-existing association, got %v", count)
}
if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil {
t.Fatalf("Error happened when delete Company, got %v", err)
}
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
t.Errorf("Invalid company count after delete, got %v", count)
}
if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil {
t.Fatalf("Error happened when delete Manager, got %v", err)
}
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
t.Errorf("Invalid manager count after delete non-existing association, got %v", count)
}
if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil {
t.Fatalf("Error happened when delete Manager, got %v", err)
}
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 {
t.Errorf("Invalid manager count after delete, got %v", count)
}
// Prepare Data
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
t.Fatalf("Error happened when append Company, got %v", err)
}
if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil {
t.Fatalf("Error happened when append Manager, got %v", err)
}
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
t.Errorf("Invalid company count after append, got %v", count)
}
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
t.Errorf("Invalid manager count after append, got %v", count)
}
// Clear
if err := DB.Model(&user2).Association("Company").Clear(); err != nil {
t.Errorf("Error happened when clear Company, got %v", err)
}
if err := DB.Model(&user2).Association("Manager").Clear(); err != nil {
t.Errorf("Error happened when clear Manager, got %v", err)
}
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
t.Errorf("Invalid company count after clear, got %v", count)
}
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 {
t.Errorf("Invalid manager count after clear, got %v", count)
}
} }

View File

@ -33,7 +33,7 @@ func TestCount(t *testing.T) {
var count3 int64 var count3 int64
if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil {
t.Errorf("No error should happen when count with group, but got %v", err) t.Errorf("Error happened when count with group, but got %v", err)
} }
if count3 != 2 { if count3 != 2 {