diff --git a/callbacks/helper.go b/callbacks/helper.go index 56c0767d..baad2302 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -44,13 +44,14 @@ func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) ( sort.Strings(keys) for _, k := range keys { + value := mapValue[k] if field := stmt.Schema.LookUpField(k); field != nil { k = field.DBName } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { columns = append(columns, k) - values.Values[0] = append(values.Values[0], mapValue[k]) + values.Values[0] = append(values.Values[0], value) } } return diff --git a/callbacks/update.go b/callbacks/update.go index 82df3e81..9e1e9b78 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -2,8 +2,10 @@ package callbacks import ( "reflect" + "sort" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" ) func BeforeUpdate(db *gorm.DB) { @@ -40,6 +42,17 @@ func BeforeUpdate(db *gorm.DB) { } func Update(db *gorm.DB) { + db.Statement.AddClauseIfNotExists(clause.Update{}) + db.Statement.AddClause(ConvertToAssignments(db.Statement)) + db.Statement.Build("UPDATE", "SET", "WHERE") + + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } func AfterUpdate(db *gorm.DB) { @@ -74,3 +87,48 @@ func AfterUpdate(db *gorm.DB) { } } } + +// ConvertToAssignments convert to update assignments +func ConvertToAssignments(stmt *gorm.Statement) clause.Set { + selectColumns, restricted := SelectAndOmitColumns(stmt) + reflectModelValue := reflect.ValueOf(stmt.Model) + + switch value := stmt.Dest.(type) { + case map[string]interface{}: + var set clause.Set = make([]clause.Assignment, 0, len(value)) + + var keys []string + for k, _ := range value { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + if field := stmt.Schema.LookUpField(k); field != nil { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + field.Set(reflectModelValue, value[k]) + } + } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) + } + } + + return set + default: + switch stmt.ReflectValue.Kind() { + case reflect.Struct: + var set clause.Set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + for _, field := range stmt.Schema.FieldsByDBName { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + value, _ := field.ValueOf(stmt.ReflectValue) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + field.Set(reflectModelValue, value) + } + } + return set + } + } + + return clause.Set{} +} diff --git a/clause/limit.go b/clause/limit.go index 7775e6bf..e30666af 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -18,11 +18,11 @@ func (limit Limit) Build(builder Builder) { if limit.Limit > 0 { builder.Write("LIMIT ") builder.Write(strconv.Itoa(limit.Limit)) - } - if limit.Offset > 0 { - builder.Write(" OFFSET ") - builder.Write(strconv.Itoa(limit.Offset)) + if limit.Offset > 0 { + builder.Write(" OFFSET ") + builder.Write(strconv.Itoa(limit.Offset)) + } } } diff --git a/finisher_api.go b/finisher_api.go index c918c08a..e2f89cf0 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -22,11 +22,13 @@ func (db *DB) Save(value interface{}) (tx *DB) { } // First find first record that match given conditions, order by primary key -func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { - // TODO handle where +func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) @@ -34,8 +36,11 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { } // Take return a record that match given conditions, the order will depend on the database implementation -func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) @@ -43,11 +48,14 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { } // Last find last record that match given conditions, order by primary key -func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) @@ -55,8 +63,11 @@ func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { } // Find find records that match given conditions -func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) return @@ -75,22 +86,30 @@ func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.callbacks.Update().Execute(tx) return } // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = values + tx.callbacks.Update().Execute(tx) return } func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.callbacks.Update().Execute(tx) return } func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = values + tx.callbacks.Update().Execute(tx) return } diff --git a/tests/tests.go b/tests/tests.go index 2f0dfd34..18207268 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -18,6 +18,7 @@ func Now() *time.Time { func RunTestsSuit(t *testing.T, db *gorm.DB) { TestCreate(t, db) TestFind(t, db) + TestUpdate(t, db) } func TestCreate(t *testing.T, db *gorm.DB) { @@ -133,3 +134,62 @@ func TestFind(t *testing.T, db *gorm.DB) { } }) } + +func TestUpdate(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Update", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + if err := db.Model(&user).Update("Age", 10).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 10 { + t.Errorf("Age should equals to 10, but got %v", user.Age) + } + + var result User + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result, user, "Name", "Age", "Birthday") + } + + values := map[string]interface{}{"Active": true, "age": 5} + if err := db.Model(&user).Updates(values).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 5 { + t.Errorf("Age should equals to 5, but got %v", user.Age) + } else if user.Active != true { + t.Errorf("Active should be true, but got %v", user.Active) + } + + var result2 User + if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") + } + + if err := db.Model(&user).Updates(User{Age: 2}).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 2 { + t.Errorf("Age should equals to 2, but got %v", user.Age) + } + + var result3 User + if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") + } + }) +}