Support mssql merge

This commit is contained in:
Jinzhu 2020-05-30 12:46:30 +08:00
parent c07a08d88b
commit cc07ee0444
2 changed files with 108 additions and 37 deletions

View File

@ -16,16 +16,18 @@ func Create(db *gorm.DB) {
} }
if db.Statement.SQL.String() == "" { if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{ c := db.Statement.Clauses["ON CONFLICT"]
Table: clause.Table{Name: db.Statement.Table}, onConflict, hasConflict := c.Expression.(clause.OnConflict)
})
db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement))
if hasConflict {
MergeCreate(db, onConflict)
} else {
db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
db.Statement.Build("INSERT") db.Statement.Build("INSERT")
db.Statement.WriteByte(' ') db.Statement.WriteByte(' ')
c := db.Statement.Clauses["VALUES"] db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement))
if values, ok := c.Expression.(clause.Values); ok { if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok {
if len(values.Columns) > 0 { if len(values.Columns) > 0 {
db.Statement.WriteByte('(') db.Statement.WriteByte('(')
for idx, column := range values.Columns { for idx, column := range values.Columns {
@ -36,10 +38,7 @@ func Create(db *gorm.DB) {
} }
db.Statement.WriteByte(')') db.Statement.WriteByte(')')
if db.Statement.Schema.PrioritizedPrimaryField != nil { outputInserted(db)
db.Statement.WriteString(" OUTPUT INSERTED.")
db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName)
}
db.Statement.WriteString(" VALUES ") db.Statement.WriteString(" VALUES ")
@ -52,13 +51,13 @@ func Create(db *gorm.DB) {
db.Statement.AddVar(db.Statement, value...) db.Statement.AddVar(db.Statement, value...)
db.Statement.WriteByte(')') db.Statement.WriteByte(')')
} }
db.Statement.WriteString(";")
} else { } else {
db.Statement.WriteString("DEFAULT VALUES") db.Statement.WriteString("DEFAULT VALUES")
} }
} }
}
db.Statement.WriteByte(' ')
db.Statement.Build("ON CONFLICT")
} }
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
@ -93,3 +92,75 @@ func Create(db *gorm.DB) {
db.AddError(err) db.AddError(err)
} }
} }
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
values := callbacks.ConvertToCreateValues(db.Statement)
db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString(" USING (VALUES")
for idx, value := range values.Values {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteByte('(')
db.Statement.AddVar(db.Statement, value...)
db.Statement.WriteByte(')')
}
db.Statement.WriteString(") AS source (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(column.Name)
}
db.Statement.WriteString(") ON ")
var where clause.Where
for _, field := range db.Statement.Schema.PrimaryFields {
where.Exprs = append(where.Exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
Value: clause.Column{Table: "source", Name: field.DBName},
})
}
where.Build(db.Statement)
if len(onConflict.DoUpdates) > 0 {
db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
onConflict.DoUpdates.Build(db.Statement)
}
db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(column.Name)
}
db.Statement.WriteString(") VALUES (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(clause.Column{
Table: "source",
Name: column.Name,
})
}
db.Statement.WriteString(")")
outputInserted(db)
db.Statement.WriteString(";")
}
func outputInserted(db *gorm.DB) {
if db.Statement.Schema.PrioritizedPrimaryField != nil {
db.Statement.WriteString(" OUTPUT INSERTED.")
db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName)
}
}

View File

@ -112,7 +112,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
if size > 0 && size <= 4000 { if size > 0 && size <= 4000 {
return fmt.Sprintf("nvarchar(%d)", size) return fmt.Sprintf("nvarchar(%d)", size)
} }
return "ntext" return "nvarchar(MAX)"
case schema.Time: case schema.Time:
return "datetimeoffset" return "datetimeoffset"
case schema.Bytes: case schema.Bytes: