From 43e9035dad7eb9fba4306d367cb25d0cfd9738c2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 12 Jan 2016 13:44:16 +0800 Subject: [PATCH] Fix Association Count with Soft Delete --- association_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++ callback_delete.go | 2 +- main.go | 7 ++-- main_private.go | 2 +- scope_private.go | 8 ++--- search.go | 6 ---- structs_test.go | 2 +- 7 files changed, 99 insertions(+), 16 deletions(-) diff --git a/association_test.go b/association_test.go index 29a65292..ab3abd91 100644 --- a/association_test.go +++ b/association_test.go @@ -293,6 +293,34 @@ func TestHasOne(t *testing.T) { if DB.Model(&user).Association("CreditCard").Count() != 0 { t.Errorf("User's credit card count should be 0 after Clear") } + + // Check Association mode with soft delete + var creditcard6 = CreditCard{ + Number: "411111111116", + } + DB.Model(&user).Association("CreditCard").Append(&creditcard6) + + if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { + t.Errorf("User's credit card count should be 1 after Append, but got %v", count) + } + + DB.Delete(&creditcard6) + + if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { + t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) + } + + if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { + t.Errorf("User's creditcard is not findable after Delete") + } + + if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { + t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) + } + + if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { + t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) + } } func TestHasMany(t *testing.T) { @@ -402,6 +430,36 @@ func TestHasMany(t *testing.T) { if len(comments51) != 0 { t.Errorf("Clear has many relations") } + + // Check Association mode with soft delete + var comment6 = Comment{ + Content: "comment 6", + } + DB.Model(&post).Association("Comments").Append(&comment6) + + if count := DB.Model(&post).Association("Comments").Count(); count != 1 { + t.Errorf("post's comments count should be 1 after Append, but got %v", count) + } + + DB.Delete(&comment6) + + if count := DB.Model(&post).Association("Comments").Count(); count != 0 { + t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) + } + + var comments6 []Comment + if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { + t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) + } + + if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { + t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) + } + + var comments61 []Comment + if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { + t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) + } } func TestManyToMany(t *testing.T) { @@ -500,6 +558,36 @@ func TestManyToMany(t *testing.T) { if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { t.Errorf("Relations should be cleared") } + + // Check Association mode with soft delete + var language6 = Language{ + Name: "language 6", + } + DB.Model(&user).Association("Languages").Append(&language6) + + if count := DB.Model(&user).Association("Languages").Count(); count != 1 { + t.Errorf("user's languages count should be 1 after Append, but got %v", count) + } + + DB.Delete(&language6) + + if count := DB.Model(&user).Association("Languages").Count(); count != 0 { + t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) + } + + var languages6 []Language + if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { + t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) + } + + if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { + t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) + } + + var languages61 []Language + if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { + t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) + } } func TestRelated(t *testing.T) { diff --git a/callback_delete.go b/callback_delete.go index 72236659..8e56196b 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -8,7 +8,7 @@ func BeforeDelete(scope *Scope) { func Delete(scope *Scope) { if !scope.HasError() { - if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { + if !scope.db.unscoped && scope.HasColumn("DeletedAt") { scope.Raw( fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.QuotedTableName(), diff --git a/main.go b/main.go index 9fe6cf4e..d84c139b 100644 --- a/main.go +++ b/main.go @@ -28,6 +28,7 @@ type DB struct { parent *DB search *search logMode int + unscoped bool logger logger dialect Dialect singularTable bool @@ -186,7 +187,9 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { } func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db + clone := s.clone() + clone.unscoped = true + return clone } func (s *DB) Attrs(attrs ...interface{}) *DB { @@ -434,7 +437,7 @@ func (s *DB) DropColumn(column string) *DB { } func (s *DB) AddIndex(indexName string, column ...string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.Unscoped().NewScope(s.Value) scope.addIndex(false, indexName, column...) return scope.db } diff --git a/main_private.go b/main_private.go index bd097ce0..3431de81 100644 --- a/main_private.go +++ b/main_private.go @@ -3,7 +3,7 @@ package gorm import "time" func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, unscoped: s.unscoped, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} for key, value := range s.values { db.values[key] = value diff --git a/scope_private.go b/scope_private.go index d893c095..36292423 100644 --- a/scope_private.go +++ b/scope_private.go @@ -161,7 +161,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) func (scope *Scope) whereSql() (sql string) { var primaryConditions, andConditions, orConditions []string - if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { + if !scope.db.unscoped && scope.Fields()["deleted_at"] != nil { sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) primaryConditions = append(primaryConditions, sql) } @@ -601,9 +601,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { sqlCreate = "CREATE UNIQUE INDEX" } - scope.Search.Unscoped = true scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec() - scope.Search.Unscoped = false } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { @@ -659,11 +657,11 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - scope.addIndex(false, name, columns...) + scope.NewDB().Model(scope.Value).AddIndex(name, columns...) } for name, columns := range uniqueIndexes { - scope.addIndex(true, name, columns...) + scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) } return scope diff --git a/search.go b/search.go index 166b9a86..cabce05c 100644 --- a/search.go +++ b/search.go @@ -20,7 +20,6 @@ type search struct { group string tableName string raw bool - Unscoped bool countingQuery bool } @@ -124,11 +123,6 @@ func (s *search) Raw(b bool) *search { return s } -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - func (s *search) Table(name string) *search { s.tableName = name return s diff --git a/structs_test.go b/structs_test.go index 8f529952..20666740 100644 --- a/structs_test.go +++ b/structs_test.go @@ -66,7 +66,7 @@ type Address struct { } type Language struct { - Id int + gorm.Model Name string Users []User `gorm:"many2many:user_languages;"` }