diff --git a/association.go b/association.go index 78e09108..9ce64043 100644 --- a/association.go +++ b/association.go @@ -78,7 +78,7 @@ func (association *Association) Delete(values ...interface{}) *Association { relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName), relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName)) - if err := association.Scope.db.Model("").Table(relationship.JoinTable). + if err := association.Scope.DB().Table(relationship.JoinTable). Where(whereSql, association.PrimaryKey, primaryKeys).Delete("").Error; err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { @@ -136,7 +136,7 @@ func (association *Association) Replace(values ...interface{}) *Association { relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName), relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName)) - scope.db.Model("").Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey, addedPrimaryKeys).Delete("") + scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey, addedPrimaryKeys).Delete("") } else { association.setErr(errors.New("replace only support many to many")) } @@ -148,7 +148,7 @@ func (association *Association) Clear() *Association { scope := association.Scope if relationship.Kind == "many_to_many" { whereSql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName)) - if err := scope.db.Model("").Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey).Delete("").Error; err == nil { + if err := scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey).Delete("").Error; err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { association.setErr(err) @@ -166,18 +166,12 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - whereSql := fmt.Sprintf("%v.%v IN (SELECT %v.%v FROM %v WHERE %v.%v = ?)", - newScope.QuotedTableName(), - scope.Quote(newScope.PrimaryKey()), - relationship.JoinTable, - scope.Quote(relationship.AssociationForeignDBName), - relationship.JoinTable, - relationship.JoinTable, - scope.Quote(relationship.ForeignDBName)) - scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) + scope.DB().Table(relationship.JoinTable). + Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). + Where(relationship.ForeignDBName+" = ?", association.PrimaryKey).Row().Scan(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) - countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey) + countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) if relationship.PolymorphicType != "" { countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } @@ -185,7 +179,7 @@ func (association *Association) Count() int { } else if relationship.Kind == "belongs_to" { if v, ok := scope.FieldByName(association.Column); ok { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) - scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count) + scope.DB().Table(newScope.TableName()).Where(whereSql, v).Count(&count) } } diff --git a/association_test.go b/association_test.go index fb73d443..d7984b12 100644 --- a/association_test.go +++ b/association_test.go @@ -1,9 +1,6 @@ package gorm_test -import ( - "os" - "testing" -) +import "testing" func TestHasOneAndHasManyAssociation(t *testing.T) { DB.DropTable(Category{}) @@ -183,7 +180,6 @@ func TestManyToMany(t *testing.T) { DB.Model(&user).Association("Languages").Delete(language, &language) if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { t.Errorf("Relations should be deleted with Delete") - os.Exit(1) } if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { t.Errorf("Language EE should not be deleted") diff --git a/main.go b/main.go index f0787d20..cf416799 100644 --- a/main.go +++ b/main.go @@ -322,7 +322,10 @@ func (s *DB) Count(value interface{}) *DB { } func (s *DB) Table(name string) *DB { - return s.clone().search.table(name).db + clone := s.clone() + clone.search.table(name) + clone.Value = nil + return clone } func (s *DB) Debug() *DB { diff --git a/scope.go b/scope.go index f88bf6dd..cabe9743 100644 --- a/scope.go +++ b/scope.go @@ -53,11 +53,16 @@ func (scope *Scope) NewDB() *DB { if scope.db != nil { db := scope.db.clone() db.search = nil + db.Value = nil return db } return nil } +func (scope *Scope) DB() *DB { + return scope.db +} + // SqlDB return *sql.DB func (scope *Scope) SqlDB() sqlCommon { return scope.db.db @@ -221,12 +226,12 @@ func (scope *Scope) TableName() string { return scope.GetModelStruct().TableName } -func (scope *Scope) QuotedTableName() string { +func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.TableName) > 0 { - return scope.Search.TableName + return scope.Quote(scope.Search.TableName) + } else { + return scope.Quote(scope.TableName()) } - - return scope.Quote(scope.TableName()) } // CombinedConditionSql get combined condition sql