Fix FirstOr(Init/Create) when assigning with association

This commit is contained in:
Jinzhu 2020-08-13 18:38:39 +08:00
parent 2c4e857125
commit 2faff25dfb
2 changed files with 48 additions and 21 deletions

View File

@ -8,6 +8,7 @@ import (
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
@ -132,19 +133,47 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
return
}
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
default:
}
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
exprs := tx.Statement.BuildCondition(value)
tx.assignInterfacesToValue(exprs)
default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
}
}
}
}
}
default:
} else if len(values) > 0 {
exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
tx.assignInterfacesToValue(exprs)
return
}
}
}
@ -154,22 +183,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs)
tx.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
tx.assignExprsToValue(exprs)
tx.assignInterfacesToValue(tx.Statement.attrs...)
}
tx.Error = nil
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
tx.assignExprsToValue(exprs)
tx.assignInterfacesToValue(tx.Statement.assigns...)
}
return
}
@ -180,20 +207,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs)
tx.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
tx.assignExprsToValue(exprs)
tx.assignInterfacesToValue(tx.Statement.attrs...)
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
tx.assignExprsToValue(exprs)
tx.assignInterfacesToValue(tx.Statement.assigns...)
}
return tx.Create(dest)

View File

@ -110,6 +110,8 @@ func TestQueryWithAssociation(t *testing.T) {
t.Fatalf("errors happened when create user: %v", err)
}
user.CreatedAt = time.Time{}
user.UpdatedAt = time.Time{}
if err := DB.Where(&user).First(&User{}).Error; err != nil {
t.Errorf("search with struct with association should returns no error, but got %v", err)
}