// Source: gorm/scan.go (308 lines, 8.5 KiB, Go), captured from a repository file view.
package gorm
2020-03-04 06:32:36 +03:00
import (
"database/sql"
"database/sql/driver"
2020-03-04 06:32:36 +03:00
"reflect"
2020-04-29 18:47:18 +03:00
"strings"
2020-09-24 15:08:24 +03:00
"time"
2020-03-04 06:32:36 +03:00
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/schema"
2020-03-04 06:32:36 +03:00
)
// prepareValues prepare values slice
//
// It stores one fresh scan destination per column into values, which must
// already be long enough to hold them:
//   - with a statement schema, a column that resolves to a schema field gets a
//     new **FieldType; any other column gets a *interface{};
//   - without a schema but with driver column types, each column gets a new
//     **ScanType, or a *interface{} when the driver reports no scan type;
//   - otherwise every column gets a *interface{}.
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
	sch := db.Statement.Schema

	switch {
	case sch != nil:
		for i, name := range columns {
			f := sch.LookUpField(name)
			if f == nil {
				values[i] = new(interface{})
				continue
			}
			values[i] = reflect.New(reflect.PtrTo(f.FieldType)).Interface()
		}
	case len(columnTypes) > 0:
		for i, ct := range columnTypes {
			st := ct.ScanType()
			if st == nil {
				values[i] = new(interface{})
				continue
			}
			values[i] = reflect.New(reflect.PtrTo(st)).Interface()
		}
	default:
		for i := range columns {
			values[i] = new(interface{})
		}
	}
}
// scanIntoMap copies scanned column values into mapValue, keyed by column name.
//
// Each values[i] is dereferenced through up to two pointer levels (e.g. the
// **T or *interface{} destinations built by prepareValues). If a nil pointer is
// hit along the way, the column is stored as nil. Otherwise the dereferenced
// value is stored, with two conversions:
//   - a driver.Valuer is replaced by the result of its Value method (the error
//     from Value is discarded);
//   - sql.RawBytes is copied into a string, because RawBytes may refer to memory
//     owned by the driver.
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
	for i := range columns {
		name := columns[i]

		rv := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[i])))
		if !rv.IsValid() {
			mapValue[name] = nil
			continue
		}

		v := rv.Interface()
		mapValue[name] = v
		switch x := v.(type) {
		case driver.Valuer:
			mapValue[name], _ = x.Value()
		case sql.RawBytes:
			mapValue[name] = string(x)
		}
	}
}
2022-03-29 13:14:29 +03:00
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
2022-02-27 17:54:43 +03:00
for idx, field := range fields {
if field != nil {
values[idx] = field.NewValuePool.Get()
2022-02-27 17:54:43 +03:00
} else if len(fields) == 1 {
2022-02-27 04:09:29 +03:00
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
2021-10-26 17:36:37 +03:00
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
joinedSchemaMap := make(map[*schema.Field]interface{}, 0)
2022-02-28 04:28:19 +03:00
for idx, field := range fields {
if field != nil {
if len(joinFields) == 0 || joinFields[idx][0] == nil {
2022-03-23 12:24:25 +03:00
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
2022-02-28 04:28:19 +03:00
} else {
joinSchema := joinFields[idx][0]
relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr {
if _, ok := joinedSchemaMap[joinSchema]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
continue
}
2022-02-27 17:54:43 +03:00
relValue.Set(reflect.New(relValue.Type().Elem()))
joinedSchemaMap[joinSchema] = nil
}
2022-02-28 04:28:19 +03:00
}
2022-03-23 12:24:25 +03:00
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
2021-10-26 17:36:37 +03:00
}
2022-02-28 04:28:19 +03:00
// release data to pool
field.NewValuePool.Put(values[idx])
2021-10-26 17:36:37 +03:00
}
}
}
// ScanMode scan data mode
//
// ScanMode is a bit set: its flags may be combined with bitwise OR and are
// interpreted by Scan.
type ScanMode uint8

