diff --git a/README.md b/README.md index e85e8e02..fa710802 100644 --- a/README.md +++ b/README.md @@ -595,7 +595,7 @@ db.Model(&user).Related(&profile) ### Has Many ```go -// User belongs to a profile, ProfileID is the foreign key +// User has many emails, UserID is the foreign key type User struct { gorm.Model Emails []Email diff --git a/association.go b/association.go index 5cc32e1c..30ea36b2 100644 --- a/association.go +++ b/association.go @@ -297,13 +297,16 @@ func (association *Association) Clear() *Association { } func (association *Association) Count() int { - count := -1 - relationship := association.Field.Relationship - scope := association.Scope - newScope := scope.New(association.Field.Field.Interface()) + var ( + count = 0 + relationship = association.Field.Relationship + scope = association.Scope + fieldValue = association.Field.Field.Interface() + newScope = scope.New(fieldValue) + ) if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) + relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { query := scope.DB() for idx, foreignKey := range relationship.ForeignDBNames { @@ -316,16 +319,16 @@ func (association *Association) Count() int { if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - query.Table(newScope.TableName()).Count(&count) + query.Model(fieldValue).Count(&count) } else if relationship.Kind == "belongs_to" { query := scope.DB() - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + for idx, primaryKey := range relationship.AssociationForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)), field.Field.Interface()) } } - query.Table(newScope.TableName()).Count(&count) + query.Model(fieldValue).Count(&count) } return count diff --git a/association_test.go b/association_test.go index 0e61d51f..ab3abd91 100644 --- a/association_test.go +++ b/association_test.go @@ -19,7 +19,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Got errors when save post", err.Error()) } - if post.Category.Id == 0 || post.MainCategory.Id == 0 { + if post.Category.ID == 0 || post.MainCategory.ID == 0 { t.Errorf("Category's primary key should be updated") } @@ -46,11 +46,11 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Query belongs to relations with Related") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } - if DB.Model(&post).Association("MainCategory").Count() == 1 { + if DB.Model(&post).Association("MainCategory").Count() != 1 { t.Errorf("Post's main category count should be 1") } @@ -60,7 +60,7 @@ func TestBelongsTo(t *testing.T) { } DB.Model(&post).Association("Category").Append(&category2) - if category2.Id == 0 { + if category2.ID == 0 { t.Errorf("Category should has ID when created with Append") } @@ -71,7 +71,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be updated with Append") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } @@ -81,7 +81,7 @@ func TestBelongsTo(t *testing.T) { } DB.Model(&post).Association("Category").Replace(&category3) - if category3.Id == 0 { + if category3.ID == 0 { t.Errorf("Category should has ID when created with Replace") } @@ -91,7 +91,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be updated with Replace") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } @@ -117,8 +117,8 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be deleted with Delete") } - if DB.Model(&post).Association("Category").Count() == 0 { - t.Errorf("Post's category count should be 0 after Delete") + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after Delete, but got %v", count) } // Clear @@ -144,8 +144,36 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Should not find any category after Clear") } - if DB.Model(&post).Association("Category").Count() == 0 { - t.Errorf("Post's category count should be 0 after Clear") + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after Clear, but got %v", count) + } + + // Check Association mode with soft delete + category6 := Category{ + Name: "Category 6", + } + DB.Model(&post).Association("Category").Append(&category6) + + if count := DB.Model(&post).Association("Category").Count(); count != 1 { + t.Errorf("Post's category count should be 1 after Append, but got %v", count) + } + + DB.Delete(&category6) + + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) + } + + if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { + t.Errorf("Post's category is not findable after Delete") + } + + if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { + t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) + } + + if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { + t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) } } @@ -265,6 +293,34 @@ func TestHasOne(t *testing.T) { if DB.Model(&user).Association("CreditCard").Count() != 0 { t.Errorf("User's credit card count should be 0 after Clear") } + + // Check Association mode with soft delete + var creditcard6 = CreditCard{ + Number: "411111111116", + } + DB.Model(&user).Association("CreditCard").Append(&creditcard6) + + if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { + t.Errorf("User's credit card count should be 1 after Append, but got %v", count) + } + + DB.Delete(&creditcard6) + + if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { + t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) + } + + if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { + t.Errorf("User's creditcard is not findable after Delete") + } + + if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { + t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) + } + + if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { + t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) + } } func TestHasMany(t *testing.T) { @@ -374,6 +430,36 @@ func TestHasMany(t *testing.T) { if len(comments51) != 0 { t.Errorf("Clear has many relations") } + + // Check Association mode with soft delete + var comment6 = Comment{ + Content: "comment 6", + } + DB.Model(&post).Association("Comments").Append(&comment6) + + if count := DB.Model(&post).Association("Comments").Count(); count != 1 { + t.Errorf("post's comments count should be 1 after Append, but got %v", count) + } + + DB.Delete(&comment6) + + if count := DB.Model(&post).Association("Comments").Count(); count != 0 { + t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) + } + + var comments6 []Comment + if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { + t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) + } + + if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { + t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) + } + + var comments61 []Comment + if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { + t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) + } } func TestManyToMany(t *testing.T) { @@ -472,6 +558,36 @@ func TestManyToMany(t *testing.T) { if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { t.Errorf("Relations should be cleared") } + + // Check Association mode with soft delete + var language6 = Language{ + Name: "language 6", + } + DB.Model(&user).Association("Languages").Append(&language6) + + if count := DB.Model(&user).Association("Languages").Count(); count != 1 { + t.Errorf("user's languages count should be 1 after Append, but got %v", count) + } + + DB.Delete(&language6) + + if count := DB.Model(&user).Association("Languages").Count(); count != 0 { + t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) + } + + var languages6 []Language + if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { + t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) + } + + if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { + t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) + } + + var languages61 []Language + if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { + t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) + } } func TestRelated(t *testing.T) { diff --git a/join_table_test.go b/join_table_test.go index ce92e42f..1a83a9c8 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -39,7 +39,7 @@ func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) - return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) + return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } func TestJoinTable(t *testing.T) { diff --git a/main.go b/main.go index f6cd66ad..ed902d0b 100644 --- a/main.go +++ b/main.go @@ -434,7 +434,7 @@ func (s *DB) DropColumn(column string) *DB { } func (s *DB) AddIndex(indexName string, column ...string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.Unscoped().NewScope(s.Value) scope.addIndex(false, indexName, column...) return scope.db } diff --git a/main_test.go b/main_test.go index d288b8ad..9d90bbc7 100644 --- a/main_test.go +++ b/main_test.go @@ -33,8 +33,9 @@ func init() { // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // DB.SetLogger(log.New(os.Stdout, "\r\n", 0)) - // DB.LogMode(true) - DB.LogMode(false) + if os.Getenv("DEBUG") == "true" { + DB.LogMode(true) + } DB.DB().SetMaxIdleConns(10) diff --git a/scope_private.go b/scope_private.go index f8a68229..135e7f92 100644 --- a/scope_private.go +++ b/scope_private.go @@ -604,9 +604,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { sqlCreate = "CREATE UNIQUE INDEX" } - scope.Search.Unscoped = true scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec() - scope.Search.Unscoped = false } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { @@ -662,11 +660,11 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - scope.addIndex(false, name, columns...) + scope.NewDB().Model(scope.Value).AddIndex(name, columns...) } for name, columns := range uniqueIndexes { - scope.addIndex(true, name, columns...) + scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) } return scope diff --git a/structs_test.go b/structs_test.go index ef04cd4b..e595df58 100644 --- a/structs_test.go +++ b/structs_test.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" + "github.com/jinzhu/gorm" + "reflect" "time" ) @@ -64,7 +66,7 @@ type Address struct { } type Language struct { - Id int + gorm.Model Name string Users []User `gorm:"many2many:user_languages;"` } @@ -154,12 +156,12 @@ type Post struct { } type Category struct { - Id int64 + gorm.Model Name string } type Comment struct { - Id int64 + gorm.Model PostId int64 Content string Post Post