diff --git a/association.go b/association.go index 55dd7772..23e5a82f 100644 --- a/association.go +++ b/association.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" @@ -44,7 +45,16 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro tx = association.DB.Model(out) ) - if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped { + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE", "LIMIT") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + tx.Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: queryConds}, @@ -321,10 +331,13 @@ func (association *Association) Count() (count int64) { ) if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - tx.Clauses(queryClause) + joinStmt.AddClause(queryClause) } + joinStmt.Build("WHERE", "LIMIT") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } tx.Clauses(clause.From{Joins: []clause.Join{{ diff --git a/callbacks/create.go b/callbacks/create.go index 01329141..0277407e 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -169,12 +169,12 @@ func CreateWithReturning(db *gorm.DB) { if err != nil { db.AddError(err) } + } + } else if !db.DryRun { + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() } else { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) - } + db.AddError(err) } } } diff --git a/gorm.go b/gorm.go index 7d6bd2ed..fd0d4b7e 100644 --- a/gorm.go +++ b/gorm.go @@ -108,6 +108,7 @@ func (db *DB) Session(config *Session) *DB { if config.Context != nil { if tx.Statement != nil { tx.Statement = tx.Statement.clone() + tx.Statement.DB = tx } else { tx.Statement = &Statement{ DB: tx, @@ -181,6 +182,42 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { return nil, false } +func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { + var ( + tx = db.getInstance() + stmt = tx.Statement + modelSchema, joinSchema *schema.Schema + ) + + if err := stmt.Parse(model); err == nil { + modelSchema = stmt.Schema + } else { + return err + } + + if err := stmt.Parse(joinTable); err == nil { + joinSchema = stmt.Schema + } else { + return err + } + + if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { + for _, ref := range relation.References { + if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { + ref.ForeignKey = f + } else { + return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) + } + } + + relation.JoinTable = joinSchema + } else { + return fmt.Errorf("failed to found relation: %v", field) + } + + return nil +} + // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks diff --git a/schema/relationship.go b/schema/relationship.go index dffe5988..194fbeff 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -33,7 +33,7 @@ type Relationship struct { Type RelationshipType Field *Field Polymorphic *Polymorphic - References []Reference + References []*Reference Schema *Schema FieldSchema *Schema JoinTable *Schema @@ -139,7 +139,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } if schema.err == nil { - relation.References = append(relation.References, Reference{ + relation.References = append(relation.References, &Reference{ PrimaryValue: relation.Polymorphic.Value, ForeignKey: relation.Polymorphic.PolymorphicType, }) @@ -150,7 +150,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) } } - relation.References = append(relation.References, Reference{ + relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, ForeignKey: relation.Polymorphic.PolymorphicID, OwnPrimaryKey: true, @@ -246,7 +246,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // build references for _, f := range relation.JoinTable.Fields { - relation.References = append(relation.References, Reference{ + relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], ForeignKey: f, OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], @@ -326,7 +326,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH // build references for idx, foreignField := range foreignFields { - relation.References = append(relation.References, Reference{ + relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], ForeignKey: foreignField, OwnPrimaryKey: schema == primarySchema && guessHas, diff --git a/statement.go b/statement.go index 03d1b8a8..e78dfea9 100644 --- a/statement.go +++ b/statement.go @@ -158,9 +158,11 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement - writer.WriteString(result.SQL.String()) - stmt.Vars = append(stmt.Vars, result.Vars...) + subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance() + subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) + subdb.callbacks.Query().Execute(subdb) + writer.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars default: switch rv := reflect.ValueOf(v); rv.Kind() { case reflect.Slice, reflect.Array: diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go new file mode 100644 index 00000000..091ca65c --- /dev/null +++ b/tests/joins_table_test.go @@ -0,0 +1,99 @@ +package tests_test + +import ( + "testing" + "time" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +type Person struct { + ID int + Name string + Addresses []Address `gorm:"many2many:person_addresses;"` +} + +type Address struct { + ID uint + Name string +} + +type PersonAddress struct { + PersonID int + AddressID int + CreatedAt time.Time + DeletedAt gorm.DeletedAt +} + +func TestOverrideJoinTable(t *testing.T) { + DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{}) + + if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil { + t.Fatalf("Failed to setup join table for person, got error %v", err) + } + + if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil { + t.Fatalf("Failed to migrate, got %v", err) + } + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + person := Person{Name: "person", Addresses: []Address{address1, address2}} + DB.Create(&person) + + var addresses1 []Address + if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1)) + } + + if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil { + t.Fatalf("Failed to delete address, got error %v", err) + } + + if len(person.Addresses) != 1 { + t.Fatalf("Should have one address left") + } + + if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 { + t.Fatalf("Should found one address") + } + + var addresses2 []Address + if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2)) + } + + if DB.Model(&person).Association("Addresses").Count() != 1 { + t.Fatalf("Should found one address") + } + + var addresses3 []Address + if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3)) + } + + if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + DB.Model(&person).Association("Addresses").Clear() + + if DB.Model(&person).Association("Addresses").Count() != 0 { + t.Fatalf("Should deleted all addresses") + } + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + DB.Unscoped().Model(&person).Association("Addresses").Clear() + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { + t.Fatalf("address should be deleted when clear with unscoped") + } +}