Refactor check missing where condition

This commit is contained in:
Jinzhu 2022-02-25 10:48:23 +08:00
parent 397b583b8e
commit f2edda50e1
5 changed files with 33 additions and 35 deletions

View File

@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{})
@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) {
}
db.Statement.AddClauseIfNotExists(clause.From{})
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
db.AddError(gorm.ErrMissingWhereClause)
return
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
ok, mode := hasReturning(db, supportReturning)

View File

@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
}
return false, 0
}
func checkMissingWhereConditions(db *gorm.DB) {
if !db.AllowGlobalUpdate && db.Error == nil {
where, withCondition := db.Statement.Clauses["WHERE"]
if withCondition {
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
whereClause, _ := where.Expression.(clause.Where)
withCondition = len(whereClause.Exprs) > 1
}
}
if !withCondition {
db.AddError(gorm.ErrMissingWhereClause)
}
return
}
}

View File

@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) {
return
}
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok {

View File

@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
}
@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
}
}
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
stmt.DB.AddError(ErrMissingWhereClause)
} else {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build(stmt.DB.Callback().Update().Clauses...)
}

View File

@ -645,7 +645,7 @@ func TestSave(t *testing.T) {
dryDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryDB.Save(&user).Statement
if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
}