diff --git a/callback_shared.go b/callback_shared.go index 324d121d..b49ab017 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -3,6 +3,7 @@ package gorm import ( "fmt" "reflect" + "strings" ) func BeginTransaction(scope *Scope) { @@ -49,12 +50,36 @@ func SaveAfterAssociations(scope *Scope) { newDB := scope.NewDB() elem := value.Index(i).Addr().Interface() - if field.JoinTable != nil && field.JoinTable.foreignKey != "" { + if field.JoinTable != nil && field.JoinTable.joinTable == "" && field.JoinTable.foreignKey != "" { newDB.NewScope(elem).SetColumn(field.JoinTable.foreignKey, scope.PrimaryKeyValue()) } scope.Err(newDB.Save(elem).Error) - fmt.Sprintf("INSERT INTO %v (%v, %v) SELECT (%v, %v) FROM %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v) limit 1;") + + if field.JoinTable != nil && field.JoinTable.joinTable != "" { + newScope := scope.New(elem) + joinTable := field.JoinTable.joinTable + foreignKey := ToSnake(field.JoinTable.foreignKey) + foreignValue := fmt.Sprintf("%v", scope.PrimaryKeyValue()) + associationForeignKey := ToSnake(field.JoinTable.associationForeignKey) + associationForeignValue := fmt.Sprintf("%v", newScope.PrimaryKeyValue()) + + newScope.Raw(fmt.Sprintf( + "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);", + joinTable, + strings.Join([]string{scope.Quote(foreignKey), scope.Quote(associationForeignKey)}, ","), + strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","), + scope.Dialect().SelectFromDummyTable(), + joinTable, + scope.Quote(foreignKey), + newScope.AddToVars(foreignValue), + scope.Quote(associationForeignKey), + newScope.AddToVars(associationForeignValue), + )) + if _, err := scope.DB().Exec(newScope.Sql, newScope.SqlVars...); err != nil { + scope.Err(err) + } + } } default: newDB := scope.NewDB() diff --git a/common_dialect.go b/common_dialect.go index ec5fe7ec..e54df6ff 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -64,6 +64,10 @@ func (s *commonDialect) ReturningStr(key string) string { return "" } +func (s *commonDialect) SelectFromDummyTable() string { + return "" +} + func (s *commonDialect) Quote(key string) string { return fmt.Sprintf("`%s`", key) } diff --git a/dialect.go b/dialect.go index 702660ae..771a7ac4 100644 --- a/dialect.go +++ b/dialect.go @@ -14,6 +14,7 @@ type Dialect interface { SqlTag(value reflect.Value, size int) string PrimaryKeyTag(value reflect.Value, size int) string ReturningStr(key string) string + SelectFromDummyTable() string Quote(key string) string HasTable(scope *Scope, tableName string) bool HasColumn(scope *Scope, tableName string, columnName string) bool diff --git a/mysql.go b/mysql.go index afb748a7..1332add7 100644 --- a/mysql.go +++ b/mysql.go @@ -65,6 +65,10 @@ func (s *mysql) ReturningStr(key string) string { return "" } +func (s *mysql) SelectFromDummyTable() string { + return "FROM DUAL" +} + func (s *mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } diff --git a/postgres.go b/postgres.go index b91b1309..c137f182 100644 --- a/postgres.go +++ b/postgres.go @@ -66,6 +66,10 @@ func (s *postgres) ReturningStr(key string) string { return fmt.Sprintf("RETURNING \"%v\"", key) } +func (s *postgres) SelectFromDummyTable() string { + return "" +} + func (s *postgres) Quote(key string) string { return fmt.Sprintf("\"%s\"", key) } diff --git a/scope.go b/scope.go index ccf50dde..3856a2fc 100644 --- a/scope.go +++ b/scope.go @@ -269,8 +269,8 @@ func (scope *Scope) Fields() []*Field { // parse association typ := indirectValue.Type() - foreignKey := settings["FOREIGNKEY"] - associationForeignKey := settings["ASSOCIATIONFOREIGNKEY"] + foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) + associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) many2many := settings["MANY2MANY"] switch indirectValue.Kind() { diff --git a/scope_private.go b/scope_private.go index e6e91c15..8ea47eed 100644 --- a/scope_private.go +++ b/scope_private.go @@ -432,6 +432,19 @@ func (scope *Scope) createTable() *Scope { if !field.IsIgnored && len(field.SqlTag) > 0 { sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) } + if field.JoinTable != nil && field.JoinTable.joinTable != "" { + if !scope.Dialect().HasTable(scope, field.JoinTable.joinTable) { + newScope := scope.db.NewScope("") + primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255) + newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", + field.JoinTable.joinTable, + strings.Join([]string{ + scope.Quote(ToSnake(field.JoinTable.foreignKey)) + " " + primaryKeySqlType, + scope.Quote(ToSnake(field.JoinTable.associationForeignKey)) + " " + primaryKeySqlType}, ",")), + ).Exec() + scope.Err(newScope.db.Error) + } + } } scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec() return scope diff --git a/sqlite3.go b/sqlite3.go index 795ee2bd..088da6b1 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -56,6 +56,10 @@ func (s *sqlite3) ReturningStr(key string) string { return "" } +func (s *sqlite3) SelectFromDummyTable() string { + return "" +} + func (s *sqlite3) Quote(key string) string { return fmt.Sprintf("\"%s\"", key) }