Add additional methods of specifying the 'select' portion of a query.

This commit adds more ways of specifying selects:

-) You can now pass in a []string.  This is mostly for convenience,
since you may want to dynamically create a list of fields to be
selected.

-) You can now use variables.  This is important because a select
could take user input.  For example, finding a MAX between a record
and a given number could be easily done using select, and then
you don't have to process anything in backend logic.  This is also
necessary to use postgres text search capabilities (which actaully
play nicely with the rest of gorm).

-) You can now chain select calls.  This could be useful in
conjunction with gorm's scopes functionality.
This commit is contained in:
jnfeinstein 2014-11-17 07:12:32 -05:00
parent 62b447bc9a
commit 429a100856
6 changed files with 74 additions and 10 deletions

View File

@ -609,6 +609,12 @@ db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user)
```go ```go
db.Select("name, age").Find(&users) db.Select("name, age").Find(&users)
//// SELECT name, age FROM users; //// SELECT name, age FROM users;
db.Select([]string{"name", "age"}).Find(&users)
//// SELECT name, age FROM users;
db.Table("users").Select("COALESCE(age,?)", 42).Rows()
//// SELECT COALESCE(age,'42') FROM users;
``` ```
## Order ## Order

View File

@ -125,8 +125,8 @@ func (s *DB) Order(value string, reorder ...bool) *DB {
return s.clone().search.order(value, reorder...).db return s.clone().search.order(value, reorder...).db
} }
func (s *DB) Select(value interface{}) *DB { func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.selects(value).db return s.clone().search.selects(query, args...).db
} }
func (s *DB) Group(query string) *DB { func (s *DB) Group(query string) *DB {

View File

@ -6,6 +6,7 @@ import (
"github.com/jinzhu/now" "github.com/jinzhu/now"
"math/rand"
"testing" "testing"
"time" "time"
) )
@ -537,3 +538,25 @@ func TestSelectWithEscapedFieldName(t *testing.T) {
t.Errorf("Expected 3 name, but got: %d", len(names)) t.Errorf("Expected 3 name, but got: %d", len(names))
} }
} }
func TestSelectWithVariables(t *testing.T) {
DB.Save(&User{Name: "jinzhu"})
randomNum := rand.Intn(1000000000)
rows, _ := DB.Table("users").Select("? as fake", randomNum).Where("fake = ?", randomNum).Rows()
if !rows.Next() {
t.Errorf("Should have returned at least one row")
}
}
func TestSelectWithArrayInput(t *testing.T) {
DB.Save(&User{Name: "jinzhu", Age: 42})
var user User
DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user)
if user.Name != "jinzhu" || user.Age != 42 {
t.Errorf("Should have selected both age and name")
}
}

View File

@ -129,6 +129,34 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
return return
} }
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
switch value := clause["query"].(type) {
case string:
str = value
case []string:
str = strings.Join(value, ", ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.TypeOf(arg).Kind() {
case reflect.Slice:
values := reflect.ValueOf(arg)
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) where(where ...interface{}) { func (scope *Scope) where(where ...interface{}) {
if len(where) > 0 { if len(where) > 0 {
scope.Search = scope.Search.clone().where(where[0], where[1:]...) scope.Search = scope.Search.clone().where(where[0], where[1:]...)
@ -180,11 +208,18 @@ func (scope *Scope) whereSql() (sql string) {
} }
func (s *Scope) selectSql() string { func (s *Scope) selectSql() string {
if len(s.Search.Select) == 0 { if len(s.Search.Selects) == 0 {
return "*" return "*"
} else {
return s.Search.Select
} }
var selectQueries []string
for _, clause := range s.Search.Selects {
selectQueries = append(selectQueries, s.buildSelectQuery(clause))
}
return strings.Join(selectQueries, ", ")
} }
func (s *Scope) orderSql() string { func (s *Scope) orderSql() string {

View File

@ -12,7 +12,7 @@ type search struct {
HavingCondition map[string]interface{} HavingCondition map[string]interface{}
Orders []string Orders []string
Joins string Joins string
Select string Selects []map[string]interface{}
Offset string Offset string
Limit string Limit string
Group string Group string
@ -30,7 +30,7 @@ func (s *search) clone() *search {
AssignAttrs: s.AssignAttrs, AssignAttrs: s.AssignAttrs,
HavingCondition: s.HavingCondition, HavingCondition: s.HavingCondition,
Orders: s.Orders, Orders: s.Orders,
Select: s.Select, Selects: s.Selects,
Offset: s.Offset, Offset: s.Offset,
Limit: s.Limit, Limit: s.Limit,
Unscope: s.Unscope, Unscope: s.Unscope,
@ -75,8 +75,8 @@ func (s *search) order(value string, reorder ...bool) *search {
return s return s
} }
func (s *search) selects(value interface{}) *search { func (s *search) selects(query interface{}, args ...interface{}) *search {
s.Select = s.getInterfaceAsSql(value) s.Selects = append(s.Selects, map[string]interface{}{"query": query, "args": args})
return s return s
} }

View File

@ -24,7 +24,7 @@ func TestCloneSearch(t *testing.T) {
t.Errorf("InitAttrs should be copied") t.Errorf("InitAttrs should be copied")
} }
if reflect.DeepEqual(s.Select, s1.Select) { if reflect.DeepEqual(s.Selects, s1.Selects) {
t.Errorf("selectStr should be copied") t.Errorf("selectStr should be copied")
} }
} }