mirror of https://github.com/go-gorm/gorm.git
Add returning tests
This commit is contained in:
parent
835d7bde59
commit
e953880d19
|
@ -84,7 +84,10 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||
if !db.DryRun && db.Error == nil {
|
||||
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
dest := db.Statement.Dest
|
||||
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
|
||||
gorm.Scan(rows, db, mode)
|
||||
db.Statement.Dest = dest
|
||||
rows.Close()
|
||||
}
|
||||
} else {
|
||||
|
@ -152,20 +155,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var primaryKeyExprs []clause.Expression
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||
var notZero bool
|
||||
for idx, field := range stmt.Schema.PrimaryFields {
|
||||
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
|
||||
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||
notZero = notZero || !isZero
|
||||
}
|
||||
if notZero {
|
||||
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
|
||||
if size := stmt.ReflectValue.Len(); size > 0 {
|
||||
var primaryKeyExprs []clause.Expression
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||
var notZero bool
|
||||
for idx, field := range stmt.Schema.PrimaryFields {
|
||||
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
|
||||
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||
notZero = notZero || !isZero
|
||||
}
|
||||
if notZero {
|
||||
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
|
||||
}
|
||||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
|
||||
}
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
|
||||
case reflect.Struct:
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||
|
|
16
scan.go
16
scan.go
|
@ -120,22 +120,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
if update && db.Statement.Schema != nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
fields := make([]*schema.Field, len(columns))
|
||||
for idx, column := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
}
|
||||
}
|
||||
|
||||
if initialized || rows.Next() {
|
||||
db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if initialized || rows.Next() {
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
|
|
@ -159,6 +159,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
|||
}
|
||||
|
||||
stmt.AddClauseIfNotExists(clause.Update{})
|
||||
stmt.Build("UPDATE", "SET", "WHERE")
|
||||
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -205,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// only sqlite, postgres support returning
|
||||
func TestSoftDeleteReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
users := []*User{
|
||||
GetUser("delete-returning-1", Config{}),
|
||||
GetUser("delete-returning-2", Config{}),
|
||||
GetUser("delete-returning-3", Config{}),
|
||||
}
|
||||
DB.Create(&users)
|
||||
|
||||
var results []User
|
||||
DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results)
|
||||
if len(results) != 2 {
|
||||
t.Errorf("failed to return delete data, got %v", results)
|
||||
}
|
||||
|
||||
var count int64
|
||||
DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("failed to delete data, current count %v", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
companies := []Company{
|
||||
{Name: "delete-returning-1"},
|
||||
{Name: "delete-returning-2"},
|
||||
{Name: "delete-returning-3"},
|
||||
}
|
||||
DB.Create(&companies)
|
||||
|
||||
var results []Company
|
||||
DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results)
|
||||
if len(results) != 2 {
|
||||
t.Errorf("failed to return delete data, got %v", results)
|
||||
}
|
||||
|
||||
var count int64
|
||||
DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("failed to delete data, current count %v", count)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ require (
|
|||
github.com/jinzhu/now v1.1.2
|
||||
github.com/lib/pq v1.10.3
|
||||
gorm.io/driver/mysql v1.1.2
|
||||
gorm.io/driver/postgres v1.2.0
|
||||
gorm.io/driver/sqlite v1.2.0
|
||||
gorm.io/driver/postgres v1.2.1
|
||||
gorm.io/driver/sqlite v1.2.2
|
||||
gorm.io/driver/sqlserver v1.1.2
|
||||
gorm.io/gorm v1.22.0
|
||||
)
|
||||
|
|
|
@ -167,16 +167,13 @@ func TestUpdates(t *testing.T) {
|
|||
}
|
||||
|
||||
// update with gorm exprs
|
||||
if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||
}
|
||||
var user4 User
|
||||
DB.First(&user4, user3.ID)
|
||||
|
||||
// sqlite, postgres support returning
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
user3.Age += 100
|
||||
}
|
||||
user3.Age += 100
|
||||
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
|
||||
}
|
||||
|
||||
|
@ -728,3 +725,35 @@ func TestSaveWithPrimaryValue(t *testing.T) {
|
|||
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
|
||||
}
|
||||
}
|
||||
|
||||
// only sqlite, postgres support returning
|
||||
func TestUpdateReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
users := []*User{
|
||||
GetUser("update-returning-1", Config{}),
|
||||
GetUser("update-returning-2", Config{}),
|
||||
GetUser("update-returning-3", Config{}),
|
||||
}
|
||||
DB.Create(&users)
|
||||
|
||||
var results []User
|
||||
DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88)
|
||||
if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 {
|
||||
t.Errorf("failed to return updated data, got %v", results)
|
||||
}
|
||||
|
||||
if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||
}
|
||||
|
||||
if results[1].Age-results[0].Age != 100 {
|
||||
t.Errorf("failed to return updated age column")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue