mirror of https://github.com/go-gorm/gorm.git
Refactor Scan
This commit is contained in:
parent
530b0a12b4
commit
43a72b369e
68
scan.go
68
scan.go
|
@ -50,58 +50,37 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||||
for idx, column := range columns {
|
for idx, field := range fields {
|
||||||
if sch == nil {
|
if field != nil {
|
||||||
values[idx] = reflectValue.Interface()
|
|
||||||
} else if field := sch.LookUpField(column); field != nil && field.Readable {
|
|
||||||
values[idx] = field.NewValuePool.Get()
|
values[idx] = field.NewValuePool.Get()
|
||||||
defer field.NewValuePool.Put(values[idx])
|
defer field.NewValuePool.Put(values[idx])
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
if len(joinFields) == 0 || joinFields[idx][0] == nil {
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
defer field.Set(db.Statement.Context, reflectValue, values[idx])
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
||||||
values[idx] = field.NewValuePool.Get()
|
|
||||||
defer field.NewValuePool.Put(values[idx])
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
} else if len(fields) == 1 {
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
} else if len(columns) == 1 {
|
|
||||||
sch = nil
|
|
||||||
if reflectValue.CanAddr() {
|
if reflectValue.CanAddr() {
|
||||||
values[idx] = reflectValue.Addr().Interface()
|
values[idx] = reflectValue.Addr().Interface()
|
||||||
} else {
|
} else {
|
||||||
values[idx] = reflectValue.Interface()
|
values[idx] = reflectValue.Interface()
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
if sch != nil {
|
for idx, joinField := range joinFields {
|
||||||
for idx, column := range columns {
|
if joinField[0] != nil {
|
||||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue)
|
||||||
field.Set(db.Statement.Context, reflectValue, values[idx])
|
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
||||||
relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
|
|
||||||
|
|
||||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||||
}
|
}
|
||||||
|
joinField[1].Set(db.Statement.Context, relValue, values[idx])
|
||||||
field.Set(db.Statement.Context, relValue, values[idx])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -180,7 +159,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
reflectValue = db.Statement.ReflectValue
|
reflectValue = db.Statement.ReflectValue
|
||||||
)
|
)
|
||||||
|
|
||||||
for reflectValue.Kind() == reflect.Interface {
|
if reflectValue.Kind() == reflect.Interface {
|
||||||
reflectValue = reflectValue.Elem()
|
reflectValue = reflectValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,6 +178,17 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(columns) == 1 {
|
||||||
|
// Is Pluck
|
||||||
|
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
||||||
|
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||||
|
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||||
|
sch = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not Pluck
|
||||||
|
if sch != nil {
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||||
fields[idx] = field
|
fields[idx] = field
|
||||||
|
@ -219,14 +209,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
values[idx] = &sql.RawBytes{}
|
values[idx] = &sql.RawBytes{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(columns) == 1 {
|
|
||||||
// isPluck
|
|
||||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
|
||||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
|
||||||
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
|
||||||
sch = nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,7 +242,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
elem = reflect.New(reflectValueType)
|
elem = reflect.New(reflectValueType)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
|
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||||
|
|
||||||
if !update {
|
if !update {
|
||||||
if isPtr {
|
if isPtr {
|
||||||
|
@ -276,7 +258,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
}
|
}
|
||||||
case reflect.Struct, reflect.Ptr:
|
case reflect.Struct, reflect.Ptr:
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
|
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
db.AddError(rows.Scan(dest))
|
db.AddError(rows.Scan(dest))
|
||||||
|
|
Loading…
Reference in New Issue