diff --git a/callbacks/query.go b/callbacks/query.go index 27d53a4d..4b7f5bd5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -40,7 +40,7 @@ func BuildQuerySQL(db *gorm.DB) { db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} - if db.Statement.ReflectValue.Kind() == reflect.Struct { + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { @@ -64,6 +64,16 @@ func BuildQuerySQL(db *gorm.DB) { clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() && db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Name: dbName} + } + } } // inline joins diff --git a/scan.go b/scan.go index 2d227ec2..0b199029 100644 --- a/scan.go +++ b/scan.go @@ -69,6 +69,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(dest)) } default: + Schema := db.Statement.Schema + switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( @@ -84,16 +86,20 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) - if db.Statement.Schema != nil { + if Schema != nil { + if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field } else if names := strings.Split(column, "__"); len(names) > 1 { if len(joinFields) == 0 { joinFields = make([][2]*schema.Field, len(columns)) } - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { fields[idx] = field joinFields[idx] = [2]*schema.Field{rel.Field, field} @@ -151,12 +157,16 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } case reflect.Struct: + if db.Statement.ReflectValue.Type() != Schema.ModelType { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + if initialized || rows.Next() { for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue @@ -172,10 +182,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(values...)) for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { value := reflect.ValueOf(values[idx]).Elem() diff --git a/tests/query_test.go b/tests/query_test.go index de65b63b..7973fd51 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -3,6 +3,7 @@ package tests_test import ( "fmt" "reflect" + "regexp" "sort" "strconv" "testing" @@ -144,8 +145,8 @@ func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user) type SimpleUser struct { - Name string ID int64 + Name string UpdatedAt time.Time CreatedAt time.Time } @@ -156,6 +157,26 @@ func TestFillSmallerStruct(t *testing.T) { } AssertObjEqual(t, user, simpleUser, "Name", "ID", "UpdatedAt", "CreatedAt") + + var simpleUser2 SimpleUser + if err := DB.Model(&User{}).Select("id").First(&simpleUser2, user.ID).Error; err != nil { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUser2, "ID") + + var simpleUsers []SimpleUser + if err := DB.Model(&User{}).Select("id").Find(&simpleUsers, user.ID).Error; err != nil || len(simpleUsers) != 1 { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUsers[0], "ID") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID) + + if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } } func TestPluck(t *testing.T) {