Support scan into map, slice, struct

This commit is contained in:
Jinzhu 2020-03-04 11:32:36 +08:00
parent 1403ee70c3
commit b0e1bccf4a
7 changed files with 228 additions and 71 deletions

View File

@ -1,7 +1,6 @@
package callbacks
import (
"database/sql"
"reflect"
"github.com/jinzhu/gorm"
@ -22,25 +21,7 @@ func Query(db *gorm.DB) {
}
defer rows.Close()
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
for idx, column := range columns {
if field, ok := db.Statement.Schema.FieldsByDBName[column]; ok {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} else {
values[idx] = sql.RawBytes{}
}
}
for rows.Next() {
db.RowsAffected++
rows.Scan(values...)
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
db.AddError(gorm.ErrRecordNotFound)
}
Scan(rows, db)
}
func Preload(db *gorm.DB) {

98
callbacks/scan.go Normal file
View File

@ -0,0 +1,98 @@
package callbacks
import (
"database/sql"
"reflect"
"github.com/jinzhu/gorm"
)
func Scan(rows *sql.Rows, db *gorm.DB) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
for idx, _ := range columns {
values[idx] = new(interface{})
}
if rows.Next() {
db.RowsAffected++
rows.Scan(values...)
}
mapValue, ok := dest.(map[string]interface{})
if ok {
if v, ok := dest.(*map[string]interface{}); ok {
mapValue = *v
}
}
for idx, column := range columns {
mapValue[column] = *(values[idx].(*interface{}))
}
case *[]map[string]interface{}:
for idx, _ := range columns {
values[idx] = new(interface{})
}
for rows.Next() {
db.RowsAffected++
rows.Scan(values...)
v := map[string]interface{}{}
for idx, column := range columns {
v[column] = *(values[idx].(*interface{}))
}
*dest = append(*dest, v)
}
default:
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
for rows.Next() {
elem := reflect.New(db.Statement.Schema.ModelType).Elem()
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
} else if db.RowsAffected == 0 {
values[idx] = sql.RawBytes{}
}
}
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
} else {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
}
}
case reflect.Struct:
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} else {
values[idx] = sql.RawBytes{}
}
}
if rows.Next() {
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
}
}
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
db.AddError(gorm.ErrRecordNotFound)
}
}

View File

@ -26,7 +26,6 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
// TODO handle where
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true,
})
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = out
@ -47,6 +46,7 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) {
func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true,
})
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = out

View File

@ -1,7 +1,6 @@
package schema_test
import (
"database/sql/driver"
"fmt"
"reflect"
"strings"
@ -13,7 +12,7 @@ import (
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
tests.AssertEqual(t, s, v, "Name", "Table")
tests.AssertObjEqual(t, s, v, "Name", "Table")
for idx, field := range primaryFields {
var found bool
@ -53,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
} else {
tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
@ -195,39 +194,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) {
var (
checker func(fv interface{}, v interface{})
field = s.FieldsByDBName[k]
fv, _ = field.ValueOf(value)
)
checker = func(fv interface{}, v interface{}) {
if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v {
t.Errorf("expects: %p, but got %p", v, fv)
} else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) {
if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv {
t.Errorf("expects: %p, but got %p", v, fv)
}
} else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) {
if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v {
t.Errorf("expects: %p, but got %p", v, fv)
}
} else if valuer, isValuer := fv.(driver.Valuer); isValuer {
valuerv, _ := valuer.Value()
checker(valuerv, v)
} else if valuer, isValuer := v.(driver.Valuer); isValuer {
valuerv, _ := valuer.Value()
checker(fv, valuerv)
} else if reflect.ValueOf(fv).Kind() == reflect.Ptr {
checker(reflect.ValueOf(fv).Elem().Interface(), v)
} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
checker(fv, reflect.ValueOf(v).Elem().Interface())
} else {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
}
checker(fv, v)
fv, _ := s.FieldsByDBName[k].ValueOf(value)
tests.AssertEqual(t, v, fv)
})
}
}

