forked from mirror/gorm
Support scan into map, slice, struct
This commit is contained in:
parent
1403ee70c3
commit
b0e1bccf4a
|
@ -1,7 +1,6 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
@ -22,25 +21,7 @@ func Query(db *gorm.DB) {
|
|||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
values := make([]interface{}, len(columns))
|
||||
|
||||
for idx, column := range columns {
|
||||
if field, ok := db.Statement.Schema.FieldsByDBName[column]; ok {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
} else {
|
||||
values[idx] = sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
db.RowsAffected++
|
||||
rows.Scan(values...)
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
|
||||
db.AddError(gorm.ErrRecordNotFound)
|
||||
}
|
||||
Scan(rows, db)
|
||||
}
|
||||
|
||||
func Preload(db *gorm.DB) {
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||
columns, _ := rows.Columns()
|
||||
values := make([]interface{}, len(columns))
|
||||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
for idx, _ := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
rows.Scan(values...)
|
||||
}
|
||||
|
||||
mapValue, ok := dest.(map[string]interface{})
|
||||
if ok {
|
||||
if v, ok := dest.(*map[string]interface{}); ok {
|
||||
mapValue = *v
|
||||
}
|
||||
}
|
||||
|
||||
for idx, column := range columns {
|
||||
mapValue[column] = *(values[idx].(*interface{}))
|
||||
}
|
||||
case *[]map[string]interface{}:
|
||||
for idx, _ := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
db.RowsAffected++
|
||||
rows.Scan(values...)
|
||||
|
||||
v := map[string]interface{}{}
|
||||
for idx, column := range columns {
|
||||
v[column] = *(values[idx].(*interface{}))
|
||||
}
|
||||
*dest = append(*dest, v)
|
||||
}
|
||||
default:
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
||||
|
||||
for rows.Next() {
|
||||
elem := reflect.New(db.Statement.Schema.ModelType).Elem()
|
||||
for idx, column := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(column); field != nil {
|
||||
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
|
||||
} else if db.RowsAffected == 0 {
|
||||
values[idx] = sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
|
||||
} else {
|
||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
for idx, column := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(column); field != nil {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
} else {
|
||||
values[idx] = sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
|
||||
db.AddError(gorm.ErrRecordNotFound)
|
||||
}
|
||||
}
|
|
@ -26,7 +26,6 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
|
|||
// TODO handle where
|
||||
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
Desc: true,
|
||||
})
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = out
|
||||
|
@ -47,6 +46,7 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) {
|
|||
func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
Desc: true,
|
||||
})
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = out
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package schema_test
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -13,7 +12,7 @@ import (
|
|||
|
||||
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
|
||||
t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
|
||||
tests.AssertEqual(t, s, v, "Name", "Table")
|
||||
tests.AssertObjEqual(t, s, v, "Name", "Table")
|
||||
|
||||
for idx, field := range primaryFields {
|
||||
var found bool
|
||||
|
@ -53,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
|
|||
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
|
||||
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
|
||||
} else {
|
||||
tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
|
||||
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
|
||||
|
||||
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
|
||||
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
|
||||
|
@ -195,39 +194,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
|||
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
||||
for k, v := range values {
|
||||
t.Run("CheckField/"+k, func(t *testing.T) {
|
||||
var (
|
||||
checker func(fv interface{}, v interface{})
|
||||
field = s.FieldsByDBName[k]
|
||||
fv, _ = field.ValueOf(value)
|
||||
)
|
||||
|
||||
checker = func(fv interface{}, v interface{}) {
|
||||
if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v {
|
||||
t.Errorf("expects: %p, but got %p", v, fv)
|
||||
} else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) {
|
||||
if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv {
|
||||
t.Errorf("expects: %p, but got %p", v, fv)
|
||||
}
|
||||
} else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) {
|
||||
if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v {
|
||||
t.Errorf("expects: %p, but got %p", v, fv)
|
||||
}
|
||||
} else if valuer, isValuer := fv.(driver.Valuer); isValuer {
|
||||
valuerv, _ := valuer.Value()
|
||||
checker(valuerv, v)
|
||||
} else if valuer, isValuer := v.(driver.Valuer); isValuer {
|
||||
valuerv, _ := valuer.Value()
|
||||
checker(fv, valuerv)
|
||||
} else if reflect.ValueOf(fv).Kind() == reflect.Ptr {
|
||||
checker(reflect.ValueOf(fv).Elem().Interface(), v)
|
||||
} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
|
||||
checker(fv, reflect.ValueOf(v).Elem().Interface())
|
||||
} else {
|
||||
t.Errorf("expects: %+v, but got %+v", v, fv)
|
||||
}
|
||||
}
|
||||
|
||||
checker(fv, v)
|
||||
fv, _ := s.FieldsByDBName[k].ValueOf(value)
|
||||
tests.AssertEqual(t, v, fv)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"log"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -14,6 +17,7 @@ func Now() *time.Time {
|
|||
|
||||
func RunTestsSuit(t *testing.T, db *gorm.DB) {
|
||||
TestCreate(t, db)
|
||||
TestFind(t, db)
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T, db *gorm.DB) {
|
||||
|
@ -38,7 +42,94 @@ func TestCreate(t *testing.T, db *gorm.DB) {
|
|||
if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||
t.Errorf("errors happened when query: %v", err)
|
||||
} else {
|
||||
AssertEqual(t, newUser, user, "Name", "Age", "Birthday")
|
||||
AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFind(t *testing.T, db *gorm.DB) {
|
||||
db.Migrator().DropTable(&User{})
|
||||
db.AutoMigrate(&User{})
|
||||
|
||||
t.Run("Find", func(t *testing.T) {
|
||||
var users = []User{{
|
||||
Name: "find",
|
||||
Age: 1,
|
||||
Birthday: Now(),
|
||||
}, {
|
||||
Name: "find",
|
||||
Age: 2,
|
||||
Birthday: Now(),
|
||||
}, {
|
||||
Name: "find",
|
||||
Age: 3,
|
||||
Birthday: Now(),
|
||||
}}
|
||||
|
||||
if err := db.Create(&users).Error; err != nil {
|
||||
t.Errorf("errors happened when create users: %v", err)
|
||||
}
|
||||
|
||||
t.Run("First", func(t *testing.T) {
|
||||
var first User
|
||||
if err := db.Where("name = ?", "find").First(&first).Error; err != nil {
|
||||
t.Errorf("errors happened when query first: %v", err)
|
||||
} else {
|
||||
AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Last", func(t *testing.T) {
|
||||
var last User
|
||||
if err := db.Where("name = ?", "find").Last(&last).Error; err != nil {
|
||||
t.Errorf("errors happened when query last: %v", err)
|
||||
} else {
|
||||
AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday")
|
||||
}
|
||||
})
|
||||
|
||||
var all []User
|
||||
if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 {
|
||||
t.Errorf("errors happened when query find: %v, length: %v", err, len(all))
|
||||
} else {
|
||||
for idx, user := range users {
|
||||
t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("FirstMap", func(t *testing.T) {
|
||||
var first = map[string]interface{}{}
|
||||
if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil {
|
||||
t.Errorf("errors happened when query first: %v", err)
|
||||
} else {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := db.NamingStrategy.ColumnName("", name)
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
var allMap = []map[string]interface{}{}
|
||||
if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||
t.Errorf("errors happened when query first: %v", err)
|
||||
} else {
|
||||
log.Printf("all map %+v %+v", len(allMap), allMap)
|
||||
for idx, user := range users {
|
||||
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := db.NamingStrategy.ColumnName("", name)
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,24 +6,43 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func AssertEqual(t *testing.T, r, e interface{}, names ...string) {
|
||||
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
|
||||
for _, name := range names {
|
||||
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
|
||||
expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
|
||||
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
|
||||
t.Run(name, func(t *testing.T) {
|
||||
AssertEqual(t, got, expect)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, expects) {
|
||||
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
|
||||
expects = reflect.Indirect(reflect.ValueOf(got)).Interface()
|
||||
func AssertEqual(t *testing.T, got, expect interface{}) {
|
||||
if !reflect.DeepEqual(got, expect) {
|
||||
isEqual := func() {
|
||||
if curTime, ok := got.(time.Time); ok {
|
||||
format := "2006-01-02T15:04:05Z07:00"
|
||||
if curTime.Format(format) != expects.(time.Time).Format(format) {
|
||||
t.Errorf("expects: %v, got %v", expects.(time.Time).Format(format), curTime.Format(format))
|
||||
if curTime.Format(format) != expect.(time.Time).Format(format) {
|
||||
t.Errorf("expect: %v, got %v", expect.(time.Time).Format(format), curTime.Format(format))
|
||||
}
|
||||
} else {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Errorf("expects: %v, got %v", expects, got)
|
||||
})
|
||||
} else if got != expect {
|
||||
t.Errorf("expect: %#v, got %#v", expect, got)
|
||||
}
|
||||
}
|
||||
|
||||
if got != nil {
|
||||
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
|
||||
}
|
||||
|
||||
if expect != nil {
|
||||
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
||||
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
||||
isEqual()
|
||||
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
|
||||
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
|
||||
isEqual()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"runtime"
|
||||
)
|
||||
|
||||
var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`)
|
||||
var goTestRegexp = regexp.MustCompile(`/gorm/.*test.go`)
|
||||
var goSrcRegexp = regexp.MustCompile(`/gorm/.*\.go`)
|
||||
var goTestRegexp = regexp.MustCompile(`/gorm/.*test\.go`)
|
||||
|
||||
func FileWithLineNum() string {
|
||||
for i := 2; i < 15; i++ {
|
||||
|
|
Loading…
Reference in New Issue