In some cases (Error not checked, missed data) one can perform very harmful operation - global update or delete (all records)

This is to prevent it.
This commit is contained in:
slockij 2016-11-04 17:41:31 +01:00
parent d5d3e3a67b
commit e26cb8dbc4
5 changed files with 81 additions and 2 deletions

View File

@ -1,6 +1,9 @@
package gorm
import "fmt"
import (
"errors"
"fmt"
)
// Define callbacks for deleting
func init() {
@ -13,6 +16,10 @@ func init() {
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while deleting"))
return
}
if !scope.HasError() {
scope.CallMethod("BeforeDelete")
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"errors"
"fmt"
"strings"
)
@ -31,6 +32,10 @@ func assignUpdatingAttributesCallback(scope *Scope) {
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while updating"))
return
}
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("BeforeSave")

24
main.go
View File

@ -25,6 +25,7 @@ type DB struct {
source string
values map[string]interface{}
joinTableHandlers map[string]JoinTableHandler
blockGlobalUpdate bool
}
// Open initialize a new db connection, need to import driver first, e.g:
@ -142,6 +143,18 @@ func (s *DB) LogMode(enable bool) *DB {
return s
}
// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
// This is to prevent eventual error with empty objects updates/deletions
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
s.blockGlobalUpdate = enable
return s
}
// HasBlockGlobalUpdate return state of block
func (s *DB) HasBlockGlobalUpdate() bool {
return s.blockGlobalUpdate
}
// SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) {
modelStructsMap = newModelStructsMap()
@ -682,7 +695,16 @@ func (s *DB) GetErrors() (errors []error) {
////////////////////////////////////////////////////////////////////////////////
func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
db := DB{
db: s.db,
parent: s.parent,
logger: s.logger,
logMode: s.logMode,
values: map[string]interface{}{},
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,
}
for key, value := range s.values {
db.values[key] = value

View File

@ -771,6 +771,44 @@ func TestOpenWithOneParameter(t *testing.T) {
}
}
func TestBlockGlobalUpdate(t *testing.T) {
db := DB.New()
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
err := db.Model(&Toy{}).Update("OwnerType", "Human").Error
if err != nil {
t.Error("Unexpected error on global update")
}
err = db.Delete(&Toy{}).Error
if err != nil {
t.Error("Unexpected error on global delete")
}
db.BlockGlobalUpdate(true)
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
err = db.Model(&Toy{}).Update("OwnerType", "Human").Error
if err == nil {
t.Error("Expected error on global update")
}
err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error
if err != nil {
t.Error("Unxpected error on conditional update")
}
err = db.Delete(&Toy{}).Error
if err == nil {
t.Error("Expected error on global delete")
}
err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error
if err != nil {
t.Error("Unexpected error on conditional delete")
}
}
func BenchmarkGorm(b *testing.B) {
b.N = 2000
for x := 0; x < b.N; x++ {

View File

@ -1280,3 +1280,10 @@ func (scope *Scope) getColumnAsScope(column string) *Scope {
}
return nil
}
func (scope *Scope) hasConditions() bool {
return !scope.PrimaryKeyZero() ||
len(scope.Search.whereConditions) > 0 ||
len(scope.Search.orConditions) > 0 ||
len(scope.Search.notConditions) > 0
}