OnConflict support for mysql

This commit is contained in:
Jinzhu 2020-05-29 22:34:35 +08:00
parent 55074213bc
commit d05128be78
4 changed files with 41 additions and 5 deletions

View File

@ -8,9 +8,7 @@ type Interface interface {
}
// ClauseBuilder clause builder, allows to custmize how to build clause
type ClauseBuilder interface {
Build(Clause, Builder)
}
type ClauseBuilder func(Clause, Builder)
type Writer interface {
WriteByte(byte) error
@ -38,7 +36,7 @@ type Clause struct {
// Build build clause
func (c Clause) Build(builder Builder) {
if c.Builder != nil {
c.Builder.Build(c, builder)
c.Builder(c, builder)
} else {
builders := c.BeforeExpressions
if c.Name != "" {

View File

@ -26,9 +26,43 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
db.ConnPool, err = sql.Open("mysql", dialector.DSN)
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"ON CONFLICT": func(c clause.Clause, builder clause.Builder) {
if onConflict, ok := c.Expression.(clause.OnConflict); ok {
builder.WriteString("ON DUPLICATE KEY UPDATE ")
if len(onConflict.DoUpdates) == 0 {
if s := builder.(*gorm.Statement).Schema; s != nil {
var column clause.Column
onConflict.DoNothing = false
if s.PrioritizedPrimaryField != nil {
column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
} else {
for _, field := range s.FieldsByDBName {
column = clause.Column{Name: field.DBName}
break
}
}
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
}
}
onConflict.DoUpdates.Build(builder)
} else {
c.Build(builder)
}
},
}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,

View File

@ -95,6 +95,10 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
db.callbacks = initializeCallbacks(db)
if config.ClauseBuilders == nil {
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
}
if dialector != nil {
err = dialector.Initialize(db)
}

View File

@ -286,7 +286,7 @@ func (stmt *Statement) Build(clauses ...string) {
firstClauseWritten = true
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
b.Build(c, stmt)
b(c, stmt)
} else {
c.Build(stmt)
}