Save many to many relations into join table

This commit is contained in:
Jinzhu 2014-07-30 14:18:15 +08:00
parent a7aaf151cf
commit 14590a65dc
8 changed files with 59 additions and 4 deletions

View File

@ -3,6 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"strings"
)
func BeginTransaction(scope *Scope) {
@ -49,12 +50,36 @@ func SaveAfterAssociations(scope *Scope) {
newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface()
if field.JoinTable != nil && field.JoinTable.foreignKey != "" {
if field.JoinTable != nil && field.JoinTable.joinTable == "" && field.JoinTable.foreignKey != "" {
newDB.NewScope(elem).SetColumn(field.JoinTable.foreignKey, scope.PrimaryKeyValue())
}
scope.Err(newDB.Save(elem).Error)
fmt.Sprintf("INSERT INTO %v (%v, %v) SELECT (%v, %v) FROM %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v) limit 1;")
if field.JoinTable != nil && field.JoinTable.joinTable != "" {
newScope := scope.New(elem)
joinTable := field.JoinTable.joinTable
foreignKey := ToSnake(field.JoinTable.foreignKey)
foreignValue := fmt.Sprintf("%v", scope.PrimaryKeyValue())
associationForeignKey := ToSnake(field.JoinTable.associationForeignKey)
associationForeignValue := fmt.Sprintf("%v", newScope.PrimaryKeyValue())
newScope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);",
joinTable,
strings.Join([]string{scope.Quote(foreignKey), scope.Quote(associationForeignKey)}, ","),
strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","),
scope.Dialect().SelectFromDummyTable(),
joinTable,
scope.Quote(foreignKey),
newScope.AddToVars(foreignValue),
scope.Quote(associationForeignKey),
newScope.AddToVars(associationForeignValue),
))
if _, err := scope.DB().Exec(newScope.Sql, newScope.SqlVars...); err != nil {
scope.Err(err)
}
}
}
default:
newDB := scope.NewDB()

View File

@ -64,6 +64,10 @@ func (s *commonDialect) ReturningStr(key string) string {
return ""
}
func (s *commonDialect) SelectFromDummyTable() string {
return ""
}
func (s *commonDialect) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}

View File

@ -14,6 +14,7 @@ type Dialect interface {
SqlTag(value reflect.Value, size int) string
PrimaryKeyTag(value reflect.Value, size int) string
ReturningStr(key string) string
SelectFromDummyTable() string
Quote(key string) string
HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool

View File

@ -65,6 +65,10 @@ func (s *mysql) ReturningStr(key string) string {
return ""
}
func (s *mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s *mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}

View File

@ -66,6 +66,10 @@ func (s *postgres) ReturningStr(key string) string {
return fmt.Sprintf("RETURNING \"%v\"", key)
}
func (s *postgres) SelectFromDummyTable() string {
return ""
}
func (s *postgres) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key)
}

View File

@ -269,8 +269,8 @@ func (scope *Scope) Fields() []*Field {
// parse association
typ := indirectValue.Type()
foreignKey := settings["FOREIGNKEY"]
associationForeignKey := settings["ASSOCIATIONFOREIGNKEY"]
foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"])
associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"])
many2many := settings["MANY2MANY"]
switch indirectValue.Kind() {

View File

@ -432,6 +432,19 @@ func (scope *Scope) createTable() *Scope {
if !field.IsIgnored && len(field.SqlTag) > 0 {
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag)
}
if field.JoinTable != nil && field.JoinTable.joinTable != "" {
if !scope.Dialect().HasTable(scope, field.JoinTable.joinTable) {
newScope := scope.db.NewScope("")
primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255)
newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)",
field.JoinTable.joinTable,
strings.Join([]string{
scope.Quote(ToSnake(field.JoinTable.foreignKey)) + " " + primaryKeySqlType,
scope.Quote(ToSnake(field.JoinTable.associationForeignKey)) + " " + primaryKeySqlType}, ",")),
).Exec()
scope.Err(newScope.db.Error)
}
}
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec()
return scope

View File

@ -56,6 +56,10 @@ func (s *sqlite3) ReturningStr(key string) string {
return ""
}
func (s *sqlite3) SelectFromDummyTable() string {
return ""
}
func (s *sqlite3) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key)
}