forked from mirror/gorm
Improve returning support
This commit is contained in:
parent
d3211908a0
commit
af3fbdc2fc
|
@ -13,7 +13,6 @@ var (
|
|||
|
||||
type Config struct {
|
||||
LastInsertIDReversed bool
|
||||
WithReturning bool
|
||||
CreateClauses []string
|
||||
QueryClauses []string
|
||||
UpdateClauses []string
|
||||
|
@ -25,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
|||
return !db.SkipDefaultTransaction
|
||||
}
|
||||
|
||||
if len(config.CreateClauses) == 0 {
|
||||
config.CreateClauses = createClauses
|
||||
}
|
||||
if len(config.QueryClauses) == 0 {
|
||||
config.QueryClauses = queryClauses
|
||||
}
|
||||
if len(config.DeleteClauses) == 0 {
|
||||
config.DeleteClauses = deleteClauses
|
||||
}
|
||||
if len(config.UpdateClauses) == 0 {
|
||||
config.UpdateClauses = updateClauses
|
||||
}
|
||||
|
||||
createCallback := db.Callback().Create()
|
||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||
|
@ -33,18 +45,12 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
|||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
||||
createCallback.Register("gorm:after_create", AfterCreate)
|
||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
if len(config.CreateClauses) == 0 {
|
||||
config.CreateClauses = createClauses
|
||||
}
|
||||
createCallback.Clauses = config.CreateClauses
|
||||
|
||||
queryCallback := db.Callback().Query()
|
||||
queryCallback.Register("gorm:query", Query)
|
||||
queryCallback.Register("gorm:preload", Preload)
|
||||
queryCallback.Register("gorm:after_query", AfterQuery)
|
||||
if len(config.QueryClauses) == 0 {
|
||||
config.QueryClauses = queryClauses
|
||||
}
|
||||
queryCallback.Clauses = config.QueryClauses
|
||||
|
||||
deleteCallback := db.Callback().Delete()
|
||||
|
@ -54,9 +60,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
|||
deleteCallback.Register("gorm:delete", Delete)
|
||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
if len(config.DeleteClauses) == 0 {
|
||||
config.DeleteClauses = deleteClauses
|
||||
}
|
||||
deleteCallback.Clauses = config.DeleteClauses
|
||||
|
||||
updateCallback := db.Callback().Update()
|
||||
|
@ -64,13 +67,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
|||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
||||
updateCallback.Register("gorm:update", Update)
|
||||
updateCallback.Register("gorm:update", Update(config))
|
||||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
if len(config.UpdateClauses) == 0 {
|
||||
config.UpdateClauses = updateClauses
|
||||
}
|
||||
updateCallback.Clauses = config.UpdateClauses
|
||||
|
||||
rowCallback := db.Callback().Row()
|
||||
|
|
|
@ -31,21 +31,38 @@ func BeforeCreate(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func Create(config *Config) func(db *gorm.DB) {
|
||||
if config.WithReturning {
|
||||
return CreateWithReturning
|
||||
withReturning := false
|
||||
for _, clause := range config.CreateClauses {
|
||||
if clause == "RETURNING" {
|
||||
withReturning = true
|
||||
}
|
||||
}
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
onReturning := false
|
||||
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
if db.Statement.Schema != nil {
|
||||
if !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||
onReturning = true
|
||||
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
||||
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
|
@ -55,6 +72,21 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if onReturning {
|
||||
doNothing := false
|
||||
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
||||
onConflict, _ := c.Expression.(clause.OnConflict)
|
||||
doNothing = onConflict.DoNothing
|
||||
}
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
if doNothing {
|
||||
gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing)
|
||||
} else {
|
||||
gorm.Scan(rows, db, gorm.ScanUpdate)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
} else {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err != nil {
|
||||
|
@ -63,7 +95,6 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
|
||||
|
@ -107,130 +138,6 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CreateWithReturning(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
||||
db.Statement.WriteString(" RETURNING ")
|
||||
|
||||
var (
|
||||
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
|
||||
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
|
||||
)
|
||||
|
||||
for idx, field := range sch.FieldsWithDefaultDBValue {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
|
||||
fields[idx] = field
|
||||
db.Statement.WriteQuoted(field.DBName)
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
db.RowsAffected = 0
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
c = db.Statement.Clauses["ON CONFLICT"]
|
||||
onConflict, _ = c.Expression.(clause.OnConflict)
|
||||
resetFieldValues = map[int]reflect.Value{}
|
||||
)
|
||||
|
||||
for rows.Next() {
|
||||
BEGIN:
|
||||
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
|
||||
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
for idx, field := range fields {
|
||||
fieldValue := field.ReflectValueOf(reflectValue)
|
||||
|
||||
if onConflict.DoNothing && !fieldValue.IsZero() {
|
||||
db.RowsAffected++
|
||||
|
||||
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
|
||||
return
|
||||
}
|
||||
|
||||
goto BEGIN
|
||||
}
|
||||
|
||||
if field.FieldType.Kind() == reflect.Ptr {
|
||||
values[idx] = fieldValue.Addr().Interface()
|
||||
} else {
|
||||
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
|
||||
reflectValue.Elem().Set(fieldValue.Addr())
|
||||
values[idx] = reflectValue.Interface()
|
||||
resetFieldValues[idx] = fieldValue
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
|
||||
for idx, fv := range resetFieldValues {
|
||||
if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
|
||||
fv.Set(v.Elem())
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
resetFieldValues := map[int]reflect.Value{}
|
||||
for idx, field := range fields {
|
||||
if field.FieldType.Kind() == reflect.Ptr {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
} else {
|
||||
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
|
||||
fieldValue := field.ReflectValueOf(db.Statement.ReflectValue)
|
||||
reflectValue.Elem().Set(fieldValue.Addr())
|
||||
values[idx] = reflectValue.Interface()
|
||||
resetFieldValues[idx] = fieldValue
|
||||
}
|
||||
}
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
for idx, fv := range resetFieldValues {
|
||||
if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
|
||||
fv.Set(v.Elem())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
} else if !db.DryRun && db.Error == nil {
|
||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
|
|
|
@ -22,7 +22,7 @@ func Query(db *gorm.DB) {
|
|||
}
|
||||
defer rows.Close()
|
||||
|
||||
gorm.Scan(rows, db, false)
|
||||
gorm.Scan(rows, db, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,7 +50,15 @@ func BeforeUpdate(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
|
||||
func Update(db *gorm.DB) {
|
||||
func Update(config *Config) func(db *gorm.DB) {
|
||||
withReturning := false
|
||||
for _, clause := range config.UpdateClauses {
|
||||
if clause == "RETURNING" {
|
||||
withReturning = true
|
||||
}
|
||||
}
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
@ -78,6 +86,12 @@ func Update(db *gorm.DB) {
|
|||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok {
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
gorm.Scan(rows, db, gorm.ScanUpdate)
|
||||
rows.Close()
|
||||
}
|
||||
} else {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
|
@ -87,6 +101,8 @@ func Update(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
|
|
|
@ -511,7 +511,7 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
|||
}
|
||||
tx.Statement.ReflectValue = elem
|
||||
}
|
||||
Scan(rows, tx, true)
|
||||
Scan(rows, tx, ScanInitialized)
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
|
|
356
scan.go
356
scan.go
|
@ -49,165 +49,14 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
|||
}
|
||||
}
|
||||
|
||||
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
columns, _ := rows.Columns()
|
||||
values := make([]interface{}, len(columns))
|
||||
db.RowsAffected = 0
|
||||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
if initialized || rows.Next() {
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue, ok := dest.(map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := dest.(*map[string]interface{}); ok {
|
||||
mapValue = *v
|
||||
}
|
||||
}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
}
|
||||
case *[]map[string]interface{}:
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
for initialized || rows.Next() {
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue := map[string]interface{}{}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
*dest = append(*dest, mapValue)
|
||||
}
|
||||
case *int, *int8, *int16, *int32, *int64,
|
||||
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
||||
*float32, *float64,
|
||||
*bool, *string, *time.Time,
|
||||
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
|
||||
*sql.NullBool, *sql.NullString, *sql.NullTime:
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(dest))
|
||||
}
|
||||
default:
|
||||
Schema := db.Statement.Schema
|
||||
reflectValue := db.Statement.ReflectValue
|
||||
if reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
reflectValueType = reflectValue.Type().Elem()
|
||||
isPtr = reflectValueType.Kind() == reflect.Ptr
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][2]*schema.Field
|
||||
)
|
||||
|
||||
if isPtr {
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
|
||||
if Schema != nil {
|
||||
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
}
|
||||
|
||||
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||
for idx, column := range columns {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
}
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||
continue
|
||||
}
|
||||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pluck values into slice of data
|
||||
isPluck := false
|
||||
if len(fields) == 1 {
|
||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
|
||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||
isPluck = true
|
||||
}
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
|
||||
elem := reflect.New(reflectValueType)
|
||||
if isPluck {
|
||||
db.AddError(rows.Scan(elem.Interface()))
|
||||
} else {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
for idx, field := range fields {
|
||||
if len(joinFields) != 0 && joinFields[idx][0] != nil {
|
||||
value := reflect.ValueOf(values[idx]).Elem()
|
||||
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
||||
|
||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||
if value.IsNil() {
|
||||
continue
|
||||
}
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
}
|
||||
|
||||
field.Set(relValue, values[idx])
|
||||
} else if field != nil {
|
||||
field.Set(elem, values[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
} else {
|
||||
reflectValue = reflect.Append(reflectValue, elem.Elem())
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if reflectValue.Type() != Schema.ModelType {
|
||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
}
|
||||
|
||||
if initialized || rows.Next() {
|
||||
for idx, column := range columns {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||
if sch == nil {
|
||||
values[idx] = reflectValue.Interface()
|
||||
} else if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
continue
|
||||
|
@ -215,7 +64,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else if len(columns) == 1 {
|
||||
values[idx] = dest
|
||||
sch = nil
|
||||
values[idx] = reflectValue.Interface()
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
}
|
||||
|
@ -224,11 +74,12 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
if sch != nil {
|
||||
for idx, column := range columns {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
field.Set(reflectValue, values[idx])
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
relValue := rel.Field.ReflectValueOf(reflectValue)
|
||||
value := reflect.ValueOf(values[idx]).Elem()
|
||||
|
@ -246,6 +97,193 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ScanMode uint8
|
||||
|
||||
const (
|
||||
ScanInitialized ScanMode = 1 << 0
|
||||
ScanUpdate = 1 << 1
|
||||
ScanOnConflictDoNothing = 1 << 2
|
||||
)
|
||||
|
||||
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||
var (
|
||||
columns, _ = rows.Columns()
|
||||
values = make([]interface{}, len(columns))
|
||||
initialized = mode&ScanInitialized != 0
|
||||
update = mode&ScanUpdate != 0
|
||||
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
|
||||
)
|
||||
|
||||
db.RowsAffected = 0
|
||||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
if update && db.Statement.Schema != nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
fields := make([]*schema.Field, len(columns))
|
||||
for idx, column := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
}
|
||||
}
|
||||
|
||||
if initialized || rows.Next() {
|
||||
db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if initialized || rows.Next() {
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue, ok := dest.(map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := dest.(*map[string]interface{}); ok {
|
||||
mapValue = *v
|
||||
}
|
||||
}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
}
|
||||
case *[]map[string]interface{}, []map[string]interface{}:
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
for initialized || rows.Next() {
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue := map[string]interface{}{}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
if values, ok := dest.([]map[string]interface{}); ok {
|
||||
values = append(values, mapValue)
|
||||
} else if values, ok := dest.(*[]map[string]interface{}); ok {
|
||||
*values = append(*values, mapValue)
|
||||
}
|
||||
}
|
||||
case *int, *int8, *int16, *int32, *int64,
|
||||
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
||||
*float32, *float64,
|
||||
*bool, *string, *time.Time,
|
||||
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
|
||||
*sql.NullBool, *sql.NullString, *sql.NullTime:
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(dest))
|
||||
}
|
||||
default:
|
||||
var (
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][2]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
reflectValueType := reflectValue.Type()
|
||||
switch reflectValueType.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
isPtr := reflectValueType.Kind() == reflect.Ptr
|
||||
if isPtr {
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
|
||||
if sch != nil {
|
||||
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
}
|
||||
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
}
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||
continue
|
||||
}
|
||||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(columns) == 1 {
|
||||
// isPluck
|
||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||
sch = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var elem reflect.Value
|
||||
|
||||
if !update {
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
BEGIN:
|
||||
initialized = false
|
||||
|
||||
if update {
|
||||
if int(db.RowsAffected) >= reflectValue.Len() {
|
||||
return
|
||||
}
|
||||
elem = reflectValue.Index(int(db.RowsAffected))
|
||||
if onConflictDonothing {
|
||||
for _, field := range fields {
|
||||
if _, ok := field.ValueOf(elem); !ok {
|
||||
db.RowsAffected++
|
||||
goto BEGIN
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
elem = reflect.New(reflectValueType)
|
||||
}
|
||||
|
||||
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
|
||||
|
||||
if !update {
|
||||
if isPtr {
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
} else {
|
||||
reflectValue = reflect.Append(reflectValue, elem.Elem())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !update {
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if initialized || rows.Next() {
|
||||
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
|
||||
}
|
||||
default:
|
||||
db.AddError(rows.Scan(dest))
|
||||
}
|
||||
|
|
|
@ -7,9 +7,9 @@ require (
|
|||
github.com/jinzhu/now v1.1.2
|
||||
github.com/lib/pq v1.10.3
|
||||
gorm.io/driver/mysql v1.1.2
|
||||
gorm.io/driver/postgres v1.1.2
|
||||
gorm.io/driver/sqlite v1.1.6
|
||||
gorm.io/driver/sqlserver v1.1.0
|
||||
gorm.io/driver/postgres v1.2.0
|
||||
gorm.io/driver/sqlite v1.2.0
|
||||
gorm.io/driver/sqlserver v1.1.1
|
||||
gorm.io/gorm v1.21.16
|
||||
)
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestReturningWithNullToZeroValues(t *testing.T) {
|
||||
|
@ -20,11 +20,6 @@ func TestReturningWithNullToZeroValues(t *testing.T) {
|
|||
Name string `gorm:"default:null"`
|
||||
}
|
||||
u1 := user{}
|
||||
c := DB.Callback().Create().Get("gorm:create")
|
||||
t.Cleanup(func() {
|
||||
DB.Callback().Create().Replace("gorm:create", c)
|
||||
})
|
||||
DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true}))
|
||||
|
||||
if results := DB.Create(&u1); results.Error != nil {
|
||||
t.Fatalf("errors happened on create: %v", results.Error)
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/utils"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
@ -166,13 +167,16 @@ func TestUpdates(t *testing.T) {
|
|||
}
|
||||
|
||||
// update with gorm exprs
|
||||
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||
if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||
}
|
||||
var user4 User
|
||||
DB.First(&user4, user3.ID)
|
||||
|
||||
// sqlite, postgres support returning
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
user3.Age += 100
|
||||
}
|
||||
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue