mirror of https://github.com/go-gorm/gorm.git
Fix Scan with interface
This commit is contained in:
parent
61b018cb94
commit
12bbde89e6
|
@ -506,7 +506,12 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||||
tx.Statement.Dest = dest
|
tx.Statement.Dest = dest
|
||||||
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
||||||
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||||
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
|
elem := tx.Statement.ReflectValue.Elem()
|
||||||
|
if !elem.IsValid() {
|
||||||
|
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
|
||||||
|
tx.Statement.ReflectValue.Set(elem)
|
||||||
|
}
|
||||||
|
tx.Statement.ReflectValue = elem
|
||||||
}
|
}
|
||||||
Scan(rows, tx, true)
|
Scan(rows, tx, true)
|
||||||
return tx.Error
|
return tx.Error
|
||||||
|
|
20
scan.go
20
scan.go
|
@ -97,11 +97,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
Schema := db.Statement.Schema
|
Schema := db.Statement.Schema
|
||||||
|
reflectValue := db.Statement.ReflectValue
|
||||||
|
if reflectValue.Kind() == reflect.Interface {
|
||||||
|
reflectValue = reflectValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
var (
|
var (
|
||||||
reflectValueType = db.Statement.ReflectValue.Type().Elem()
|
reflectValueType = reflectValue.Type().Elem()
|
||||||
isPtr = reflectValueType.Kind() == reflect.Ptr
|
isPtr = reflectValueType.Kind() == reflect.Ptr
|
||||||
fields = make([]*schema.Field, len(columns))
|
fields = make([]*schema.Field, len(columns))
|
||||||
joinFields [][2]*schema.Field
|
joinFields [][2]*schema.Field
|
||||||
|
@ -111,7 +115,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
reflectValueType = reflectValueType.Elem()
|
reflectValueType = reflectValueType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20))
|
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||||
|
|
||||||
if Schema != nil {
|
if Schema != nil {
|
||||||
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
|
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||||
|
@ -186,13 +190,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if isPtr {
|
if isPtr {
|
||||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
|
db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem))
|
||||||
} else {
|
} else {
|
||||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem()))
|
db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem.Elem()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct, reflect.Ptr:
|
case reflect.Struct, reflect.Ptr:
|
||||||
if db.Statement.ReflectValue.Type() != Schema.ModelType {
|
if reflectValue.Type() != Schema.ModelType {
|
||||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,11 +224,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
|
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
field.Set(db.Statement.ReflectValue, values[idx])
|
field.Set(reflectValue, values[idx])
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
relValue := rel.Field.ReflectValueOf(reflectValue)
|
||||||
value := reflect.ValueOf(values[idx]).Elem()
|
value := reflect.ValueOf(values[idx]).Elem()
|
||||||
|
|
||||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||||
|
|
|
@ -77,7 +77,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
||||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelType := reflect.ValueOf(dest).Type()
|
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
|
||||||
|
if modelType.Kind() == reflect.Interface {
|
||||||
|
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
|
||||||
|
}
|
||||||
|
|
||||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,8 +29,9 @@ func TestScan(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var resPointer *result
|
var resPointer *result
|
||||||
DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer)
|
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil {
|
||||||
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
|
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||||
|
} else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) {
|
||||||
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
|
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,6 +71,38 @@ func TestScan(t *testing.T) {
|
||||||
if uint(id) != user2.ID {
|
if uint(id) != user2.ID {
|
||||||
t.Errorf("Failed to scan to customized data type")
|
t.Errorf("Failed to scan to customized data type")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var resInt interface{}
|
||||||
|
resInt = &User{}
|
||||||
|
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||||
|
} else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age {
|
||||||
|
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resInt2 interface{}
|
||||||
|
resInt2 = &User{}
|
||||||
|
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||||
|
} else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age {
|
||||||
|
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resInt3 interface{}
|
||||||
|
resInt3 = []User{}
|
||||||
|
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||||
|
} else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age {
|
||||||
|
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resInt4 interface{}
|
||||||
|
resInt4 = []User{}
|
||||||
|
if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to query with pointer of value, got error %v", err)
|
||||||
|
} else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age {
|
||||||
|
t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanRows(t *testing.T) {
|
func TestScanRows(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue