Improve returning support

This commit is contained in:
Jinzhu 2021-10-26 22:36:37 +08:00
parent d3211908a0
commit af3fbdc2fc
10 changed files with 315 additions and 355 deletions

View File

@ -13,7 +13,6 @@ var (
type Config struct { type Config struct {
LastInsertIDReversed bool LastInsertIDReversed bool
WithReturning bool
CreateClauses []string CreateClauses []string
QueryClauses []string QueryClauses []string
UpdateClauses []string UpdateClauses []string
@ -25,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
return !db.SkipDefaultTransaction return !db.SkipDefaultTransaction
} }
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
createCallback := db.Callback().Create() createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:before_create", BeforeCreate)
@ -33,18 +45,12 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate) createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
createCallback.Clauses = config.CreateClauses createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query() queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery) queryCallback.Register("gorm:after_query", AfterQuery)
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
queryCallback.Clauses = config.QueryClauses queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete() deleteCallback := db.Callback().Delete()
@ -54,9 +60,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
deleteCallback.Clauses = config.DeleteClauses deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update() updateCallback := db.Callback().Update()
@ -64,13 +67,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update) updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
updateCallback.Clauses = config.UpdateClauses updateCallback.Clauses = config.UpdateClauses
rowCallback := db.Callback().Row() rowCallback := db.Callback().Row()

View File

@ -31,18 +31,35 @@ func BeforeCreate(db *gorm.DB) {
} }
func Create(config *Config) func(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) {
if config.WithReturning { withReturning := false
return CreateWithReturning for _, clause := range config.CreateClauses {
if clause == "RETURNING" {
withReturning = true
}
} }
return func(db *gorm.DB) { return func(db *gorm.DB) {
if db.Error != nil { if db.Error != nil {
return return
} }
onReturning := false
if db.Statement.Schema != nil && !db.Statement.Unscoped { if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.CreateClauses { if !db.Statement.Unscoped {
db.Statement.AddClause(c) for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
onReturning = true
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
} }
} }
@ -55,180 +72,70 @@ func Create(config *Config) func(db *gorm.DB) {
} }
if !db.DryRun && db.Error == nil { if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if onReturning {
doNothing := false
if err != nil { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
db.AddError(err) onConflict, _ := c.Expression.(clause.OnConflict)
return doNothing = onConflict.DoNothing
} }
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected() if doNothing {
gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing)
if db.RowsAffected != 0 && db.Statement.Schema != nil && } else {
db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { gorm.Scan(rows, db, gorm.ScanUpdate)
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
} }
} else { rows.Close()
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err) db.AddError(err)
} return
}
}
}
}
func CreateWithReturning(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
db.Statement.WriteString(" RETURNING ")
var (
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
)
for idx, field := range sch.FieldsWithDefaultDBValue {
if idx > 0 {
db.Statement.WriteByte(',')
} }
fields[idx] = field db.RowsAffected, _ = result.RowsAffected()
db.Statement.WriteQuoted(field.DBName) if db.RowsAffected != 0 && db.Statement.Schema != nil &&
} db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
if !db.DryRun && db.Error == nil { switch db.Statement.ReflectValue.Kind() {
db.RowsAffected = 0 case reflect.Slice, reflect.Array:
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
if err == nil { rv := db.Statement.ReflectValue.Index(i)
defer rows.Close() if reflect.Indirect(rv).Kind() != reflect.Struct {
break
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
c = db.Statement.Clauses["ON CONFLICT"]
onConflict, _ = c.Expression.(clause.OnConflict)
resetFieldValues = map[int]reflect.Value{}
)
for rows.Next() {
BEGIN:
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
break
}
for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue)
if onConflict.DoNothing && !fieldValue.IsZero() {
db.RowsAffected++
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
return
} }
goto BEGIN _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
} }
if field.FieldType.Kind() == reflect.Ptr {
values[idx] = fieldValue.Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
reflectValue.Elem().Set(fieldValue.Addr())
values[idx] = reflectValue.Interface()
resetFieldValues[idx] = fieldValue
}
}
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
for idx, fv := range resetFieldValues {
if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
fv.Set(v.Elem())
}
}
}
case reflect.Struct:
resetFieldValues := map[int]reflect.Value{}
for idx, field := range fields {
if field.FieldType.Kind() == reflect.Ptr {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} else { } else {
reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
fieldValue := field.ReflectValueOf(db.Statement.ReflectValue) rv := db.Statement.ReflectValue.Index(i)
reflectValue.Elem().Set(fieldValue.Addr()) if reflect.Indirect(rv).Kind() != reflect.Struct {
values[idx] = reflectValue.Interface() break
resetFieldValues[idx] = fieldValue }
}
} if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
if rows.Next() { db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
db.RowsAffected++ insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
db.AddError(rows.Scan(values...)) }
for idx, fv := range resetFieldValues {
if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
fv.Set(v.Elem())
} }
} }
case reflect.Struct:
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
} }
} else {
db.AddError(err)
} }
} else {
db.AddError(err)
} }
} }
} else if !db.DryRun && db.Error == nil {
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
} }
} }
} }

View File

@ -22,7 +22,7 @@ func Query(db *gorm.DB) {
} }
defer rows.Close() defer rows.Close()
gorm.Scan(rows, db, false) gorm.Scan(rows, db, 0)
} }
} }
} }

View File

@ -50,40 +50,56 @@ func BeforeUpdate(db *gorm.DB) {
} }
} }
func Update(db *gorm.DB) { func Update(config *Config) func(db *gorm.DB) {
if db.Error != nil { withReturning := false
return for _, clause := range config.UpdateClauses {
} if clause == "RETURNING" {
withReturning = true
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
} }
} }
if db.Statement.SQL.String() == "" { return func(db *gorm.DB) {
db.Statement.SQL.Grow(180) if db.Error != nil {
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return return
} }
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { if db.Statement.Schema != nil && !db.Statement.Unscoped {
db.AddError(gorm.ErrMissingWhereClause) for _, c := range db.Statement.Schema.UpdateClauses {
return db.Statement.AddClause(c)
} }
}
if !db.DryRun && db.Error == nil { if db.Statement.SQL.String() == "" {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return
}
db.Statement.Build(db.Statement.BuildClauses...)
}
if err == nil { if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
db.RowsAffected, _ = result.RowsAffected() db.AddError(gorm.ErrMissingWhereClause)
} else { return
db.AddError(err) }
if !db.DryRun && db.Error == nil {
if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, gorm.ScanUpdate)
rows.Close()
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
} }
} }
} }

View File

@ -26,7 +26,7 @@ func (onConflict OnConflict) Build(builder Builder) {
} }
builder.WriteString(`) `) builder.WriteString(`) `)
} }
if len(onConflict.TargetWhere.Exprs) > 0 { if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString(" WHERE ") builder.WriteString(" WHERE ")
onConflict.TargetWhere.Build(builder) onConflict.TargetWhere.Build(builder)

View File

@ -511,7 +511,7 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
} }
tx.Statement.ReflectValue = elem tx.Statement.ReflectValue = elem
} }
Scan(rows, tx, true) Scan(rows, tx, ScanInitialized)
return tx.Error return tx.Error
} }

308
scan.go
View File

