diff --git a/README.md b/README.md index dba65d5f..c5e06edd 100644 --- a/README.md +++ b/README.md @@ -609,6 +609,12 @@ db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user) ```go db.Select("name, age").Find(&users) //// SELECT name, age FROM users; + +db.Select([]string{"name", "age"}).Find(&users) +//// SELECT name, age FROM users; + +db.Table("users").Select("COALESCE(age,?)", 42).Rows() +//// SELECT COALESCE(age,'42') FROM users; ``` ## Order diff --git a/main.go b/main.go index 14a76b19..ec974cae 100644 --- a/main.go +++ b/main.go @@ -125,8 +125,8 @@ func (s *DB) Order(value string, reorder ...bool) *DB { return s.clone().search.order(value, reorder...).db } -func (s *DB) Select(value interface{}) *DB { - return s.clone().search.selects(value).db +func (s *DB) Select(query interface{}, args ...interface{}) *DB { + return s.clone().search.selects(query, args...).db } func (s *DB) Group(query string) *DB { diff --git a/query_test.go b/query_test.go index 5867612f..d4a68aa1 100644 --- a/query_test.go +++ b/query_test.go @@ -6,6 +6,7 @@ import ( "github.com/jinzhu/now" + "math/rand" "testing" "time" ) @@ -537,3 +538,25 @@ func TestSelectWithEscapedFieldName(t *testing.T) { t.Errorf("Expected 3 name, but got: %d", len(names)) } } + +func TestSelectWithVariables(t *testing.T) { + DB.Save(&User{Name: "jinzhu"}) + + randomNum := rand.Intn(1000000000) + rows, _ := DB.Table("users").Select("? as fake", randomNum).Where("fake = ?", randomNum).Rows() + + if !rows.Next() { + t.Errorf("Should have returned at least one row") + } +} + +func TestSelectWithArrayInput(t *testing.T) { + DB.Save(&User{Name: "jinzhu", Age: 42}) + + var user User + DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) + + if user.Name != "jinzhu" || user.Age != 42 { + t.Errorf("Should have selected both age and name") + } +} diff --git a/scope_private.go b/scope_private.go index b9004efd..0e2308e9 100644 --- a/scope_private.go +++ b/scope_private.go @@ -129,6 +129,34 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string return } +func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { + switch value := clause["query"].(type) { + case string: + str = value + case []string: + str = strings.Join(value, ", ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.TypeOf(arg).Kind() { + case reflect.Slice: + values := reflect.ValueOf(arg) + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() + } + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } + } + return +} + func (scope *Scope) where(where ...interface{}) { if len(where) > 0 { scope.Search = scope.Search.clone().where(where[0], where[1:]...) @@ -180,11 +208,18 @@ func (scope *Scope) whereSql() (sql string) { } func (s *Scope) selectSql() string { - if len(s.Search.Select) == 0 { + if len(s.Search.Selects) == 0 { return "*" - } else { - return s.Search.Select } + + var selectQueries []string + + for _, clause := range s.Search.Selects { + selectQueries = append(selectQueries, s.buildSelectQuery(clause)) + } + + return strings.Join(selectQueries, ", ") + } func (s *Scope) orderSql() string { diff --git a/search.go b/search.go index 8591d659..78ed5300 100644 --- a/search.go +++ b/search.go @@ -12,7 +12,7 @@ type search struct { HavingCondition map[string]interface{} Orders []string Joins string - Select string + Selects []map[string]interface{} Offset string Limit string Group string @@ -30,7 +30,7 @@ func (s *search) clone() *search { AssignAttrs: s.AssignAttrs, HavingCondition: s.HavingCondition, Orders: s.Orders, - Select: s.Select, + Selects: s.Selects, Offset: s.Offset, Limit: s.Limit, Unscope: s.Unscope, @@ -75,8 +75,8 @@ func (s *search) order(value string, reorder ...bool) *search { return s } -func (s *search) selects(value interface{}) *search { - s.Select = s.getInterfaceAsSql(value) +func (s *search) selects(query interface{}, args ...interface{}) *search { + s.Selects = append(s.Selects, map[string]interface{}{"query": query, "args": args}) return s } diff --git a/search_test.go b/search_test.go index 4e19d531..2d0af08b 100644 --- a/search_test.go +++ b/search_test.go @@ -24,7 +24,7 @@ func TestCloneSearch(t *testing.T) { t.Errorf("InitAttrs should be copied") } - if reflect.DeepEqual(s.Select, s1.Select) { + if reflect.DeepEqual(s.Selects, s1.Selects) { t.Errorf("selectStr should be copied") } }