forked from mirror/gorm
Add SetupJoinTable support
This commit is contained in:
parent
db03616993
commit
e490e09db5
|
@ -4,6 +4,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm/clause"
|
"github.com/jinzhu/gorm/clause"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
|
@ -44,7 +45,16 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro
|
||||||
tx = association.DB.Model(out)
|
tx = association.DB.Model(out)
|
||||||
)
|
)
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped {
|
if association.Relationship.JoinTable != nil {
|
||||||
|
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||||
|
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||||
|
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||||
|
joinStmt.AddClause(queryClause)
|
||||||
|
}
|
||||||
|
joinStmt.Build("WHERE", "LIMIT")
|
||||||
|
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||||
|
}
|
||||||
|
|
||||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||||
ON: clause.Where{Exprs: queryConds},
|
ON: clause.Where{Exprs: queryConds},
|
||||||
|
@ -321,10 +331,13 @@ func (association *Association) Count() (count int64) {
|
||||||
)
|
)
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
if association.Relationship.JoinTable != nil {
|
||||||
if !tx.Statement.Unscoped {
|
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||||
|
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||||
tx.Clauses(queryClause)
|
joinStmt.AddClause(queryClause)
|
||||||
}
|
}
|
||||||
|
joinStmt.Build("WHERE", "LIMIT")
|
||||||
|
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||||
|
|
|
@ -169,12 +169,12 @@ func CreateWithReturning(db *gorm.DB) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
} else if !db.DryRun {
|
||||||
|
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
} else {
|
} else {
|
||||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
db.AddError(err)
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
|
||||||
} else {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
37
gorm.go
37
gorm.go
|
@ -108,6 +108,7 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
if config.Context != nil {
|
if config.Context != nil {
|
||||||
if tx.Statement != nil {
|
if tx.Statement != nil {
|
||||||
tx.Statement = tx.Statement.clone()
|
tx.Statement = tx.Statement.clone()
|
||||||
|
tx.Statement.DB = tx
|
||||||
} else {
|
} else {
|
||||||
tx.Statement = &Statement{
|
tx.Statement = &Statement{
|
||||||
DB: tx,
|
DB: tx,
|
||||||
|
@ -181,6 +182,42 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||||
|
var (
|
||||||
|
tx = db.getInstance()
|
||||||
|
stmt = tx.Statement
|
||||||
|
modelSchema, joinSchema *schema.Schema
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := stmt.Parse(model); err == nil {
|
||||||
|
modelSchema = stmt.Schema
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stmt.Parse(joinTable); err == nil {
|
||||||
|
joinSchema = stmt.Schema
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
|
||||||
|
for _, ref := range relation.References {
|
||||||
|
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
||||||
|
ref.ForeignKey = f
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
relation.JoinTable = joinSchema
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("failed to found relation: %v", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Callback returns callback manager
|
// Callback returns callback manager
|
||||||
func (db *DB) Callback() *callbacks {
|
func (db *DB) Callback() *callbacks {
|
||||||
return db.callbacks
|
return db.callbacks
|
||||||
|
|
|
@ -33,7 +33,7 @@ type Relationship struct {
|
||||||
Type RelationshipType
|
Type RelationshipType
|
||||||
Field *Field
|
Field *Field
|
||||||
Polymorphic *Polymorphic
|
Polymorphic *Polymorphic
|
||||||
References []Reference
|
References []*Reference
|
||||||
Schema *Schema
|
Schema *Schema
|
||||||
FieldSchema *Schema
|
FieldSchema *Schema
|
||||||
JoinTable *Schema
|
JoinTable *Schema
|
||||||
|
@ -139,7 +139,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
||||||
}
|
}
|
||||||
|
|
||||||
if schema.err == nil {
|
if schema.err == nil {
|
||||||
relation.References = append(relation.References, Reference{
|
relation.References = append(relation.References, &Reference{
|
||||||
PrimaryValue: relation.Polymorphic.Value,
|
PrimaryValue: relation.Polymorphic.Value,
|
||||||
ForeignKey: relation.Polymorphic.PolymorphicType,
|
ForeignKey: relation.Polymorphic.PolymorphicType,
|
||||||
})
|
})
|
||||||
|
@ -150,7 +150,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
||||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
|
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
relation.References = append(relation.References, Reference{
|
relation.References = append(relation.References, &Reference{
|
||||||
PrimaryKey: primaryKeyField,
|
PrimaryKey: primaryKeyField,
|
||||||
ForeignKey: relation.Polymorphic.PolymorphicID,
|
ForeignKey: relation.Polymorphic.PolymorphicID,
|
||||||
OwnPrimaryKey: true,
|
OwnPrimaryKey: true,
|
||||||
|
@ -246,7 +246,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||||
|
|
||||||
// build references
|
// build references
|
||||||
for _, f := range relation.JoinTable.Fields {
|
for _, f := range relation.JoinTable.Fields {
|
||||||
relation.References = append(relation.References, Reference{
|
relation.References = append(relation.References, &Reference{
|
||||||
PrimaryKey: fieldsMap[f.Name],
|
PrimaryKey: fieldsMap[f.Name],
|
||||||
ForeignKey: f,
|
ForeignKey: f,
|
||||||
OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name],
|
OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name],
|
||||||
|
@ -326,7 +326,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
||||||
|
|
||||||
// build references
|
// build references
|
||||||
for idx, foreignField := range foreignFields {
|
for idx, foreignField := range foreignFields {
|
||||||
relation.References = append(relation.References, Reference{
|
relation.References = append(relation.References, &Reference{
|
||||||
PrimaryKey: primaryFields[idx],
|
PrimaryKey: primaryFields[idx],
|
||||||
ForeignKey: foreignField,
|
ForeignKey: foreignField,
|
||||||
OwnPrimaryKey: schema == primarySchema && guessHas,
|
OwnPrimaryKey: schema == primarySchema && guessHas,
|
||||||
|
|
|
@ -158,9 +158,11 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||||
writer.WriteString("(NULL)")
|
writer.WriteString("(NULL)")
|
||||||
}
|
}
|
||||||
case *DB:
|
case *DB:
|
||||||
result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement
|
subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance()
|
||||||
writer.WriteString(result.SQL.String())
|
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
|
||||||
stmt.Vars = append(stmt.Vars, result.Vars...)
|
subdb.callbacks.Query().Execute(subdb)
|
||||||
|
writer.WriteString(subdb.Statement.SQL.String())
|
||||||
|
stmt.Vars = subdb.Statement.Vars
|
||||||
default:
|
default:
|
||||||
switch rv := reflect.ValueOf(v); rv.Kind() {
|
switch rv := reflect.ValueOf(v); rv.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
. "github.com/jinzhu/gorm/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Person struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
Addresses []Address `gorm:"many2many:person_addresses;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Address struct {
|
||||||
|
ID uint
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PersonAddress struct {
|
||||||
|
PersonID int
|
||||||
|
AddressID int
|
||||||
|
CreatedAt time.Time
|
||||||
|
DeletedAt gorm.DeletedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOverrideJoinTable(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{})
|
||||||
|
|
||||||
|
if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil {
|
||||||
|
t.Fatalf("Failed to setup join table for person, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil {
|
||||||
|
t.Fatalf("Failed to migrate, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
address1 := Address{Name: "address 1"}
|
||||||
|
address2 := Address{Name: "address 2"}
|
||||||
|
person := Person{Name: "person", Addresses: []Address{address1, address2}}
|
||||||
|
DB.Create(&person)
|
||||||
|
|
||||||
|
var addresses1 []Address
|
||||||
|
if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 {
|
||||||
|
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil {
|
||||||
|
t.Fatalf("Failed to delete address, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(person.Addresses) != 1 {
|
||||||
|
t.Fatalf("Should have one address left")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 {
|
||||||
|
t.Fatalf("Should found one address")
|
||||||
|
}
|
||||||
|
|
||||||
|
var addresses2 []Address
|
||||||
|
if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 {
|
||||||
|
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2))
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&person).Association("Addresses").Count() != 1 {
|
||||||
|
t.Fatalf("Should found one address")
|
||||||
|
}
|
||||||
|
|
||||||
|
var addresses3 []Address
|
||||||
|
if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 {
|
||||||
|
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3))
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 {
|
||||||
|
t.Fatalf("Should found soft deleted addresses with unscoped")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 {
|
||||||
|
t.Fatalf("Should found soft deleted addresses with unscoped")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&person).Association("Addresses").Clear()
|
||||||
|
|
||||||
|
if DB.Model(&person).Association("Addresses").Count() != 0 {
|
||||||
|
t.Fatalf("Should deleted all addresses")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 {
|
||||||
|
t.Fatalf("Should found soft deleted addresses with unscoped")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Unscoped().Model(&person).Association("Addresses").Clear()
|
||||||
|
|
||||||
|
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
|
||||||
|
t.Fatalf("address should be deleted when clear with unscoped")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue