forked from mirror/gorm
Refactor check missing where condition
This commit is contained in:
parent
3741f258d0
commit
6a18a15c93
|
@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(100)
|
||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||
|
@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
checkMissingWhereConditions(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
|
|
|
@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
|
|||
}
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func checkMissingWhereConditions(db *gorm.DB) {
|
||||
if !db.AllowGlobalUpdate && db.Error == nil {
|
||||
where, withCondition := db.Statement.Clauses["WHERE"]
|
||||
if withCondition {
|
||||
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
|
||||
whereClause, _ := where.Expression.(clause.Where)
|
||||
withCondition = len(whereClause.Exprs) > 1
|
||||
}
|
||||
}
|
||||
if !withCondition {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
|
@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
checkMissingWhereConditions(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||
|
|
|
@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
|
|||
|
||||
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||
if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
}
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
|||
}
|
||||
}
|
||||
|
||||
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
|
||||
stmt.DB.AddError(ErrMissingWhereClause)
|
||||
} else {
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
}
|
||||
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
stmt.AddClauseIfNotExists(clause.Update{})
|
||||
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||
}
|
||||
|
|
|
@ -645,7 +645,7 @@ func TestSave(t *testing.T) {
|
|||
|
||||
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
stmt := dryDB.Save(&user).Statement
|
||||
if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
|
||||
if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
|
||||
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue