From fee1e4aafd39800814c08c8ab4d5c2d1dc773856 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 21 Jun 2020 10:19:16 +0800 Subject: [PATCH] Fix create foreign keys for many2many relations --- gorm.go | 7 ++++++ migrator/migrator.go | 29 ++++++++++++++++++------- schema/naming.go | 2 +- schema/relationship.go | 49 +++++++++++++++++++++++++++++++++++++++++- tests/go.mod | 4 ++-- 5 files changed, 79 insertions(+), 12 deletions(-) diff --git a/gorm.go b/gorm.go index a5f8bbfd..e3193f59 100644 --- a/gorm.go +++ b/gorm.go @@ -293,6 +293,13 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac } } + for name, rel := range relation.JoinTable.Relationships.Relations { + if _, ok := joinSchema.Relationships.Relations[name]; !ok { + rel.Schema = joinSchema + joinSchema.Relationships.Relations[name] = rel + } + } + relation.JoinTable = joinSchema } else { return fmt.Errorf("failed to found relation: %v", field) diff --git a/migrator/migrator.go b/migrator/migrator.go index b598bd93..90ab7892 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -88,7 +88,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return err } } else { - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { for _, field := range stmt.Schema.FieldsByDBName { if !tx.Migrator().HasColumn(value, field.DBName) { if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { @@ -120,9 +120,13 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) + defer func() { + errr = tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) + }() } else { - defer tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue) + defer func() { + errr = tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue) + }() } } } @@ -139,7 +143,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{clause.Table{Name: stmt.Table}} @@ -166,7 +170,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { - defer tx.Migrator().CreateIndex(value, idx.Name) + defer func() { + errr = tx.Migrator().CreateIndex(value, idx.Name) + }() } else { createTableSQL += "INDEX ? ?," values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) @@ -186,7 +192,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) + defer func(table string, joinValue interface{}) { + errr = tx.Table(table).Migrator().CreateTable(joinValue) + }(rel.JoinTable.Table, joinValue) } } } @@ -204,7 +212,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL += fmt.Sprint(tableOption) } - return tx.Exec(createTableSQL, values...).Error + errr = tx.Exec(createTableSQL, values...).Error + return errr }); err != nil { return err } @@ -553,6 +562,10 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { dep.Depends = append(dep.Depends, c.ReferenceSchema) } + + if rel.JoinTable != nil && rel.Schema != rel.FieldSchema { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } } valuesMap[dep.Schema.Table] = dep @@ -566,6 +579,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i if _, ok := orderedModelNamesMap[name]; ok { return // avoid loop } + orderedModelNamesMap[name] = true dep := valuesMap[name] for _, d := range dep.Depends { @@ -578,7 +592,6 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } orderedModelNames = append(orderedModelNames, name) - orderedModelNamesMap[name] = true } for _, value := range values { diff --git a/schema/naming.go b/schema/naming.go index f7c82f32..d2a4919f 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name)) + return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)) } // CheckerName generate checker name diff --git a/schema/relationship.go b/schema/relationship.go index c69a4a09..a13d53b9 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -253,16 +253,63 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel relation.JoinTable.Table = schema.namer.JoinTableName(many2many) relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields)) + relName := relation.Schema.Name + relRefName := relation.FieldSchema.Name + if relName == relRefName { + relRefName = relation.Field.Name + } + + if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok { + relation.JoinTable.Relationships.Relations[relName] = &Relationship{ + Name: relName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.Schema, + } + } else { + relation.JoinTable.Relationships.Relations[relName].References = []*Reference{} + } + + if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok { + relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{ + Name: relRefName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.FieldSchema, + } + } else { + relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{} + } + // build references for idx, f := range relation.JoinTable.Fields { // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType relation.JoinTable.PrimaryFields[idx] = f + ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + + if ownPriamryField { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } else { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], ForeignKey: f, - OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], + OwnPrimaryKey: ownPriamryField, }) } return diff --git a/tests/go.mod b/tests/go.mod index 1cd56f6b..85ef8dcb 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,9 +6,9 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.2 + gorm.io/driver/mysql v0.2.3 gorm.io/driver/postgres v0.2.2 - gorm.io/driver/sqlite v1.0.5 + gorm.io/driver/sqlite v1.0.6 gorm.io/driver/sqlserver v0.2.2 gorm.io/gorm v0.2.9 )