From 429a10085676fdd0fff56169b3b9e42dddcbd3b1 Mon Sep 17 00:00:00 2001 From: jnfeinstein Date: Mon, 17 Nov 2014 07:12:32 -0500 Subject: [PATCH] Add additional methods of specifying the 'select' portion of a query. This commit adds more ways of specifying selects: -) You can now pass in a []string. This is mostly for convenience, since you may want to dynamically create a list of fields to be selected. -) You can now use variables. This is important because a select could take user input. For example, finding a MAX between a record and a given number could be easily done using select, and then you don't have to process anything in backend logic. This is also necessary to use postgres text search capabilities (which actaully play nicely with the rest of gorm). -) You can now chain select calls. This could be useful in conjunction with gorm's scopes functionality. --- README.md | 6 ++++++ main.go | 4 ++-- query_test.go | 23 +++++++++++++++++++++++ scope_private.go | 41 ++++++++++++++++++++++++++++++++++++++--- search.go | 8 ++++---- search_test.go | 2 +- 6 files changed, 74 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index dba65d5f..c5e06edd 100644 --- a/README.md +++ b/README.md @@ -609,6 +609,12 @@ db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user) ```go db.Select("name, age").Find(&users) //// SELECT name, age FROM users; + +db.Select([]string{"name", "age"}).Find(&users) +//// SELECT name, age FROM users; + +db.Table("users").Select("COALESCE(age,?)", 42).Rows() +//// SELECT COALESCE(age,'42') FROM users; ``` ## Order diff --git a/main.go b/main.go index 14a76b19..ec974cae 100644 --- a/main.go +++ b/main.go @@ -125,8 +125,8 @@ func (s *DB) Order(value string, reorder ...bool) *DB { return s.clone().search.order(value, reorder...).db } -func (s *DB) Select(value interface{}) *DB { - return s.clone().search.selects(value).db +func (s *DB) Select(query interface{}, args ...interface{}) *DB { + return s.clone().search.selects(query, args...).db } func (s *DB) Group(query string) *DB { diff --git a/query_test.go b/query_test.go index 5867612f..d4a68aa1 100644 --- a/query_test.go +++ b/query_test.go @@ -6,6 +6,7 @@ import ( "github.com/jinzhu/now" + "math/rand" "testing" "time" ) @@ -537,3 +538,25 @@ func TestSelectWithEscapedFieldName(t *testing.T) { t.Errorf("Expected 3 name, but got: %d", len(names)) } } + +func TestSelectWithVariables(t *testing.T) { + DB.Save(&User{Name: "jinzhu"}) + + randomNum := rand.Intn(1000000000) + rows, _ := DB.Table("users").Select("? as fake", randomNum).Where("fake = ?", randomNum).Rows() + + if !rows.Next() { + t.Errorf("Should have returned at least one row") + } +} + +func TestSelectWithArrayInput(t *testing.T) { + DB.Save(&User{Name: "jinzhu", Age: 42}) + + var user User + DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) + + if user.Name != "jinzhu" || user.Age != 42 { + t.Errorf("Should have selected both age and name") + } +} diff --git a/scope_private.go b/scope_private.go index b9004efd..0e2308e9 100644 --- a/scope_private.go +++ b/scope_private.go @@ -129,6 +129,34 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string return } +func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { + switch value := clause["query"].(type) { + case string: + str = value + case []string: + str = strings.Join(value, ", ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.TypeOf(arg).Kind() { + case reflect.Slice: + values := reflect.ValueOf(arg) + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() + } + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } + } + return +} + func (scope *Scope) where(where ...interface{}) { if len(where) > 0 { scope.Search = scope.Search.clone().where(where[0], where[1:]...) @@ -180,11 +208,18 @@ func (scope *Scope) whereSql() (sql string) { } func (s *Scope) selectSql() string { - if len(s.Search.Select) == 0 { + if len(s.Search.Selects) == 0 { return "*" - } else { - return s.Search.Select } + + var selectQueries []string + + for _, clause := range s.Search.Selects { + selectQueries = append(selectQueries, s.buildSelectQuery(clause)) + } + + return strings.Join(selectQueries, ", ") + } func (s *Scope) orderSql() string { diff --git a/search.go b/search.go index 8591d659..78ed5300 100644 --- a/search.go +++ b/search.go @@ -12,7 +12,7 @@ type search struct { HavingCondition map[string]interface{} Orders []string Joins string - Select string + Selects []map[string]interface{} Offset string Limit string Group string @@ -30,7 +30,7 @@ func (s *search) clone() *search { AssignAttrs: s.AssignAttrs, HavingCondition: s.HavingCondition, Orders: s.Orders, - Select: s.Select, + Selects: s.Selects, Offset: s.Offset, Limit: s.Limit, Unscope: s.Unscope, @@ -75,8 +75,8 @@ func (s *search) order(value string, reorder ...bool) *search { return s } -func (s *search) selects(value interface{}) *search { - s.Select = s.getInterfaceAsSql(value) +func (s *search) selects(query interface{}, args ...interface{}) *search { + s.Selects = append(s.Selects, map[string]interface{}{"query": query, "args": args}) return s } diff --git a/search_test.go b/search_test.go index 4e19d531..2d0af08b 100644 --- a/search_test.go +++ b/search_test.go @@ -24,7 +24,7 @@ func TestCloneSearch(t *testing.T) { t.Errorf("InitAttrs should be copied") } - if reflect.DeepEqual(s.Select, s1.Select) { + if reflect.DeepEqual(s.Selects, s1.Selects) { t.Errorf("selectStr should be copied") } }