From 46b1c85f88e332a36dec31b17a3bd8e6eae07da9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 4 Feb 2020 08:56:15 +0800 Subject: [PATCH] Add more clauses --- callbacks.go | 20 +++++++++++++------- callbacks/callbacks.go | 6 ++++-- callbacks/create.go | 2 +- callbacks/query.go | 17 ++++++++++++++++- chainable_api.go | 19 ++++++++++++++++++- clause/clause.go | 31 +++++++++++-------------------- clause/expression.go | 25 ++++++++++++++++++------- clause/from.go | 7 +++++++ clause/on_conflict.go | 6 ++++++ clause/order_by.go | 34 ++++++++++++++++++++++++++++++++++ clause/select.go | 12 ++++++++---- finisher_api.go | 8 ++++++-- gorm.go | 9 +++++---- statement.go | 16 +++++++++++++--- 14 files changed, 160 insertions(+), 52 deletions(-) create mode 100644 clause/on_conflict.go diff --git a/callbacks.go b/callbacks.go index 51ee150f..8546ae16 100644 --- a/callbacks.go +++ b/callbacks.go @@ -69,14 +69,20 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) { - if stmt := db.Statement; stmt != nil && stmt.Dest != nil { - var err error - stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy) + if stmt := db.Statement; stmt != nil { + if stmt.Model == nil { + stmt.Model = stmt.Dest + } - if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) { - db.AddError(err) - } else if stmt.Table == "" && stmt.Schema != nil { - stmt.Table = stmt.Schema.Table + if stmt.Model != nil { + var err error + stmt.Schema, err = schema.Parse(stmt.Model, db.cacheStore, db.NamingStrategy) + + if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { + db.AddError(err) + } else if stmt.Table == "" && stmt.Schema != nil { + stmt.Table = stmt.Schema.Table + } } } diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index a3e5245b..f9d5543d 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -1,6 +1,8 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "github.com/jinzhu/gorm" +) func RegisterDefaultCallbacks(db *gorm.DB) { enableTransaction := func(db *gorm.DB) bool { @@ -17,7 +19,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) { createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) queryCallback := db.Callback().Query() - queryCallback.Register("gorm:query", BeforeCreate) + queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) diff --git a/callbacks/create.go b/callbacks/create.go index 983b95ce..58256085 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -22,7 +22,7 @@ func Create(db *gorm.DB) { Table: clause.Table{Table: db.Statement.Table}, }) - db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING") + db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) fmt.Println(err) fmt.Println(result) diff --git a/callbacks/query.go b/callbacks/query.go index 5d27ea17..edf8f281 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,8 +1,23 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" +) func Query(db *gorm.DB) { + db.Statement.AddClauseIfNotExists(clause.Select{}) + db.Statement.AddClauseIfNotExists(clause.From{ + Tables: []clause.Table{{Table: clause.CurrentTable}}, + }) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + fmt.Println(err) + fmt.Println(result) + fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) } func Preload(db *gorm.DB) { diff --git a/chainable_api.go b/chainable_api.go index b577d5cf..f358d316 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -1,6 +1,10 @@ package gorm -import "github.com/jinzhu/gorm/clause" +import ( + "fmt" + + "github.com/jinzhu/gorm/clause" +) // Model specify the model you would like to run db operations // // update all users's name to `hello` @@ -107,6 +111,19 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { // db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() + + switch v := value.(type) { + case clause.OrderBy: + db.Statement.AddClause(clause.OrderByClause{ + Columns: []clause.OrderBy{v}, + }) + default: + db.Statement.AddClause(clause.OrderByClause{ + Columns: []clause.OrderBy{{ + Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, + }}, + }) + } return } diff --git a/clause/clause.go b/clause/clause.go index c0ebe7e2..6d4698e9 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -11,11 +11,6 @@ type Clause struct { Builder ClauseBuilder } -// ClauseBuilder clause builder, allows to custmize how to build clause -type ClauseBuilder interface { - Build(Clause, Builder) -} - // Build build clause func (c Clause) Build(builder Builder) { if c.Builder != nil { @@ -47,25 +42,21 @@ type Interface interface { MergeExpression(Expression) } +// OverrideNameInterface override name interface type OverrideNameInterface interface { OverrideName() string } -// Column quote with name -type Column struct { - Table string - Name string - Alias string - Raw bool +// ClauseBuilder clause builder, allows to custmize how to build clause +type ClauseBuilder interface { + Build(Clause, Builder) } -func ToColumns(value ...interface{}) []Column { - return nil -} - -// Table quote with name -type Table struct { - Table string - Alias string - Raw bool +// Builder builder interface +type Builder interface { + WriteByte(byte) error + Write(sql ...string) error + WriteQuoted(field interface{}) error + AddVar(vars ...interface{}) string + Quote(field interface{}) string } diff --git a/clause/expression.go b/clause/expression.go index 17313d43..722df7c7 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,5 +1,10 @@ package clause +const ( + PrimaryKey string = "@@@priamry_key@@@" + CurrentTable string = "@@@table@@@" +) + // Expression expression interface type Expression interface { Build(builder Builder) @@ -10,13 +15,19 @@ type NegationExpressionBuilder interface { NegationBuild(builder Builder) } -// Builder builder interface -type Builder interface { - WriteByte(byte) error - Write(sql ...string) error - WriteQuoted(field interface{}) error - AddVar(vars ...interface{}) string - Quote(field interface{}) string +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool +} + +// Table quote with name +type Table struct { + Table string + Alias string + Raw bool } // Expr raw expression diff --git a/clause/from.go b/clause/from.go index 610d69a4..1a7bcb5c 100644 --- a/clause/from.go +++ b/clause/from.go @@ -20,3 +20,10 @@ func (from From) Build(builder Builder) { builder.WriteQuoted(table) } } + +// MergeExpression merge order by clauses +func (from From) MergeExpression(expr Expression) { + if v, ok := expr.(From); ok { + from.Tables = append(v.Tables, from.Tables...) + } +} diff --git a/clause/on_conflict.go b/clause/on_conflict.go new file mode 100644 index 00000000..5cbe3dd7 --- /dev/null +++ b/clause/on_conflict.go @@ -0,0 +1,6 @@ +package clause + +type OnConflict struct { + ON string // duplicate key + Values *Values // update c=c+1 +} diff --git a/clause/order_by.go b/clause/order_by.go index a11a3c48..6025e1ba 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -1,4 +1,38 @@ package clause type OrderBy struct { + Column Column + Desc bool + Reorder bool +} + +type OrderByClause struct { + Columns []OrderBy +} + +// Name where clause name +func (orderBy OrderByClause) Name() string { + return "ORDER BY" +} + +// Build build where clause +func (orderBy OrderByClause) Build(builder Builder) { + for i := len(orderBy.Columns) - 1; i >= 0; i-- { + builder.WriteQuoted(orderBy.Columns[i].Column) + + if orderBy.Columns[i].Desc { + builder.Write(" DESC") + } + + if orderBy.Columns[i].Reorder { + break + } + } +} + +// MergeExpression merge order by clauses +func (orderBy OrderByClause) MergeExpression(expr Expression) { + if v, ok := expr.(OrderByClause); ok { + orderBy.Columns = append(v.Columns, orderBy.Columns...) + } } diff --git a/clause/select.go b/clause/select.go index 1342c411..7f0e4438 100644 --- a/clause/select.go +++ b/clause/select.go @@ -1,15 +1,19 @@ package clause +// SelectInterface select clause interface +type SelectInterface interface { + Selects() []Column + Omits() []Column +} + // Select select attrs when querying, updating, creating type Select struct { SelectColumns []Column OmitColumns []Column } -// SelectInterface select clause interface -type SelectInterface interface { - Selects() []Column - Omits() []Column +func (s Select) Name() string { + return "SELECT" } func (s Select) Selects() []Column { diff --git a/finisher_api.go b/finisher_api.go index a311ca78..06809651 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -2,6 +2,8 @@ package gorm import ( "database/sql" + + "github.com/jinzhu/gorm/clause" ) // Create insert the value into database @@ -20,9 +22,11 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() + tx = db.getInstance().Limit(1).Order(clause.OrderBy{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Desc: true, + }) tx.Statement.Dest = out - tx.Limit(1) tx.callbacks.Query().Execute(tx) return } diff --git a/gorm.go b/gorm.go index a72314bd..10d61f80 100644 --- a/gorm.go +++ b/gorm.go @@ -61,10 +61,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } db = &DB{ - Config: config, - Dialector: dialector, - clone: true, - cacheStore: &sync.Map{}, + Config: config, + Dialector: dialector, + ClauseBuilders: map[string]clause.ClauseBuilder{}, + clone: true, + cacheStore: &sync.Map{}, } db.callbacks = initializeCallbacks(db) diff --git a/statement.go b/statement.go index c01be0f5..b2407599 100644 --- a/statement.go +++ b/statement.go @@ -84,18 +84,28 @@ func (stmt Statement) Quote(field interface{}) string { switch v := field.(type) { case clause.Table: - str.WriteString(v.Table) + if v.Alias != "" { str.WriteString(" AS ") str.WriteString(v.Alias) } case clause.Column: if v.Table != "" { - str.WriteString(v.Table) + if v.Table == clause.CurrentTable { + str.WriteString(stmt.Table) + } else { + str.WriteString(v.Table) + } str.WriteByte('.') } - str.WriteString(v.Name) + if v.Name == clause.PrimaryKey { + if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { + str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName) + } + } else { + str.WriteString(v.Name) + } if v.Alias != "" { str.WriteString(" AS ") str.WriteString(v.Alias)