Refactor Scan

This commit is contained in:
Jinzhu 2022-02-27 22:54:43 +08:00
parent 530b0a12b4
commit 43a72b369e
1 changed files with 43 additions and 61 deletions

68
scan.go
View File

@ -50,58 +50,37 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
}
}
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, column := range columns {
if sch == nil {
values[idx] = reflectValue.Interface()
} else if field := sch.LookUpField(column); field != nil && field.Readable {
func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, field := range fields {
if field != nil {
values[idx] = field.NewValuePool.Get()
defer field.NewValuePool.Put(values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = field.NewValuePool.Get()
defer field.NewValuePool.Put(values[idx])
continue
if len(joinFields) == 0 || joinFields[idx][0] == nil {
defer field.Set(db.Statement.Context, reflectValue, values[idx])
}
}
values[idx] = &sql.RawBytes{}
} else if len(columns) == 1 {
sch = nil
} else if len(fields) == 1 {
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
field.Set(db.Statement.Context, reflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
for idx, joinField := range joinFields {
if joinField[0] != nil {
relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
continue
return
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(db.Statement.Context, relValue, values[idx])
}
}
}
joinField[1].Set(db.Statement.Context, relValue, values[idx])
}
}
}
@ -180,7 +159,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
reflectValue = db.Statement.ReflectValue
)
for reflectValue.Kind() == reflect.Interface {
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
@ -199,6 +178,17 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if len(columns) == 1 {
// Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
// Not Pluck
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
@ -219,14 +209,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
values[idx] = &sql.RawBytes{}
}
}
if len(columns) == 1 {
// isPluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
}
@ -260,7 +242,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
elem = reflect.New(reflectValueType)
}
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update {
if isPtr {
@ -276,7 +258,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
}
default:
db.AddError(rows.Scan(dest))