// Mirror of https://github.com/go-gorm/gorm.git (361 lines, 10 KiB, Go).

package gorm
|
|
|
|
import (
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"reflect"
|
|
"time"
|
|
|
|
"gorm.io/gorm/schema"
|
|
"gorm.io/gorm/utils"
|
|
)
|
|
|
|
// prepareValues prepare values slice
|
|
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
|
if db.Statement.Schema != nil {
|
|
for idx, name := range columns {
|
|
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
|
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
|
continue
|
|
}
|
|
values[idx] = new(interface{})
|
|
}
|
|
} else if len(columnTypes) > 0 {
|
|
for idx, columnType := range columnTypes {
|
|
if columnType.ScanType() != nil {
|
|
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
|
} else {
|
|
values[idx] = new(interface{})
|
|
}
|
|
}
|
|
} else {
|
|
for idx := range columns {
|
|
values[idx] = new(interface{})
|
|
}
|
|
}
|
|
}
|
|
|
|
// scanIntoMap copies the scanned values into mapValue, keyed by column name.
//
// Each entry of values is dereferenced up to two pointer levels (the shape
// produced by prepareValues). A nil entry or a nil pointer at either level is
// stored as nil. Otherwise the stored value is:
//   - the result of Value() when the scanned value implements driver.Valuer
//     (the error returned by Value is discarded);
//   - a string copy when the scanned value is sql.RawBytes;
//   - the scanned value itself in every other case.
//
// mapValue must be non-nil and values must be at least as long as columns.
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
	for i := range columns {
		name := columns[i]

		scanned := reflect.ValueOf(values[i])
		scanned = reflect.Indirect(scanned)
		scanned = reflect.Indirect(scanned)
		if !scanned.IsValid() {
			mapValue[name] = nil
			continue
		}

		switch v := scanned.Interface().(type) {
		case driver.Valuer:
			mapValue[name], _ = v.Value()
		case sql.RawBytes:
			mapValue[name] = string(v)
		default:
			mapValue[name] = v
		}
	}
}
|
|
|
|
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
|
|
for idx, field := range fields {
|
|
if field != nil {
|
|
values[idx] = field.NewValuePool.Get()
|
|
} else if len(fields) == 1 {
|
|
if reflectValue.CanAddr() {
|
|
values[idx] = reflectValue.Addr().Interface()
|
|
} else {
|
|
values[idx] = reflectValue.Interface()
|
|
}
|
|
}
|
|
}
|
|
|
|
db.RowsAffected++
|
|
db.AddError(rows.Scan(values...))
|
|
joinedNestedSchemaMap := make(map[string]interface{})
|
|
for idx, field := range fields {
|
|
if field == nil {
|
|
continue
|
|
}
|
|
|
|
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
|
|
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
|
} else { // joinFields count is larger than 2 when using join
|
|
var isNilPtrValue bool
|
|
var relValue reflect.Value
|
|
// does not contain raw dbname
|
|
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
|
|
// current reflect value
|
|
currentReflectValue := reflectValue
|
|
fullRels := make([]string, 0, len(nestedJoinSchemas))
|
|
for _, joinSchema := range nestedJoinSchemas {
|
|
fullRels = append(fullRels, joinSchema.Name)
|
|
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
|
|
if relValue.Kind() == reflect.Ptr {
|
|
fullRelsName := utils.JoinNestedRelationNames(fullRels)
|
|
// same nested structure
|
|
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
|
|
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
|
isNilPtrValue = true
|
|
break
|
|
}
|
|
|
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
|
joinedNestedSchemaMap[fullRelsName] = nil
|
|
}
|
|
}
|
|
currentReflectValue = relValue
|
|
}
|
|
|
|
if !isNilPtrValue { // ignore if value is nil
|
|
f := joinFields[idx][len(joinFields[idx])-1]
|
|
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
|
|
}
|
|
}
|
|
|
|
// release data to pool
|
|
field.NewValuePool.Put(values[idx])
|
|
}
|
|
}
|
|
|
|
// ScanMode is a set of bit flags that adjusts how Scan consumes rows.
type ScanMode uint8

