From 12bbde89e683d85181b0344ff71f44d3148bf9cd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Sep 2021 14:04:19 +0800 Subject: [PATCH] Fix Scan with interface --- finisher_api.go | 7 ++++++- scan.go | 20 ++++++++++++-------- schema/schema.go | 6 +++++- tests/scan_test.go | 37 +++++++++++++++++++++++++++++++++++-- 4 files changed, 58 insertions(+), 12 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 741a9456..d273093f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -506,7 +506,12 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx.Statement.Dest = dest tx.Statement.ReflectValue = reflect.ValueOf(dest) for tx.Statement.ReflectValue.Kind() == reflect.Ptr { - tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + elem := tx.Statement.ReflectValue.Elem() + if !elem.IsValid() { + elem = reflect.New(tx.Statement.ReflectValue.Type().Elem()) + tx.Statement.ReflectValue.Set(elem) + } + tx.Statement.ReflectValue = elem } Scan(rows, tx, true) return tx.Error diff --git a/scan.go b/scan.go index 2beecd45..20bdde9e 100644 --- a/scan.go +++ b/scan.go @@ -97,11 +97,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } default: Schema := db.Statement.Schema + reflectValue := db.Statement.ReflectValue + if reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } - switch db.Statement.ReflectValue.Kind() { + switch reflectValue.Kind() { case reflect.Slice, reflect.Array: var ( - reflectValueType = db.Statement.ReflectValue.Type().Elem() + reflectValueType = reflectValue.Type().Elem() isPtr = reflectValueType.Kind() == reflect.Ptr fields = make([]*schema.Field, len(columns)) joinFields [][2]*schema.Field @@ -111,7 +115,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { reflectValueType = reflectValueType.Elem() } - db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) if Schema != nil { if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { @@ -186,13 +190,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } if isPtr { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem)) } else { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) + db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem.Elem())) } } case reflect.Struct, reflect.Ptr: - if db.Statement.ReflectValue.Type() != Schema.ModelType { + if reflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } @@ -220,11 +224,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { - field.Set(db.Statement.ReflectValue, values[idx]) + field.Set(reflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + relValue := rel.Field.ReflectValueOf(reflectValue) value := reflect.ValueOf(values[idx]).Elem() if relValue.Kind() == reflect.Ptr && relValue.IsNil() { diff --git a/schema/schema.go b/schema/schema.go index faba2e21..c425070b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -77,7 +77,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - modelType := reflect.ValueOf(dest).Type() + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } diff --git a/tests/scan_test.go b/tests/scan_test.go index 67d5f385..aacad827 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -29,8 +29,9 @@ func TestScan(t *testing.T) { } var resPointer *result - DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer) - if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } @@ -70,6 +71,38 @@ func TestScan(t *testing.T) { if uint(id) != user2.ID { t.Errorf("Failed to scan to customized data type") } + + var resInt interface{} + resInt = &User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3) + } + + var resInt2 interface{} + resInt2 = &User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3) + } + + var resInt3 interface{} + resInt3 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3) + } + + var resInt4 interface{} + resInt4 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3) + } } func TestScanRows(t *testing.T) {