Fix auto select with smaller struct for slices

This commit is contained in:
Jinzhu 2020-07-09 12:15:35 +08:00
parent 0790ff6937
commit a8655f7947
2 changed files with 35 additions and 7 deletions

View File

@ -64,10 +64,19 @@ func BuildQuerySQL(db *gorm.DB) {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
} }
} }
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() && db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType { } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
smallerStruct := false
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
case reflect.Slice:
smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
}
if smallerStruct {
stmt := gorm.Statement{DB: db} stmt := gorm.Statement{DB: db}
// smaller struct // smaller struct
if err := stmt.Parse(db.Statement.Dest); err == nil { if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType {
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
for idx, dbName := range stmt.Schema.DBNames { for idx, dbName := range stmt.Schema.DBNames {
@ -75,6 +84,7 @@ func BuildQuerySQL(db *gorm.DB) {
} }
} }
} }
}
// inline joins // inline joins
if len(db.Statement.Joins) != 0 { if len(db.Statement.Joins) != 0 {

View File

@ -177,6 +177,24 @@ func TestFillSmallerStruct(t *testing.T) {
if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) {
t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String())
} }
result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&User{}, user.ID)
if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) {
t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String())
}
result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]User{}, user.ID)
if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) {
t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String())
}
result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]*User{}, user.ID)
if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) {
t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String())
}
} }
func TestNot(t *testing.T) { func TestNot(t *testing.T) {