From 98ad29f2c24bd5c358355c8daacf575dd888d6ca Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 16 Feb 2020 13:45:27 +0800 Subject: [PATCH] Add Selects, Omits for statement --- chainable_api.go | 72 ++++++++++++++++++++++++++++++++++--------- clause/select.go | 12 ++++---- clause/select_test.go | 2 +- dialects/mysql/go.mod | 7 ----- go.mod | 4 +-- helpers.go | 5 +++ statement.go | 2 ++ 7 files changed, 73 insertions(+), 31 deletions(-) delete mode 100644 dialects/mysql/go.mod diff --git a/chainable_api.go b/chainable_api.go index 432026cf..9aa08b54 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -2,6 +2,7 @@ package gorm import ( "fmt" + "strings" "github.com/jinzhu/gorm/clause" ) @@ -31,9 +32,7 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { } if len(whereConds) > 0 { - tx.Statement.AddClause(&clause.Where{ - tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...), - }) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...)}) } return } @@ -48,38 +47,83 @@ func (db *DB) Table(name string) (tx *DB) { // Select specify fields that you want when querying, creating, updating func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() + + switch v := query.(type) { + case []string: + tx.Statement.Selects = v + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + return + } + } + case string: + fields := strings.FieldsFunc(v, isChar) + + // normal field names + if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + tx.Statement.Selects = fields + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.Statement.AddClause(clause.Select{ + Expression: clause.Expr{SQL: v, Vars: args}, + }) + return + } + } + } else { + tx.Statement.AddClause(clause.Select{ + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + } + return } // Omit specify fields that you want to ignore when creating, updating and querying func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() + + if len(columns) == 1 && strings.Contains(columns[0], ",") { + tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar) + } else { + tx.Statement.Omits = columns + } return } func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(&clause.Where{ - tx.Statement.BuildCondtion(query, args...), - }) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) return } // Not add NOT condition func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(&clause.Where{ - []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}, - }) + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) return } // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(&clause.Where{ - []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}, - }) + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) return } @@ -110,11 +154,11 @@ func (db *DB) Order(value interface{}) (tx *DB) { switch v := value.(type) { case clause.OrderByColumn: - db.Statement.AddClause(clause.OrderBy{ + tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{v}, }) default: - db.Statement.AddClause(clause.OrderBy{ + tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{{ Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, }}, diff --git a/clause/select.go b/clause/select.go index 4bb1af8d..20b17e07 100644 --- a/clause/select.go +++ b/clause/select.go @@ -2,8 +2,8 @@ package clause // Select select attrs when querying, updating, creating type Select struct { - Columns []Column - Omits []Column + Columns []Column + Expression Expression } func (s Select) Name() string { @@ -24,9 +24,9 @@ func (s Select) Build(builder Builder) { } func (s Select) MergeClause(clause *Clause) { - if v, ok := clause.Expression.(Select); ok { - s.Columns = append(v.Columns, s.Columns...) - s.Omits = append(v.Omits, s.Omits...) + if s.Expression != nil { + clause.Expression = s.Expression + } else { + clause.Expression = s } - clause.Expression = s } diff --git a/clause/select_test.go b/clause/select_test.go index 8255e51b..0863d086 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -29,7 +29,7 @@ func TestSelect(t *testing.T) { }, clause.Select{ Columns: []clause.Column{{Name: "name"}}, }, clause.From{}}, - "SELECT `users`.`id`,`name` FROM `users`", nil, + "SELECT `name` FROM `users`", nil, }, } diff --git a/dialects/mysql/go.mod b/dialects/mysql/go.mod deleted file mode 100644 index a1f29122..00000000 --- a/dialects/mysql/go.mod +++ /dev/null @@ -1,7 +0,0 @@ -module github.com/jinzhu/gorm/dialects/mysql - -go 1.13 - -require ( - github.com/go-sql-driver/mysql v1.5.0 -) diff --git a/go.mod b/go.mod index e47297fb..cdb7e574 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,6 @@ module github.com/jinzhu/gorm go 1.13 require ( - github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 - github.com/lib/pq v1.3.0 // indirect - github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/jinzhu/now v1.1.1 ) diff --git a/helpers.go b/helpers.go index 77bbece8..2e5c8ed1 100644 --- a/helpers.go +++ b/helpers.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "time" + "unicode" ) var ( @@ -27,3 +28,7 @@ type Model struct { UpdatedAt time.Time DeletedAt *time.Time `gorm:"index"` } + +func isChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) +} diff --git a/statement.go b/statement.go index 1c3934c1..b2626d95 100644 --- a/statement.go +++ b/statement.go @@ -43,6 +43,8 @@ type Statement struct { Model interface{} Dest interface{} Clauses map[string]clause.Clause + Selects []string // selected columns + Omits []string // omit columns Settings sync.Map DB *DB Schema *schema.Schema