Be able to find many to many relations

This commit is contained in:
Jinzhu 2014-07-30 16:22:26 +08:00
parent 9b3be66a77
commit efd3b7678b
3 changed files with 50 additions and 8 deletions

View File

@ -1,25 +1,67 @@
package gorm package gorm
import (
"errors"
"fmt"
"reflect"
)
type Association struct { type Association struct {
Scope *Scope Scope *Scope
Column string Column string
Error error
} }
func (*Association) Find(value interface{}) { func (association *Association) Find(value interface{}) *Association {
scope := association.Scope
primaryKey := scope.PrimaryKeyValue()
if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
association.Error = errors.New("primary key can't be nil")
} }
func (*Association) Append(values interface{}) { scopeType := scope.IndirectValue().Type()
if f, ok := scopeType.FieldByName(SnakeToUpperCamel(association.Column)); ok {
field := scope.fieldFromStruct(f)
joinTable := field.JoinTable
if joinTable != nil && joinTable.foreignKey != "" {
if joinTable.joinTable != "" {
newScope := scope.New(value)
joinSql := fmt.Sprintf(
"INNER JOIN %v ON %v.%v = %v.%v",
scope.Quote(joinTable.joinTable),
scope.Quote(joinTable.joinTable),
scope.Quote(ToSnake(joinTable.associationForeignKey)),
newScope.QuotedTableName(),
scope.Quote(newScope.PrimaryKey()))
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(joinTable.joinTable), scope.Quote(ToSnake(joinTable.foreignKey)))
scope.db.Joins(joinSql).Where(whereSql, primaryKey).Find(value)
} else {
}
} else {
association.Error = errors.New(fmt.Sprintf("invalid association %v for %v", association.Column, scopeType))
}
} else {
association.Error = errors.New(fmt.Sprintf("%v doesn't have column %v", scopeType, association.Column))
}
return association
} }
func (*Association) Delete(value interface{}) { func (association *Association) Append(values interface{}) *Association {
return association
} }
func (*Association) Clear(value interface{}) { func (association *Association) Delete(value interface{}) *Association {
return association
} }
func (*Association) Replace(values interface{}) { func (association *Association) Clear(value interface{}) *Association {
return association
} }
func (*Association) Count(values interface{}) int { func (association *Association) Replace(values interface{}) *Association {
return association
}
func (association *Association) Count(values interface{}) int {
return 0 return 0
} }

View File

@ -140,7 +140,7 @@ func TestManyToMany(t *testing.T) {
newLanguages = []Language{} newLanguages = []Language{}
db.Model(&user).Association("Languages").Find(&newLanguages) db.Model(&user).Association("Languages").Find(&newLanguages)
if len(newLanguages) != 3 { if len(newLanguages) != 3 {
t.Errorf("Query many to many relations") t.Errorf("Should be able to find many to many relations")
} }
// db.Model(&User{}).Many2Many("Languages").Add(&Language{}) // db.Model(&User{}).Many2Many("Languages").Add(&Language{})

View File

@ -18,7 +18,7 @@ func runMigration() {
db.Exec("drop table roles") db.Exec("drop table roles")
db.Exec("drop table companies") db.Exec("drop table companies")
db.Exec("drop table animals") db.Exec("drop table animals")
db.Exec("drop table user_companies") db.Exec("drop table user_languages")
if err := db.CreateTable(&Animal{}).Error; err != nil { if err := db.CreateTable(&Animal{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))