// Flags understood by Scan. They may be combined with bitwise OR.
const (
	// ScanInitialized makes Scan treat rows as already positioned on a row,
	// so the first rows.Next call is skipped.
	ScanInitialized ScanMode = 1 << iota // 1
	// ScanUpdate makes Scan write into the existing elements of a non-empty
	// slice or array destination instead of rebuilding it.
	ScanUpdate // 2
	// ScanOnConflictDoNothing, together with ScanUpdate, makes Scan skip
	// destination elements based on the result of the fields' ValueOf.
	ScanOnConflictDoNothing // 4
)
|
|
|
|
// Scan scan rows into db statement
|
|
func Scan(rows Rows, db *DB, mode ScanMode) {
|
|
var (
|
|
columns, _ = rows.Columns()
|
|
values = make([]interface{}, len(columns))
|
|
initialized = mode&ScanInitialized != 0
|
|
update = mode&ScanUpdate != 0
|
|
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
|
|
)
|
|
|
|
if len(db.Statement.ColumnMapping) > 0 {
|
|
for i, column := range columns {
|
|
v, ok := db.Statement.ColumnMapping[column]
|
|
if ok {
|
|
columns[i] = v
|
|
}
|
|
}
|
|
}
|
|
|
|
db.RowsAffected = 0
|
|
|
|
switch dest := db.Statement.Dest.(type) {
|
|
case map[string]interface{}, *map[string]interface{}:
|
|
if initialized || rows.Next() {
|
|
columnTypes, _ := rows.ColumnTypes()
|
|
prepareValues(values, db, columnTypes, columns)
|
|
|
|
db.RowsAffected++
|
|
db.AddError(rows.Scan(values...))
|
|
|
|
mapValue, ok := dest.(map[string]interface{})
|
|
if !ok {
|
|
if v, ok := dest.(*map[string]interface{}); ok {
|
|
if *v == nil {
|
|
*v = map[string]interface{}{}
|
|
}
|
|
mapValue = *v
|
|
}
|
|
}
|
|
scanIntoMap(mapValue, values, columns)
|
|
}
|
|
case *[]map[string]interface{}:
|
|
columnTypes, _ := rows.ColumnTypes()
|
|
for initialized || rows.Next() {
|
|
prepareValues(values, db, columnTypes, columns)
|
|
|
|
initialized = false
|
|
db.RowsAffected++
|
|
db.AddError(rows.Scan(values...))
|
|
|
|
mapValue := map[string]interface{}{}
|
|
scanIntoMap(mapValue, values, columns)
|
|
*dest = append(*dest, mapValue)
|
|
}
|
|
case *int, *int8, *int16, *int32, *int64,
|
|
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
|
*float32, *float64,
|
|
*bool, *string, *time.Time,
|
|
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
|
|
*sql.NullBool, *sql.NullString, *sql.NullTime:
|
|
for initialized || rows.Next() {
|
|
initialized = false
|
|
db.RowsAffected++
|
|
db.AddError(rows.Scan(dest))
|
|
}
|
|
default:
|
|
var (
|
|
fields = make([]*schema.Field, len(columns))
|
|
joinFields [][]*schema.Field
|
|
sch = db.Statement.Schema
|
|
reflectValue = db.Statement.ReflectValue
|
|
)
|
|
|
|
if reflectValue.Kind() == reflect.Interface {
|
|
reflectValue = reflectValue.Elem()
|
|
}
|
|
|
|
reflectValueType := reflectValue.Type()
|
|
switch reflectValueType.Kind() {
|
|
case reflect.Array, reflect.Slice:
|
|
reflectValueType = reflectValueType.Elem()
|
|
}
|
|
isPtr := reflectValueType.Kind() == reflect.Ptr
|
|
if isPtr {
|
|
reflectValueType = reflectValueType.Elem()
|
|
}
|
|
|
|
if sch != nil {
|
|
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
|
|
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
|
}
|
|
|
|
if len(columns) == 1 {
|
|
// Is Pluck
|
|
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
|
reflectValueType.Kind() != reflect.Struct || // is not struct
|
|
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
|
sch = nil
|
|
}
|
|
}
|
|
|
|
// Not Pluck
|
|
if sch != nil {
|
|
matchedFieldCount := make(map[string]int, len(columns))
|
|
for idx, column := range columns {
|
|
if field := sch.LookUpField(column); field != nil && field.Readable {
|
|
fields[idx] = field
|
|
if count, ok := matchedFieldCount[column]; ok {
|
|
// handle duplicate fields
|
|
for _, selectField := range sch.Fields {
|
|
if selectField.DBName == column && selectField.Readable {
|
|
if count == 0 {
|
|
matchedFieldCount[column]++
|
|
fields[idx] = selectField
|
|
break
|
|
}
|
|
count--
|
|
}
|
|
}
|
|
} else {
|
|
matchedFieldCount[column] = 1
|
|
}
|
|
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
|
subNameCount := len(names)
|
|
// nested relation fields
|
|
relFields := make([]*schema.Field, 0, subNameCount-1)
|
|
relFields = append(relFields, rel.Field)
|
|
for _, name := range names[1 : subNameCount-1] {
|
|
rel = rel.FieldSchema.Relationships.Relations[name]
|
|
relFields = append(relFields, rel.Field)
|
|
}
|
|
// latest name is raw dbname
|
|
dbName := names[subNameCount-1]
|
|
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
|
|
fields[idx] = field
|
|
|
|
if len(joinFields) == 0 {
|
|
joinFields = make([][]*schema.Field, len(columns))
|
|
}
|
|
relFields = append(relFields, field)
|
|
joinFields[idx] = relFields
|
|
continue
|
|
}
|
|
}
|
|
var val interface{}
|
|
values[idx] = &val
|
|
} else {
|
|
var val interface{}
|
|
values[idx] = &val
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
switch reflectValue.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
var (
|
|
elem reflect.Value
|
|
isArrayKind = reflectValue.Kind() == reflect.Array
|
|
)
|
|
|
|
if !update || reflectValue.Len() == 0 {
|
|
update = false
|
|
if isArrayKind {
|
|
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
|
} else {
|
|
// if the slice cap is externally initialized, the externally initialized slice is directly used here
|
|
if reflectValue.Cap() == 0 {
|
|
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
|
} else {
|
|
reflectValue.SetLen(0)
|
|
db.Statement.ReflectValue.Set(reflectValue)
|
|
}
|
|
}
|
|
}
|
|
|
|
for initialized || rows.Next() {
|
|
BEGIN:
|
|
initialized = false
|
|
|
|
if update {
|
|
if int(db.RowsAffected) >= reflectValue.Len() {
|
|
return
|
|
}
|
|
elem = reflectValue.Index(int(db.RowsAffected))
|
|
if onConflictDonothing {
|
|
for _, field := range fields {
|
|
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
|
|
db.RowsAffected++
|
|
goto BEGIN
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
elem = reflect.New(reflectValueType)
|
|
}
|
|
|
|
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
|
|
|
if !update {
|
|
if !isPtr {
|
|
elem = elem.Elem()
|
|
}
|
|
if isArrayKind {
|
|
if reflectValue.Len() >= int(db.RowsAffected) {
|
|
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
|
|
}
|
|
} else {
|
|
reflectValue = reflect.Append(reflectValue, elem)
|
|
}
|
|
}
|
|
}
|
|
|
|
if !update {
|
|
db.Statement.ReflectValue.Set(reflectValue)
|
|
}
|
|
case reflect.Struct, reflect.Ptr:
|
|
if initialized || rows.Next() {
|
|
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
|
|
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
|
}
|
|
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
|
}
|
|
default:
|
|
db.AddError(rows.Scan(dest))
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil && err != db.Error {
|
|
db.AddError(err)
|
|
}
|
|
|
|
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
|
|
db.AddError(ErrRecordNotFound)
|
|
}
|
|
}
|