OnConflict support for mysql

This commit is contained in:
Jinzhu 2020-05-29 22:34:35 +08:00
parent 55074213bc
commit d05128be78
4 changed files with 41 additions and 5 deletions

View File

@ -8,9 +8,7 @@ type Interface interface {
} }
// ClauseBuilder clause builder, allows to custmize how to build clause // ClauseBuilder clause builder, allows to custmize how to build clause
type ClauseBuilder interface { type ClauseBuilder func(Clause, Builder)
Build(Clause, Builder)
}
type Writer interface { type Writer interface {
WriteByte(byte) error WriteByte(byte) error
@ -38,7 +36,7 @@ type Clause struct {
// Build build clause // Build build clause
func (c Clause) Build(builder Builder) { func (c Clause) Build(builder Builder) {
if c.Builder != nil { if c.Builder != nil {
c.Builder.Build(c, builder) c.Builder(c, builder)
} else { } else {
builders := c.BeforeExpressions builders := c.BeforeExpressions
if c.Name != "" { if c.Name != "" {

View File

@ -26,9 +26,43 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
db.ConnPool, err = sql.Open("mysql", dialector.DSN) db.ConnPool, err = sql.Open("mysql", dialector.DSN)
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return return
} }
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"ON CONFLICT": func(c clause.Clause, builder clause.Builder) {
if onConflict, ok := c.Expression.(clause.OnConflict); ok {
builder.WriteString("ON DUPLICATE KEY UPDATE ")
if len(onConflict.DoUpdates) == 0 {
if s := builder.(*gorm.Statement).Schema; s != nil {
var column clause.Column
onConflict.DoNothing = false
if s.PrioritizedPrimaryField != nil {
column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
} else {
for _, field := range s.FieldsByDBName {
column = clause.Column{Name: field.DBName}
break
}
}
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
}
}
onConflict.DoUpdates.Build(builder)
} else {
c.Build(builder)
}
},
}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{ return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db, DB: db,

View File

@ -95,6 +95,10 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
db.callbacks = initializeCallbacks(db) db.callbacks = initializeCallbacks(db)
if config.ClauseBuilders == nil {
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
}
if dialector != nil { if dialector != nil {
err = dialector.Initialize(db) err = dialector.Initialize(db)
} }

View File

@ -286,7 +286,7 @@ func (stmt *Statement) Build(clauses ...string) {
firstClauseWritten = true firstClauseWritten = true
if b, ok := stmt.DB.ClauseBuilders[name]; ok { if b, ok := stmt.DB.ClauseBuilders[name]; ok {
b.Build(c, stmt) b(c, stmt)
} else { } else {
c.Build(stmt) c.Build(stmt)
} }