From ee48d6986c4f88a42b5dfc3b230cf9d198f2de6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Jan 2016 07:28:08 +0800 Subject: [PATCH 1/4] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e85e8e02..fa710802 100644 --- a/README.md +++ b/README.md @@ -595,7 +595,7 @@ db.Model(&user).Related(&profile) ### Has Many ```go -// User belongs to a profile, ProfileID is the foreign key +// User has many emails, UserID is the foreign key type User struct { gorm.Model Emails []Email From 801a271d0760865d714a4e3532a273dbf2998676 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 12 Jan 2016 12:16:22 +0800 Subject: [PATCH 2/4] Fix Association Count --- association.go | 23 ++++++++++++--------- association_test.go | 50 +++++++++++++++++++++++++++++++++++---------- join_table_test.go | 2 +- main_test.go | 5 +++-- structs_test.go | 6 ++++-- 5 files changed, 60 insertions(+), 26 deletions(-) diff --git a/association.go b/association.go index 5cc32e1c..4660cb27 100644 --- a/association.go +++ b/association.go @@ -297,13 +297,16 @@ func (association *Association) Clear() *Association { } func (association *Association) Count() int { - count := -1 - relationship := association.Field.Relationship - scope := association.Scope - newScope := scope.New(association.Field.Field.Interface()) + var ( + count = 0 + relationship = association.Field.Relationship + scope = association.Scope + fieldValue = association.Field.Field.Interface() + newScope = scope.New(fieldValue) + ) if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) + relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Model(fieldValue).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { query := scope.DB() for idx, foreignKey := range relationship.ForeignDBNames { @@ -316,16 +319,16 @@ func (association *Association) Count() int { if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - query.Table(newScope.TableName()).Count(&count) + query.Model(fieldValue).Count(&count) } else if relationship.Kind == "belongs_to" { query := scope.DB() - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + for idx, primaryKey := range relationship.AssociationForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)), field.Field.Interface()) } } - query.Table(newScope.TableName()).Count(&count) + query.Model(fieldValue).Count(&count) } return count diff --git a/association_test.go b/association_test.go index 0e61d51f..29a65292 100644 --- a/association_test.go +++ b/association_test.go @@ -19,7 +19,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Got errors when save post", err.Error()) } - if post.Category.Id == 0 || post.MainCategory.Id == 0 { + if post.Category.ID == 0 || post.MainCategory.ID == 0 { t.Errorf("Category's primary key should be updated") } @@ -46,11 +46,11 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Query belongs to relations with Related") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } - if DB.Model(&post).Association("MainCategory").Count() == 1 { + if DB.Model(&post).Association("MainCategory").Count() != 1 { t.Errorf("Post's main category count should be 1") } @@ -60,7 +60,7 @@ func TestBelongsTo(t *testing.T) { } DB.Model(&post).Association("Category").Append(&category2) - if category2.Id == 0 { + if category2.ID == 0 { t.Errorf("Category should has ID when created with Append") } @@ -71,7 +71,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be updated with Append") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } @@ -81,7 +81,7 @@ func TestBelongsTo(t *testing.T) { } DB.Model(&post).Association("Category").Replace(&category3) - if category3.Id == 0 { + if category3.ID == 0 { t.Errorf("Category should has ID when created with Replace") } @@ -91,7 +91,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be updated with Replace") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } @@ -117,8 +117,8 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be deleted with Delete") } - if DB.Model(&post).Association("Category").Count() == 0 { - t.Errorf("Post's category count should be 0 after Delete") + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after Delete, but got %v", count) } // Clear @@ -144,8 +144,36 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Should not find any category after Clear") } - if DB.Model(&post).Association("Category").Count() == 0 { - t.Errorf("Post's category count should be 0 after Clear") + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after Clear, but got %v", count) + } + + // Check Association mode with soft delete + category6 := Category{ + Name: "Category 6", + } + DB.Model(&post).Association("Category").Append(&category6) + + if count := DB.Model(&post).Association("Category").Count(); count != 1 { + t.Errorf("Post's category count should be 1 after Append, but got %v", count) + } + + DB.Delete(&category6) + + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) + } + + if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { + t.Errorf("Post's category is not findable after Delete") + } + + if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { + t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) + } + + if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { + t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) } } diff --git a/join_table_test.go b/join_table_test.go index 3353aee2..70e792ed 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -39,7 +39,7 @@ func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) - return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) + return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } func TestJoinTable(t *testing.T) { diff --git a/main_test.go b/main_test.go index e6c703e4..65467d73 100644 --- a/main_test.go +++ b/main_test.go @@ -33,8 +33,9 @@ func init() { // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // DB.SetLogger(log.New(os.Stdout, "\r\n", 0)) - // DB.LogMode(true) - DB.LogMode(false) + if os.Getenv("DEBUG") == "true" { + DB.LogMode(true) + } DB.DB().SetMaxIdleConns(10) diff --git a/structs_test.go b/structs_test.go index a3dfa8b1..8f529952 100644 --- a/structs_test.go +++ b/structs_test.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" + "github.com/jinzhu/gorm" + "reflect" "time" ) @@ -154,12 +156,12 @@ type Post struct { } type Category struct { - Id int64 + gorm.Model Name string } type Comment struct { - Id int64 + gorm.Model PostId int64 Content string Post Post From 43e9035dad7eb9fba4306d367cb25d0cfd9738c2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 12 Jan 2016 13:44:16 +0800 Subject: [PATCH 3/4] Fix Association Count with Soft Delete --- association_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++ callback_delete.go | 2 +- main.go | 7 ++-- main_private.go | 2 +- scope_private.go | 8 ++--- search.go | 6 ---- structs_test.go | 2 +- 7 files changed, 99 insertions(+), 16 deletions(-) diff --git a/association_test.go b/association_test.go index 29a65292..ab3abd91 100644 --- a/association_test.go +++ b/association_test.go @@ -293,6 +293,34 @@ func TestHasOne(t *testing.T) { if DB.Model(&user).Association("CreditCard").Count() != 0 { t.Errorf("User's credit card count should be 0 after Clear") } + + // Check Association mode with soft delete + var creditcard6 = CreditCard{ + Number: "411111111116", + } + DB.Model(&user).Association("CreditCard").Append(&creditcard6) + + if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { + t.Errorf("User's credit card count should be 1 after Append, but got %v", count) + } + + DB.Delete(&creditcard6) + + if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { + t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) + } + + if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { + t.Errorf("User's creditcard is not findable after Delete") + } + + if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { + t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) + } + + if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { + t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) + } } func TestHasMany(t *testing.T) { @@ -402,6 +430,36 @@ func TestHasMany(t *testing.T) { if len(comments51) != 0 { t.Errorf("Clear has many relations") } + + // Check Association mode with soft delete + var comment6 = Comment{ + Content: "comment 6", + } + DB.Model(&post).Association("Comments").Append(&comment6) + + if count := DB.Model(&post).Association("Comments").Count(); count != 1 { + t.Errorf("post's comments count should be 1 after Append, but got %v", count) + } + + DB.Delete(&comment6) + + if count := DB.Model(&post).Association("Comments").Count(); count != 0 { + t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) + } + + var comments6 []Comment + if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { + t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) + } + + if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { + t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) + } + + var comments61 []Comment + if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { + t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) + } } func TestManyToMany(t *testing.T) { @@ -500,6 +558,36 @@ func TestManyToMany(t *testing.T) { if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { t.Errorf("Relations should be cleared") } + + // Check Association mode with soft delete + var language6 = Language{ + Name: "language 6", + } + DB.Model(&user).Association("Languages").Append(&language6) + + if count := DB.Model(&user).Association("Languages").Count(); count != 1 { + t.Errorf("user's languages count should be 1 after Append, but got %v", count) + } + + DB.Delete(&language6) + + if count := DB.Model(&user).Association("Languages").Count(); count != 0 { + t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) + } + + var languages6 []Language + if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { + t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) + } + + if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { + t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) + } + + var languages61 []Language + if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { + t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) + } } func TestRelated(t *testing.T) { diff --git a/callback_delete.go b/callback_delete.go index 72236659..8e56196b 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -8,7 +8,7 @@ func BeforeDelete(scope *Scope) { func Delete(scope *Scope) { if !scope.HasError() { - if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { + if !scope.db.unscoped && scope.HasColumn("DeletedAt") { scope.Raw( fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.QuotedTableName(), diff --git a/main.go b/main.go index 9fe6cf4e..d84c139b 100644 --- a/main.go +++ b/main.go @@ -28,6 +28,7 @@ type DB struct { parent *DB search *search logMode int + unscoped bool logger logger dialect Dialect singularTable bool @@ -186,7 +187,9 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { } func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db + clone := s.clone() + clone.unscoped = true + return clone } func (s *DB) Attrs(attrs ...interface{}) *DB { @@ -434,7 +437,7 @@ func (s *DB) DropColumn(column string) *DB { } func (s *DB) AddIndex(indexName string, column ...string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.Unscoped().NewScope(s.Value) scope.addIndex(false, indexName, column...) return scope.db } diff --git a/main_private.go b/main_private.go index bd097ce0..3431de81 100644 --- a/main_private.go +++ b/main_private.go @@ -3,7 +3,7 @@ package gorm import "time" func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, unscoped: s.unscoped, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} for key, value := range s.values { db.values[key] = value diff --git a/scope_private.go b/scope_private.go index d893c095..36292423 100644 --- a/scope_private.go +++ b/scope_private.go @@ -161,7 +161,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) func (scope *Scope) whereSql() (sql string) { var primaryConditions, andConditions, orConditions []string - if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { + if !scope.db.unscoped && scope.Fields()["deleted_at"] != nil { sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) primaryConditions = append(primaryConditions, sql) } @@ -601,9 +601,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { sqlCreate = "CREATE UNIQUE INDEX" } - scope.Search.Unscoped = true scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec() - scope.Search.Unscoped = false } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { @@ -659,11 +657,11 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - scope.addIndex(false, name, columns...) + scope.NewDB().Model(scope.Value).AddIndex(name, columns...) } for name, columns := range uniqueIndexes { - scope.addIndex(true, name, columns...) + scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) } return scope diff --git a/search.go b/search.go index 166b9a86..cabce05c 100644 --- a/search.go +++ b/search.go @@ -20,7 +20,6 @@ type search struct { group string tableName string raw bool - Unscoped bool countingQuery bool } @@ -124,11 +123,6 @@ func (s *search) Raw(b bool) *search { return s } -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - func (s *search) Table(name string) *search { s.tableName = name return s diff --git a/structs_test.go b/structs_test.go index 8f529952..20666740 100644 --- a/structs_test.go +++ b/structs_test.go @@ -66,7 +66,7 @@ type Address struct { } type Language struct { - Id int + gorm.Model Name string Users []User `gorm:"many2many:user_languages;"` } From 341d047aa7ae166f12478d4c3c0941681aa22323 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 12 Jan 2016 15:27:25 +0800 Subject: [PATCH 4/4] Rollback to old Unscoped API --- association.go | 2 +- callback_delete.go | 2 +- main.go | 5 +---- main_private.go | 2 +- scope_private.go | 2 +- search.go | 6 ++++++ 6 files changed, 11 insertions(+), 8 deletions(-) diff --git a/association.go b/association.go index 4660cb27..30ea36b2 100644 --- a/association.go +++ b/association.go @@ -306,7 +306,7 @@ func (association *Association) Count() int { ) if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Model(fieldValue).Count(&count) + relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { query := scope.DB() for idx, foreignKey := range relationship.ForeignDBNames { diff --git a/callback_delete.go b/callback_delete.go index 8e56196b..72236659 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -8,7 +8,7 @@ func BeforeDelete(scope *Scope) { func Delete(scope *Scope) { if !scope.HasError() { - if !scope.db.unscoped && scope.HasColumn("DeletedAt") { + if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { scope.Raw( fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.QuotedTableName(), diff --git a/main.go b/main.go index d84c139b..ff707f3f 100644 --- a/main.go +++ b/main.go @@ -28,7 +28,6 @@ type DB struct { parent *DB search *search logMode int - unscoped bool logger logger dialect Dialect singularTable bool @@ -187,9 +186,7 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { } func (s *DB) Unscoped() *DB { - clone := s.clone() - clone.unscoped = true - return clone + return s.clone().search.unscoped().db } func (s *DB) Attrs(attrs ...interface{}) *DB { diff --git a/main_private.go b/main_private.go index 3431de81..bd097ce0 100644 --- a/main_private.go +++ b/main_private.go @@ -3,7 +3,7 @@ package gorm import "time" func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, unscoped: s.unscoped, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} for key, value := range s.values { db.values[key] = value diff --git a/scope_private.go b/scope_private.go index 36292423..a154c426 100644 --- a/scope_private.go +++ b/scope_private.go @@ -161,7 +161,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) func (scope *Scope) whereSql() (sql string) { var primaryConditions, andConditions, orConditions []string - if !scope.db.unscoped && scope.Fields()["deleted_at"] != nil { + if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) primaryConditions = append(primaryConditions, sql) } diff --git a/search.go b/search.go index cabce05c..166b9a86 100644 --- a/search.go +++ b/search.go @@ -20,6 +20,7 @@ type search struct { group string tableName string raw bool + Unscoped bool countingQuery bool } @@ -123,6 +124,11 @@ func (s *search) Raw(b bool) *search { return s } +func (s *search) unscoped() *search { + s.Unscoped = true + return s +} + func (s *search) Table(name string) *search { s.tableName = name return s