From e26cb8dbc455455f59ecf50de92b11ade29909b7 Mon Sep 17 00:00:00 2001
From: slockij <jakub.slocki@showpad.com>
Date: Fri, 4 Nov 2016 17:41:31 +0100
Subject: [PATCH] In some cases (Error not checked, missed data) one can
 perform very harmful operation - global update or delete (all records)

This is to prevent it.
---
 callback_delete.go |  9 ++++++++-
 callback_update.go |  5 +++++
 main.go            | 24 +++++++++++++++++++++++-
 main_test.go       | 38 ++++++++++++++++++++++++++++++++++++++
 scope.go           |  7 +++++++
 5 files changed, 81 insertions(+), 2 deletions(-)

diff --git a/callback_delete.go b/callback_delete.go
index c8ffcc82..6217706e 100644
--- a/callback_delete.go
+++ b/callback_delete.go
@@ -1,6 +1,9 @@
 package gorm
 
-import "fmt"
+import (
+	"errors"
+	"fmt"
+)
 
 // Define callbacks for deleting
 func init() {
@@ -13,6 +16,10 @@ func init() {
 
 // beforeDeleteCallback will invoke `BeforeDelete` method before deleting
 func beforeDeleteCallback(scope *Scope) {
+	if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
+		scope.Err(errors.New("Missing WHERE clause while deleting"))
+		return
+	}
 	if !scope.HasError() {
 		scope.CallMethod("BeforeDelete")
 	}
diff --git a/callback_update.go b/callback_update.go
index aa27b5fb..6948439f 100644
--- a/callback_update.go
+++ b/callback_update.go
@@ -1,6 +1,7 @@
 package gorm
 
 import (
+	"errors"
 	"fmt"
 	"strings"
 )
@@ -31,6 +32,10 @@ func assignUpdatingAttributesCallback(scope *Scope) {
 
 // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
 func beforeUpdateCallback(scope *Scope) {
+	if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
+		scope.Err(errors.New("Missing WHERE clause while updating"))
+		return
+	}
 	if _, ok := scope.Get("gorm:update_column"); !ok {
 		if !scope.HasError() {
 			scope.CallMethod("BeforeSave")
diff --git a/main.go b/main.go
index 192dbd7c..558e9674 100644
--- a/main.go
+++ b/main.go
@@ -25,6 +25,7 @@ type DB struct {
 	source            string
 	values            map[string]interface{}
 	joinTableHandlers map[string]JoinTableHandler
+	blockGlobalUpdate bool
 }
 
 // Open initialize a new db connection, need to import driver first, e.g:
@@ -142,6 +143,18 @@ func (s *DB) LogMode(enable bool) *DB {
 	return s
 }
 
+// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
+// This is to prevent eventual error with empty objects updates/deletions
+func (s *DB) BlockGlobalUpdate(enable bool) *DB {
+	s.blockGlobalUpdate = enable
+	return s
+}
+
+// HasBlockGlobalUpdate return state of block
+func (s *DB) HasBlockGlobalUpdate() bool {
+	return s.blockGlobalUpdate
+}
+
 // SingularTable use singular table by default
 func (s *DB) SingularTable(enable bool) {
 	modelStructsMap = newModelStructsMap()
@@ -682,7 +695,16 @@ func (s *DB) GetErrors() (errors []error) {
 ////////////////////////////////////////////////////////////////////////////////
 
 func (s *DB) clone() *DB {
-	db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
+	db := DB{
+		db:                s.db,
+		parent:            s.parent,
+		logger:            s.logger,
+		logMode:           s.logMode,
+		values:            map[string]interface{}{},
+		Value:             s.Value,
+		Error:             s.Error,
+		blockGlobalUpdate: s.blockGlobalUpdate,
+	}
 
 	for key, value := range s.values {
 		db.values[key] = value
diff --git a/main_test.go b/main_test.go
index 729e6eb2..9869a7ad 100644
--- a/main_test.go
+++ b/main_test.go
@@ -771,6 +771,44 @@ func TestOpenWithOneParameter(t *testing.T) {
 	}
 }
 
+func TestBlockGlobalUpdate(t *testing.T) {
+	db := DB.New()
+	db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
+
+	err := db.Model(&Toy{}).Update("OwnerType", "Human").Error
+	if err != nil {
+		t.Error("Unexpected error on global update")
+	}
+
+	err = db.Delete(&Toy{}).Error
+	if err != nil {
+		t.Error("Unexpected error on global delete")
+	}
+
+	db.BlockGlobalUpdate(true)
+
+	db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
+
+	err = db.Model(&Toy{}).Update("OwnerType", "Human").Error
+	if err == nil {
+		t.Error("Expected error on global update")
+	}
+
+	err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error
+	if err != nil {
+		t.Error("Unxpected error on conditional update")
+	}
+
+	err = db.Delete(&Toy{}).Error
+	if err == nil {
+		t.Error("Expected error on global delete")
+	}
+	err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error
+	if err != nil {
+		t.Error("Unexpected error on conditional delete")
+	}
+}
+
 func BenchmarkGorm(b *testing.B) {
 	b.N = 2000
 	for x := 0; x < b.N; x++ {
diff --git a/scope.go b/scope.go
index 4a962062..8212d4a6 100644
--- a/scope.go
+++ b/scope.go
@@ -1280,3 +1280,10 @@ func (scope *Scope) getColumnAsScope(column string) *Scope {
 	}
 	return nil
 }
+
+func (scope *Scope) hasConditions() bool {
+	return !scope.PrimaryKeyZero() ||
+		len(scope.Search.whereConditions) > 0 ||
+		len(scope.Search.orConditions) > 0 ||
+		len(scope.Search.notConditions) > 0
+}