diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index 4aecce10..9183ba76 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -16,49 +16,48 @@ func Create(db *gorm.DB) { } if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) + c := db.Statement.Clauses["ON CONFLICT"] + onConflict, hasConflict := c.Expression.(clause.OnConflict) - db.Statement.Build("INSERT") - db.Statement.WriteByte(' ') - - c := db.Statement.Clauses["VALUES"] - if values, ok := c.Expression.(clause.Values); ok { - if len(values.Columns) > 0 { - db.Statement.WriteByte('(') - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(column) - } - db.Statement.WriteByte(')') - - if db.Statement.Schema.PrioritizedPrimaryField != nil { - db.Statement.WriteString(" OUTPUT INSERTED.") - db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) - } - - db.Statement.WriteString(" VALUES ") - - for idx, value := range values.Values { - if idx > 0 { - db.Statement.WriteByte(',') - } + if hasConflict { + MergeCreate(db, onConflict) + } else { + db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) + db.Statement.Build("INSERT") + db.Statement.WriteByte(' ') + db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) + if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok { + if len(values.Columns) > 0 { db.Statement.WriteByte('(') - db.Statement.AddVar(db.Statement, value...) + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(column) + } db.Statement.WriteByte(')') + + outputInserted(db) + + db.Statement.WriteString(" VALUES ") + + for idx, value := range values.Values { + if idx > 0 { + db.Statement.WriteByte(',') + } + + db.Statement.WriteByte('(') + db.Statement.AddVar(db.Statement, value...) + db.Statement.WriteByte(')') + } + + db.Statement.WriteString(";") + } else { + db.Statement.WriteString("DEFAULT VALUES") } - } else { - db.Statement.WriteString("DEFAULT VALUES") } } - - db.Statement.WriteByte(' ') - db.Statement.Build("ON CONFLICT") } rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -93,3 +92,75 @@ func Create(db *gorm.DB) { db.AddError(err) } } + +func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { + values := callbacks.ConvertToCreateValues(db.Statement) + + db.Statement.WriteString("MERGE INTO ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString(" USING (VALUES") + for idx, value := range values.Values { + if idx > 0 { + db.Statement.WriteByte(',') + } + + db.Statement.WriteByte('(') + db.Statement.AddVar(db.Statement, value...) + db.Statement.WriteByte(')') + } + + db.Statement.WriteString(") AS source (") + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(column.Name) + } + db.Statement.WriteString(") ON ") + + var where clause.Where + for _, field := range db.Statement.Schema.PrimaryFields { + where.Exprs = append(where.Exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, + Value: clause.Column{Table: "source", Name: field.DBName}, + }) + } + where.Build(db.Statement) + + if len(onConflict.DoUpdates) > 0 { + db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") + onConflict.DoUpdates.Build(db.Statement) + } + + db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") + + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(column.Name) + } + + db.Statement.WriteString(") VALUES (") + + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(clause.Column{ + Table: "source", + Name: column.Name, + }) + } + + db.Statement.WriteString(")") + outputInserted(db) + db.Statement.WriteString(";") +} + +func outputInserted(db *gorm.DB) { + if db.Statement.Schema.PrioritizedPrimaryField != nil { + db.Statement.WriteString(" OUTPUT INSERTED.") + db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) + } +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 35fcb484..de82f375 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -112,7 +112,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { if size > 0 && size <= 4000 { return fmt.Sprintf("nvarchar(%d)", size) } - return "ntext" + return "nvarchar(MAX)" case schema.Time: return "datetimeoffset" case schema.Bytes: