mirror of https://github.com/go-gorm/gorm.git
Support mssql merge
This commit is contained in:
parent
c07a08d88b
commit
cc07ee0444
|
@ -16,49 +16,48 @@ func Create(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
if db.Statement.SQL.String() == "" {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
c := db.Statement.Clauses["ON CONFLICT"]
|
||||||
Table: clause.Table{Name: db.Statement.Table},
|
onConflict, hasConflict := c.Expression.(clause.OnConflict)
|
||||||
})
|
|
||||||
db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement))
|
|
||||||
|
|
||||||
db.Statement.Build("INSERT")
|
if hasConflict {
|
||||||
db.Statement.WriteByte(' ')
|
MergeCreate(db, onConflict)
|
||||||
|
} else {
|
||||||
c := db.Statement.Clauses["VALUES"]
|
db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
|
||||||
if values, ok := c.Expression.(clause.Values); ok {
|
db.Statement.Build("INSERT")
|
||||||
if len(values.Columns) > 0 {
|
db.Statement.WriteByte(' ')
|
||||||
db.Statement.WriteByte('(')
|
|
||||||
for idx, column := range values.Columns {
|
|
||||||
if idx > 0 {
|
|
||||||
db.Statement.WriteByte(',')
|
|
||||||
}
|
|
||||||
db.Statement.WriteQuoted(column)
|
|
||||||
}
|
|
||||||
db.Statement.WriteByte(')')
|
|
||||||
|
|
||||||
if db.Statement.Schema.PrioritizedPrimaryField != nil {
|
|
||||||
db.Statement.WriteString(" OUTPUT INSERTED.")
|
|
||||||
db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Statement.WriteString(" VALUES ")
|
|
||||||
|
|
||||||
for idx, value := range values.Values {
|
|
||||||
if idx > 0 {
|
|
||||||
db.Statement.WriteByte(',')
|
|
||||||
}
|
|
||||||
|
|
||||||
|
db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement))
|
||||||
|
if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok {
|
||||||
|
if len(values.Columns) > 0 {
|
||||||
db.Statement.WriteByte('(')
|
db.Statement.WriteByte('(')
|
||||||
db.Statement.AddVar(db.Statement, value...)
|
for idx, column := range values.Columns {
|
||||||
|
if idx > 0 {
|
||||||
|
db.Statement.WriteByte(',')
|
||||||
|
}
|
||||||
|
db.Statement.WriteQuoted(column)
|
||||||
|
}
|
||||||
db.Statement.WriteByte(')')
|
db.Statement.WriteByte(')')
|
||||||
|
|
||||||
|
outputInserted(db)
|
||||||
|
|
||||||
|
db.Statement.WriteString(" VALUES ")
|
||||||
|
|
||||||
|
for idx, value := range values.Values {
|
||||||
|
if idx > 0 {
|
||||||
|
db.Statement.WriteByte(',')
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.WriteByte('(')
|
||||||
|
db.Statement.AddVar(db.Statement, value...)
|
||||||
|
db.Statement.WriteByte(')')
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.WriteString(";")
|
||||||
|
} else {
|
||||||
|
db.Statement.WriteString("DEFAULT VALUES")
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
db.Statement.WriteString("DEFAULT VALUES")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.WriteByte(' ')
|
|
||||||
db.Statement.Build("ON CONFLICT")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
@ -93,3 +92,75 @@ func Create(db *gorm.DB) {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
||||||
|
values := callbacks.ConvertToCreateValues(db.Statement)
|
||||||
|
|
||||||
|
db.Statement.WriteString("MERGE INTO ")
|
||||||
|
db.Statement.WriteQuoted(db.Statement.Table)
|
||||||
|
db.Statement.WriteString(" USING (VALUES")
|
||||||
|
for idx, value := range values.Values {
|
||||||
|
if idx > 0 {
|
||||||
|
db.Statement.WriteByte(',')
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.WriteByte('(')
|
||||||
|
db.Statement.AddVar(db.Statement, value...)
|
||||||
|
db.Statement.WriteByte(')')
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.WriteString(") AS source (")
|
||||||
|
for idx, column := range values.Columns {
|
||||||
|
if idx > 0 {
|
||||||
|
db.Statement.WriteByte(',')
|
||||||
|
}
|
||||||
|
db.Statement.WriteQuoted(column.Name)
|
||||||
|
}
|
||||||
|
db.Statement.WriteString(") ON ")
|
||||||
|
|
||||||
|
var where clause.Where
|
||||||
|
for _, field := range db.Statement.Schema.PrimaryFields {
|
||||||
|
where.Exprs = append(where.Exprs, clause.Eq{
|
||||||
|
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
|
||||||
|
Value: clause.Column{Table: "source", Name: field.DBName},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
where.Build(db.Statement)
|
||||||
|
|
||||||
|
if len(onConflict.DoUpdates) > 0 {
|
||||||
|
db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
|
||||||
|
onConflict.DoUpdates.Build(db.Statement)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
|
||||||
|
|
||||||
|
for idx, column := range values.Columns {
|
||||||
|
if idx > 0 {
|
||||||
|
db.Statement.WriteByte(',')
|
||||||
|
}
|
||||||
|
db.Statement.WriteQuoted(column.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.WriteString(") VALUES (")
|
||||||
|
|
||||||
|
for idx, column := range values.Columns {
|
||||||
|
if idx > 0 {
|
||||||
|
db.Statement.WriteByte(',')
|
||||||
|
}
|
||||||
|
db.Statement.WriteQuoted(clause.Column{
|
||||||
|
Table: "source",
|
||||||
|
Name: column.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.WriteString(")")
|
||||||
|
outputInserted(db)
|
||||||
|
db.Statement.WriteString(";")
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputInserted(db *gorm.DB) {
|
||||||
|
if db.Statement.Schema.PrioritizedPrimaryField != nil {
|
||||||
|
db.Statement.WriteString(" OUTPUT INSERTED.")
|
||||||
|
db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -112,7 +112,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||||
if size > 0 && size <= 4000 {
|
if size > 0 && size <= 4000 {
|
||||||
return fmt.Sprintf("nvarchar(%d)", size)
|
return fmt.Sprintf("nvarchar(%d)", size)
|
||||||
}
|
}
|
||||||
return "ntext"
|
return "nvarchar(MAX)"
|
||||||
case schema.Time:
|
case schema.Time:
|
||||||
return "datetimeoffset"
|
return "datetimeoffset"
|
||||||
case schema.Bytes:
|
case schema.Bytes:
|
||||||
|
|
Loading…
Reference in New Issue