Fix Scan with interface

This commit is contained in:
Jinzhu 2021-09-17 14:04:19 +08:00
parent 61b018cb94
commit 12bbde89e6
4 changed files with 58 additions and 12 deletions

View File

@ -506,7 +506,12 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.ValueOf(dest) tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr { for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() elem := tx.Statement.ReflectValue.Elem()
if !elem.IsValid() {
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
tx.Statement.ReflectValue.Set(elem)
}
tx.Statement.ReflectValue = elem
} }
Scan(rows, tx, true) Scan(rows, tx, true)
return tx.Error return tx.Error

20
scan.go
View File

@ -97,11 +97,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
default: default:
Schema := db.Statement.Schema Schema := db.Statement.Schema
reflectValue := db.Statement.ReflectValue
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
switch db.Statement.ReflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var ( var (
reflectValueType = db.Statement.ReflectValue.Type().Elem() reflectValueType = reflectValue.Type().Elem()
isPtr = reflectValueType.Kind() == reflect.Ptr isPtr = reflectValueType.Kind() == reflect.Ptr
fields = make([]*schema.Field, len(columns)) fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field joinFields [][2]*schema.Field
@ -111,7 +115,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
reflectValueType = reflectValueType.Elem() reflectValueType = reflectValueType.Elem()
} }
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
if Schema != nil { if Schema != nil {
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
@ -186,13 +190,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
if isPtr { if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem))
} else { } else {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem.Elem()))
} }
} }
case reflect.Struct, reflect.Ptr: case reflect.Struct, reflect.Ptr:
if db.Statement.ReflectValue.Type() != Schema.ModelType { if reflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
} }
@ -220,11 +224,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
for idx, column := range columns { for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable { if field := Schema.LookUpField(column); field != nil && field.Readable {
field.Set(db.Statement.ReflectValue, values[idx]) field.Set(reflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) relValue := rel.Field.ReflectValueOf(reflectValue)
value := reflect.ValueOf(values[idx]).Elem() value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if relValue.Kind() == reflect.Ptr && relValue.IsNil() {

View File

@ -77,7 +77,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
modelType := reflect.ValueOf(dest).Type() modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
if modelType.Kind() == reflect.Interface {
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
}
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem() modelType = modelType.Elem()
} }

View File

@ -29,8 +29,9 @@ func TestScan(t *testing.T) {
} }
var resPointer *result var resPointer *result
DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer) if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil {
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { t.Fatalf("Failed to query with pointer of value, got error %v", err)
} else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) {
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
} }
@ -70,6 +71,38 @@ func TestScan(t *testing.T) {
if uint(id) != user2.ID { if uint(id) != user2.ID {
t.Errorf("Failed to scan to customized data type") t.Errorf("Failed to scan to customized data type")
} }
var resInt interface{}
resInt = &User{}
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil {
t.Fatalf("Failed to query with pointer of value, got error %v", err)
} else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age {
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3)
}
var resInt2 interface{}
resInt2 = &User{}
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil {
t.Fatalf("Failed to query with pointer of value, got error %v", err)
} else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age {
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3)
}
var resInt3 interface{}
resInt3 = []User{}
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil {
t.Fatalf("Failed to query with pointer of value, got error %v", err)
} else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age {
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3)
}
var resInt4 interface{}
resInt4 = []User{}
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil {
t.Fatalf("Failed to query with pointer of value, got error %v", err)
} else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age {
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3)
}
} }
func TestScanRows(t *testing.T) { func TestScanRows(t *testing.T) {