forked from mirror/gorm
Add returning tests
This commit is contained in:
parent
835d7bde59
commit
e953880d19
|
@ -84,7 +84,10 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
if ok, mode := hasReturning(db, supportReturning); ok {
|
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
|
dest := db.Statement.Dest
|
||||||
|
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
|
||||||
gorm.Scan(rows, db, mode)
|
gorm.Scan(rows, db, mode)
|
||||||
|
db.Statement.Dest = dest
|
||||||
rows.Close()
|
rows.Close()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -152,20 +155,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||||
switch stmt.ReflectValue.Kind() {
|
switch stmt.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
var primaryKeyExprs []clause.Expression
|
if size := stmt.ReflectValue.Len(); size > 0 {
|
||||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
var primaryKeyExprs []clause.Expression
|
||||||
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||||
var notZero bool
|
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||||
for idx, field := range stmt.Schema.PrimaryFields {
|
var notZero bool
|
||||||
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
|
for idx, field := range stmt.Schema.PrimaryFields {
|
||||||
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
|
||||||
notZero = notZero || !isZero
|
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||||
}
|
notZero = notZero || !isZero
|
||||||
if notZero {
|
}
|
||||||
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
|
if notZero {
|
||||||
|
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
|
||||||
}
|
}
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
|
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
for _, field := range stmt.Schema.PrimaryFields {
|
for _, field := range stmt.Schema.PrimaryFields {
|
||||||
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||||
|
|
16
scan.go
16
scan.go
|
@ -120,22 +120,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
|
|
||||||
switch dest := db.Statement.Dest.(type) {
|
switch dest := db.Statement.Dest.(type) {
|
||||||
case map[string]interface{}, *map[string]interface{}:
|
case map[string]interface{}, *map[string]interface{}:
|
||||||
if update && db.Statement.Schema != nil {
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
|
||||||
case reflect.Struct:
|
|
||||||
fields := make([]*schema.Field, len(columns))
|
|
||||||
for idx, column := range columns {
|
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
|
||||||
fields[idx] = field
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
|
||||||
db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
columnTypes, _ := rows.ColumnTypes()
|
columnTypes, _ := rows.ColumnTypes()
|
||||||
prepareValues(values, db, columnTypes, columns)
|
prepareValues(values, db, columnTypes, columns)
|
||||||
|
|
|
@ -159,6 +159,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt.AddClauseIfNotExists(clause.Update{})
|
stmt.AddClauseIfNotExists(clause.Update{})
|
||||||
stmt.Build("UPDATE", "SET", "WHERE")
|
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -205,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// only sqlite, postgres support returning
|
||||||
|
func TestSoftDeleteReturning(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
users := []*User{
|
||||||
|
GetUser("delete-returning-1", Config{}),
|
||||||
|
GetUser("delete-returning-2", Config{}),
|
||||||
|
GetUser("delete-returning-3", Config{}),
|
||||||
|
}
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
var results []User
|
||||||
|
DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results)
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("failed to return delete data, got %v", results)
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count)
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("failed to delete data, current count %v", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteReturning(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
companies := []Company{
|
||||||
|
{Name: "delete-returning-1"},
|
||||||
|
{Name: "delete-returning-2"},
|
||||||
|
{Name: "delete-returning-3"},
|
||||||
|
}
|
||||||
|
DB.Create(&companies)
|
||||||
|
|
||||||
|
var results []Company
|
||||||
|
DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results)
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("failed to return delete data, got %v", results)
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count)
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("failed to delete data, current count %v", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -7,8 +7,8 @@ require (
|
||||||
github.com/jinzhu/now v1.1.2
|
github.com/jinzhu/now v1.1.2
|
||||||
github.com/lib/pq v1.10.3
|
github.com/lib/pq v1.10.3
|
||||||
gorm.io/driver/mysql v1.1.2
|
gorm.io/driver/mysql v1.1.2
|
||||||
gorm.io/driver/postgres v1.2.0
|
gorm.io/driver/postgres v1.2.1
|
||||||
gorm.io/driver/sqlite v1.2.0
|
gorm.io/driver/sqlite v1.2.2
|
||||||
gorm.io/driver/sqlserver v1.1.2
|
gorm.io/driver/sqlserver v1.1.2
|
||||||
gorm.io/gorm v1.22.0
|
gorm.io/gorm v1.22.0
|
||||||
)
|
)
|
||||||
|
|
|
@ -167,16 +167,13 @@ func TestUpdates(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// update with gorm exprs
|
// update with gorm exprs
|
||||||
if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||||
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||||
}
|
}
|
||||||
var user4 User
|
var user4 User
|
||||||
DB.First(&user4, user3.ID)
|
DB.First(&user4, user3.ID)
|
||||||
|
|
||||||
// sqlite, postgres support returning
|
user3.Age += 100
|
||||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
|
||||||
user3.Age += 100
|
|
||||||
}
|
|
||||||
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
|
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -728,3 +725,35 @@ func TestSaveWithPrimaryValue(t *testing.T) {
|
||||||
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
|
t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// only sqlite, postgres support returning
|
||||||
|
func TestUpdateReturning(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
users := []*User{
|
||||||
|
GetUser("update-returning-1", Config{}),
|
||||||
|
GetUser("update-returning-2", Config{}),
|
||||||
|
GetUser("update-returning-3", Config{}),
|
||||||
|
}
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
var results []User
|
||||||
|
DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88)
|
||||||
|
if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 {
|
||||||
|
t.Errorf("failed to return updated data, got %v", results)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||||
|
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||||
|
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[1].Age-results[0].Age != 100 {
|
||||||
|
t.Errorf("failed to return updated age column")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue