Add SelectAttrs, OmitAttrs

This commit is contained in:
Jinzhu 2015-03-12 15:50:38 +08:00
parent 94adc3e1d8
commit da7830ea50
4 changed files with 28 additions and 9 deletions

View File

@ -152,7 +152,7 @@ func (s *DB) Order(value string, reorder ...bool) *DB {
} }
func (s *DB) Select(query interface{}, args ...interface{}) *DB { func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.Selects(query, args...).db return s.clone().search.Select(query, args...).db
} }
func (s *DB) Group(query string) *DB { func (s *DB) Group(query string) *DB {

View File

@ -358,7 +358,7 @@ func (scope *Scope) initialize() *Scope {
func (scope *Scope) pluck(column string, value interface{}) *Scope { func (scope *Scope) pluck(column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value)) dest := reflect.Indirect(reflect.ValueOf(value))
scope.Search.Selects(column) scope.Search.Select(column)
if dest.Kind() != reflect.Slice { if dest.Kind() != reflect.Slice {
scope.Err(errors.New("results should be a slice")) scope.Err(errors.New("results should be a slice"))
return scope return scope
@ -377,7 +377,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
} }
func (scope *Scope) count(value interface{}) *Scope { func (scope *Scope) count(value interface{}) *Scope {
scope.Search.Selects("count(*)") scope.Search.Select("count(*)")
scope.Err(scope.row().Scan(value)) scope.Err(scope.row().Scan(value))
return scope return scope
} }

View File

@ -11,6 +11,7 @@ type search struct {
initAttrs []interface{} initAttrs []interface{}
assignAttrs []interface{} assignAttrs []interface{}
selects map[string]interface{} selects map[string]interface{}
omits []string
orders []string orders []string
joins string joins string
preload map[string][]interface{} preload map[string][]interface{}
@ -18,8 +19,8 @@ type search struct {
limit string limit string
group string group string
tableName string tableName string
Unscoped bool
raw bool raw bool
Unscoped bool
} }
func (s *search) clone() *search { func (s *search) clone() *search {
@ -32,14 +33,15 @@ func (s *search) clone() *search {
initAttrs: s.initAttrs, initAttrs: s.initAttrs,
assignAttrs: s.assignAttrs, assignAttrs: s.assignAttrs,
selects: s.selects, selects: s.selects,
omits: s.omits,
orders: s.orders, orders: s.orders,
joins: s.joins, joins: s.joins,
offset: s.offset, offset: s.offset,
limit: s.limit, limit: s.limit,
group: s.group, group: s.group,
tableName: s.tableName, tableName: s.tableName,
Unscoped: s.Unscoped,
raw: s.raw, raw: s.raw,
Unscoped: s.Unscoped,
} }
} }
@ -77,11 +79,28 @@ func (s *search) Order(value string, reorder ...bool) *search {
return s return s
} }
func (s *search) Selects(query interface{}, args ...interface{}) *search { func (s *search) Select(query interface{}, args ...interface{}) *search {
s.selects = map[string]interface{}{"query": query, "args": args} s.selects = map[string]interface{}{"query": query, "args": args}
return s return s
} }
func (s *search) Omit(columns ...string) *search {
s.omits = columns
return s
}
func (s *search) SelectAttrs() (attrs []string) {
for key, value := range s.selects {
attrs = append(attrs, key)
attrs = append(attrs, value.([]string)...)
}
return attrs
}
func (s *search) OmitAttrs() []string {
return s.omits
}
func (s *search) Limit(value interface{}) *search { func (s *search) Limit(value interface{}) *search {
s.limit = s.getInterfaceAsSql(value) s.limit = s.getInterfaceAsSql(value)
return s return s

View File

@ -7,10 +7,10 @@ import (
func TestCloneSearch(t *testing.T) { func TestCloneSearch(t *testing.T) {
s := new(search) s := new(search)
s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Selects("name, age") s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age")
s1 := s.clone() s1 := s.clone()
s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Selects("email") s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email")
if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { if reflect.DeepEqual(s.whereConditions, s1.whereConditions) {
t.Errorf("Where should be copied") t.Errorf("Where should be copied")
@ -24,7 +24,7 @@ func TestCloneSearch(t *testing.T) {
t.Errorf("InitAttrs should be copied") t.Errorf("InitAttrs should be copied")
} }
if reflect.DeepEqual(s.Selects, s1.Selects) { if reflect.DeepEqual(s.Select, s1.Select) {
t.Errorf("selectStr should be copied") t.Errorf("selectStr should be copied")
} }
} }