Merge pull request from slockij/block-global-update-delete

Block global updates / deletes
This commit is contained in:
Jinzhu 2016-11-10 08:45:29 +08:00 committed by GitHub
commit 45fc640bf3
5 changed files with 81 additions and 2 deletions

View File

@ -1,6 +1,9 @@
package gorm package gorm
import "fmt" import (
"errors"
"fmt"
)
// Define callbacks for deleting // Define callbacks for deleting
func init() { func init() {
@ -13,6 +16,10 @@ func init() {
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting // beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) { func beforeDeleteCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while deleting"))
return
}
if !scope.HasError() { if !scope.HasError() {
scope.CallMethod("BeforeDelete") scope.CallMethod("BeforeDelete")
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
) )
@ -31,6 +32,10 @@ func assignUpdatingAttributesCallback(scope *Scope) {
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) { func beforeUpdateCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while updating"))
return
}
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() { if !scope.HasError() {
scope.CallMethod("BeforeSave") scope.CallMethod("BeforeSave")

24
main.go
View File

@ -25,6 +25,7 @@ type DB struct {
source string source string
values map[string]interface{} values map[string]interface{}
joinTableHandlers map[string]JoinTableHandler joinTableHandlers map[string]JoinTableHandler
blockGlobalUpdate bool
} }
// Open initialize a new db connection, need to import driver first, e.g: // Open initialize a new db connection, need to import driver first, e.g:
@ -142,6 +143,18 @@ func (s *DB) LogMode(enable bool) *DB {
return s return s
} }
// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
// This is to prevent eventual error with empty objects updates/deletions
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
s.blockGlobalUpdate = enable
return s
}
// HasBlockGlobalUpdate return state of block
func (s *DB) HasBlockGlobalUpdate() bool {
return s.blockGlobalUpdate
}
// SingularTable use singular table by default // SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) { func (s *DB) SingularTable(enable bool) {
modelStructsMap = newModelStructsMap() modelStructsMap = newModelStructsMap()
@ -682,7 +695,16 @@ func (s *DB) GetErrors() []error {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
func (s *DB) clone() *DB { func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} db := DB{
db: s.db,
parent: s.parent,
logger: s.logger,
logMode: s.logMode,
values: map[string]interface{}{},
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,
}
for key, value := range s.values { for key, value := range s.values {
db.values[key] = value db.values[key] = value

View File

@ -771,6 +771,44 @@ func TestOpenWithOneParameter(t *testing.T) {
} }
} }
func TestBlockGlobalUpdate(t *testing.T) {
db := DB.New()
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
err := db.Model(&Toy{}).Update("OwnerType", "Human").Error
if err != nil {
t.Error("Unexpected error on global update")
}
err = db.Delete(&Toy{}).Error
if err != nil {
t.Error("Unexpected error on global delete")
}
db.BlockGlobalUpdate(true)
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
err = db.Model(&Toy{}).Update("OwnerType", "Human").Error
if err == nil {
t.Error("Expected error on global update")
}
err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error
if err != nil {
t.Error("Unxpected error on conditional update")
}
err = db.Delete(&Toy{}).Error
if err == nil {
t.Error("Expected error on global delete")
}
err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error
if err != nil {
t.Error("Unexpected error on conditional delete")
}
}
func BenchmarkGorm(b *testing.B) { func BenchmarkGorm(b *testing.B) {
b.N = 2000 b.N = 2000
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {

View File

@ -1280,3 +1280,10 @@ func (scope *Scope) getColumnAsScope(column string) *Scope {
} }
return nil return nil
} }
func (scope *Scope) hasConditions() bool {
return !scope.PrimaryKeyZero() ||
len(scope.Search.whereConditions) > 0 ||
len(scope.Search.orConditions) > 0 ||
len(scope.Search.notConditions) > 0
}