forked from mirror/gorm
Fix Scan with interface
This commit is contained in:
parent
61b018cb94
commit
12bbde89e6
|
@ -506,7 +506,12 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
|||
tx.Statement.Dest = dest
|
||||
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
||||
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
|
||||
elem := tx.Statement.ReflectValue.Elem()
|
||||
if !elem.IsValid() {
|
||||
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
|
||||
tx.Statement.ReflectValue.Set(elem)
|
||||
}
|
||||
tx.Statement.ReflectValue = elem
|
||||
}
|
||||
Scan(rows, tx, true)
|
||||
return tx.Error
|
||||
|
|
20
scan.go
20
scan.go
|
@ -97,11 +97,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
}
|
||||
default:
|
||||
Schema := db.Statement.Schema
|
||||
reflectValue := db.Statement.ReflectValue
|
||||
if reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
reflectValueType = db.Statement.ReflectValue.Type().Elem()
|
||||
reflectValueType = reflectValue.Type().Elem()
|
||||
isPtr = reflectValueType.Kind() == reflect.Ptr
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][2]*schema.Field
|
||||
|
@ -111,7 +115,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20))
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
|
||||
if Schema != nil {
|
||||
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||
|
@ -186,13 +190,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
}
|
||||
|
||||
if isPtr {
|
||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
|
||||
db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem))
|
||||
} else {
|
||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem()))
|
||||
db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem.Elem()))
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if db.Statement.ReflectValue.Type() != Schema.ModelType {
|
||||
if reflectValue.Type() != Schema.ModelType {
|
||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
}
|
||||
|
||||
|
@ -220,11 +224,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
|
||||
for idx, column := range columns {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||
field.Set(db.Statement.ReflectValue, values[idx])
|
||||
field.Set(reflectValue, values[idx])
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||
relValue := rel.Field.ReflectValueOf(reflectValue)
|
||||
value := reflect.ValueOf(values[idx]).Elem()
|
||||
|
||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||
|
|
|
@ -77,7 +77,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
|
||||
if modelType.Kind() == reflect.Interface {
|
||||
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
|
||||
}
|
||||
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
|
|
@ -29,8 +29,9 @@ func TestScan(t *testing.T) {
|
|||
}
|
||||
|
||||
var resPointer *result
|
||||
DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer)
|
||||
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
|
||||
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil {
|
||||
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||
} else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) {
|
||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
|
||||
}
|
||||
|
||||
|
@ -70,6 +71,38 @@ func TestScan(t *testing.T) {
|
|||
if uint(id) != user2.ID {
|
||||
t.Errorf("Failed to scan to customized data type")
|
||||
}
|
||||
|
||||
var resInt interface{}
|
||||
resInt = &User{}
|
||||
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil {
|
||||
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||
} else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age {
|
||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3)
|
||||
}
|
||||
|
||||
var resInt2 interface{}
|
||||
resInt2 = &User{}
|
||||
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil {
|
||||
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||
} else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age {
|
||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3)
|
||||
}
|
||||
|
||||
var resInt3 interface{}
|
||||
resInt3 = []User{}
|
||||
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil {
|
||||
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||
} else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age {
|
||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3)
|
||||
}
|
||||
|
||||
var resInt4 interface{}
|
||||
resInt4 = []User{}
|
||||
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil {
|
||||
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||
} else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age {
|
||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRows(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue