diff --git a/callbacks/query.go b/callbacks/query.go index 4b7f5bd5..9601f9bd 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -64,14 +64,24 @@ func BuildQuerySQL(db *gorm.DB) { clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } } - } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() && db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType { - stmt := gorm.Statement{DB: db} - // smaller struct - if err := stmt.Parse(db.Statement.Dest); err == nil { - clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { + smallerStruct := false + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + } - for idx, dbName := range stmt.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Name: dbName} + if smallerStruct { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Name: dbName} + } } } } diff --git a/tests/query_test.go b/tests/query_test.go index 1db490b7..62005e3a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -177,6 +177,24 @@ func TestFillSmallerStruct(t *testing.T) { if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]*User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } } func TestNot(t *testing.T) {