gorm/dialects/mssql/create.go

197 lines
5.1 KiB
Go
Raw Normal View History

2020-05-30 05:43:41 +03:00
package mssql
import (
"reflect"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/clause"
2020-05-30 08:46:33 +03:00
"github.com/jinzhu/gorm/schema"
2020-05-30 05:43:41 +03:00
)
func Create(db *gorm.DB) {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
2020-05-30 07:46:30 +03:00
c := db.Statement.Clauses["ON CONFLICT"]
onConflict, hasConflict := c.Expression.(clause.OnConflict)
if hasConflict {
MergeCreate(db, onConflict)
} else {
db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
db.Statement.Build("INSERT")
db.Statement.WriteByte(' ')
db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement))
if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok {
if len(values.Columns) > 0 {
db.Statement.WriteByte('(')
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(column)
2020-05-30 05:43:41 +03:00
}
2020-05-30 07:46:30 +03:00
db.Statement.WriteByte(')')
2020-05-30 05:43:41 +03:00
2020-05-30 07:46:30 +03:00
outputInserted(db)
2020-05-30 05:43:41 +03:00
2020-05-30 07:46:30 +03:00
db.Statement.WriteString(" VALUES ")
2020-05-30 05:43:41 +03:00
2020-05-30 07:46:30 +03:00
for idx, value := range values.Values {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteByte('(')
db.Statement.AddVar(db.Statement, value...)
db.Statement.WriteByte(')')
2020-05-30 05:43:41 +03:00
}
2020-05-30 07:46:30 +03:00
db.Statement.WriteString(";")
} else {
db.Statement.WriteString("DEFAULT VALUES")
2020-05-30 05:43:41 +03:00
}
}
}
}
2020-06-01 16:26:23 +03:00
if !db.DryRun {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
2020-05-30 05:43:41 +03:00
2020-06-01 16:26:23 +03:00
if err == nil {
defer rows.Close()
2020-05-30 05:43:41 +03:00
2020-06-01 16:26:23 +03:00
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if len(db.Statement.Schema.PrimaryFields) > 0 {
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
2020-05-31 07:52:49 +03:00
2020-06-01 16:26:23 +03:00
for rows.Next() {
for idx, field := range db.Statement.Schema.PrimaryFields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
}
2020-05-31 07:52:49 +03:00
2020-06-01 16:26:23 +03:00
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
2020-05-30 05:43:41 +03:00
}
2020-06-01 16:26:23 +03:00
case reflect.Struct:
if len(db.Statement.Schema.PrimaryFields) > 0 {
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
2020-05-31 07:52:49 +03:00
2020-06-01 16:26:23 +03:00
for idx, field := range db.Statement.Schema.PrimaryFields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
}
2020-05-31 07:52:49 +03:00
2020-06-01 16:26:23 +03:00
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
2020-05-31 07:52:49 +03:00
}
2020-05-30 05:43:41 +03:00
}
2020-06-01 16:26:23 +03:00
} else {
db.AddError(err)
2020-05-30 05:43:41 +03:00
}
}
}
2020-05-30 07:46:30 +03:00
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
values := callbacks.ConvertToCreateValues(db.Statement)
2020-05-30 08:46:33 +03:00
setIdentityInsert := false
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
if field.DataType == schema.Int || field.DataType == schema.Uint {
setIdentityInsert = true
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString("ON;")
}
}
2020-05-30 07:46:30 +03:00
db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString(" USING (VALUES")
for idx, value := range values.Values {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteByte('(')
db.Statement.AddVar(db.Statement, value...)
db.Statement.WriteByte(')')
}
db.Statement.WriteString(") AS source (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(column.Name)
}
db.Statement.WriteString(") ON ")
var where clause.Where
for _, field := range db.Statement.Schema.PrimaryFields {
where.Exprs = append(where.Exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
Value: clause.Column{Table: "source", Name: field.DBName},
})
}
where.Build(db.Statement)
if len(onConflict.DoUpdates) > 0 {
db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
onConflict.DoUpdates.Build(db.Statement)
}
db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(column.Name)
}
db.Statement.WriteString(") VALUES (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(clause.Column{
Table: "source",
Name: column.Name,
})
}
db.Statement.WriteString(")")
outputInserted(db)
db.Statement.WriteString(";")
2020-05-30 08:46:33 +03:00
if setIdentityInsert {
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString("OFF;")
}
2020-05-30 07:46:30 +03:00
}
func outputInserted(db *gorm.DB) {
2020-05-31 07:52:49 +03:00
if len(db.Statement.Schema.PrimaryFields) > 0 {
db.Statement.WriteString(" OUTPUT ")
for idx, field := range db.Statement.Schema.PrimaryFields {
if idx > 0 {
db.Statement.WriteString(",")
}
db.Statement.WriteString(" INSERTED.")
db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName})
}
2020-05-30 07:46:30 +03:00
}
}