diff --git a/gorm.go b/gorm.go index 7f7bad26..71cd01e8 100644 --- a/gorm.go +++ b/gorm.go @@ -387,44 +387,46 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac modelSchema, joinSchema *schema.Schema ) - if err := stmt.Parse(model); err == nil { - modelSchema = stmt.Schema - } else { + err := stmt.Parse(model) + if err != nil { return err } + modelSchema = stmt.Schema - if err := stmt.Parse(joinTable); err == nil { - joinSchema = stmt.Schema - } else { + err = stmt.Parse(joinTable) + if err != nil { return err } + joinSchema = stmt.Schema - if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { - for _, ref := range relation.References { - if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { - f.DataType = ref.ForeignKey.DataType - f.GORMDataType = ref.ForeignKey.GORMDataType - if f.Size == 0 { - f.Size = ref.ForeignKey.Size - } - ref.ForeignKey = f - } else { - return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) - } - } - - for name, rel := range relation.JoinTable.Relationships.Relations { - if _, ok := joinSchema.Relationships.Relations[name]; !ok { - rel.Schema = joinSchema - joinSchema.Relationships.Relations[name] = rel - } - } - - relation.JoinTable = joinSchema - } else { + relation, ok := modelSchema.Relationships.Relations[field] + isRelation := ok && relation.JoinTable != nil + if !isRelation { return fmt.Errorf("failed to found relation: %s", field) } + for _, ref := range relation.References { + f := joinSchema.LookUpField(ref.ForeignKey.DBName) + if f == nil { + return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) + } + + f.DataType = ref.ForeignKey.DataType + f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size + } + ref.ForeignKey = f + } + + for name, rel := range relation.JoinTable.Relationships.Relations { + if _, ok := joinSchema.Relationships.Relations[name]; !ok { + rel.Schema = joinSchema + joinSchema.Relationships.Relations[name] = rel + } + } + relation.JoinTable = joinSchema + return nil }