forked from mirror/gorm
Fix FirstOr(Init/Create) when assigning with association
This commit is contained in:
parent
2c4e857125
commit
2faff25dfb
|
@ -8,6 +8,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -132,8 +133,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
|
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||||
for _, expr := range exprs {
|
for _, value := range values {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []clause.Expression:
|
||||||
|
for _, expr := range v {
|
||||||
if eq, ok := expr.(clause.Eq); ok {
|
if eq, ok := expr.(clause.Eq); ok {
|
||||||
switch column := eq.Column.(type) {
|
switch column := eq.Column.(type) {
|
||||||
case string:
|
case string:
|
||||||
|
@ -148,28 +152,51 @@ func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||||
|
exprs := tx.Statement.BuildCondition(value)
|
||||||
|
tx.assignInterfacesToValue(exprs)
|
||||||
|
default:
|
||||||
|
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
|
||||||
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
for _, f := range s.Fields {
|
||||||
|
if f.Readable {
|
||||||
|
if v, isZero := f.ValueOf(reflectValue); !isZero {
|
||||||
|
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
|
||||||
|
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if len(values) > 0 {
|
||||||
|
exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
|
||||||
|
tx.assignInterfacesToValue(exprs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
|
if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
|
||||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||||
if where, ok := c.Expression.(clause.Where); ok {
|
if where, ok := c.Expression.(clause.Where); ok {
|
||||||
tx.assignExprsToValue(where.Exprs)
|
tx.assignInterfacesToValue(where.Exprs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.attrs) > 0 {
|
if len(tx.Statement.attrs) > 0 {
|
||||||
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
|
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||||
tx.assignExprsToValue(exprs)
|
|
||||||
}
|
}
|
||||||
tx.Error = nil
|
tx.Error = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.assigns) > 0 {
|
if len(tx.Statement.assigns) > 0 {
|
||||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||||
tx.assignExprsToValue(exprs)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -180,20 +207,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
|
|
||||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||||
if where, ok := c.Expression.(clause.Where); ok {
|
if where, ok := c.Expression.(clause.Where); ok {
|
||||||
tx.assignExprsToValue(where.Exprs)
|
tx.assignInterfacesToValue(where.Exprs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.attrs) > 0 {
|
if len(tx.Statement.attrs) > 0 {
|
||||||
exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
|
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||||
tx.assignExprsToValue(exprs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.assigns) > 0 {
|
if len(tx.Statement.assigns) > 0 {
|
||||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||||
tx.assignExprsToValue(exprs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Create(dest)
|
return tx.Create(dest)
|
||||||
|
|
|
@ -110,6 +110,8 @@ func TestQueryWithAssociation(t *testing.T) {
|
||||||
t.Fatalf("errors happened when create user: %v", err)
|
t.Fatalf("errors happened when create user: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user.CreatedAt = time.Time{}
|
||||||
|
user.UpdatedAt = time.Time{}
|
||||||
if err := DB.Where(&user).First(&User{}).Error; err != nil {
|
if err := DB.Where(&user).First(&User{}).Error; err != nil {
|
||||||
t.Errorf("search with struct with association should returns no error, but got %v", err)
|
t.Errorf("search with struct with association should returns no error, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue