Improve returning support

This commit is contained in:
Jinzhu 2021-10-26 22:36:37 +08:00
parent d3211908a0
commit af3fbdc2fc
10 changed files with 315 additions and 355 deletions

View File

@ -13,7 +13,6 @@ var (
type Config struct {
LastInsertIDReversed bool
WithReturning bool
CreateClauses []string
QueryClauses []string
UpdateClauses []string
@ -25,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
return !db.SkipDefaultTransaction
}
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate)
@ -33,18 +45,12 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete()
@ -54,9 +60,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update()
@ -64,13 +67,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update)
updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
updateCallback.Clauses = config.UpdateClauses
rowCallback := db.Callback().Row()

View File

@ -31,21 +31,38 @@ func BeforeCreate(db *gorm.DB) {
}
func Create(config *Config) func(db *gorm.DB) {
if config.WithReturning {
return CreateWithReturning
withReturning := false
for _, clause := range config.CreateClauses {
if clause == "RETURNING" {
withReturning = true
}
}
return func(db *gorm.DB) {
if db.Error != nil {
return
}
onReturning := false
if db.Statement.Schema != nil && !db.Statement.Unscoped {
if db.Statement.Schema != nil {
if !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
onReturning = true
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
}
if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
@ -55,6 +72,21 @@ func Create(config *Config) func(db *gorm.DB) {
}
if !db.DryRun && db.Error == nil {
if onReturning {
doNothing := false
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
onConflict, _ := c.Expression.(clause.OnConflict)
doNothing = onConflict.DoNothing
}
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
if doNothing {
gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing)
} else {
gorm.Scan(rows, db, gorm.ScanUpdate)
}
rows.Close()
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
@ -63,7 +95,6 @@ func Create(config *Config) func(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
@ -107,130 +138,6 @@ func Create(config *Config) func(db *gorm.DB) {
}
}
}
func CreateWithReturning(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
db.Statement.WriteString(" RETURNING ")
var (
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
)
for idx, field := range sch.FieldsWithDefaultDBValue {
if idx > 0 {
db.Statement.WriteByte(',')
}
fields[idx] = field
db.Statement.WriteQuoted(field.DBName)
}
if !db.DryRun && db.Error == nil {
db.RowsAffected = 0
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
defer rows.Close()
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
c = db.Statement.Clauses["ON CONFLICT"]
onConflict, _ = c.Expression.(clause.OnConflict)
resetFieldValues = map[int]reflect.Value{}
)
for rows.Next() {
BEGIN:
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
break
}
for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue)
if onConflict.DoNothing && !fieldValue.IsZero() {
db.RowsAffected++
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
return
}
goto BEGIN
}
if field.FieldType.Kind() == reflect.Ptr {
values[idx] = fieldValue.Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
reflectValue.Elem().Set(fieldValue.Addr())
values[idx] = reflectValue.Interface()
resetFieldValues[idx] = fieldValue
}
}
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
for idx, fv := range resetFieldValues {
if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
fv.Set(v.Elem())
}
}
}
case reflect.Struct:
resetFieldValues := map[int]reflect.Value{}
for idx, field := range fields {
if field.FieldType.Kind() == reflect.Ptr {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
fieldValue := field.ReflectValueOf(db.Statement.ReflectValue)
reflectValue.Elem().Set(fieldValue.Addr())
values[idx] = reflectValue.Interface()
resetFieldValues[idx] = fieldValue
}
}
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
for idx, fv := range resetFieldValues {
if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
fv.Set(v.Elem())
}
}
}
}
} else {
db.AddError(err)
}
}
} else if !db.DryRun && db.Error == nil {
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func AfterCreate(db *gorm.DB) {

View File

@ -22,7 +22,7 @@ func Query(db *gorm.DB) {
}
defer rows.Close()
gorm.Scan(rows, db, false)
gorm.Scan(rows, db, 0)
}
}
}

View File

@ -50,7 +50,15 @@ func BeforeUpdate(db *gorm.DB) {
}
}
func Update(db *gorm.DB) {
func Update(config *Config) func(db *gorm.DB) {
withReturning := false
for _, clause := range config.UpdateClauses {
if clause == "RETURNING" {
withReturning = true
}
}
return func(db *gorm.DB) {
if db.Error != nil {
return
}
@ -78,6 +86,12 @@ func Update(db *gorm.DB) {
}
if !db.DryRun && db.Error == nil {
if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, gorm.ScanUpdate)
rows.Close()
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
@ -87,6 +101,8 @@ func Update(db *gorm.DB) {
}
}
}
}
}
func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {

View File

@ -511,7 +511,7 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
}
tx.Statement.ReflectValue = elem
}
Scan(rows, tx, true)
Scan(rows, tx, ScanInitialized)
return tx.Error
}

356
scan.go
View File