// scan modes
const (
	// ScanInitialized tells Scan that rows has already been advanced to the
	// first row, so that row is consumed without calling rows.Next first.
	ScanInitialized ScanMode = 1 << iota // 1
	// ScanUpdate tells Scan to write rows into the existing elements of a
	// non-empty destination slice/array instead of replacing its contents.
	ScanUpdate // 2
	// ScanOnConflictDoNothing is used together with ScanUpdate. It lets Scan
	// skip existing destination elements based on their current field values.
	ScanOnConflictDoNothing // 4
)
// Scan scan rows into db statement
2022-03-29 13:14:29 +03:00
func Scan(rows Rows, db *DB, mode ScanMode) {
2021-10-26 17:36:37 +03:00
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
db.RowsAffected = 0
2020-03-04 06:32:36 +03:00
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
2020-05-26 18:13:05 +03:00
if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
2020-03-04 06:32:36 +03:00
db.RowsAffected++
2020-05-30 20:21:16 +03:00
db.AddError(rows.Scan(values...))
2020-03-04 06:32:36 +03:00
mapValue, ok := dest.(map[string]interface{})
if !ok {
if v, ok := dest.(*map[string]interface{}); ok {
2021-11-01 07:00:36 +03:00
if *v == nil {
*v = map[string]interface{}{}
}
mapValue = *v
}
2020-03-04 06:32:36 +03:00
}
scanIntoMap(mapValue, values, columns)
2020-03-04 06:32:36 +03:00
}
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
2020-05-26 18:13:05 +03:00
for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
2020-05-26 18:13:05 +03:00
initialized = false
2020-03-04 06:32:36 +03:00
db.RowsAffected++
2020-05-30 20:21:16 +03:00
db.AddError(rows.Scan(values...))
2020-03-04 06:32:36 +03:00
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue)
2020-03-04 06:32:36 +03:00
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
*float32, *float64,
*bool, *string, *time.Time,
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
*sql.NullBool, *sql.NullString, *sql.NullTime:
2020-05-26 18:13:05 +03:00
for initialized || rows.Next() {
initialized = false
2020-05-24 06:32:59 +03:00
db.RowsAffected++
2020-05-30 20:21:16 +03:00
db.AddError(rows.Scan(dest))
2020-05-24 06:32:59 +03:00
}
2020-03-04 06:32:36 +03:00
default:
2021-10-26 17:36:37 +03:00
var (
fields = make([]*schema.Field, len(columns))
selectedColumnsMap = make(map[string]int, len(columns))
joinFields [][2]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
2021-10-26 17:36:37 +03:00
)
2022-02-27 17:54:43 +03:00
if reflectValue.Kind() == reflect.Interface {
2021-09-17 09:04:19 +03:00
reflectValue = reflectValue.Elem()
}
2020-07-01 03:56:21 +03:00
2021-10-26 17:36:37 +03:00
reflectValueType := reflectValue.Type()
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
2020-05-31 16:11:20 +03:00
}
2021-10-26 17:36:37 +03:00
if len(columns) == 1 {
2022-02-27 17:54:43 +03:00
// Is Pluck
2021-10-26 17:36:37 +03:00
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
2020-11-10 13:38:24 +03:00
reflectValueType.Kind() != reflect.Struct || // is not struct
2021-10-26 17:36:37 +03:00
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
2020-09-24 15:08:24 +03:00
}
}
2022-02-27 17:54:43 +03:00
// Not Pluck
if sch != nil {
schFieldsCount := len(sch.Fields)
2022-02-27 17:54:43 +03:00
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
if curIndex, ok := selectedColumnsMap[column]; ok {
fields[idx] = field // handle duplicate fields
offset := curIndex + 1
// handle sch inconsistent with database
// like Raw(`...`).Scan
if schFieldsCount > offset {
for fieldIndex, selectField := range sch.Fields[offset:] {
if selectField.DBName == column && selectField.Readable {
selectedColumnsMap[column] = curIndex + fieldIndex + 1
fields[idx] = selectField
break
}
}
}
} else {
fields[idx] = field
selectedColumnsMap[column] = idx
}
2022-02-27 17:54:43 +03:00
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
2021-10-26 17:36:37 +03:00
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var elem reflect.Value
recyclableStruct := reflect.New(reflectValueType)
2021-10-26 17:36:37 +03:00
if !update || reflectValue.Len() == 0 {
update = false
2021-10-26 17:36:37 +03:00
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
}
2020-09-24 15:08:24 +03:00
2020-05-26 18:13:05 +03:00
for initialized || rows.Next() {
2021-10-26 17:36:37 +03:00
BEGIN:
2020-05-26 18:13:05 +03:00
initialized = false
2021-10-26 17:36:37 +03:00
if update {
if int(db.RowsAffected) >= reflectValue.Len() {
return
}
2021-10-26 17:36:37 +03:00
elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing {
for _, field := range fields {
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
2021-10-26 17:36:37 +03:00
db.RowsAffected++
goto BEGIN
2020-05-31 16:11:20 +03:00
}
}
}
2020-11-10 13:38:24 +03:00
} else {
if isPtr && db.RowsAffected > 0 {
elem = reflect.New(reflectValueType)
} else {
elem = recyclableStruct
}
2020-03-04 06:32:36 +03:00
}
2021-09-17 13:35:14 +03:00
2022-02-27 17:54:43 +03:00
db.scanIntoStruct(rows, elem, values, fields, joinFields)
2020-07-01 03:56:21 +03:00
2021-10-26 17:36:37 +03:00
if !update {
if isPtr {
reflectValue = reflect.Append(reflectValue, elem)
} else {
2021-10-26 17:36:37 +03:00
reflectValue = reflect.Append(reflectValue, elem.Elem())
}
2020-03-04 06:32:36 +03:00
}
2021-10-26 17:36:37 +03:00
}
2020-03-04 06:32:36 +03:00
2021-10-26 17:36:37 +03:00
if !update {
db.Statement.ReflectValue.Set(reflectValue)
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
2022-02-27 17:54:43 +03:00
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
2020-03-04 06:32:36 +03:00
}
default:
db.AddError(rows.Scan(dest))
2020-03-04 06:32:36 +03:00
}
}
if err := rows.Err(); err != nil && err != db.Error {
db.AddError(err)
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
2020-05-26 18:13:05 +03:00
db.AddError(ErrRecordNotFound)
2020-03-04 06:32:36 +03:00
}
}