Refactor Scan

This commit is contained in:
Jinzhu 2022-02-27 22:54:43 +08:00
parent 530b0a12b4
commit 43a72b369e
1 changed files with 43 additions and 61 deletions

104
scan.go
View File

@ -50,58 +50,37 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
} }
} }
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, column := range columns { for idx, field := range fields {
if sch == nil { if field != nil {
values[idx] = reflectValue.Interface()
} else if field := sch.LookUpField(column); field != nil && field.Readable {
values[idx] = field.NewValuePool.Get() values[idx] = field.NewValuePool.Get()
defer field.NewValuePool.Put(values[idx]) defer field.NewValuePool.Put(values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 { if len(joinFields) == 0 || joinFields[idx][0] == nil {
if rel, ok := sch.Relationships.Relations[names[0]]; ok { defer field.Set(db.Statement.Context, reflectValue, values[idx])
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = field.NewValuePool.Get()
defer field.NewValuePool.Put(values[idx])
continue
}
} }
values[idx] = &sql.RawBytes{} } else if len(fields) == 1 {
} else if len(columns) == 1 {
sch = nil
if reflectValue.CanAddr() { if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface() values[idx] = reflectValue.Addr().Interface()
} else { } else {
values[idx] = reflectValue.Interface() values[idx] = reflectValue.Interface()
} }
} else {
values[idx] = &sql.RawBytes{}
} }
} }
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
if sch != nil { for idx, joinField := range joinFields {
for idx, column := range columns { if joinField[0] != nil {
if field := sch.LookUpField(column); field != nil && field.Readable { relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue)
field.Set(db.Statement.Context, reflectValue, values[idx]) if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
} else if names := strings.Split(column, "__"); len(names) > 1 { if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
if rel, ok := sch.Relationships.Relations[names[0]]; ok { return
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(db.Statement.Context, relValue, values[idx])
}
} }
relValue.Set(reflect.New(relValue.Type().Elem()))
} }
joinField[1].Set(db.Statement.Context, relValue, values[idx])
} }
} }
} }
@ -180,7 +159,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
reflectValue = db.Statement.ReflectValue reflectValue = db.Statement.ReflectValue
) )
for reflectValue.Kind() == reflect.Interface { if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem() reflectValue = reflectValue.Elem()
} }
@ -199,35 +178,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
} }
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
if len(columns) == 1 { if len(columns) == 1 {
// isPluck // Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil sch = nil
} }
} }
// Not Pluck
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
} }
switch reflectValue.Kind() { switch reflectValue.Kind() {
@ -260,7 +242,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
elem = reflect.New(reflectValueType) elem = reflect.New(reflectValueType)
} }
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update { if !update {
if isPtr { if isPtr {
@ -276,7 +258,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
} }
case reflect.Struct, reflect.Ptr: case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() { if initialized || rows.Next() {
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
} }
default: default:
db.AddError(rows.Scan(dest)) db.AddError(rows.Scan(dest))