@ -49,165 +49,14 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
}
}
func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue, ok := dest.(map[string]interface{})
if !ok {
if v, ok := dest.(*map[string]interface{}); ok {
mapValue = *v
}
}
scanIntoMap(mapValue, values, columns)
}
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue)
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
*float32, *float64,
*bool, *string, *time.Time,
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
*sql.NullBool, *sql.NullString, *sql.NullTime:
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(dest))
}
default:
Schema := db.Statement.Schema
reflectValue := db.Statement.ReflectValue
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
reflectValueType = reflectValue.Type().Elem()
isPtr = reflectValueType.Kind() == reflect.Ptr
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
)
if isPtr {
reflectValueType = reflectValueType.Elem()
}
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
if Schema != nil {
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
// pluck values into slice of data
isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
isPluck = true
}
}
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
elem := reflect.New(reflectValueType)
if isPluck {
db.AddError(rows.Scan(elem.Interface()))
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
}
}
db.AddError(rows.Scan(values...))
for idx, field := range fields {
if len(joinFields) != 0 && joinFields[idx][0] != nil {
value := reflect.ValueOf(values[idx]).Elem()
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
} else if field != nil {
field.Set(elem, values[idx])
}
}
}
if isPtr {
reflectValue = reflect.Append(reflectValue, elem)
} else {
reflectValue = reflect.Append(reflectValue, elem.Elem())
}
}
db.Statement.ReflectValue.Set(reflectValue)
case reflect.Struct, reflect.Ptr:
if reflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if initialized || rows.Next() {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
if sch == nil {
values[idx] = reflectValue.Interface()
} else if field := sch.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
@ -215,7 +64,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
values[idx] = &sql.RawBytes{}
} else if len(columns) == 1 {
values[idx] = dest
sch = nil
values[idx] = reflectValue.Interface()
} else {
values[idx] = &sql.RawBytes{}
}
@ -224,11 +74,12 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
db.RowsAffected++
db.AddError(rows.Scan(values...))
if sch != nil {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
if field := sch.LookUpField(column); field != nil && field.Readable {
field.Set(reflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(reflectValue)
value := reflect.ValueOf(values[idx]).Elem()
@ -246,6 +97,193 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
}
}
}
type ScanMode uint8
const (
ScanInitialized ScanMode = 1 << 0
ScanUpdate = 1 << 1
ScanOnConflictDoNothing = 1 << 2
)
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if update && db.Statement.Schema != nil {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
fields := make([]*schema.Field, len(columns))
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
}
}
if initialized || rows.Next() {
db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil)
}
}
}
if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue, ok := dest.(map[string]interface{})
if !ok {
if v, ok := dest.(*map[string]interface{}); ok {
mapValue = *v
}
}
scanIntoMap(mapValue, values, columns)
}
case *[]map[string]interface{}, []map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
if values, ok := dest.([]map[string]interface{}); ok {
values = append(values, mapValue)
} else if values, ok := dest.(*[]map[string]interface{}); ok {
*values = append(*values, mapValue)
}
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
*float32, *float64,
*bool, *string, *time.Time,
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
*sql.NullBool, *sql.NullString, *sql.NullTime:
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(dest))
}
default:
var (
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
)
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
reflectValueType := reflectValue.Type()
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
if len(columns) == 1 {
// isPluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var elem reflect.Value
if !update {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
}
for initialized || rows.Next() {
BEGIN:
initialized = false
if update {
if int(db.RowsAffected) >= reflectValue.Len() {
return
}
elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing {
for _, field := range fields {
if _, ok := field.ValueOf(elem); !ok {
db.RowsAffected++
goto BEGIN
}
}
}
} else {
elem = reflect.New(reflectValueType)
}
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
if !update {
if isPtr {
reflectValue = reflect.Append(reflectValue, elem)
} else {
reflectValue = reflect.Append(reflectValue, elem.Elem())
}
}
}
if !update {
db.Statement.ReflectValue.Set(reflectValue)
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
}
default:
db.AddError(rows.Scan(dest))
}

View File

@ -7,9 +7,9 @@ require (
github.com/jinzhu/now v1.1.2
github.com/lib/pq v1.10.3
gorm.io/driver/mysql v1.1.2
gorm.io/driver/postgres v1.1.2
gorm.io/driver/sqlite v1.1.6
gorm.io/driver/sqlserver v1.1.0
gorm.io/driver/postgres v1.2.0
gorm.io/driver/sqlite v1.2.0
gorm.io/driver/sqlserver v1.1.1
gorm.io/gorm v1.21.16
)

View File

@ -1,9 +1,9 @@
package tests_test
import (
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"testing"
"gorm.io/gorm"
)
func TestReturningWithNullToZeroValues(t *testing.T) {
@ -20,11 +20,6 @@ func TestReturningWithNullToZeroValues(t *testing.T) {
Name string `gorm:"default:null"`
}
u1 := user{}
c := DB.Callback().Create().Get("gorm:create")
t.Cleanup(func() {
DB.Callback().Create().Replace("gorm:create", c)
})
DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true}))
if results := DB.Create(&u1); results.Error != nil {
t.Fatalf("errors happened on create: %v", results.Error)

View File

@ -9,6 +9,7 @@ import (
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
. "gorm.io/gorm/utils/tests"
)
@ -166,13 +167,16 @@ func TestUpdates(t *testing.T) {
}
// update with gorm exprs
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
var user4 User
DB.First(&user4, user3.ID)
// sqlite, postgres support returning
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
user3.Age += 100
}
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
}