View File

@ -1,6 +1,9 @@
package tests
import (
"log"
"reflect"
"strconv"
"testing"
"time"
@ -14,6 +17,7 @@ func Now() *time.Time {
func RunTestsSuit(t *testing.T, db *gorm.DB) {
TestCreate(t, db)
TestFind(t, db)
}
func TestCreate(t *testing.T, db *gorm.DB) {
@ -38,7 +42,94 @@ func TestCreate(t *testing.T, db *gorm.DB) {
if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Errorf("errors happened when query: %v", err)
} else {
AssertEqual(t, newUser, user, "Name", "Age", "Birthday")
AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday")
}
})
}
func TestFind(t *testing.T, db *gorm.DB) {
db.Migrator().DropTable(&User{})
db.AutoMigrate(&User{})
t.Run("Find", func(t *testing.T) {
var users = []User{{
Name: "find",
Age: 1,
Birthday: Now(),
}, {
Name: "find",
Age: 2,
Birthday: Now(),
}, {
Name: "find",
Age: 3,
Birthday: Now(),
}}
if err := db.Create(&users).Error; err != nil {
t.Errorf("errors happened when create users: %v", err)
}
t.Run("First", func(t *testing.T) {
var first User
if err := db.Where("name = ?", "find").First(&first).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
} else {
AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday")
}
})
t.Run("Last", func(t *testing.T) {
var last User
if err := db.Where("name = ?", "find").Last(&last).Error; err != nil {
t.Errorf("errors happened when query last: %v", err)
} else {
AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday")
}
})
var all []User
if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 {
t.Errorf("errors happened when query find: %v, length: %v", err, len(all))
} else {
for idx, user := range users {
t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) {
AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday")
})
}
}
t.Run("FirstMap", func(t *testing.T) {
var first = map[string]interface{}{}
if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
} else {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := db.NamingStrategy.ColumnName("", name)
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
})
}
}
})
var allMap = []map[string]interface{}{}
if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
} else {
log.Printf("all map %+v %+v", len(allMap), allMap)
for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := db.NamingStrategy.ColumnName("", name)
reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
})
}
})
}
}
})
}

View File

@ -6,24 +6,43 @@ import (
"time"
)
func AssertEqual(t *testing.T, r, e interface{}, names ...string) {
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
for _, name := range names {
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
t.Run(name, func(t *testing.T) {
AssertEqual(t, got, expect)
})
}
}
if !reflect.DeepEqual(got, expects) {
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
expects = reflect.Indirect(reflect.ValueOf(got)).Interface()
func AssertEqual(t *testing.T, got, expect interface{}) {
if !reflect.DeepEqual(got, expect) {
isEqual := func() {
if curTime, ok := got.(time.Time); ok {
format := "2006-01-02T15:04:05Z07:00"
if curTime.Format(format) != expects.(time.Time).Format(format) {
t.Errorf("expects: %v, got %v", expects.(time.Time).Format(format), curTime.Format(format))
if curTime.Format(format) != expect.(time.Time).Format(format) {
t.Errorf("expect: %v, got %v", expect.(time.Time).Format(format), curTime.Format(format))
}
} else {
t.Run(name, func(t *testing.T) {
t.Errorf("expects: %v, got %v", expects, got)
})
} else if got != expect {
t.Errorf("expect: %#v, got %#v", expect, got)
}
}
if got != nil {
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
}
if expect != nil {
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
}
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
isEqual()
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
isEqual()
}
}
}

View File

@ -6,8 +6,8 @@ import (
"runtime"
)
var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`)
var goTestRegexp = regexp.MustCompile(`/gorm/.*test.go`)
var goSrcRegexp = regexp.MustCompile(`/gorm/.*\.go`)
var goTestRegexp = regexp.MustCompile(`/gorm/.*test\.go`)
func FileWithLineNum() string {
for i := 2; i < 15; i++ {