mirror of https://github.com/go-gorm/gorm.git
Add Update test
This commit is contained in:
parent
0c34123796
commit
cbd55dbcd5
|
@ -44,13 +44,14 @@ func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) (
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
|
value := mapValue[k]
|
||||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||||
k = field.DBName
|
k = field.DBName
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||||
columns = append(columns, k)
|
columns = append(columns, k)
|
||||||
values.Values[0] = append(values.Values[0], mapValue[k])
|
values.Values[0] = append(values.Values[0], value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -2,8 +2,10 @@ package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
"github.com/jinzhu/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BeforeUpdate(db *gorm.DB) {
|
func BeforeUpdate(db *gorm.DB) {
|
||||||
|
@ -40,6 +42,17 @@ func BeforeUpdate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Update(db *gorm.DB) {
|
func Update(db *gorm.DB) {
|
||||||
|
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||||
|
db.Statement.AddClause(ConvertToAssignments(db.Statement))
|
||||||
|
db.Statement.Build("UPDATE", "SET", "WHERE")
|
||||||
|
|
||||||
|
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterUpdate(db *gorm.DB) {
|
func AfterUpdate(db *gorm.DB) {
|
||||||
|
@ -74,3 +87,48 @@ func AfterUpdate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConvertToAssignments convert to update assignments
|
||||||
|
func ConvertToAssignments(stmt *gorm.Statement) clause.Set {
|
||||||
|
selectColumns, restricted := SelectAndOmitColumns(stmt)
|
||||||
|
reflectModelValue := reflect.ValueOf(stmt.Model)
|
||||||
|
|
||||||
|
switch value := stmt.Dest.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
var set clause.Set = make([]clause.Assignment, 0, len(value))
|
||||||
|
|
||||||
|
var keys []string
|
||||||
|
for k, _ := range value {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
for _, k := range keys {
|
||||||
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||||
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
|
||||||
|
field.Set(reflectModelValue, value[k])
|
||||||
|
}
|
||||||
|
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||||
|
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return set
|
||||||
|
default:
|
||||||
|
switch stmt.ReflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
var set clause.Set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||||
|
for _, field := range stmt.Schema.FieldsByDBName {
|
||||||
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
|
value, _ := field.ValueOf(stmt.ReflectValue)
|
||||||
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||||
|
field.Set(reflectModelValue, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return clause.Set{}
|
||||||
|
}
|
||||||
|
|
|
@ -18,11 +18,11 @@ func (limit Limit) Build(builder Builder) {
|
||||||
if limit.Limit > 0 {
|
if limit.Limit > 0 {
|
||||||
builder.Write("LIMIT ")
|
builder.Write("LIMIT ")
|
||||||
builder.Write(strconv.Itoa(limit.Limit))
|
builder.Write(strconv.Itoa(limit.Limit))
|
||||||
}
|
|
||||||
|
|
||||||
if limit.Offset > 0 {
|
if limit.Offset > 0 {
|
||||||
builder.Write(" OFFSET ")
|
builder.Write(" OFFSET ")
|
||||||
builder.Write(strconv.Itoa(limit.Offset))
|
builder.Write(strconv.Itoa(limit.Offset))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,11 +22,13 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// First find first record that match given conditions, order by primary key
|
// First find first record that match given conditions, order by primary key
|
||||||
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
|
func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) {
|
||||||
// TODO handle where
|
|
||||||
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
|
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
})
|
})
|
||||||
|
if len(conds) > 0 {
|
||||||
|
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
|
||||||
|
}
|
||||||
tx.Statement.RaiseErrorOnNotFound = true
|
tx.Statement.RaiseErrorOnNotFound = true
|
||||||
tx.Statement.Dest = out
|
tx.Statement.Dest = out
|
||||||
tx.callbacks.Query().Execute(tx)
|
tx.callbacks.Query().Execute(tx)
|
||||||
|
@ -34,8 +36,11 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||||
func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) {
|
func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance().Limit(1)
|
tx = db.getInstance().Limit(1)
|
||||||
|
if len(conds) > 0 {
|
||||||
|
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
|
||||||
|
}
|
||||||
tx.Statement.RaiseErrorOnNotFound = true
|
tx.Statement.RaiseErrorOnNotFound = true
|
||||||
tx.Statement.Dest = out
|
tx.Statement.Dest = out
|
||||||
tx.callbacks.Query().Execute(tx)
|
tx.callbacks.Query().Execute(tx)
|
||||||
|
@ -43,11 +48,14 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Last find last record that match given conditions, order by primary key
|
// Last find last record that match given conditions, order by primary key
|
||||||
func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) {
|
func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
|
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
Desc: true,
|
Desc: true,
|
||||||
})
|
})
|
||||||
|
if len(conds) > 0 {
|
||||||
|
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
|
||||||
|
}
|
||||||
tx.Statement.RaiseErrorOnNotFound = true
|
tx.Statement.RaiseErrorOnNotFound = true
|
||||||
tx.Statement.Dest = out
|
tx.Statement.Dest = out
|
||||||
tx.callbacks.Query().Execute(tx)
|
tx.callbacks.Query().Execute(tx)
|
||||||
|
@ -55,8 +63,11 @@ func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find find records that match given conditions
|
// Find find records that match given conditions
|
||||||
func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
|
func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
if len(conds) > 0 {
|
||||||
|
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
|
||||||
|
}
|
||||||
tx.Statement.Dest = out
|
tx.Statement.Dest = out
|
||||||
tx.callbacks.Query().Execute(tx)
|
tx.callbacks.Query().Execute(tx)
|
||||||
return
|
return
|
||||||
|
@ -75,22 +86,30 @@ func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
|
||||||
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||||
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||||
|
tx.callbacks.Update().Execute(tx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||||
func (db *DB) Updates(values interface{}) (tx *DB) {
|
func (db *DB) Updates(values interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.Statement.Dest = values
|
||||||
|
tx.callbacks.Update().Execute(tx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
|
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||||
|
tx.callbacks.Update().Execute(tx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.Statement.Dest = values
|
||||||
|
tx.callbacks.Update().Execute(tx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ func Now() *time.Time {
|
||||||
func RunTestsSuit(t *testing.T, db *gorm.DB) {
|
func RunTestsSuit(t *testing.T, db *gorm.DB) {
|
||||||
TestCreate(t, db)
|
TestCreate(t, db)
|
||||||
TestFind(t, db)
|
TestFind(t, db)
|
||||||
|
TestUpdate(t, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreate(t *testing.T, db *gorm.DB) {
|
func TestCreate(t *testing.T, db *gorm.DB) {
|
||||||
|
@ -133,3 +134,62 @@ func TestFind(t *testing.T, db *gorm.DB) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdate(t *testing.T, db *gorm.DB) {
|
||||||
|
db.Migrator().DropTable(&User{})
|
||||||
|
db.AutoMigrate(&User{})
|
||||||
|
|
||||||
|
t.Run("Update", func(t *testing.T) {
|
||||||
|
var user = User{
|
||||||
|
Name: "create",
|
||||||
|
Age: 18,
|
||||||
|
Birthday: Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Create(&user).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Model(&user).Update("Age", 10).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when update: %v", err)
|
||||||
|
} else if user.Age != 10 {
|
||||||
|
t.Errorf("Age should equals to 10, but got %v", user.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result User
|
||||||
|
if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when query: %v", err)
|
||||||
|
} else {
|
||||||
|
AssertObjEqual(t, result, user, "Name", "Age", "Birthday")
|
||||||
|
}
|
||||||
|
|
||||||
|
values := map[string]interface{}{"Active": true, "age": 5}
|
||||||
|
if err := db.Model(&user).Updates(values).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when update: %v", err)
|
||||||
|
} else if user.Age != 5 {
|
||||||
|
t.Errorf("Age should equals to 5, but got %v", user.Age)
|
||||||
|
} else if user.Active != true {
|
||||||
|
t.Errorf("Active should be true, but got %v", user.Active)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result2 User
|
||||||
|
if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when query: %v", err)
|
||||||
|
} else {
|
||||||
|
AssertObjEqual(t, result2, user, "Name", "Age", "Birthday")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Model(&user).Updates(User{Age: 2}).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when update: %v", err)
|
||||||
|
} else if user.Age != 2 {
|
||||||
|
t.Errorf("Age should equals to 2, but got %v", user.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result3 User
|
||||||
|
if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when query: %v", err)
|
||||||
|
} else {
|
||||||
|
AssertObjEqual(t, result3, user, "Name", "Age", "Birthday")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue