diff --git a/finisher_api.go b/finisher_api.go index 33a4f121..8a3d4199 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -132,19 +133,47 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat return } -func (tx *DB) assignExprsToValue(exprs []clause.Expression) { - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) +func (tx *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + default: + } } - case clause.Column: - if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + exprs := tx.Statement.BuildCondition(value) + tx.assignInterfacesToValue(exprs) + default: + if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(reflectValue); !isZero { + if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + } + } + } + } } - default: + } else if len(values) > 0 { + exprs := tx.Statement.BuildCondition(values[0], values[1:]...) + tx.assignInterfacesToValue(exprs) + return } } } @@ -154,22 +183,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } tx.Error = nil } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return } @@ -180,20 +207,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return tx.Create(dest) diff --git a/tests/query_test.go b/tests/query_test.go index 4c2a2abd..72dd89b9 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -110,6 +110,8 @@ func TestQueryWithAssociation(t *testing.T) { t.Fatalf("errors happened when create user: %v", err) } + user.CreatedAt = time.Time{} + user.UpdatedAt = time.Time{} if err := DB.Where(&user).First(&User{}).Error; err != nil { t.Errorf("search with struct with association should returns no error, but got %v", err) }