Merge pull request #282 from jnfeinstein/master

Add additional methods of specifying the 'select' portion of a query.
This commit is contained in:
Jinzhu 2014-11-18 20:34:24 +08:00
commit f43456fecf
6 changed files with 74 additions and 10 deletions

View File

@ -609,6 +609,12 @@ db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user)
```go ```go
db.Select("name, age").Find(&users) db.Select("name, age").Find(&users)
//// SELECT name, age FROM users; //// SELECT name, age FROM users;
db.Select([]string{"name", "age"}).Find(&users)
//// SELECT name, age FROM users;
db.Table("users").Select("COALESCE(age,?)", 42).Rows()
//// SELECT COALESCE(age,'42') FROM users;
``` ```
## Order ## Order

View File

@ -125,8 +125,8 @@ func (s *DB) Order(value string, reorder ...bool) *DB {
return s.clone().search.order(value, reorder...).db return s.clone().search.order(value, reorder...).db
} }
func (s *DB) Select(value interface{}) *DB { func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.selects(value).db return s.clone().search.selects(query, args...).db
} }
func (s *DB) Group(query string) *DB { func (s *DB) Group(query string) *DB {

View File

@ -6,6 +6,7 @@ import (
"github.com/jinzhu/now" "github.com/jinzhu/now"
"math/rand"
"testing" "testing"
"time" "time"
) )
@ -537,3 +538,25 @@ func TestSelectWithEscapedFieldName(t *testing.T) {
t.Errorf("Expected 3 name, but got: %d", len(names)) t.Errorf("Expected 3 name, but got: %d", len(names))
} }
} }
func TestSelectWithVariables(t *testing.T) {
DB.Save(&User{Name: "jinzhu"})
randomNum := rand.Intn(1000000000)
rows, _ := DB.Table("users").Select("? as fake", randomNum).Where("fake = ?", randomNum).Rows()
if !rows.Next() {
t.Errorf("Should have returned at least one row")
}
}
func TestSelectWithArrayInput(t *testing.T) {
DB.Save(&User{Name: "jinzhu", Age: 42})
var user User
DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user)
if user.Name != "jinzhu" || user.Age != 42 {
t.Errorf("Should have selected both age and name")
}
}

View File

@ -129,6 +129,34 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
return return
} }
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
switch value := clause["query"].(type) {
case string:
str = value
case []string:
str = strings.Join(value, ", ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.TypeOf(arg).Kind() {
case reflect.Slice:
values := reflect.ValueOf(arg)
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) where(where ...interface{}) { func (scope *Scope) where(where ...interface{}) {
if len(where) > 0 { if len(where) > 0 {
scope.Search = scope.Search.clone().where(where[0], where[1:]...) scope.Search = scope.Search.clone().where(where[0], where[1:]...)
@ -180,11 +208,18 @@ func (scope *Scope) whereSql() (sql string) {
} }
func (s *Scope) selectSql() string { func (s *Scope) selectSql() string {
if len(s.Search.Select) == 0 { if len(s.Search.Selects) == 0 {
return "*" return "*"
} else {
return s.Search.Select
} }
var selectQueries []string
for _, clause := range s.Search.Selects {
selectQueries = append(selectQueries, s.buildSelectQuery(clause))
}
return strings.Join(selectQueries, ", ")
} }
func (s *Scope) orderSql() string { func (s *Scope) orderSql() string {

View File

@ -12,7 +12,7 @@ type search struct {
HavingCondition map[string]interface{} HavingCondition map[string]interface{}
Orders []string Orders []string
Joins string Joins string
Select string Selects []map[string]interface{}
Offset string Offset string
Limit string Limit string
Group string Group string
@ -30,7 +30,7 @@ func (s *search) clone() *search {
AssignAttrs: s.AssignAttrs, AssignAttrs: s.AssignAttrs,
HavingCondition: s.HavingCondition, HavingCondition: s.HavingCondition,
Orders: s.Orders, Orders: s.Orders,
Select: s.Select, Selects: s.Selects,
Offset: s.Offset, Offset: s.Offset,
Limit: s.Limit, Limit: s.Limit,
Unscope: s.Unscope, Unscope: s.Unscope,
@ -75,8 +75,8 @@ func (s *search) order(value string, reorder ...bool) *search {
return s return s
} }
func (s *search) selects(value interface{}) *search { func (s *search) selects(query interface{}, args ...interface{}) *search {
s.Select = s.getInterfaceAsSql(value) s.Selects = append(s.Selects, map[string]interface{}{"query": query, "args": args})
return s return s
} }

View File

@ -24,7 +24,7 @@ func TestCloneSearch(t *testing.T) {
t.Errorf("InitAttrs should be copied") t.Errorf("InitAttrs should be copied")
} }
if reflect.DeepEqual(s.Select, s1.Select) { if reflect.DeepEqual(s.Selects, s1.Selects) {
t.Errorf("selectStr should be copied") t.Errorf("selectStr should be copied")
} }
} }