mirror of https://github.com/go-gorm/gorm.git
Add SetupJoinTable support
This commit is contained in:
parent
db03616993
commit
e490e09db5
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
|
@ -44,7 +45,16 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro
|
|||
tx = association.DB.Model(out)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped {
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE", "LIMIT")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: queryConds},
|
||||
|
@ -321,10 +331,13 @@ func (association *Association) Count() (count int64) {
|
|||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
tx.Clauses(queryClause)
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE", "LIMIT")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||
|
|
|
@ -169,12 +169,12 @@ func CreateWithReturning(db *gorm.DB) {
|
|||
if err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
} else if !db.DryRun {
|
||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
37
gorm.go
37
gorm.go
|
@ -108,6 +108,7 @@ func (db *DB) Session(config *Session) *DB {
|
|||
if config.Context != nil {
|
||||
if tx.Statement != nil {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
} else {
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
|
@ -181,6 +182,42 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
stmt = tx.Statement
|
||||
modelSchema, joinSchema *schema.Schema
|
||||
)
|
||||
|
||||
if err := stmt.Parse(model); err == nil {
|
||||
modelSchema = stmt.Schema
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := stmt.Parse(joinTable); err == nil {
|
||||
joinSchema = stmt.Schema
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
|
||||
for _, ref := range relation.References {
|
||||
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
||||
ref.ForeignKey = f
|
||||
} else {
|
||||
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
relation.JoinTable = joinSchema
|
||||
} else {
|
||||
return fmt.Errorf("failed to found relation: %v", field)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Callback returns callback manager
|
||||
func (db *DB) Callback() *callbacks {
|
||||
return db.callbacks
|
||||
|
|
|
@ -33,7 +33,7 @@ type Relationship struct {
|
|||
Type RelationshipType
|
||||
Field *Field
|
||||
Polymorphic *Polymorphic
|
||||
References []Reference
|
||||
References []*Reference
|
||||
Schema *Schema
|
||||
FieldSchema *Schema
|
||||
JoinTable *Schema
|
||||
|
@ -139,7 +139,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
|||
}
|
||||
|
||||
if schema.err == nil {
|
||||
relation.References = append(relation.References, Reference{
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryValue: relation.Polymorphic.Value,
|
||||
ForeignKey: relation.Polymorphic.PolymorphicType,
|
||||
})
|
||||
|
@ -150,7 +150,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
|||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
|
||||
}
|
||||
}
|
||||
relation.References = append(relation.References, Reference{
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: primaryKeyField,
|
||||
ForeignKey: relation.Polymorphic.PolymorphicID,
|
||||
OwnPrimaryKey: true,
|
||||
|
@ -246,7 +246,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
|
||||
// build references
|
||||
for _, f := range relation.JoinTable.Fields {
|
||||
relation.References = append(relation.References, Reference{
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: fieldsMap[f.Name],
|
||||
ForeignKey: f,
|
||||
OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name],
|
||||
|
@ -326,7 +326,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||
|
||||
// build references
|
||||
for idx, foreignField := range foreignFields {
|
||||
relation.References = append(relation.References, Reference{
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: primaryFields[idx],
|
||||
ForeignKey: foreignField,
|
||||
OwnPrimaryKey: schema == primarySchema && guessHas,
|
||||
|
|
|
@ -158,9 +158,11 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||
writer.WriteString("(NULL)")
|
||||
}
|
||||
case *DB:
|
||||
result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement
|
||||
writer.WriteString(result.SQL.String())
|
||||
stmt.Vars = append(stmt.Vars, result.Vars...)
|
||||
subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance()
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
writer.WriteString(subdb.Statement.SQL.String())
|
||||
stmt.Vars = subdb.Statement.Vars
|
||||
default:
|
||||
switch rv := reflect.ValueOf(v); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
type Person struct {
|
||||
ID int
|
||||
Name string
|
||||
Addresses []Address `gorm:"many2many:person_addresses;"`
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
ID uint
|
||||
Name string
|
||||
}
|
||||
|
||||
type PersonAddress struct {
|
||||
PersonID int
|
||||
AddressID int
|
||||
CreatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt
|
||||
}
|
||||
|
||||
func TestOverrideJoinTable(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{})
|
||||
|
||||
if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil {
|
||||
t.Fatalf("Failed to setup join table for person, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil {
|
||||
t.Fatalf("Failed to migrate, got %v", err)
|
||||
}
|
||||
|
||||
address1 := Address{Name: "address 1"}
|
||||
address2 := Address{Name: "address 2"}
|
||||
person := Person{Name: "person", Addresses: []Address{address1, address2}}
|
||||
DB.Create(&person)
|
||||
|
||||
var addresses1 []Address
|
||||
if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 {
|
||||
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1))
|
||||
}
|
||||
|
||||
if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil {
|
||||
t.Fatalf("Failed to delete address, got error %v", err)
|
||||
}
|
||||
|
||||
if len(person.Addresses) != 1 {
|
||||
t.Fatalf("Should have one address left")
|
||||
}
|
||||
|
||||
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 {
|
||||
t.Fatalf("Should found one address")
|
||||
}
|
||||
|
||||
var addresses2 []Address
|
||||
if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 {
|
||||
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2))
|
||||
}
|
||||
|
||||
if DB.Model(&person).Association("Addresses").Count() != 1 {
|
||||
t.Fatalf("Should found one address")
|
||||
}
|
||||
|
||||
var addresses3 []Address
|
||||
if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 {
|
||||
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3))
|
||||
}
|
||||
|
||||
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 {
|
||||
t.Fatalf("Should found soft deleted addresses with unscoped")
|
||||
}
|
||||
|
||||
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 {
|
||||
t.Fatalf("Should found soft deleted addresses with unscoped")
|
||||
}
|
||||
|
||||
DB.Model(&person).Association("Addresses").Clear()
|
||||
|
||||
if DB.Model(&person).Association("Addresses").Count() != 0 {
|
||||
t.Fatalf("Should deleted all addresses")
|
||||
}
|
||||
|
||||
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 {
|
||||
t.Fatalf("Should found soft deleted addresses with unscoped")
|
||||
}
|
||||
|
||||
DB.Unscoped().Model(&person).Association("Addresses").Clear()
|
||||
|
||||
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
|
||||
t.Fatalf("address should be deleted when clear with unscoped")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue