From af3fbdc2fcfface01ce2a0795ee0fac3997ddc8e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 Oct 2021 22:36:37 +0800 Subject: [PATCH] Improve returning support --- callbacks/callbacks.go | 28 ++-- callbacks/create.go | 237 ++++++++++--------------------- callbacks/query.go | 2 +- callbacks/update.go | 68 +++++---- clause/on_conflict.go | 2 +- finisher_api.go | 2 +- scan.go | 308 +++++++++++++++++++++++------------------ tests/go.mod | 6 +- tests/gorm_test.go | 9 +- tests/update_test.go | 8 +- 10 files changed, 315 insertions(+), 355 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d85c1928..bc18d854 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -13,7 +13,6 @@ var ( type Config struct { LastInsertIDReversed bool - WithReturning bool CreateClauses []string QueryClauses []string UpdateClauses []string @@ -25,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { return !db.SkipDefaultTransaction } + if len(config.CreateClauses) == 0 { + config.CreateClauses = createClauses + } + if len(config.QueryClauses) == 0 { + config.QueryClauses = queryClauses + } + if len(config.DeleteClauses) == 0 { + config.DeleteClauses = deleteClauses + } + if len(config.UpdateClauses) == 0 { + config.UpdateClauses = updateClauses + } + createCallback := db.Callback().Create() createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Register("gorm:before_create", BeforeCreate) @@ -33,18 +45,12 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) - if len(config.CreateClauses) == 0 { - config.CreateClauses = createClauses - } createCallback.Clauses = config.CreateClauses queryCallback := db.Callback().Query() queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) - if len(config.QueryClauses) == 0 { - config.QueryClauses = queryClauses - } queryCallback.Clauses = config.QueryClauses deleteCallback := db.Callback().Delete() @@ -54,9 +60,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) - if len(config.DeleteClauses) == 0 { - config.DeleteClauses = deleteClauses - } deleteCallback.Clauses = config.DeleteClauses updateCallback := db.Callback().Update() @@ -64,13 +67,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) - updateCallback.Register("gorm:update", Update) + updateCallback.Register("gorm:update", Update(config)) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) - if len(config.UpdateClauses) == 0 { - config.UpdateClauses = updateClauses - } updateCallback.Clauses = config.UpdateClauses rowCallback := db.Callback().Row() diff --git a/callbacks/create.go b/callbacks/create.go index c889caf6..fe4cd797 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -31,18 +31,35 @@ func BeforeCreate(db *gorm.DB) { } func Create(config *Config) func(db *gorm.DB) { - if config.WithReturning { - return CreateWithReturning + withReturning := false + for _, clause := range config.CreateClauses { + if clause == "RETURNING" { + withReturning = true + } } return func(db *gorm.DB) { if db.Error != nil { return } + onReturning := false - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) + if db.Statement.Schema != nil { + if !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + onReturning = true + if _, ok := db.Statement.Clauses["RETURNING"]; !ok { + fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + } + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) + } } } @@ -55,180 +72,70 @@ func Create(config *Config) func(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err != nil { - db.AddError(err) - return - } - - db.RowsAffected, _ = result.RowsAffected() - - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } - - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } - - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } + if onReturning { + doNothing := false + if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { + onConflict, _ := c.Expression.(clause.OnConflict) + doNothing = onConflict.DoNothing + } + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + if doNothing { + gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing) + } else { + gorm.Scan(rows, db, gorm.ScanUpdate) } - } else { + rows.Close() + } + } else { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err != nil { db.AddError(err) - } - } - } - } -} - -func CreateWithReturning(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{}) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - - db.Statement.Build(db.Statement.BuildClauses...) - } - - if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { - db.Statement.WriteString(" RETURNING ") - - var ( - fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) - values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) - ) - - for idx, field := range sch.FieldsWithDefaultDBValue { - if idx > 0 { - db.Statement.WriteByte(',') + return } - fields[idx] = field - db.Statement.WriteQuoted(field.DBName) - } - - if !db.DryRun && db.Error == nil { - db.RowsAffected = 0 - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - defer rows.Close() - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - c = db.Statement.Clauses["ON CONFLICT"] - onConflict, _ = c.Expression.(clause.OnConflict) - resetFieldValues = map[int]reflect.Value{} - ) - - for rows.Next() { - BEGIN: - reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) - if reflect.Indirect(reflectValue).Kind() != reflect.Struct { - break - } - - for idx, field := range fields { - fieldValue := field.ReflectValueOf(reflectValue) - - if onConflict.DoNothing && !fieldValue.IsZero() { - db.RowsAffected++ - - if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { - return + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break } - goto BEGIN + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } } - - if field.FieldType.Kind() == reflect.Ptr { - values[idx] = fieldValue.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) - reflectValue.Elem().Set(fieldValue.Addr()) - values[idx] = reflectValue.Interface() - resetFieldValues[idx] = fieldValue - } - } - - db.RowsAffected++ - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } - - for idx, fv := range resetFieldValues { - if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { - fv.Set(v.Elem()) - } - } - } - case reflect.Struct: - resetFieldValues := map[int]reflect.Value{} - for idx, field := range fields { - if field.FieldType.Kind() == reflect.Ptr { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } else { - reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) - fieldValue := field.ReflectValueOf(db.Statement.ReflectValue) - reflectValue.Elem().Set(fieldValue.Addr()) - values[idx] = reflectValue.Interface() - resetFieldValues[idx] = fieldValue - } - } - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - for idx, fv := range resetFieldValues { - if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { - fv.Set(v.Elem()) + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } } } + case reflect.Struct: + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } + } else { + db.AddError(err) } - } else { - db.AddError(err) } } - } else if !db.DryRun && db.Error == nil { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) - } } } } diff --git a/callbacks/query.go b/callbacks/query.go index 0eee2a43..0cfb0b3f 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -22,7 +22,7 @@ func Query(db *gorm.DB) { } defer rows.Close() - gorm.Scan(rows, db, false) + gorm.Scan(rows, db, 0) } } } diff --git a/callbacks/update.go b/callbacks/update.go index a0a2c579..90dc6a89 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -50,40 +50,56 @@ func BeforeUpdate(db *gorm.DB) { } } -func Update(db *gorm.DB) { - if db.Error != nil { - return - } - - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) +func Update(config *Config) func(db *gorm.DB) { + withReturning := false + for _, clause := range config.UpdateClauses { + if clause == "RETURNING" { + withReturning = true } } - if db.Statement.SQL.String() == "" { - db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { + return func(db *gorm.DB) { + if db.Error != nil { return } - db.Statement.Build(db.Statement.BuildClauses...) - } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } - if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + db.Statement.Build(db.Statement.BuildClauses...) + } - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + if !db.DryRun && db.Error == nil { + if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok { + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, gorm.ScanUpdate) + rows.Close() + } + } else { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } } } } diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 64ee7f53..309c5fcd 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -26,7 +26,7 @@ func (onConflict OnConflict) Build(builder Builder) { } builder.WriteString(`) `) } - + if len(onConflict.TargetWhere.Exprs) > 0 { builder.WriteString(" WHERE ") onConflict.TargetWhere.Build(builder) diff --git a/finisher_api.go b/finisher_api.go index e98efc92..48eb94c5 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -511,7 +511,7 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { } tx.Statement.ReflectValue = elem } - Scan(rows, tx, true) + Scan(rows, tx, ScanInitialized) return tx.Error } diff --git a/scan.go b/scan.go index 4570380d..37f5112d 100644 --- a/scan.go +++ b/scan.go @@ -49,13 +49,93 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func Scan(rows *sql.Rows, db *DB, initialized bool) { - columns, _ := rows.Columns() - values := make([]interface{}, len(columns)) +func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { + for idx, column := range columns { + if sch == nil { + values[idx] = reflectValue.Interface() + } else if field := sch.LookUpField(column); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + continue + } + } + values[idx] = &sql.RawBytes{} + } else if len(columns) == 1 { + sch = nil + values[idx] = reflectValue.Interface() + } else { + values[idx] = &sql.RawBytes{} + } + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + if sch != nil { + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + field.Set(reflectValue, values[idx]) + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + relValue := rel.Field.ReflectValueOf(reflectValue) + value := reflect.ValueOf(values[idx]).Elem() + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(relValue, values[idx]) + } + } + } + } + } +} + +type ScanMode uint8 + +const ( + ScanInitialized ScanMode = 1 << 0 + ScanUpdate = 1 << 1 + ScanOnConflictDoNothing = 1 << 2 +) + +func Scan(rows *sql.Rows, db *DB, mode ScanMode) { + var ( + columns, _ = rows.Columns() + values = make([]interface{}, len(columns)) + initialized = mode&ScanInitialized != 0 + update = mode&ScanUpdate != 0 + onConflictDonothing = mode&ScanOnConflictDoNothing != 0 + ) + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: + if update && db.Statement.Schema != nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + fields := make([]*schema.Field, len(columns)) + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } + } + + if initialized || rows.Next() { + db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil) + } + } + } + if initialized || rows.Next() { columnTypes, _ := rows.ColumnTypes() prepareValues(values, db, columnTypes, columns) @@ -71,7 +151,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } scanIntoMap(mapValue, values, columns) } - case *[]map[string]interface{}: + case *[]map[string]interface{}, []map[string]interface{}: columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { prepareValues(values, db, columnTypes, columns) @@ -82,7 +162,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { mapValue := map[string]interface{}{} scanIntoMap(mapValue, values, columns) - *dest = append(*dest, mapValue) + if values, ok := dest.([]map[string]interface{}); ok { + values = append(values, mapValue) + } else if values, ok := dest.(*[]map[string]interface{}); ok { + *values = append(*values, mapValue) + } } case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, @@ -96,155 +180,109 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(dest)) } default: - Schema := db.Statement.Schema - reflectValue := db.Statement.ReflectValue + var ( + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue + ) + if reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } + reflectValueType := reflectValue.Type() + switch reflectValueType.Kind() { + case reflect.Array, reflect.Slice: + reflectValueType = reflectValueType.Elem() + } + isPtr := reflectValueType.Kind() == reflect.Ptr + if isPtr { + reflectValueType = reflectValueType.Elem() + } + + if sch != nil { + if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { + sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + + if len(columns) == 1 { + // isPluck + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner + reflectValueType.Kind() != reflect.Struct || // is not struct + sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time + sch = nil + } + } + } + switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - var ( - reflectValueType = reflectValue.Type().Elem() - isPtr = reflectValueType.Kind() == reflect.Ptr - fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field - ) + var elem reflect.Value - if isPtr { - reflectValueType = reflectValueType.Elem() - } - - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) - - if Schema != nil { - if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) - } - - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue - } - } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} - } - } - } - - // pluck values into slice of data - isPluck := false - if len(fields) == 1 { - if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner - reflectValueType.Kind() != reflect.Struct || // is not struct - Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time - isPluck = true - } + if !update { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) } for initialized || rows.Next() { + BEGIN: initialized = false - db.RowsAffected++ - elem := reflect.New(reflectValueType) - if isPluck { - db.AddError(rows.Scan(elem.Interface())) - } else { - for idx, field := range fields { - if field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - } + if update { + if int(db.RowsAffected) >= reflectValue.Len() { + return } - - db.AddError(rows.Scan(values...)) - - for idx, field := range fields { - if len(joinFields) != 0 && joinFields[idx][0] != nil { - value := reflect.ValueOf(values[idx]).Elem() - relValue := joinFields[idx][0].ReflectValueOf(elem) - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(relValue, values[idx]) - } else if field != nil { - field.Set(elem, values[idx]) - } - } - } - - if isPtr { - reflectValue = reflect.Append(reflectValue, elem) - } else { - reflectValue = reflect.Append(reflectValue, elem.Elem()) - } - } - - db.Statement.ReflectValue.Set(reflectValue) - case reflect.Struct, reflect.Ptr: - if reflectValue.Type() != Schema.ModelType { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) - } - - if initialized || rows.Next() { - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - continue + elem = reflectValue.Index(int(db.RowsAffected)) + if onConflictDonothing { + for _, field := range fields { + if _, ok := field.ValueOf(elem); !ok { + db.RowsAffected++ + goto BEGIN } } - values[idx] = &sql.RawBytes{} - } else if len(columns) == 1 { - values[idx] = dest + } + } else { + elem = reflect.New(reflectValueType) + } + + db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + + if !update { + if isPtr { + reflectValue = reflect.Append(reflectValue, elem) } else { - values[idx] = &sql.RawBytes{} + reflectValue = reflect.Append(reflectValue, elem.Elem()) } } + } - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - field.Set(reflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(reflectValue) - value := reflect.ValueOf(values[idx]).Elem() - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(relValue, values[idx]) - } - } - } - } + if !update { + db.Statement.ReflectValue.Set(reflectValue) + } + case reflect.Struct, reflect.Ptr: + if initialized || rows.Next() { + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) } default: db.AddError(rows.Scan(dest)) diff --git a/tests/go.mod b/tests/go.mod index e18dc1dc..96db0559 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,9 +7,9 @@ require ( github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.2 - gorm.io/driver/postgres v1.1.2 - gorm.io/driver/sqlite v1.1.6 - gorm.io/driver/sqlserver v1.1.0 + gorm.io/driver/postgres v1.2.0 + gorm.io/driver/sqlite v1.2.0 + gorm.io/driver/sqlserver v1.1.1 gorm.io/gorm v1.21.16 ) diff --git a/tests/gorm_test.go b/tests/gorm_test.go index 39741439..9827465c 100644 --- a/tests/gorm_test.go +++ b/tests/gorm_test.go @@ -1,9 +1,9 @@ package tests_test import ( - "gorm.io/gorm" - "gorm.io/gorm/callbacks" "testing" + + "gorm.io/gorm" ) func TestReturningWithNullToZeroValues(t *testing.T) { @@ -20,11 +20,6 @@ func TestReturningWithNullToZeroValues(t *testing.T) { Name string `gorm:"default:null"` } u1 := user{} - c := DB.Callback().Create().Get("gorm:create") - t.Cleanup(func() { - DB.Callback().Create().Replace("gorm:create", c) - }) - DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true})) if results := DB.Create(&u1); results.Error != nil { t.Fatalf("errors happened on create: %v", results.Error) diff --git a/tests/update_test.go b/tests/update_test.go index 631d0d6d..0dd9465a 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -9,6 +9,7 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) @@ -166,13 +167,16 @@ func TestUpdates(t *testing.T) { } // update with gorm exprs - if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) } var user4 User DB.First(&user4, user3.ID) - user3.Age += 100 + // sqlite, postgres support returning + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + user3.Age += 100 + } AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") }