diff --git a/callback_delete.go b/callback_delete.go index c8ffcc82..6217706e 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -1,6 +1,9 @@ package gorm -import "fmt" +import ( + "errors" + "fmt" +) // Define callbacks for deleting func init() { @@ -13,6 +16,10 @@ func init() { // beforeDeleteCallback will invoke `BeforeDelete` method before deleting func beforeDeleteCallback(scope *Scope) { + if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { + scope.Err(errors.New("Missing WHERE clause while deleting")) + return + } if !scope.HasError() { scope.CallMethod("BeforeDelete") } diff --git a/callback_update.go b/callback_update.go index aa27b5fb..6948439f 100644 --- a/callback_update.go +++ b/callback_update.go @@ -1,6 +1,7 @@ package gorm import ( + "errors" "fmt" "strings" ) @@ -31,6 +32,10 @@ func assignUpdatingAttributesCallback(scope *Scope) { // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating func beforeUpdateCallback(scope *Scope) { + if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { + scope.Err(errors.New("Missing WHERE clause while updating")) + return + } if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { scope.CallMethod("BeforeSave") diff --git a/main.go b/main.go index 192dbd7c..558e9674 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ type DB struct { source string values map[string]interface{} joinTableHandlers map[string]JoinTableHandler + blockGlobalUpdate bool } // Open initialize a new db connection, need to import driver first, e.g: @@ -142,6 +143,18 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// BlockGlobalUpdate if true, generates an error on update/delete without where clause. +// This is to prevent eventual error with empty objects updates/deletions +func (s *DB) BlockGlobalUpdate(enable bool) *DB { + s.blockGlobalUpdate = enable + return s +} + +// HasBlockGlobalUpdate return state of block +func (s *DB) HasBlockGlobalUpdate() bool { + return s.blockGlobalUpdate +} + // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { modelStructsMap = newModelStructsMap() @@ -682,7 +695,16 @@ func (s *DB) GetErrors() (errors []error) { //////////////////////////////////////////////////////////////////////////////// func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + db := DB{ + db: s.db, + parent: s.parent, + logger: s.logger, + logMode: s.logMode, + values: map[string]interface{}{}, + Value: s.Value, + Error: s.Error, + blockGlobalUpdate: s.blockGlobalUpdate, + } for key, value := range s.values { db.values[key] = value diff --git a/main_test.go b/main_test.go index 729e6eb2..9869a7ad 100644 --- a/main_test.go +++ b/main_test.go @@ -771,6 +771,44 @@ func TestOpenWithOneParameter(t *testing.T) { } } +func TestBlockGlobalUpdate(t *testing.T) { + db := DB.New() + db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) + + err := db.Model(&Toy{}).Update("OwnerType", "Human").Error + if err != nil { + t.Error("Unexpected error on global update") + } + + err = db.Delete(&Toy{}).Error + if err != nil { + t.Error("Unexpected error on global delete") + } + + db.BlockGlobalUpdate(true) + + db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) + + err = db.Model(&Toy{}).Update("OwnerType", "Human").Error + if err == nil { + t.Error("Expected error on global update") + } + + err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error + if err != nil { + t.Error("Unxpected error on conditional update") + } + + err = db.Delete(&Toy{}).Error + if err == nil { + t.Error("Expected error on global delete") + } + err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error + if err != nil { + t.Error("Unexpected error on conditional delete") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { diff --git a/scope.go b/scope.go index 4a962062..8212d4a6 100644 --- a/scope.go +++ b/scope.go @@ -1280,3 +1280,10 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { } return nil } + +func (scope *Scope) hasConditions() bool { + return !scope.PrimaryKeyZero() || + len(scope.Search.whereConditions) > 0 || + len(scope.Search.orConditions) > 0 || + len(scope.Search.notConditions) > 0 +}