diff --git a/scan.go b/scan.go index e83390ca..d7b58e03 100644 --- a/scan.go +++ b/scan.go @@ -50,58 +50,37 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { - for idx, column := range columns { - if sch == nil { - values[idx] = reflectValue.Interface() - } else if field := sch.LookUpField(column); field != nil && field.Readable { +func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { + for idx, field := range fields { + if field != nil { values[idx] = field.NewValuePool.Get() defer field.NewValuePool.Put(values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = field.NewValuePool.Get() - defer field.NewValuePool.Put(values[idx]) - continue - } + if len(joinFields) == 0 || joinFields[idx][0] == nil { + defer field.Set(db.Statement.Context, reflectValue, values[idx]) } - values[idx] = &sql.RawBytes{} - } else if len(columns) == 1 { - sch = nil + } else if len(fields) == 1 { if reflectValue.CanAddr() { values[idx] = reflectValue.Addr().Interface() } else { values[idx] = reflectValue.Interface() } - } else { - values[idx] = &sql.RawBytes{} } } db.RowsAffected++ db.AddError(rows.Scan(values...)) - if sch != nil { - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - field.Set(db.Statement.Context, reflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue) - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } - - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(db.Statement.Context, relValue, values[idx]) - } + for idx, joinField := range joinFields { + if joinField[0] != nil { + relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + return } + + relValue.Set(reflect.New(relValue.Type().Elem())) } + joinField[1].Set(db.Statement.Context, relValue, values[idx]) } } } @@ -180,7 +159,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { reflectValue = db.Statement.ReflectValue ) - for reflectValue.Kind() == reflect.Interface { + if reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } @@ -199,35 +178,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue - } - } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} - } - } - if len(columns) == 1 { - // isPluck + // Is Pluck if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner reflectValueType.Kind() != reflect.Struct || // is not struct sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time sch = nil } } + + // Not Pluck + if sch != nil { + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } } switch reflectValue.Kind() { @@ -260,7 +242,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflect.New(reflectValueType) } - db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { if isPtr { @@ -276,7 +258,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) } default: db.AddError(rows.Scan(dest))