mirror of https://github.com/go-gorm/gorm.git
Fix FirstOr(Init/Create) when assigning with association
This commit is contained in:
parent
2c4e857125
commit
2faff25dfb
|
@ -8,6 +8,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
|
@ -132,19 +133,47 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||
return
|
||||
}
|
||||
|
||||
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
if field := tx.Statement.Schema.LookUpField(column); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
for _, value := range values {
|
||||
switch v := value.(type) {
|
||||
case []clause.Expression:
|
||||
for _, expr := range v {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
if field := tx.Statement.Schema.LookUpField(column); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
case clause.Column:
|
||||
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
case clause.Column:
|
||||
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||
exprs := tx.Statement.BuildCondition(value)
|
||||
tx.assignInterfacesToValue(exprs)
|
||||
default:
|
||||
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, f := range s.Fields {
|
||||
if f.Readable {
|
||||
if v, isZero := f.ValueOf(reflectValue); !isZero {
|
||||
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
} else if len(values) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
|
||||
tx.assignInterfacesToValue(exprs)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -154,22 +183,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
|
||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
tx.assignExprsToValue(where.Exprs)
|
||||
tx.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.attrs) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
|
||||
tx.assignExprsToValue(exprs)
|
||||
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||
}
|
||||
tx.Error = nil
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
||||
tx.assignExprsToValue(exprs)
|
||||
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -180,20 +207,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
|
||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
tx.assignExprsToValue(where.Exprs)
|
||||
tx.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.attrs) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
|
||||
tx.assignExprsToValue(exprs)
|
||||
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
||||
tx.assignExprsToValue(exprs)
|
||||
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||
}
|
||||
|
||||
return tx.Create(dest)
|
||||
|
|
|
@ -110,6 +110,8 @@ func TestQueryWithAssociation(t *testing.T) {
|
|||
t.Fatalf("errors happened when create user: %v", err)
|
||||
}
|
||||
|
||||
user.CreatedAt = time.Time{}
|
||||
user.UpdatedAt = time.Time{}
|
||||
if err := DB.Where(&user).First(&User{}).Error; err != nil {
|
||||
t.Errorf("search with struct with association should returns no error, but got %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue