From 44b106c8e26ab7b0320d489cfc66cf9db0ed1dc9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 19 Mar 2015 18:23:54 +0800 Subject: [PATCH] Fix tests --- association.go | 4 +--- association_test.go | 2 +- join_table.go | 46 ++++++++++++++++++++++----------------------- join_table_test.go | 25 +++++++++++++----------- main.go | 9 +++++++++ scope_private.go | 5 +---- 6 files changed, 49 insertions(+), 42 deletions(-) diff --git a/association.go b/association.go index 60763f8c..89bb1bec 100644 --- a/association.go +++ b/association.go @@ -163,9 +163,7 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). - Where(relationship.ForeignDBName+" = ?", association.PrimaryKey) - relationship.JoinTableHandler.JoinWith(query, association.Scope.Value).Count(&count) + relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/association_test.go b/association_test.go index a7b8f136..3ffd8880 100644 --- a/association_test.go +++ b/association_test.go @@ -143,7 +143,7 @@ func TestManyToMany(t *testing.T) { // Query var newLanguages []Language - DB.Debug().Model(&user).Related(&newLanguages, "Languages") + DB.Model(&user).Related(&newLanguages, "Languages") if len(newLanguages) != len([]string{"ZH", "EN"}) { t.Errorf("Query many to many relations") } diff --git a/join_table.go b/join_table.go index d29f7bfb..163bb4e2 100644 --- a/join_table.go +++ b/join_table.go @@ -18,28 +18,6 @@ type JoinTableForeignKey struct { AssociationDBName string } -func updateJoinTableHandler(relationship *Relationship) { - handler := relationship.JoinTableHandler.(*JoinTableHandler) - - destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()} - for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { - db := relationship.ForeignDBName - handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: db, - AssociationDBName: primaryField.DBName, - }) - } - - sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()} - for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { - db := relationship.AssociationForeignDBName - handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{ - DBName: db, - AssociationDBName: primaryField.DBName, - }) - } -} - type JoinTableSource struct { ModelType reflect.Type ForeignKeys []JoinTableForeignKey @@ -51,6 +29,28 @@ type JoinTableHandler struct { Destination JoinTableSource `sql:"-"` } +func updateJoinTableHandler(relationship *Relationship) { + handler := relationship.JoinTableHandler.(*JoinTableHandler) + + destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()} + for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { + db := relationship.AssociationForeignDBName + handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{ + DBName: db, + AssociationDBName: primaryField.DBName, + }) + } + + sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()} + for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + db := relationship.ForeignDBName + handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{ + DBName: db, + AssociationDBName: primaryField.DBName, + }) + } +} + func (s JoinTableHandler) Table(*DB) string { return s.TableName } @@ -88,7 +88,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) values = append(values, value) } - for _, value := range searchMap { + for _, value := range values { values = append(values, value) } diff --git a/join_table_test.go b/join_table_test.go index 429e46e1..38e9f943 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -15,29 +15,32 @@ type Person struct { } type PersonAddress struct { - gorm.JoinTableHandler PersonID int AddressID int DeletedAt time.Time CreatedAt time.Time } -func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error { +func (*PersonAddress) Table(db *gorm.DB) string { + return "person_addresses" +} + +func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ - relationship.ForeignDBName: foreignValue, - relationship.AssociationForeignDBName: associationValue, + "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), + "address_id": db.NewScope(associationValue).PrimaryKeyValue(), }).Assign(map[string]interface{}{ - relationship.ForeignFieldName: foreignValue, - relationship.AssociationForeignFieldName: associationValue, - "DeletedAt": gorm.Expr("NULL"), + "person_id": foreignValue, + "address_id": associationValue, + "DeletedAt": gorm.Expr("NULL"), }).FirstOrCreate(&PersonAddress{}).Error } -func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error { +func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error { return db.Delete(&PersonAddress{}).Error } -func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { +func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } @@ -45,7 +48,7 @@ func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *go func TestJoinTable(t *testing.T) { DB.Exec("drop table person_addresses;") DB.AutoMigrate(&Person{}) - // DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses") + DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) address1 := &Address{Address1: "address 1"} address2 := &Address{Address1: "address 2"} @@ -58,7 +61,7 @@ func TestJoinTable(t *testing.T) { t.Errorf("Should found one address") } - if DB.Model(person).Association("Addresses").Count() != 1 { + if DB.Debug().Model(person).Association("Addresses").Count() != 1 { t.Errorf("Should found one address") } diff --git a/main.go b/main.go index 377ef582..5f7db05a 100644 --- a/main.go +++ b/main.go @@ -469,3 +469,12 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { value, ok = s.values[name] return } + +func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { + for _, field := range s.NewScope(source).GetModelStruct().StructFields { + if field.Name == column || field.DBName == column { + field.Relationship.JoinTableHandler = handler + s.Table(handler.Table(s)).AutoMigrate(handler) + } + } +} diff --git a/scope_private.go b/scope_private.go index da78d3f2..d1f4a10b 100644 --- a/scope_private.go +++ b/scope_private.go @@ -403,10 +403,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db)) - scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value). - Where(fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)), scope.PrimaryKeyValue()). - Find(value).Error) + scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()