Allow customize clauses

This commit is contained in:
Jinzhu 2021-04-28 17:19:30 +08:00
parent 82cb4ebfe2
commit 6951be0284
7 changed files with 53 additions and 9 deletions

View File

@ -32,6 +32,7 @@ type callbacks struct {
type processor struct { type processor struct {
db *DB db *DB
Clauses []string
fns []func(*DB) fns []func(*DB)
callbacks []*callback callbacks []*callback
} }
@ -84,8 +85,14 @@ func (p *processor) Execute(db *DB) {
var ( var (
curTime = time.Now() curTime = time.Now()
stmt = db.Statement stmt = db.Statement
resetBuildClauses bool
) )
if len(stmt.BuildClauses) == 0 {
stmt.BuildClauses = p.Clauses
resetBuildClauses = true
}
// assign model values // assign model values
if stmt.Model == nil { if stmt.Model == nil {
stmt.Model = stmt.Dest stmt.Model = stmt.Dest
@ -131,6 +138,10 @@ func (p *processor) Execute(db *DB) {
stmt.SQL.Reset() stmt.SQL.Reset()
stmt.Vars = nil stmt.Vars = nil
} }
if resetBuildClauses {
stmt.BuildClauses = nil
}
} }
func (p *processor) Get(name string) func(*DB) { func (p *processor) Get(name string) func(*DB) {

View File

@ -4,9 +4,20 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
type Config struct { type Config struct {
LastInsertIDReversed bool LastInsertIDReversed bool
WithReturning bool WithReturning bool
CreateClauses []string
QueryClauses []string
UpdateClauses []string
DeleteClauses []string
} }
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
@ -22,11 +33,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate) createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query() queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery) queryCallback.Register("gorm:after_query", AfterQuery)
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete() deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
@ -35,6 +54,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update() updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
@ -45,7 +68,16 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.UpdateClauses) == 0 {
db.Callback().Row().Register("gorm:row", RowQuery) config.UpdateClauses = updateClauses
db.Callback().Raw().Register("gorm:raw", RawExec) }
updateCallback.Clauses = config.UpdateClauses
rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
} }

View File

@ -47,7 +47,7 @@ func Create(config *Config) func(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") db.Statement.Build(db.Statement.BuildClauses...)
} }
if !db.DryRun && db.Error == nil { if !db.DryRun && db.Error == nil {
@ -118,7 +118,7 @@ func CreateWithReturning(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") db.Statement.Build(db.Statement.BuildClauses...)
} }
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {

View File

@ -135,7 +135,7 @@ func Delete(db *gorm.DB) {
} }
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("DELETE", "FROM", "WHERE") db.Statement.Build(db.Statement.BuildClauses...)
} }
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {

View File

@ -167,7 +167,7 @@ func BuildQuerySQL(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build(db.Statement.BuildClauses...)
} }
} }

View File

@ -66,7 +66,7 @@ func Update(db *gorm.DB) {
} else { } else {
return return
} }
db.Statement.Build("UPDATE", "SET", "WHERE") db.Statement.Build(db.Statement.BuildClauses...)
} }
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {

View File

@ -27,6 +27,7 @@ type Statement struct {
Dest interface{} Dest interface{}
ReflectValue reflect.Value ReflectValue reflect.Value
Clauses map[string]clause.Clause Clauses map[string]clause.Clause
BuildClauses []string
Distinct bool Distinct bool
Selects []string // selected columns Selects []string // selected columns
Omits []string // omit columns Omits []string // omit columns