forked from mirror/gorm
Improve returning support
This commit is contained in:
parent
d3211908a0
commit
af3fbdc2fc
|
@ -13,7 +13,6 @@ var (
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
LastInsertIDReversed bool
|
LastInsertIDReversed bool
|
||||||
WithReturning bool
|
|
||||||
CreateClauses []string
|
CreateClauses []string
|
||||||
QueryClauses []string
|
QueryClauses []string
|
||||||
UpdateClauses []string
|
UpdateClauses []string
|
||||||
|
@ -25,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||||
return !db.SkipDefaultTransaction
|
return !db.SkipDefaultTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(config.CreateClauses) == 0 {
|
||||||
|
config.CreateClauses = createClauses
|
||||||
|
}
|
||||||
|
if len(config.QueryClauses) == 0 {
|
||||||
|
config.QueryClauses = queryClauses
|
||||||
|
}
|
||||||
|
if len(config.DeleteClauses) == 0 {
|
||||||
|
config.DeleteClauses = deleteClauses
|
||||||
|
}
|
||||||
|
if len(config.UpdateClauses) == 0 {
|
||||||
|
config.UpdateClauses = updateClauses
|
||||||
|
}
|
||||||
|
|
||||||
createCallback := db.Callback().Create()
|
createCallback := db.Callback().Create()
|
||||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||||
|
@ -33,18 +45,12 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
||||||
createCallback.Register("gorm:after_create", AfterCreate)
|
createCallback.Register("gorm:after_create", AfterCreate)
|
||||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
if len(config.CreateClauses) == 0 {
|
|
||||||
config.CreateClauses = createClauses
|
|
||||||
}
|
|
||||||
createCallback.Clauses = config.CreateClauses
|
createCallback.Clauses = config.CreateClauses
|
||||||
|
|
||||||
queryCallback := db.Callback().Query()
|
queryCallback := db.Callback().Query()
|
||||||
queryCallback.Register("gorm:query", Query)
|
queryCallback.Register("gorm:query", Query)
|
||||||
queryCallback.Register("gorm:preload", Preload)
|
queryCallback.Register("gorm:preload", Preload)
|
||||||
queryCallback.Register("gorm:after_query", AfterQuery)
|
queryCallback.Register("gorm:after_query", AfterQuery)
|
||||||
if len(config.QueryClauses) == 0 {
|
|
||||||
config.QueryClauses = queryClauses
|
|
||||||
}
|
|
||||||
queryCallback.Clauses = config.QueryClauses
|
queryCallback.Clauses = config.QueryClauses
|
||||||
|
|
||||||
deleteCallback := db.Callback().Delete()
|
deleteCallback := db.Callback().Delete()
|
||||||
|
@ -54,9 +60,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||||
deleteCallback.Register("gorm:delete", Delete)
|
deleteCallback.Register("gorm:delete", Delete)
|
||||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
if len(config.DeleteClauses) == 0 {
|
|
||||||
config.DeleteClauses = deleteClauses
|
|
||||||
}
|
|
||||||
deleteCallback.Clauses = config.DeleteClauses
|
deleteCallback.Clauses = config.DeleteClauses
|
||||||
|
|
||||||
updateCallback := db.Callback().Update()
|
updateCallback := db.Callback().Update()
|
||||||
|
@ -64,13 +67,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
||||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
||||||
updateCallback.Register("gorm:update", Update)
|
updateCallback.Register("gorm:update", Update(config))
|
||||||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
||||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
if len(config.UpdateClauses) == 0 {
|
|
||||||
config.UpdateClauses = updateClauses
|
|
||||||
}
|
|
||||||
updateCallback.Clauses = config.UpdateClauses
|
updateCallback.Clauses = config.UpdateClauses
|
||||||
|
|
||||||
rowCallback := db.Callback().Row()
|
rowCallback := db.Callback().Row()
|
||||||
|
|
|
@ -31,18 +31,35 @@ func BeforeCreate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(config *Config) func(db *gorm.DB) {
|
func Create(config *Config) func(db *gorm.DB) {
|
||||||
if config.WithReturning {
|
withReturning := false
|
||||||
return CreateWithReturning
|
for _, clause := range config.CreateClauses {
|
||||||
|
if clause == "RETURNING" {
|
||||||
|
withReturning = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(db *gorm.DB) {
|
return func(db *gorm.DB) {
|
||||||
if db.Error != nil {
|
if db.Error != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
onReturning := false
|
||||||
|
|
||||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
if db.Statement.Schema != nil {
|
||||||
for _, c := range db.Statement.Schema.CreateClauses {
|
if !db.Statement.Unscoped {
|
||||||
db.Statement.AddClause(c)
|
for _, c := range db.Statement.Schema.CreateClauses {
|
||||||
|
db.Statement.AddClause(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||||
|
onReturning = true
|
||||||
|
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
||||||
|
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
||||||
|
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||||
|
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||||
|
}
|
||||||
|
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,180 +72,70 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if onReturning {
|
||||||
|
doNothing := false
|
||||||
if err != nil {
|
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
||||||
db.AddError(err)
|
onConflict, _ := c.Expression.(clause.OnConflict)
|
||||||
return
|
doNothing = onConflict.DoNothing
|
||||||
}
|
}
|
||||||
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
if doNothing {
|
||||||
|
gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing)
|
||||||
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
} else {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
gorm.Scan(rows, db, gorm.ScanUpdate)
|
||||||
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
|
||||||
case reflect.Slice, reflect.Array:
|
|
||||||
if config.LastInsertIDReversed {
|
|
||||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
|
||||||
rv := db.Statement.ReflectValue.Index(i)
|
|
||||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
|
|
||||||
if isZero {
|
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
|
||||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
|
||||||
rv := db.Statement.ReflectValue.Index(i)
|
|
||||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
|
||||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
rows.Close()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
return
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateWithReturning(db *gorm.DB) {
|
|
||||||
if db.Error == nil {
|
|
||||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
|
||||||
for _, c := range db.Statement.Schema.CreateClauses {
|
|
||||||
db.Statement.AddClause(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
|
||||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
|
||||||
|
|
||||||
db.Statement.Build(db.Statement.BuildClauses...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
|
||||||
db.Statement.WriteString(" RETURNING ")
|
|
||||||
|
|
||||||
var (
|
|
||||||
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
|
|
||||||
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
|
|
||||||
)
|
|
||||||
|
|
||||||
for idx, field := range sch.FieldsWithDefaultDBValue {
|
|
||||||
if idx > 0 {
|
|
||||||
db.Statement.WriteByte(',')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fields[idx] = field
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
db.Statement.WriteQuoted(field.DBName)
|
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
||||||
}
|
db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||||
|
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
|
||||||
if !db.DryRun && db.Error == nil {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
db.RowsAffected = 0
|
case reflect.Slice, reflect.Array:
|
||||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if config.LastInsertIDReversed {
|
||||||
|
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||||
if err == nil {
|
rv := db.Statement.ReflectValue.Index(i)
|
||||||
defer rows.Close()
|
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||||
|
break
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
|
||||||
case reflect.Slice, reflect.Array:
|
|
||||||
var (
|
|
||||||
c = db.Statement.Clauses["ON CONFLICT"]
|
|
||||||
onConflict, _ = c.Expression.(clause.OnConflict)
|
|
||||||
resetFieldValues = map[int]reflect.Value{}
|
|
||||||
)
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
BEGIN:
|
|
||||||
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
|
|
||||||
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
for idx, field := range fields {
|
|
||||||
fieldValue := field.ReflectValueOf(reflectValue)
|
|
||||||
|
|
||||||
if onConflict.DoNothing && !fieldValue.IsZero() {
|
|
||||||
db.RowsAffected++
|
|
||||||
|
|
||||||
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
goto BEGIN
|
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
|
||||||
|
if isZero {
|
||||||
|
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||||
|
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if field.FieldType.Kind() == reflect.Ptr {
|
|
||||||
values[idx] = fieldValue.Addr().Interface()
|
|
||||||
} else {
|
|
||||||
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
|
|
||||||
reflectValue.Elem().Set(fieldValue.Addr())
|
|
||||||
values[idx] = reflectValue.Interface()
|
|
||||||
resetFieldValues[idx] = fieldValue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
db.RowsAffected++
|
|
||||||
if err := rows.Scan(values...); err != nil {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for idx, fv := range resetFieldValues {
|
|
||||||
if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
|
|
||||||
fv.Set(v.Elem())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.Struct:
|
|
||||||
resetFieldValues := map[int]reflect.Value{}
|
|
||||||
for idx, field := range fields {
|
|
||||||
if field.FieldType.Kind() == reflect.Ptr {
|
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
|
||||||
} else {
|
} else {
|
||||||
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
fieldValue := field.ReflectValueOf(db.Statement.ReflectValue)
|
rv := db.Statement.ReflectValue.Index(i)
|
||||||
reflectValue.Elem().Set(fieldValue.Addr())
|
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||||
values[idx] = reflectValue.Interface()
|
break
|
||||||
resetFieldValues[idx] = fieldValue
|
}
|
||||||
}
|
|
||||||
}
|
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
||||||
if rows.Next() {
|
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||||
db.RowsAffected++
|
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||||
db.AddError(rows.Scan(values...))
|
}
|
||||||
for idx, fv := range resetFieldValues {
|
|
||||||
if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
|
|
||||||
fv.Set(v.Elem())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
||||||
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if !db.DryRun && db.Error == nil {
|
|
||||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
|
||||||
} else {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ func Query(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
gorm.Scan(rows, db, false)
|
gorm.Scan(rows, db, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,40 +50,56 @@ func BeforeUpdate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Update(db *gorm.DB) {
|
func Update(config *Config) func(db *gorm.DB) {
|
||||||
if db.Error != nil {
|
withReturning := false
|
||||||
return
|
for _, clause := range config.UpdateClauses {
|
||||||
}
|
if clause == "RETURNING" {
|
||||||
|
withReturning = true
|
||||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
|
||||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
|
||||||
db.Statement.AddClause(c)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
return func(db *gorm.DB) {
|
||||||
db.Statement.SQL.Grow(180)
|
if db.Error != nil {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
|
||||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
|
||||||
db.Statement.AddClause(set)
|
|
||||||
} else {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
db.Statement.Build(db.Statement.BuildClauses...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
db.AddError(gorm.ErrMissingWhereClause)
|
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||||
return
|
db.Statement.AddClause(c)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if db.Statement.SQL.String() == "" {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
db.Statement.SQL.Grow(180)
|
||||||
|
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||||
|
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||||
|
db.Statement.AddClause(set)
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
db.Statement.Build(db.Statement.BuildClauses...)
|
||||||
|
}
|
||||||
|
|
||||||
if err == nil {
|
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.AddError(gorm.ErrMissingWhereClause)
|
||||||
} else {
|
return
|
||||||
db.AddError(err)
|
}
|
||||||
|
|
||||||
|
if !db.DryRun && db.Error == nil {
|
||||||
|
if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok {
|
||||||
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
|
gorm.Scan(rows, db, gorm.ScanUpdate)
|
||||||
|
rows.Close()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ func (onConflict OnConflict) Build(builder Builder) {
|
||||||
}
|
}
|
||||||
builder.WriteString(`) `)
|
builder.WriteString(`) `)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(onConflict.TargetWhere.Exprs) > 0 {
|
if len(onConflict.TargetWhere.Exprs) > 0 {
|
||||||
builder.WriteString(" WHERE ")
|
builder.WriteString(" WHERE ")
|
||||||
onConflict.TargetWhere.Build(builder)
|
onConflict.TargetWhere.Build(builder)
|
||||||
|
|
|
@ -511,7 +511,7 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||||
}
|
}
|
||||||
tx.Statement.ReflectValue = elem
|
tx.Statement.ReflectValue = elem
|
||||||
}
|
}
|
||||||
Scan(rows, tx, true)
|
Scan(rows, tx, ScanInitialized)
|
||||||
return tx.Error
|
return tx.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
308
scan.go
308
scan.go
|
@ -49,13 +49,93 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||||
columns, _ := rows.Columns()
|
for idx, column := range columns {
|
||||||
values := make([]interface{}, len(columns))
|
if sch == nil {
|
||||||
|
values[idx] = reflectValue.Interface()
|
||||||
|
} else if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||||
|
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||||
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||||
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
|
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
values[idx] = &sql.RawBytes{}
|
||||||
|
} else if len(columns) == 1 {
|
||||||
|
sch = nil
|
||||||
|
values[idx] = reflectValue.Interface()
|
||||||
|
} else {
|
||||||
|
values[idx] = &sql.RawBytes{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
db.RowsAffected++
|
||||||
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
|
if sch != nil {
|
||||||
|
for idx, column := range columns {
|
||||||
|
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||||
|
field.Set(reflectValue, values[idx])
|
||||||
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||||
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
|
relValue := rel.Field.ReflectValueOf(reflectValue)
|
||||||
|
value := reflect.ValueOf(values[idx]).Elem()
|
||||||
|
|
||||||
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||||
|
if value.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||||
|
}
|
||||||
|
|
||||||
|
field.Set(relValue, values[idx])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ScanMode uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
ScanInitialized ScanMode = 1 << 0
|
||||||
|
ScanUpdate = 1 << 1
|
||||||
|
ScanOnConflictDoNothing = 1 << 2
|
||||||
|
)
|
||||||
|
|
||||||
|
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
|
var (
|
||||||
|
columns, _ = rows.Columns()
|
||||||
|
values = make([]interface{}, len(columns))
|
||||||
|
initialized = mode&ScanInitialized != 0
|
||||||
|
update = mode&ScanUpdate != 0
|
||||||
|
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
|
||||||
|
)
|
||||||
|
|
||||||
db.RowsAffected = 0
|
db.RowsAffected = 0
|
||||||
|
|
||||||
switch dest := db.Statement.Dest.(type) {
|
switch dest := db.Statement.Dest.(type) {
|
||||||
case map[string]interface{}, *map[string]interface{}:
|
case map[string]interface{}, *map[string]interface{}:
|
||||||
|
if update && db.Statement.Schema != nil {
|
||||||
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
fields := make([]*schema.Field, len(columns))
|
||||||
|
for idx, column := range columns {
|
||||||
|
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
|
fields[idx] = field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if initialized || rows.Next() {
|
||||||
|
db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
columnTypes, _ := rows.ColumnTypes()
|
columnTypes, _ := rows.ColumnTypes()
|
||||||
prepareValues(values, db, columnTypes, columns)
|
prepareValues(values, db, columnTypes, columns)
|
||||||
|
@ -71,7 +151,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
scanIntoMap(mapValue, values, columns)
|
scanIntoMap(mapValue, values, columns)
|
||||||
}
|
}
|
||||||
case *[]map[string]interface{}:
|
case *[]map[string]interface{}, []map[string]interface{}:
|
||||||
columnTypes, _ := rows.ColumnTypes()
|
columnTypes, _ := rows.ColumnTypes()
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
prepareValues(values, db, columnTypes, columns)
|
prepareValues(values, db, columnTypes, columns)
|
||||||
|
@ -82,7 +162,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
|
|
||||||
mapValue := map[string]interface{}{}
|
mapValue := map[string]interface{}{}
|
||||||
scanIntoMap(mapValue, values, columns)
|
scanIntoMap(mapValue, values, columns)
|
||||||
*dest = append(*dest, mapValue)
|
if values, ok := dest.([]map[string]interface{}); ok {
|
||||||
|
values = append(values, mapValue)
|
||||||
|
} else if values, ok := dest.(*[]map[string]interface{}); ok {
|
||||||
|
*values = append(*values, mapValue)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case *int, *int8, *int16, *int32, *int64,
|
case *int, *int8, *int16, *int32, *int64,
|
||||||
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
||||||
|
@ -96,155 +180,109 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
db.AddError(rows.Scan(dest))
|
db.AddError(rows.Scan(dest))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
Schema := db.Statement.Schema
|
var (
|
||||||
reflectValue := db.Statement.ReflectValue
|
fields = make([]*schema.Field, len(columns))
|
||||||
|
joinFields [][2]*schema.Field
|
||||||
|
sch = db.Statement.Schema
|
||||||
|
reflectValue = db.Statement.ReflectValue
|
||||||
|
)
|
||||||
|
|
||||||
if reflectValue.Kind() == reflect.Interface {
|
if reflectValue.Kind() == reflect.Interface {
|
||||||
reflectValue = reflectValue.Elem()
|
reflectValue = reflectValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reflectValueType := reflectValue.Type()
|
||||||
|
switch reflectValueType.Kind() {
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
reflectValueType = reflectValueType.Elem()
|
||||||
|
}
|
||||||
|
isPtr := reflectValueType.Kind() == reflect.Ptr
|
||||||
|
if isPtr {
|
||||||
|
reflectValueType = reflectValueType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if sch != nil {
|
||||||
|
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||||
|
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, column := range columns {
|
||||||
|
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||||
|
fields[idx] = field
|
||||||
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||||
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
|
fields[idx] = field
|
||||||
|
|
||||||
|
if len(joinFields) == 0 {
|
||||||
|
joinFields = make([][2]*schema.Field, len(columns))
|
||||||
|
}
|
||||||
|
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
values[idx] = &sql.RawBytes{}
|
||||||
|
} else {
|
||||||
|
values[idx] = &sql.RawBytes{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columns) == 1 {
|
||||||
|
// isPluck
|
||||||
|
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
||||||
|
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||||
|
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||||
|
sch = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
var (
|
var elem reflect.Value
|
||||||
reflectValueType = reflectValue.Type().Elem()
|
|
||||||
isPtr = reflectValueType.Kind() == reflect.Ptr
|
|
||||||
fields = make([]*schema.Field, len(columns))
|
|
||||||
joinFields [][2]*schema.Field
|
|
||||||
)
|
|
||||||
|
|
||||||
if isPtr {
|
if !update {
|
||||||
reflectValueType = reflectValueType.Elem()
|
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||||
}
|
|
||||||
|
|
||||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
|
||||||
|
|
||||||
if Schema != nil {
|
|
||||||
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
|
|
||||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
|
||||||
}
|
|
||||||
|
|
||||||
for idx, column := range columns {
|
|
||||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
|
||||||
fields[idx] = field
|
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
|
||||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
||||||
fields[idx] = field
|
|
||||||
|
|
||||||
if len(joinFields) == 0 {
|
|
||||||
joinFields = make([][2]*schema.Field, len(columns))
|
|
||||||
}
|
|
||||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
} else {
|
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// pluck values into slice of data
|
|
||||||
isPluck := false
|
|
||||||
if len(fields) == 1 {
|
|
||||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
|
|
||||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
|
||||||
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
|
||||||
isPluck = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
|
BEGIN:
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
|
||||||
|
|
||||||
elem := reflect.New(reflectValueType)
|
if update {
|
||||||
if isPluck {
|
if int(db.RowsAffected) >= reflectValue.Len() {
|
||||||
db.AddError(rows.Scan(elem.Interface()))
|
return
|
||||||
} else {
|
|
||||||
for idx, field := range fields {
|
|
||||||
if field != nil {
|
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
elem = reflectValue.Index(int(db.RowsAffected))
|
||||||
db.AddError(rows.Scan(values...))
|
if onConflictDonothing {
|
||||||
|
for _, field := range fields {
|
||||||
for idx, field := range fields {
|
if _, ok := field.ValueOf(elem); !ok {
|
||||||
if len(joinFields) != 0 && joinFields[idx][0] != nil {
|
db.RowsAffected++
|
||||||
value := reflect.ValueOf(values[idx]).Elem()
|
goto BEGIN
|
||||||
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
|
||||||
|
|
||||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
|
||||||
if value.IsNil() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
|
||||||
}
|
|
||||||
|
|
||||||
field.Set(relValue, values[idx])
|
|
||||||
} else if field != nil {
|
|
||||||
field.Set(elem, values[idx])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isPtr {
|
|
||||||
reflectValue = reflect.Append(reflectValue, elem)
|
|
||||||
} else {
|
|
||||||
reflectValue = reflect.Append(reflectValue, elem.Elem())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Statement.ReflectValue.Set(reflectValue)
|
|
||||||
case reflect.Struct, reflect.Ptr:
|
|
||||||
if reflectValue.Type() != Schema.ModelType {
|
|
||||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
|
||||||
}
|
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
|
||||||
for idx, column := range columns {
|
|
||||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
|
||||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
values[idx] = &sql.RawBytes{}
|
}
|
||||||
} else if len(columns) == 1 {
|
} else {
|
||||||
values[idx] = dest
|
elem = reflect.New(reflectValueType)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
|
||||||
|
|
||||||
|
if !update {
|
||||||
|
if isPtr {
|
||||||
|
reflectValue = reflect.Append(reflectValue, elem)
|
||||||
} else {
|
} else {
|
||||||
values[idx] = &sql.RawBytes{}
|
reflectValue = reflect.Append(reflectValue, elem.Elem())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
if !update {
|
||||||
db.AddError(rows.Scan(values...))
|
db.Statement.ReflectValue.Set(reflectValue)
|
||||||
|
}
|
||||||
for idx, column := range columns {
|
case reflect.Struct, reflect.Ptr:
|
||||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
if initialized || rows.Next() {
|
||||||
field.Set(reflectValue, values[idx])
|
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
|
||||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
||||||
relValue := rel.Field.ReflectValueOf(reflectValue)
|
|
||||||
value := reflect.ValueOf(values[idx]).Elem()
|
|
||||||
|
|
||||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
|
||||||
if value.IsNil() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
|
||||||
}
|
|
||||||
|
|
||||||
field.Set(relValue, values[idx])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
db.AddError(rows.Scan(dest))
|
db.AddError(rows.Scan(dest))
|
||||||
|
|
|
@ -7,9 +7,9 @@ require (
|
||||||
github.com/jinzhu/now v1.1.2
|
github.com/jinzhu/now v1.1.2
|
||||||
github.com/lib/pq v1.10.3
|
github.com/lib/pq v1.10.3
|
||||||
gorm.io/driver/mysql v1.1.2
|
gorm.io/driver/mysql v1.1.2
|
||||||
gorm.io/driver/postgres v1.1.2
|
gorm.io/driver/postgres v1.2.0
|
||||||
gorm.io/driver/sqlite v1.1.6
|
gorm.io/driver/sqlite v1.2.0
|
||||||
gorm.io/driver/sqlserver v1.1.0
|
gorm.io/driver/sqlserver v1.1.1
|
||||||
gorm.io/gorm v1.21.16
|
gorm.io/gorm v1.21.16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gorm.io/gorm"
|
|
||||||
"gorm.io/gorm/callbacks"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReturningWithNullToZeroValues(t *testing.T) {
|
func TestReturningWithNullToZeroValues(t *testing.T) {
|
||||||
|
@ -20,11 +20,6 @@ func TestReturningWithNullToZeroValues(t *testing.T) {
|
||||||
Name string `gorm:"default:null"`
|
Name string `gorm:"default:null"`
|
||||||
}
|
}
|
||||||
u1 := user{}
|
u1 := user{}
|
||||||
c := DB.Callback().Create().Get("gorm:create")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
DB.Callback().Create().Replace("gorm:create", c)
|
|
||||||
})
|
|
||||||
DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true}))
|
|
||||||
|
|
||||||
if results := DB.Create(&u1); results.Error != nil {
|
if results := DB.Create(&u1); results.Error != nil {
|
||||||
t.Fatalf("errors happened on create: %v", results.Error)
|
t.Fatalf("errors happened on create: %v", results.Error)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
@ -166,13 +167,16 @@ func TestUpdates(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// update with gorm exprs
|
// update with gorm exprs
|
||||||
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||||
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||||
}
|
}
|
||||||
var user4 User
|
var user4 User
|
||||||
DB.First(&user4, user3.ID)
|
DB.First(&user4, user3.ID)
|
||||||
|
|
||||||
user3.Age += 100
|
// sqlite, postgres support returning
|
||||||
|
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||||
|
user3.Age += 100
|
||||||
|
}
|
||||||
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
|
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue