Fix FirstOr(Init/Create) when assigning with association

This commit is contained in:
Jinzhu 2020-08-13 18:38:39 +08:00
parent 2c4e857125
commit 2faff25dfb
2 changed files with 48 additions and 21 deletions

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
@ -132,8 +133,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
return return
} }
func (tx *DB) assignExprsToValue(exprs []clause.Expression) { func (tx *DB) assignInterfacesToValue(values ...interface{}) {
for _, expr := range exprs { for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok { if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) { switch column := eq.Column.(type) {
case string: case string:
@ -148,28 +152,51 @@ func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
} }
} }
} }
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
exprs := tx.Statement.BuildCondition(value)
tx.assignInterfacesToValue(exprs)
default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
}
}
}
}
}
} else if len(values) > 0 {
exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
tx.assignInterfacesToValue(exprs)
return
}
}
}
} }
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs) tx.assignInterfacesToValue(where.Exprs)
} }
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.attrs) > 0 { if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) tx.assignInterfacesToValue(tx.Statement.attrs...)
tx.assignExprsToValue(exprs)
} }
tx.Error = nil tx.Error = nil
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.assigns) > 0 { if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) tx.assignInterfacesToValue(tx.Statement.assigns...)
tx.assignExprsToValue(exprs)
} }
return return
} }
@ -180,20 +207,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs) tx.assignInterfacesToValue(where.Exprs)
} }
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.attrs) > 0 { if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) tx.assignInterfacesToValue(tx.Statement.attrs...)
tx.assignExprsToValue(exprs)
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.assigns) > 0 { if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) tx.assignInterfacesToValue(tx.Statement.assigns...)
tx.assignExprsToValue(exprs)
} }
return tx.Create(dest) return tx.Create(dest)

View File

@ -110,6 +110,8 @@ func TestQueryWithAssociation(t *testing.T) {
t.Fatalf("errors happened when create user: %v", err) t.Fatalf("errors happened when create user: %v", err)
} }
user.CreatedAt = time.Time{}
user.UpdatedAt = time.Time{}
if err := DB.Where(&user).First(&User{}).Error; err != nil { if err := DB.Where(&user).First(&User{}).Error; err != nil {
t.Errorf("search with struct with association should returns no error, but got %v", err) t.Errorf("search with struct with association should returns no error, but got %v", err)
} }