Save many to many relations into join table

This commit is contained in:
Jinzhu 2014-07-30 14:18:15 +08:00
parent a7aaf151cf
commit 14590a65dc
8 changed files with 59 additions and 4 deletions

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
) )
func BeginTransaction(scope *Scope) { func BeginTransaction(scope *Scope) {
@ -49,12 +50,36 @@ func SaveAfterAssociations(scope *Scope) {
newDB := scope.NewDB() newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface() elem := value.Index(i).Addr().Interface()
if field.JoinTable != nil && field.JoinTable.foreignKey != "" { if field.JoinTable != nil && field.JoinTable.joinTable == "" && field.JoinTable.foreignKey != "" {
newDB.NewScope(elem).SetColumn(field.JoinTable.foreignKey, scope.PrimaryKeyValue()) newDB.NewScope(elem).SetColumn(field.JoinTable.foreignKey, scope.PrimaryKeyValue())
} }
scope.Err(newDB.Save(elem).Error) scope.Err(newDB.Save(elem).Error)
fmt.Sprintf("INSERT INTO %v (%v, %v) SELECT (%v, %v) FROM %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v) limit 1;")
if field.JoinTable != nil && field.JoinTable.joinTable != "" {
newScope := scope.New(elem)
joinTable := field.JoinTable.joinTable
foreignKey := ToSnake(field.JoinTable.foreignKey)
foreignValue := fmt.Sprintf("%v", scope.PrimaryKeyValue())
associationForeignKey := ToSnake(field.JoinTable.associationForeignKey)
associationForeignValue := fmt.Sprintf("%v", newScope.PrimaryKeyValue())
newScope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);",
joinTable,
strings.Join([]string{scope.Quote(foreignKey), scope.Quote(associationForeignKey)}, ","),
strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","),
scope.Dialect().SelectFromDummyTable(),
joinTable,
scope.Quote(foreignKey),
newScope.AddToVars(foreignValue),
scope.Quote(associationForeignKey),
newScope.AddToVars(associationForeignValue),
))
if _, err := scope.DB().Exec(newScope.Sql, newScope.SqlVars...); err != nil {
scope.Err(err)
}
}
} }
default: default:
newDB := scope.NewDB() newDB := scope.NewDB()

View File

@ -64,6 +64,10 @@ func (s *commonDialect) ReturningStr(key string) string {
return "" return ""
} }
func (s *commonDialect) SelectFromDummyTable() string {
return ""
}
func (s *commonDialect) Quote(key string) string { func (s *commonDialect) Quote(key string) string {
return fmt.Sprintf("`%s`", key) return fmt.Sprintf("`%s`", key)
} }

View File

@ -14,6 +14,7 @@ type Dialect interface {
SqlTag(value reflect.Value, size int) string SqlTag(value reflect.Value, size int) string
PrimaryKeyTag(value reflect.Value, size int) string PrimaryKeyTag(value reflect.Value, size int) string
ReturningStr(key string) string ReturningStr(key string) string
SelectFromDummyTable() string
Quote(key string) string Quote(key string) string
HasTable(scope *Scope, tableName string) bool HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool HasColumn(scope *Scope, tableName string, columnName string) bool

View File

@ -65,6 +65,10 @@ func (s *mysql) ReturningStr(key string) string {
return "" return ""
} }
func (s *mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s *mysql) Quote(key string) string { func (s *mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key) return fmt.Sprintf("`%s`", key)
} }

View File

@ -66,6 +66,10 @@ func (s *postgres) ReturningStr(key string) string {
return fmt.Sprintf("RETURNING \"%v\"", key) return fmt.Sprintf("RETURNING \"%v\"", key)
} }
func (s *postgres) SelectFromDummyTable() string {
return ""
}
func (s *postgres) Quote(key string) string { func (s *postgres) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key) return fmt.Sprintf("\"%s\"", key)
} }

View File

@ -269,8 +269,8 @@ func (scope *Scope) Fields() []*Field {
// parse association // parse association
typ := indirectValue.Type() typ := indirectValue.Type()
foreignKey := settings["FOREIGNKEY"] foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"])
associationForeignKey := settings["ASSOCIATIONFOREIGNKEY"] associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"])
many2many := settings["MANY2MANY"] many2many := settings["MANY2MANY"]
switch indirectValue.Kind() { switch indirectValue.Kind() {

View File

@ -432,6 +432,19 @@ func (scope *Scope) createTable() *Scope {
if !field.IsIgnored && len(field.SqlTag) > 0 { if !field.IsIgnored && len(field.SqlTag) > 0 {
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag)
} }
if field.JoinTable != nil && field.JoinTable.joinTable != "" {
if !scope.Dialect().HasTable(scope, field.JoinTable.joinTable) {
newScope := scope.db.NewScope("")
primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255)
newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)",
field.JoinTable.joinTable,
strings.Join([]string{
scope.Quote(ToSnake(field.JoinTable.foreignKey)) + " " + primaryKeySqlType,
scope.Quote(ToSnake(field.JoinTable.associationForeignKey)) + " " + primaryKeySqlType}, ",")),
).Exec()
scope.Err(newScope.db.Error)
}
}
} }
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec() scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec()
return scope return scope

View File

@ -56,6 +56,10 @@ func (s *sqlite3) ReturningStr(key string) string {
return "" return ""
} }
func (s *sqlite3) SelectFromDummyTable() string {
return ""
}
func (s *sqlite3) Quote(key string) string { func (s *sqlite3) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key) return fmt.Sprintf("\"%s\"", key)
} }