diff --git a/clause/clause.go b/clause/clause.go index 59b229ce..9a5d1273 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -8,9 +8,7 @@ type Interface interface { } // ClauseBuilder clause builder, allows to custmize how to build clause -type ClauseBuilder interface { - Build(Clause, Builder) -} +type ClauseBuilder func(Clause, Builder) type Writer interface { WriteByte(byte) error @@ -38,7 +36,7 @@ type Clause struct { // Build build clause func (c Clause) Build(builder Builder) { if c.Builder != nil { - c.Builder.Build(c, builder) + c.Builder(c, builder) } else { builders := c.BeforeExpressions if c.Name != "" { diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 7b8f0491..6ca9f5f5 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -26,9 +26,43 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) db.ConnPool, err = sql.Open("mysql", dialector.DSN) + + for k, v := range dialector.ClauseBuilders() { + db.ClauseBuilders[k] = v + } return } +func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { + return map[string]clause.ClauseBuilder{ + "ON CONFLICT": func(c clause.Clause, builder clause.Builder) { + if onConflict, ok := c.Expression.(clause.OnConflict); ok { + builder.WriteString("ON DUPLICATE KEY UPDATE ") + if len(onConflict.DoUpdates) == 0 { + if s := builder.(*gorm.Statement).Schema; s != nil { + var column clause.Column + onConflict.DoNothing = false + + if s.PrioritizedPrimaryField != nil { + column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} + } else { + for _, field := range s.FieldsByDBName { + column = clause.Column{Name: field.DBName} + break + } + } + onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} + } + } + + onConflict.DoUpdates.Build(builder) + } else { + c.Build(builder) + } + }, + } +} + func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ DB: db, diff --git a/gorm.go b/gorm.go index 1fa69383..942024cf 100644 --- a/gorm.go +++ b/gorm.go @@ -95,6 +95,10 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { db.callbacks = initializeCallbacks(db) + if config.ClauseBuilders == nil { + config.ClauseBuilders = map[string]clause.ClauseBuilder{} + } + if dialector != nil { err = dialector.Initialize(db) } diff --git a/statement.go b/statement.go index 626ca689..f81ae0e5 100644 --- a/statement.go +++ b/statement.go @@ -286,7 +286,7 @@ func (stmt *Statement) Build(clauses ...string) { firstClauseWritten = true if b, ok := stmt.DB.ClauseBuilders[name]; ok { - b.Build(c, stmt) + b(c, stmt) } else { c.Build(stmt) }