diff --git a/chainable_api.go b/chainable_api.go index 6b91c9ad..8336b787 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -224,13 +224,15 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { return } -func (db *DB) Assign(attrs ...interface{}) (tx *DB) { +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.attrs = attrs return } -func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.assigns = attrs return } diff --git a/clause/expression.go b/clause/expression.go index 872736ce..067774d4 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -171,14 +171,3 @@ func (like Like) NegationBuild(builder Builder) { builder.WriteString(" NOT LIKE ") builder.AddVar(builder, like.Value) } - -// Map -type Map map[interface{}]interface{} - -func (m Map) Build(builder Builder) { - // TODO -} - -func (m Map) NegationBuild(builder Builder) { - // TODO -} diff --git a/finisher_api.go b/finisher_api.go index 04b25ed2..2590e422 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "errors" "reflect" "strings" @@ -99,13 +100,56 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return } -func (db *DB) FirstOrInit(dest interface{}, where ...interface{}) (tx *DB) { +func (tx *DB) assignExprsToValue(exprs []clause.Expression) { + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + field.Set(tx.Statement.ReflectValue, eq.Value) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + field.Set(tx.Statement.ReflectValue, eq.Value) + } + default: + } + } + } +} + +func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() + if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignExprsToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]) + tx.assignExprsToValue(exprs) + } + tx.Error = nil + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) + tx.assignExprsToValue(exprs) + } return } func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() + // if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { + // // initialize with attrs, conds + // } + + // assign dest return } diff --git a/statement.go b/statement.go index d37622dd..51dea6fc 100644 --- a/statement.go +++ b/statement.go @@ -34,6 +34,8 @@ type Statement struct { SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg + attrs []interface{} + assigns []interface{} } // StatementModifier statement modifier interface @@ -195,7 +197,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { } // BuildCondtion build condition -func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { +func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) { if sql, ok := query.(string); ok { if i, err := strconv.Atoi(sql); err == nil { query = i @@ -212,42 +214,53 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con switch v := arg.(type) { case clause.Expression: - conditions = append(conditions, v) + conds = append(conds, v) case *DB: if v.Statement == nil { if cs, ok := v.Statement.Clauses["WHERE"]; ok { - conditions = append(conditions, cs.Expression) + conds = append(conds, cs.Expression) } } case map[interface{}]interface{}: - var clauseMap = clause.Map{} for i, j := range v { - clauseMap[i] = j + conds = append(conds, clause.Eq{Column: i, Value: j}) } - conditions = append(conditions, clauseMap) case map[string]string: - var clauseMap = clause.Map{} for i, j := range v { - clauseMap[i] = j + conds = append(conds, clause.Eq{Column: i, Value: j}) } - conditions = append(conditions, clauseMap) case map[string]interface{}: - var clauseMap = clause.Map{} for i, j := range v { - clauseMap[i] = j + conds = append(conds, clause.Eq{Column: i, Value: j}) } - conditions = append(conditions, clauseMap) default: - // TODO check is struct - // struct, slice -> ids + reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + switch reflectValue.Kind() { + case reflect.Struct: + for _, field := range s.FieldsByDBName { + if v, isZero := field.ValueOf(reflectValue); !isZero { + conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for _, field := range s.FieldsByDBName { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { + conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + } + } + } + } + } } } - if len(conditions) == 0 { - conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args}) + if len(conds) == 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } - return conditions + return } // Build build sql with clauses names @@ -337,7 +350,7 @@ func (stmt *Statement) reinit() { // return true // }) - stmt.Schema = nil + // stmt.Schema = nil stmt.SQL.Reset() stmt.Vars = nil stmt.NamedVars = nil diff --git a/tests/upsert_test.go b/tests/upsert_test.go new file mode 100644 index 00000000..728550d5 --- /dev/null +++ b/tests/upsert_test.go @@ -0,0 +1,56 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestFindOrInitialize(t *testing.T) { + var user1, user2, user3, user4, user5, user6 User + if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { + t.Errorf("no error should happen when FirstOrInit, but got %v", err) + } + if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) + if user2.Name != "find or init" || user2.ID != 0 || user2.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) + if user3.Name != "find or init 2" || user3.ID != 0 { + t.Errorf("user should be initialized with inline search value") + } + + DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and attrs") + } + + DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and assign attrs") + } + + DB.Save(&User{Name: "find or init", Age: 33}) + DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or init" || user5.ID == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 33 { + t.Errorf("user should be found with FirstOrInit") + } + + DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } +} + +func TestFindOrCreate(t *testing.T) { +}