mirror of https://github.com/go-gorm/gorm.git
Test Association For BelongsTo
This commit is contained in:
parent
cbc4a81140
commit
91a695893c
|
@ -19,8 +19,10 @@ type Association struct {
|
||||||
|
|
||||||
func (db *DB) Association(column string) *Association {
|
func (db *DB) Association(column string) *Association {
|
||||||
association := &Association{DB: db}
|
association := &Association{DB: db}
|
||||||
|
table := db.Statement.Table
|
||||||
|
|
||||||
if err := db.Statement.Parse(db.Statement.Model); err == nil {
|
if err := db.Statement.Parse(db.Statement.Model); err == nil {
|
||||||
|
db.Statement.Table = table
|
||||||
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
|
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
|
||||||
|
|
||||||
if association.Relationship == nil {
|
if association.Relationship == nil {
|
||||||
|
@ -83,6 +85,16 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||||
rel := association.Relationship
|
rel := association.Relationship
|
||||||
|
|
||||||
switch rel.Type {
|
switch rel.Type {
|
||||||
|
case schema.BelongsTo:
|
||||||
|
if len(values) == 0 {
|
||||||
|
updateMap := map[string]interface{}{}
|
||||||
|
|
||||||
|
for _, ref := range rel.References {
|
||||||
|
updateMap[ref.ForeignKey.DBName] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
association.DB.UpdateColumns(updateMap)
|
||||||
|
}
|
||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
var (
|
var (
|
||||||
primaryFields []*schema.Field
|
primaryFields []*schema.Field
|
||||||
|
@ -90,6 +102,9 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||||
updateMap = map[string]interface{}{}
|
updateMap = map[string]interface{}{}
|
||||||
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
)
|
)
|
||||||
|
if rel.Type == schema.BelongsTo {
|
||||||
|
modelValue = reflect.New(rel.Schema.ModelType).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
|
@ -101,7 +116,7 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
||||||
if len(values) > 0 {
|
if len(values) == 0 {
|
||||||
column, queryValues := schema.ToQueryValues(foreignKeys, values)
|
column, queryValues := schema.ToQueryValues(foreignKeys, values)
|
||||||
association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap)
|
association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap)
|
||||||
}
|
}
|
||||||
|
@ -161,8 +176,8 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
tx = association.DB
|
tx = association.DB
|
||||||
rel = association.Relationship
|
rel = association.Relationship
|
||||||
reflectValue = tx.Statement.ReflectValue
|
reflectValue = tx.Statement.ReflectValue
|
||||||
conds = rel.ToQueryConditions(reflectValue)
|
|
||||||
relFields []*schema.Field
|
relFields []*schema.Field
|
||||||
|
foreignKeyFields []*schema.Field
|
||||||
foreignKeys []string
|
foreignKeys []string
|
||||||
updateAttrs = map[string]interface{}{}
|
updateAttrs = map[string]interface{}{}
|
||||||
)
|
)
|
||||||
|
@ -174,6 +189,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
relFields = append(relFields, ref.ForeignKey)
|
relFields = append(relFields, ref.ForeignKey)
|
||||||
} else {
|
} else {
|
||||||
relFields = append(relFields, ref.PrimaryKey)
|
relFields = append(relFields, ref.PrimaryKey)
|
||||||
|
foreignKeyFields = append(foreignKeyFields, ref.ForeignKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
||||||
|
@ -189,11 +205,14 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
switch rel.Type {
|
switch rel.Type {
|
||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
|
conds := rel.ToQueryConditions(reflectValue)
|
||||||
tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs)
|
tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs)
|
||||||
case schema.BelongsTo:
|
case schema.BelongsTo:
|
||||||
tx.Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs)
|
modelValue := reflect.New(rel.Schema.ModelType).Interface()
|
||||||
|
tx.Model(modelValue).UpdateColumns(updateAttrs)
|
||||||
case schema.Many2Many:
|
case schema.Many2Many:
|
||||||
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
||||||
|
conds := rel.ToQueryConditions(reflectValue)
|
||||||
tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue)
|
tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -216,13 +235,16 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rel.Field.Set(data, validFieldValues)
|
rel.Field.Set(data, validFieldValues.Interface())
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
for idx, field := range relFields {
|
for idx, field := range relFields {
|
||||||
fieldValues[idx], _ = field.ValueOf(data)
|
fieldValues[idx], _ = field.ValueOf(fieldValue)
|
||||||
}
|
}
|
||||||
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok {
|
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||||
rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType))
|
rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface())
|
||||||
|
for _, field := range foreignKeyFields {
|
||||||
|
field.Set(data, reflect.Zero(field.FieldType).Interface())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -275,7 +297,11 @@ func (association *Association) Count() (count int64) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
|
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
|
||||||
reflectValue := association.DB.Statement.ReflectValue
|
var (
|
||||||
|
reflectValue = association.DB.Statement.ReflectValue
|
||||||
|
assignBacks = [][2]reflect.Value{}
|
||||||
|
assignBack = association.Relationship.Field.FieldType.Kind() == reflect.Struct
|
||||||
|
)
|
||||||
|
|
||||||
appendToRelations := func(source, rv reflect.Value, clear bool) {
|
appendToRelations := func(source, rv reflect.Value, clear bool) {
|
||||||
switch association.Relationship.Type {
|
switch association.Relationship.Type {
|
||||||
|
@ -283,10 +309,16 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
switch rv.Kind() {
|
switch rv.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if rv.Len() > 0 {
|
if rv.Len() > 0 {
|
||||||
association.Error = association.Relationship.Field.Set(source, rv.Index(0))
|
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
|
||||||
|
if assignBack {
|
||||||
|
assignBacks = append(assignBacks, [2]reflect.Value{source, rv.Index(0)})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
association.Error = association.Relationship.Field.Set(source, rv)
|
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
|
||||||
|
if assignBack {
|
||||||
|
assignBacks = append(assignBacks, [2]reflect.Value{source, rv})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case schema.HasMany, schema.Many2Many:
|
case schema.HasMany, schema.Many2Many:
|
||||||
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
||||||
|
@ -315,7 +347,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
}
|
}
|
||||||
|
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
association.Error = association.Relationship.Field.Set(source, fieldValue)
|
association.Error = association.Relationship.Field.Set(source, fieldValue.Addr().Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -333,7 +365,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
if len(values) != reflectValue.Len() {
|
if len(values) != reflectValue.Len() {
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType))
|
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -349,19 +381,24 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType))
|
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, value := range values {
|
for idx, value := range values {
|
||||||
appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0)
|
rv := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
appendToRelations(reflectValue, rv, clear && idx == 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
|
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasZero {
|
if hasZero {
|
||||||
association.DB.Save(reflectValue.Interface())
|
association.DB.Save(reflectValue.Addr().Interface())
|
||||||
} else {
|
} else {
|
||||||
association.DB.Select(selectedColumns).Save(reflectValue.Interface())
|
association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, assignBack := range assignBacks {
|
||||||
|
reflect.Indirect(assignBack[1]).Set(association.Relationship.Field.ReflectValueOf(assignBack[0]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,8 +73,8 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
|
|
||||||
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
||||||
db.Session(&gorm.Session{}).Create(rv.Interface())
|
db.Session(&gorm.Session{}).Create(rv.Interface())
|
||||||
setupReferences(db.Statement.ReflectValue, rv)
|
|
||||||
}
|
}
|
||||||
|
setupReferences(db.Statement.ReflectValue, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if field := stmt.Schema.LookUpField(column); field != nil {
|
if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||||
results[field.DBName] = true
|
results[field.DBName] = true
|
||||||
} else {
|
} else {
|
||||||
results[column] = true
|
results[column] = true
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/clause"
|
"github.com/jinzhu/gorm/clause"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BeforeUpdate(db *gorm.DB) {
|
func BeforeUpdate(db *gorm.DB) {
|
||||||
|
@ -91,8 +92,27 @@ func AfterUpdate(db *gorm.DB) {
|
||||||
|
|
||||||
// ConvertToAssignments convert to update assignments
|
// ConvertToAssignments convert to update assignments
|
||||||
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
selectColumns, restricted := SelectAndOmitColumns(stmt, false, true)
|
var (
|
||||||
reflectModelValue := reflect.ValueOf(stmt.Model)
|
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true)
|
||||||
|
reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model))
|
||||||
|
assignValue func(field *schema.Field, value interface{})
|
||||||
|
)
|
||||||
|
|
||||||
|
switch reflectModelValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
assignValue = func(field *schema.Field, value interface{}) {
|
||||||
|
for i := 0; i < reflectModelValue.Len(); i++ {
|
||||||
|
field.Set(reflectModelValue.Index(i), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
assignValue = func(field *schema.Field, value interface{}) {
|
||||||
|
field.Set(reflectModelValue, value)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
assignValue = func(field *schema.Field, value interface{}) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch value := stmt.Dest.(type) {
|
switch value := stmt.Dest.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
|
@ -111,7 +131,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
value[k] = time.Now()
|
value[k] = time.Now()
|
||||||
}
|
}
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
|
||||||
field.Set(reflectModelValue, value[k])
|
assignValue(field, value[k])
|
||||||
}
|
}
|
||||||
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
|
||||||
|
@ -122,7 +142,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
||||||
field.Set(reflectModelValue, now)
|
assignValue(field, now)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -140,7 +160,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
|
|
||||||
if ok || !isZero {
|
if ok || !isZero {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||||
field.Set(reflectModelValue, value)
|
assignValue(field, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
11
gorm.go
11
gorm.go
|
@ -105,11 +105,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
||||||
func (db *DB) Session(config *Session) *DB {
|
func (db *DB) Session(config *Session) *DB {
|
||||||
var (
|
var (
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
stmt = tx.Statement.clone()
|
||||||
txConfig = *tx.Config
|
txConfig = *tx.Config
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.Context != nil {
|
if config.Context != nil {
|
||||||
tx.Statement.Context = config.Context
|
stmt.Context = config.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Logger != nil {
|
if config.Logger != nil {
|
||||||
|
@ -120,9 +121,11 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
txConfig.NowFunc = config.NowFunc
|
txConfig.NowFunc = config.NowFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.Config = &txConfig
|
return &DB{
|
||||||
tx.clone = true
|
Config: &txConfig,
|
||||||
return tx
|
Statement: stmt,
|
||||||
|
clone: true,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithContext change current instance db's context to ctx
|
// WithContext change current instance db's context to ctx
|
||||||
|
|
|
@ -372,7 +372,11 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
}
|
}
|
||||||
|
|
||||||
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
|
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
|
||||||
|
if v == nil {
|
||||||
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
|
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
||||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
@ -386,6 +390,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -387,6 +387,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
|
||||||
|
|
||||||
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
||||||
column, values := ToQueryValues(relForeignKeys, foreignValues)
|
column, values := ToQueryValues(relForeignKeys, foreignValues)
|
||||||
|
|
||||||
conds = append(conds, clause.IN{Column: column, Values: values})
|
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
33
statement.go
33
statement.go
|
@ -278,6 +278,39 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (stmt *Statement) clone() *Statement {
|
||||||
|
newStmt := &Statement{
|
||||||
|
DB: stmt.DB,
|
||||||
|
Table: stmt.Table,
|
||||||
|
Model: stmt.Model,
|
||||||
|
Dest: stmt.Dest,
|
||||||
|
ReflectValue: stmt.ReflectValue,
|
||||||
|
Clauses: map[string]clause.Clause{},
|
||||||
|
Selects: stmt.Selects,
|
||||||
|
Omits: stmt.Omits,
|
||||||
|
Joins: map[string][]interface{}{},
|
||||||
|
Preloads: map[string][]interface{}{},
|
||||||
|
ConnPool: stmt.ConnPool,
|
||||||
|
Schema: stmt.Schema,
|
||||||
|
Context: stmt.Context,
|
||||||
|
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, c := range stmt.Clauses {
|
||||||
|
newStmt.Clauses[k] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, p := range stmt.Preloads {
|
||||||
|
newStmt.Preloads[k] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, j := range stmt.Joins {
|
||||||
|
newStmt.Joins[k] = j
|
||||||
|
}
|
||||||
|
|
||||||
|
return newStmt
|
||||||
|
}
|
||||||
|
|
||||||
func (stmt *Statement) reinit() {
|
func (stmt *Statement) reinit() {
|
||||||
// stmt.Table = ""
|
// stmt.Table = ""
|
||||||
// stmt.Model = nil
|
// stmt.Model = nil
|
||||||
|
|
|
@ -15,6 +15,7 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
|
|
||||||
CheckUser(t, user, user)
|
CheckUser(t, user, user)
|
||||||
|
|
||||||
|
// Find
|
||||||
var user2 User
|
var user2 User
|
||||||
DB.Find(&user2, "id = ?", user.ID)
|
DB.Find(&user2, "id = ?", user.ID)
|
||||||
DB.Model(&user2).Association("Company").Find(&user2.Company)
|
DB.Model(&user2).Association("Company").Find(&user2.Company)
|
||||||
|
@ -22,6 +23,7 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
DB.Model(&user2).Association("Manager").Find(user2.Manager)
|
DB.Model(&user2).Association("Manager").Find(user2.Manager)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
// Count
|
||||||
if count := DB.Model(&user).Association("Company").Count(); count != 1 {
|
if count := DB.Model(&user).Association("Company").Count(); count != 1 {
|
||||||
t.Errorf("invalid company count, got %v", count)
|
t.Errorf("invalid company count, got %v", count)
|
||||||
}
|
}
|
||||||
|
@ -29,4 +31,123 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
|
if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
|
||||||
t.Errorf("invalid manager count, got %v", count)
|
t.Errorf("invalid manager count, got %v", count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
var company = Company{Name: "company-belongs-to-append"}
|
||||||
|
var manager = GetUser("manager-belongs-to-append", Config{})
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
|
||||||
|
t.Fatalf("Error happened when append Company, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if company.ID == 0 {
|
||||||
|
t.Fatalf("Company's ID should be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil {
|
||||||
|
t.Fatalf("Error happened when append Manager, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.ID == 0 {
|
||||||
|
t.Fatalf("Manager's ID should be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Company = company
|
||||||
|
user.Manager = manager
|
||||||
|
user.CompanyID = &company.ID
|
||||||
|
user.ManagerID = &manager.ID
|
||||||
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
var company2 = Company{Name: "company-belongs-to-replace"}
|
||||||
|
var manager2 = GetUser("manager-belongs-to-replace", Config{})
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil {
|
||||||
|
t.Fatalf("Error happened when replace Company, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if company2.ID == 0 {
|
||||||
|
t.Fatalf("Company's ID should be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil {
|
||||||
|
t.Fatalf("Error happened when replace Manager, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager2.ID == 0 {
|
||||||
|
t.Fatalf("Manager's ID should be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Company = company2
|
||||||
|
user.Manager = manager2
|
||||||
|
user.CompanyID = &company2.ID
|
||||||
|
user.ManagerID = &manager2.ID
|
||||||
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil {
|
||||||
|
t.Fatalf("Error happened when delete Company, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
|
||||||
|
t.Errorf("Invalid company count after delete non-existing association, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil {
|
||||||
|
t.Fatalf("Error happened when delete Company, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
|
||||||
|
t.Errorf("Invalid company count after delete, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil {
|
||||||
|
t.Fatalf("Error happened when delete Manager, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
|
||||||
|
t.Errorf("Invalid manager count after delete non-existing association, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil {
|
||||||
|
t.Fatalf("Error happened when delete Manager, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 {
|
||||||
|
t.Errorf("Invalid manager count after delete, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare Data
|
||||||
|
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
|
||||||
|
t.Fatalf("Error happened when append Company, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil {
|
||||||
|
t.Fatalf("Error happened when append Manager, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
|
||||||
|
t.Errorf("Invalid company count after append, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
|
||||||
|
t.Errorf("Invalid manager count after append, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
if err := DB.Model(&user2).Association("Company").Clear(); err != nil {
|
||||||
|
t.Errorf("Error happened when clear Company, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Manager").Clear(); err != nil {
|
||||||
|
t.Errorf("Error happened when clear Manager, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
|
||||||
|
t.Errorf("Invalid company count after clear, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 {
|
||||||
|
t.Errorf("Invalid manager count after clear, got %v", count)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ func TestCount(t *testing.T) {
|
||||||
|
|
||||||
var count3 int64
|
var count3 int64
|
||||||
if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil {
|
if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil {
|
||||||
t.Errorf("No error should happen when count with group, but got %v", err)
|
t.Errorf("Error happened when count with group, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if count3 != 2 {
|
if count3 != 2 {
|
||||||
|
|
Loading…
Reference in New Issue