From da7830ea506f877a4cad20b6c465c29870ae3358 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 12 Mar 2015 15:50:38 +0800 Subject: [PATCH] Add SelectAttrs, OmitAttrs --- main.go | 2 +- scope_private.go | 4 ++-- search.go | 25 ++++++++++++++++++++++--- search_test.go | 6 +++--- 4 files changed, 28 insertions(+), 9 deletions(-) diff --git a/main.go b/main.go index 0f4af2ae..b142e275 100644 --- a/main.go +++ b/main.go @@ -152,7 +152,7 @@ func (s *DB) Order(value string, reorder ...bool) *DB { } func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Selects(query, args...).db + return s.clone().search.Select(query, args...).db } func (s *DB) Group(query string) *DB { diff --git a/scope_private.go b/scope_private.go index 8d363225..e51e1faf 100644 --- a/scope_private.go +++ b/scope_private.go @@ -358,7 +358,7 @@ func (scope *Scope) initialize() *Scope { func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Selects(column) + scope.Search.Select(column) if dest.Kind() != reflect.Slice { scope.Err(errors.New("results should be a slice")) return scope @@ -377,7 +377,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { } func (scope *Scope) count(value interface{}) *Scope { - scope.Search.Selects("count(*)") + scope.Search.Select("count(*)") scope.Err(scope.row().Scan(value)) return scope } diff --git a/search.go b/search.go index 2908bec1..47f5a6cc 100644 --- a/search.go +++ b/search.go @@ -11,6 +11,7 @@ type search struct { initAttrs []interface{} assignAttrs []interface{} selects map[string]interface{} + omits []string orders []string joins string preload map[string][]interface{} @@ -18,8 +19,8 @@ type search struct { limit string group string tableName string - Unscoped bool raw bool + Unscoped bool } func (s *search) clone() *search { @@ -32,14 +33,15 @@ func (s *search) clone() *search { initAttrs: s.initAttrs, assignAttrs: s.assignAttrs, selects: s.selects, + omits: s.omits, orders: s.orders, joins: s.joins, offset: s.offset, limit: s.limit, group: s.group, tableName: s.tableName, - Unscoped: s.Unscoped, raw: s.raw, + Unscoped: s.Unscoped, } } @@ -77,11 +79,28 @@ func (s *search) Order(value string, reorder ...bool) *search { return s } -func (s *search) Selects(query interface{}, args ...interface{}) *search { +func (s *search) Select(query interface{}, args ...interface{}) *search { s.selects = map[string]interface{}{"query": query, "args": args} return s } +func (s *search) Omit(columns ...string) *search { + s.omits = columns + return s +} + +func (s *search) SelectAttrs() (attrs []string) { + for key, value := range s.selects { + attrs = append(attrs, key) + attrs = append(attrs, value.([]string)...) + } + return attrs +} + +func (s *search) OmitAttrs() []string { + return s.omits +} + func (s *search) Limit(value interface{}) *search { s.limit = s.getInterfaceAsSql(value) return s diff --git a/search_test.go b/search_test.go index 79be3a07..4db7ab6a 100644 --- a/search_test.go +++ b/search_test.go @@ -7,10 +7,10 @@ import ( func TestCloneSearch(t *testing.T) { s := new(search) - s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Selects("name, age") + s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") s1 := s.clone() - s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Selects("email") + s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { t.Errorf("Where should be copied") @@ -24,7 +24,7 @@ func TestCloneSearch(t *testing.T) { t.Errorf("InitAttrs should be copied") } - if reflect.DeepEqual(s.Selects, s1.Selects) { + if reflect.DeepEqual(s.Select, s1.Select) { t.Errorf("selectStr should be copied") } }