mirror of https://github.com/go-gorm/gorm.git
202 lines
5.2 KiB
Go
202 lines
5.2 KiB
Go
package gorm
|
|
|
|
import (
|
|
"database/sql"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"gorm.io/gorm/schema"
|
|
)
|
|
|
|
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|
columns, _ := rows.Columns()
|
|
values := make([]interface{}, len(columns))
|
|
|
|
switch dest := db.Statement.Dest.(type) {
|
|
case map[string]interface{}, *map[string]interface{}:
|
|
if initialized || rows.Next() {
|
|
for idx := range columns {
|
|
values[idx] = new(interface{})
|
|
}
|
|
|
|
db.RowsAffected++
|
|
db.AddError(rows.Scan(values...))
|
|
|
|
mapValue, ok := dest.(map[string]interface{})
|
|
if !ok {
|
|
if v, ok := dest.(*map[string]interface{}); ok {
|
|
mapValue = *v
|
|
}
|
|
}
|
|
|
|
for idx, column := range columns {
|
|
if v, ok := values[idx].(*interface{}); ok {
|
|
if v == nil {
|
|
mapValue[column] = nil
|
|
} else {
|
|
mapValue[column] = *v
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case *[]map[string]interface{}:
|
|
for initialized || rows.Next() {
|
|
for idx := range columns {
|
|
values[idx] = new(interface{})
|
|
}
|
|
|
|
initialized = false
|
|
db.RowsAffected++
|
|
db.AddError(rows.Scan(values...))
|
|
|
|
mapValue := map[string]interface{}{}
|
|
for idx, column := range columns {
|
|
if v, ok := values[idx].(*interface{}); ok {
|
|
if v == nil {
|
|
mapValue[column] = nil
|
|
} else {
|
|
mapValue[column] = *v
|
|
}
|
|
}
|
|
}
|
|
|
|
*dest = append(*dest, mapValue)
|
|
}
|
|
case *int, *int64, *uint, *uint64:
|
|
for initialized || rows.Next() {
|
|
initialized = false
|
|
db.RowsAffected++
|
|
db.AddError(rows.Scan(dest))
|
|
}
|
|
default:
|
|
switch db.Statement.ReflectValue.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
reflectValueType := db.Statement.ReflectValue.Type().Elem()
|
|
isPtr := reflectValueType.Kind() == reflect.Ptr
|
|
if isPtr {
|
|
reflectValueType = reflectValueType.Elem()
|
|
}
|
|
|
|
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
|
fields := make([]*schema.Field, len(columns))
|
|
joinFields := make([][2]*schema.Field, len(columns))
|
|
|
|
for idx, column := range columns {
|
|
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
|
fields[idx] = field
|
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
|
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
|
continue
|
|
}
|
|
}
|
|
values[idx] = &sql.RawBytes{}
|
|
} else {
|
|
values[idx] = &sql.RawBytes{}
|
|
}
|
|
}
|
|
|
|
for initialized || rows.Next() {
|
|
for idx := range columns {
|
|
values[idx] = new(interface{})
|
|
}
|
|
|
|
initialized = false
|
|
db.RowsAffected++
|
|
|
|
elem := reflect.New(reflectValueType).Elem()
|
|
|
|
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
|
|
// pluck
|
|
values[0] = elem.Addr().Interface()
|
|
db.AddError(rows.Scan(values...))
|
|
} else {
|
|
db.AddError(rows.Scan(values...))
|
|
|
|
for idx, field := range fields {
|
|
if v, ok := values[idx].(*interface{}); ok {
|
|
if field != nil {
|
|
if v == nil {
|
|
field.Set(elem, v)
|
|
} else {
|
|
field.Set(elem, *v)
|
|
}
|
|
} else if joinFields[idx][0] != nil {
|
|
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
|
if v == nil {
|
|
continue
|
|
}
|
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
|
}
|
|
|
|
if v == nil {
|
|
joinFields[idx][1].Set(relValue, nil)
|
|
} else {
|
|
joinFields[idx][1].Set(relValue, *v)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for idx := range columns {
|
|
values[idx] = new(interface{})
|
|
}
|
|
}
|
|
|
|
if isPtr {
|
|
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
|
|
} else {
|
|
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
|
|
}
|
|
}
|
|
case reflect.Struct:
|
|
if initialized || rows.Next() {
|
|
for idx := range columns {
|
|
values[idx] = new(interface{})
|
|
}
|
|
|
|
db.RowsAffected++
|
|
db.AddError(rows.Scan(values...))
|
|
|
|
for idx, column := range columns {
|
|
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
|
if v, ok := values[idx].(*interface{}); ok {
|
|
if v == nil {
|
|
field.Set(db.Statement.ReflectValue, v)
|
|
} else {
|
|
field.Set(db.Statement.ReflectValue, *v)
|
|
}
|
|
}
|
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
|
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
|
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
if v, ok := values[idx].(*interface{}); ok {
|
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
|
if v == nil {
|
|
continue
|
|
}
|
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
|
}
|
|
|
|
if v == nil {
|
|
field.Set(relValue, nil)
|
|
} else {
|
|
field.Set(relValue, *v)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
|
|
db.AddError(ErrRecordNotFound)
|
|
}
|
|
}
|