From 6951be0284135a5ecd6f359eb4d173b8fb35e572 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 28 Apr 2021 17:19:30 +0800 Subject: [PATCH] Allow customize clauses --- callbacks.go | 15 +++++++++++++-- callbacks/callbacks.go | 36 ++++++++++++++++++++++++++++++++++-- callbacks/create.go | 4 ++-- callbacks/delete.go | 2 +- callbacks/query.go | 2 +- callbacks/update.go | 2 +- statement.go | 1 + 7 files changed, 53 insertions(+), 9 deletions(-) diff --git a/callbacks.go b/callbacks.go index 20fec429..01d9ed30 100644 --- a/callbacks.go +++ b/callbacks.go @@ -32,6 +32,7 @@ type callbacks struct { type processor struct { db *DB + Clauses []string fns []func(*DB) callbacks []*callback } @@ -82,10 +83,16 @@ func (p *processor) Execute(db *DB) { } var ( - curTime = time.Now() - stmt = db.Statement + curTime = time.Now() + stmt = db.Statement + resetBuildClauses bool ) + if len(stmt.BuildClauses) == 0 { + stmt.BuildClauses = p.Clauses + resetBuildClauses = true + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest @@ -131,6 +138,10 @@ func (p *processor) Execute(db *DB) { stmt.SQL.Reset() stmt.Vars = nil } + + if resetBuildClauses { + stmt.BuildClauses = nil + } } func (p *processor) Get(name string) func(*DB) { diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 7bb27318..d85c1928 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -4,9 +4,20 @@ import ( "gorm.io/gorm" ) +var ( + createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"} + queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"} + updateClauses = []string{"UPDATE", "SET", "WHERE"} + deleteClauses = []string{"DELETE", "FROM", "WHERE"} +) + type Config struct { LastInsertIDReversed bool WithReturning bool + CreateClauses []string + QueryClauses []string + UpdateClauses []string + DeleteClauses []string } func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { @@ -22,11 +33,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.CreateClauses) == 0 { + config.CreateClauses = createClauses + } + createCallback.Clauses = config.CreateClauses queryCallback := db.Callback().Query() queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) + if len(config.QueryClauses) == 0 { + config.QueryClauses = queryClauses + } + queryCallback.Clauses = config.QueryClauses deleteCallback := db.Callback().Delete() deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) @@ -35,6 +54,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.DeleteClauses) == 0 { + config.DeleteClauses = deleteClauses + } + deleteCallback.Clauses = config.DeleteClauses updateCallback := db.Callback().Update() updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) @@ -45,7 +68,16 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.UpdateClauses) == 0 { + config.UpdateClauses = updateClauses + } + updateCallback.Clauses = config.UpdateClauses - db.Callback().Row().Register("gorm:row", RowQuery) - db.Callback().Raw().Register("gorm:raw", RawExec) + rowCallback := db.Callback().Row() + rowCallback.Register("gorm:row", RowQuery) + rowCallback.Clauses = config.QueryClauses + + rawCallback := db.Callback().Raw() + rawCallback.Register("gorm:raw", RawExec) + rawCallback.Clauses = config.QueryClauses } diff --git a/callbacks/create.go b/callbacks/create.go index 909d984a..727bd380 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -47,7 +47,7 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + db.Statement.Build(db.Statement.BuildClauses...) } if !db.DryRun && db.Error == nil { @@ -118,7 +118,7 @@ func CreateWithReturning(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + db.Statement.Build(db.Statement.BuildClauses...) } if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { diff --git a/callbacks/delete.go b/callbacks/delete.go index 64dd7236..91659c51 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -135,7 +135,7 @@ func Delete(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("DELETE", "FROM", "WHERE") + db.Statement.Build(db.Statement.BuildClauses...) } if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { diff --git a/callbacks/query.go b/callbacks/query.go index 11753472..d0341284 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -167,7 +167,7 @@ func BuildQuerySQL(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clauseSelect) - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + db.Statement.Build(db.Statement.BuildClauses...) } } diff --git a/callbacks/update.go b/callbacks/update.go index db5b52fb..75bb02db 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -66,7 +66,7 @@ func Update(db *gorm.DB) { } else { return } - db.Statement.Build("UPDATE", "SET", "WHERE") + db.Statement.Build(db.Statement.BuildClauses...) } if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { diff --git a/statement.go b/statement.go index 2734752d..a87fd212 100644 --- a/statement.go +++ b/statement.go @@ -27,6 +27,7 @@ type Statement struct { Dest interface{} ReflectValue reflect.Value Clauses map[string]clause.Clause + BuildClauses []string Distinct bool Selects []string // selected columns Omits []string // omit columns