@ -49,13 +49,93 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
} }
} }
func Scan(rows *sql.Rows, db *DB, initialized bool) { func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
columns, _ := rows.Columns() for idx, column := range columns {
values := make([]interface{}, len(columns)) if sch == nil {
values[idx] = reflectValue.Interface()
} else if field := sch.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
} else if len(columns) == 1 {
sch = nil
values[idx] = reflectValue.Interface()
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
field.Set(reflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(reflectValue)
value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
}
}
}
}
}
}
type ScanMode uint8
const (
ScanInitialized ScanMode = 1 << 0
ScanUpdate = 1 << 1
ScanOnConflictDoNothing = 1 << 2
)
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
db.RowsAffected = 0 db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}: case map[string]interface{}, *map[string]interface{}:
if update && db.Statement.Schema != nil {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
fields := make([]*schema.Field, len(columns))
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
}
}
if initialized || rows.Next() {
db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil)
}
}
}
if initialized || rows.Next() { if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes() columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns) prepareValues(values, db, columnTypes, columns)
@ -71,7 +151,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
scanIntoMap(mapValue, values, columns) scanIntoMap(mapValue, values, columns)
} }
case *[]map[string]interface{}: case *[]map[string]interface{}, []map[string]interface{}:
columnTypes, _ := rows.ColumnTypes() columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() { for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns) prepareValues(values, db, columnTypes, columns)
@ -82,7 +162,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
mapValue := map[string]interface{}{} mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns) scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue) if values, ok := dest.([]map[string]interface{}); ok {
values = append(values, mapValue)
} else if values, ok := dest.(*[]map[string]interface{}); ok {
*values = append(*values, mapValue)
}
} }
case *int, *int8, *int16, *int32, *int64, case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
@ -96,155 +180,109 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
db.AddError(rows.Scan(dest)) db.AddError(rows.Scan(dest))
} }
default: default:
Schema := db.Statement.Schema var (
reflectValue := db.Statement.ReflectValue fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
)
if reflectValue.Kind() == reflect.Interface { if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem() reflectValue = reflectValue.Elem()
} }
reflectValueType := reflectValue.Type()
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
if len(columns) == 1 {
// isPluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var ( var elem reflect.Value
reflectValueType = reflectValue.Type().Elem()
isPtr = reflectValueType.Kind() == reflect.Ptr
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
)
if isPtr { if !update {
reflectValueType = reflectValueType.Elem() db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
}
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
if Schema != nil {
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
// pluck values into slice of data
isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
isPluck = true
}
} }
for initialized || rows.Next() { for initialized || rows.Next() {
BEGIN:
initialized = false initialized = false
db.RowsAffected++
elem := reflect.New(reflectValueType) if update {
if isPluck { if int(db.RowsAffected) >= reflectValue.Len() {
db.AddError(rows.Scan(elem.Interface())) return
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
}
} }
elem = reflectValue.Index(int(db.RowsAffected))
db.AddError(rows.Scan(values...)) if onConflictDonothing {
for _, field := range fields {
for idx, field := range fields { if _, ok := field.ValueOf(elem); !ok {
if len(joinFields) != 0 && joinFields[idx][0] != nil { db.RowsAffected++
value := reflect.ValueOf(values[idx]).Elem() goto BEGIN
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
} else if field != nil {
field.Set(elem, values[idx])
}
}
}
if isPtr {
reflectValue = reflect.Append(reflectValue, elem)
} else {
reflectValue = reflect.Append(reflectValue, elem.Elem())
}
}
db.Statement.ReflectValue.Set(reflectValue)
case reflect.Struct, reflect.Ptr:
if reflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if initialized || rows.Next() {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
} }
} }
values[idx] = &sql.RawBytes{} }
} else if len(columns) == 1 { } else {
values[idx] = dest elem = reflect.New(reflectValueType)
}
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
if !update {
if isPtr {
reflectValue = reflect.Append(reflectValue, elem)
} else { } else {
values[idx] = &sql.RawBytes{} reflectValue = reflect.Append(reflectValue, elem.Elem())
} }
} }
}
db.RowsAffected++ if !update {
db.AddError(rows.Scan(values...)) db.Statement.ReflectValue.Set(reflectValue)
}
for idx, column := range columns { case reflect.Struct, reflect.Ptr:
if field := Schema.LookUpField(column); field != nil && field.Readable { if initialized || rows.Next() {
field.Set(reflectValue, values[idx]) db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(reflectValue)
value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
}
}
}
}
} }
default: default:
db.AddError(rows.Scan(dest)) db.AddError(rows.Scan(dest))

View File

@ -7,9 +7,9 @@ require (
github.com/jinzhu/now v1.1.2 github.com/jinzhu/now v1.1.2
github.com/lib/pq v1.10.3 github.com/lib/pq v1.10.3
gorm.io/driver/mysql v1.1.2 gorm.io/driver/mysql v1.1.2
gorm.io/driver/postgres v1.1.2 gorm.io/driver/postgres v1.2.0
gorm.io/driver/sqlite v1.1.6 gorm.io/driver/sqlite v1.2.0
gorm.io/driver/sqlserver v1.1.0 gorm.io/driver/sqlserver v1.1.1
gorm.io/gorm v1.21.16 gorm.io/gorm v1.21.16
) )

View File

@ -1,9 +1,9 @@
package tests_test package tests_test
import ( import (
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"testing" "testing"
"gorm.io/gorm"
) )
func TestReturningWithNullToZeroValues(t *testing.T) { func TestReturningWithNullToZeroValues(t *testing.T) {
@ -20,11 +20,6 @@ func TestReturningWithNullToZeroValues(t *testing.T) {
Name string `gorm:"default:null"` Name string `gorm:"default:null"`
} }
u1 := user{} u1 := user{}
c := DB.Callback().Create().Get("gorm:create")
t.Cleanup(func() {
DB.Callback().Create().Replace("gorm:create", c)
})
DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true}))
if results := DB.Create(&u1); results.Error != nil { if results := DB.Create(&u1); results.Error != nil {
t.Fatalf("errors happened on create: %v", results.Error) t.Fatalf("errors happened on create: %v", results.Error)

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -166,13 +167,16 @@ func TestUpdates(t *testing.T) {
} }
// update with gorm exprs // update with gorm exprs
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
} }
var user4 User var user4 User
DB.First(&user4, user3.ID) DB.First(&user4, user3.ID)
user3.Age += 100 // sqlite, postgres support returning
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
user3.Age += 100
}
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
} }