Add SetupJoinTable support

This commit is contained in:
Jinzhu 2020-06-01 22:31:50 +08:00
parent db03616993
commit e490e09db5
6 changed files with 167 additions and 16 deletions

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
@ -44,7 +45,16 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro
tx = association.DB.Model(out) tx = association.DB.Model(out)
) )
if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped { if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{ tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table}, Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds}, ON: clause.Where{Exprs: queryConds},
@ -321,10 +331,13 @@ func (association *Association) Count() (count int64) {
) )
if association.Relationship.JoinTable != nil { if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped { if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses { for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
tx.Clauses(queryClause) joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
} }
tx.Clauses(clause.From{Joins: []clause.Join{{ tx.Clauses(clause.From{Joins: []clause.Join{{

View File

@ -169,7 +169,8 @@ func CreateWithReturning(db *gorm.DB) {
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
} }
} else { }
} else if !db.DryRun {
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
} else { } else {
@ -178,7 +179,6 @@ func CreateWithReturning(db *gorm.DB) {
} }
} }
} }
}
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {

37
gorm.go
View File

@ -108,6 +108,7 @@ func (db *DB) Session(config *Session) *DB {
if config.Context != nil { if config.Context != nil {
if tx.Statement != nil { if tx.Statement != nil {
tx.Statement = tx.Statement.clone() tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
} else { } else {
tx.Statement = &Statement{ tx.Statement = &Statement{
DB: tx, DB: tx,
@ -181,6 +182,42 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
return nil, false return nil, false
} }
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
return err
}
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
return err
}
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
}
return nil
}
// Callback returns callback manager // Callback returns callback manager
func (db *DB) Callback() *callbacks { func (db *DB) Callback() *callbacks {
return db.callbacks return db.callbacks

View File

@ -33,7 +33,7 @@ type Relationship struct {
Type RelationshipType Type RelationshipType
Field *Field Field *Field
Polymorphic *Polymorphic Polymorphic *Polymorphic
References []Reference References []*Reference
Schema *Schema Schema *Schema
FieldSchema *Schema FieldSchema *Schema
JoinTable *Schema JoinTable *Schema
@ -139,7 +139,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
} }
if schema.err == nil { if schema.err == nil {
relation.References = append(relation.References, Reference{ relation.References = append(relation.References, &Reference{
PrimaryValue: relation.Polymorphic.Value, PrimaryValue: relation.Polymorphic.Value,
ForeignKey: relation.Polymorphic.PolymorphicType, ForeignKey: relation.Polymorphic.PolymorphicType,
}) })
@ -150,7 +150,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
} }
} }
relation.References = append(relation.References, Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: primaryKeyField, PrimaryKey: primaryKeyField,
ForeignKey: relation.Polymorphic.PolymorphicID, ForeignKey: relation.Polymorphic.PolymorphicID,
OwnPrimaryKey: true, OwnPrimaryKey: true,
@ -246,7 +246,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
// build references // build references
for _, f := range relation.JoinTable.Fields { for _, f := range relation.JoinTable.Fields {
relation.References = append(relation.References, Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name], PrimaryKey: fieldsMap[f.Name],
ForeignKey: f, ForeignKey: f,
OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name],
@ -326,7 +326,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
// build references // build references
for idx, foreignField := range foreignFields { for idx, foreignField := range foreignFields {
relation.References = append(relation.References, Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx], PrimaryKey: primaryFields[idx],
ForeignKey: foreignField, ForeignKey: foreignField,
OwnPrimaryKey: schema == primarySchema && guessHas, OwnPrimaryKey: schema == primarySchema && guessHas,

View File

@ -158,9 +158,11 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
writer.WriteString("(NULL)") writer.WriteString("(NULL)")
} }
case *DB: case *DB:
result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance()
writer.WriteString(result.SQL.String()) subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
stmt.Vars = append(stmt.Vars, result.Vars...) subdb.callbacks.Query().Execute(subdb)
writer.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
default: default:
switch rv := reflect.ValueOf(v); rv.Kind() { switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:

99
tests/joins_table_test.go Normal file
View File

@ -0,0 +1,99 @@
package tests_test
import (
"testing"
"time"
"github.com/jinzhu/gorm"
. "github.com/jinzhu/gorm/tests"
)
type Person struct {
ID int
Name string
Addresses []Address `gorm:"many2many:person_addresses;"`
}
type Address struct {
ID uint
Name string
}
type PersonAddress struct {
PersonID int
AddressID int
CreatedAt time.Time
DeletedAt gorm.DeletedAt
}
func TestOverrideJoinTable(t *testing.T) {
DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{})
if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil {
t.Fatalf("Failed to setup join table for person, got error %v", err)
}
if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil {
t.Fatalf("Failed to migrate, got %v", err)
}
address1 := Address{Name: "address 1"}
address2 := Address{Name: "address 2"}
person := Person{Name: "person", Addresses: []Address{address1, address2}}
DB.Create(&person)
var addresses1 []Address
if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 {
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1))
}
if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil {
t.Fatalf("Failed to delete address, got error %v", err)
}
if len(person.Addresses) != 1 {
t.Fatalf("Should have one address left")
}
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 {
t.Fatalf("Should found one address")
}
var addresses2 []Address
if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 {
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2))
}
if DB.Model(&person).Association("Addresses").Count() != 1 {
t.Fatalf("Should found one address")
}
var addresses3 []Address
if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 {
t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3))
}
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 {
t.Fatalf("Should found soft deleted addresses with unscoped")
}
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 {
t.Fatalf("Should found soft deleted addresses with unscoped")
}
DB.Model(&person).Association("Addresses").Clear()
if DB.Model(&person).Association("Addresses").Count() != 0 {
t.Fatalf("Should deleted all addresses")
}
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 {
t.Fatalf("Should found soft deleted addresses with unscoped")
}
DB.Unscoped().Model(&person).Association("Addresses").Clear()
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
t.Fatalf("address should be deleted when clear with unscoped")
}
}