Allow customize clauses

This commit is contained in:
Jinzhu 2021-04-28 17:19:30 +08:00
parent 82cb4ebfe2
commit 6951be0284
7 changed files with 53 additions and 9 deletions

View File

@ -32,6 +32,7 @@ type callbacks struct {
type processor struct {
db *DB
Clauses []string
fns []func(*DB)
callbacks []*callback
}
@ -84,8 +85,14 @@ func (p *processor) Execute(db *DB) {
var (
curTime = time.Now()
stmt = db.Statement
resetBuildClauses bool
)
if len(stmt.BuildClauses) == 0 {
stmt.BuildClauses = p.Clauses
resetBuildClauses = true
}
// assign model values
if stmt.Model == nil {
stmt.Model = stmt.Dest
@ -131,6 +138,10 @@ func (p *processor) Execute(db *DB) {
stmt.SQL.Reset()
stmt.Vars = nil
}
if resetBuildClauses {
stmt.BuildClauses = nil
}
}
func (p *processor) Get(name string) func(*DB) {

View File

@ -4,9 +4,20 @@ import (
"gorm.io/gorm"
)
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
type Config struct {
LastInsertIDReversed bool
WithReturning bool
CreateClauses []string
QueryClauses []string
UpdateClauses []string
DeleteClauses []string
}
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
@ -22,11 +33,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
@ -35,6 +54,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
@ -45,7 +68,16 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
updateCallback.Clauses = config.UpdateClauses
db.Callback().Row().Register("gorm:row", RowQuery)
db.Callback().Raw().Register("gorm:raw", RawExec)
rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
}

View File

@ -47,7 +47,7 @@ func Create(config *Config) func(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
db.Statement.Build(db.Statement.BuildClauses...)
}
if !db.DryRun && db.Error == nil {
@ -118,7 +118,7 @@ func CreateWithReturning(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
db.Statement.Build(db.Statement.BuildClauses...)
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {

View File

@ -135,7 +135,7 @@ func Delete(db *gorm.DB) {
}
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("DELETE", "FROM", "WHERE")
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {

View File

@ -167,7 +167,7 @@ func BuildQuerySQL(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
db.Statement.Build(db.Statement.BuildClauses...)
}
}

View File

@ -66,7 +66,7 @@ func Update(db *gorm.DB) {
} else {
return
}
db.Statement.Build("UPDATE", "SET", "WHERE")
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {

View File

@ -27,6 +27,7 @@ type Statement struct {
Dest interface{}
ReflectValue reflect.Value
Clauses map[string]clause.Clause
BuildClauses []string
Distinct bool
Selects []string // selected columns
Omits []string // omit columns