mirror of https://github.com/go-gorm/gorm.git
Query with smaller struct
This commit is contained in:
parent
ee1f46e3a1
commit
9075b33620
|
@ -40,7 +40,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
db.Statement.SQL.Grow(100)
|
db.Statement.SQL.Grow(100)
|
||||||
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
||||||
|
|
||||||
if db.Statement.ReflectValue.Kind() == reflect.Struct {
|
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
||||||
var conds []clause.Expression
|
var conds []clause.Expression
|
||||||
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
||||||
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
|
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
|
||||||
|
@ -64,6 +64,16 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() && db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType {
|
||||||
|
stmt := gorm.Statement{DB: db}
|
||||||
|
// smaller struct
|
||||||
|
if err := stmt.Parse(db.Statement.Dest); err == nil {
|
||||||
|
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
|
||||||
|
|
||||||
|
for idx, dbName := range stmt.Schema.DBNames {
|
||||||
|
clauseSelect.Columns[idx] = clause.Column{Name: dbName}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// inline joins
|
// inline joins
|
||||||
|
|
24
scan.go
24
scan.go
|
@ -69,6 +69,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
db.AddError(rows.Scan(dest))
|
db.AddError(rows.Scan(dest))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
Schema := db.Statement.Schema
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
var (
|
var (
|
||||||
|
@ -84,16 +86,20 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
|
|
||||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
||||||
|
|
||||||
if db.Statement.Schema != nil {
|
if Schema != nil {
|
||||||
|
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||||
|
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||||
|
}
|
||||||
|
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
fields[idx] = field
|
fields[idx] = field
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
if len(joinFields) == 0 {
|
if len(joinFields) == 0 {
|
||||||
joinFields = make([][2]*schema.Field, len(columns))
|
joinFields = make([][2]*schema.Field, len(columns))
|
||||||
}
|
}
|
||||||
|
|
||||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
fields[idx] = field
|
fields[idx] = field
|
||||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||||
|
@ -151,12 +157,16 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
if db.Statement.ReflectValue.Type() != Schema.ModelType {
|
||||||
|
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||||
|
}
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||||
continue
|
continue
|
||||||
|
@ -172,10 +182,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
field.Set(db.Statement.ReflectValue, values[idx])
|
field.Set(db.Statement.ReflectValue, values[idx])
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||||
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
value := reflect.ValueOf(values[idx]).Elem()
|
value := reflect.ValueOf(values[idx]).Elem()
|
||||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -144,8 +145,8 @@ func TestFillSmallerStruct(t *testing.T) {
|
||||||
user := User{Name: "SmallerUser", Age: 100}
|
user := User{Name: "SmallerUser", Age: 100}
|
||||||
DB.Save(&user)
|
DB.Save(&user)
|
||||||
type SimpleUser struct {
|
type SimpleUser struct {
|
||||||
Name string
|
|
||||||
ID int64
|
ID int64
|
||||||
|
Name string
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
@ -156,6 +157,26 @@ func TestFillSmallerStruct(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
AssertObjEqual(t, user, simpleUser, "Name", "ID", "UpdatedAt", "CreatedAt")
|
AssertObjEqual(t, user, simpleUser, "Name", "ID", "UpdatedAt", "CreatedAt")
|
||||||
|
|
||||||
|
var simpleUser2 SimpleUser
|
||||||
|
if err := DB.Model(&User{}).Select("id").First(&simpleUser2, user.ID).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to query smaller user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertObjEqual(t, user, simpleUser2, "ID")
|
||||||
|
|
||||||
|
var simpleUsers []SimpleUser
|
||||||
|
if err := DB.Model(&User{}).Select("id").Find(&simpleUsers, user.ID).Error; err != nil || len(simpleUsers) != 1 {
|
||||||
|
t.Fatalf("Failed to query smaller user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertObjEqual(t, user, simpleUsers[0], "ID")
|
||||||
|
|
||||||
|
result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID)
|
||||||
|
|
||||||
|
if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPluck(t *testing.T) {
|
func TestPluck(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue