forked from mirror/gorm
Fix mssql tests
This commit is contained in:
parent
e490e09db5
commit
9807fffdbc
|
@ -2,6 +2,7 @@ package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/callbacks"
|
"github.com/jinzhu/gorm/callbacks"
|
||||||
|
@ -17,10 +18,35 @@ func Create(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
if db.Statement.SQL.String() == "" {
|
||||||
|
setIdentityInsert := false
|
||||||
c := db.Statement.Clauses["ON CONFLICT"]
|
c := db.Statement.Clauses["ON CONFLICT"]
|
||||||
onConflict, hasConflict := c.Expression.(clause.OnConflict)
|
onConflict, hasConflict := c.Expression.(clause.OnConflict)
|
||||||
|
|
||||||
if hasConflict {
|
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
|
||||||
|
setIdentityInsert = false
|
||||||
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
_, isZero := field.ValueOf(db.Statement.ReflectValue)
|
||||||
|
setIdentityInsert = !isZero
|
||||||
|
case reflect.Slice:
|
||||||
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
|
_, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i))
|
||||||
|
setIdentityInsert = !isZero
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) {
|
||||||
|
setIdentityInsert = true
|
||||||
|
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
||||||
|
db.Statement.WriteQuoted(db.Statement.Table)
|
||||||
|
db.Statement.WriteString(" ON;")
|
||||||
|
} else {
|
||||||
|
setIdentityInsert = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||||
MergeCreate(db, onConflict)
|
MergeCreate(db, onConflict)
|
||||||
} else {
|
} else {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
|
db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
|
||||||
|
@ -55,10 +81,16 @@ func Create(db *gorm.DB) {
|
||||||
|
|
||||||
db.Statement.WriteString(";")
|
db.Statement.WriteString(";")
|
||||||
} else {
|
} else {
|
||||||
db.Statement.WriteString("DEFAULT VALUES")
|
db.Statement.WriteString("DEFAULT VALUES;")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if setIdentityInsert {
|
||||||
|
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
||||||
|
db.Statement.WriteQuoted(db.Statement.Table)
|
||||||
|
db.Statement.WriteString(" OFF;")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun {
|
if !db.DryRun {
|
||||||
|
@ -67,25 +99,32 @@ func Create(db *gorm.DB) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||||
case reflect.Slice, reflect.Array:
|
sortedKeys := []string{}
|
||||||
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||||
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
sortedKeys = append(sortedKeys, field.DBName)
|
||||||
|
}
|
||||||
|
sort.Strings(sortedKeys)
|
||||||
|
|
||||||
|
returnningFields := make([]*schema.Field, len(sortedKeys))
|
||||||
|
for idx, key := range sortedKeys {
|
||||||
|
returnningFields[idx] = db.Statement.Schema.LookUpField(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make([]interface{}, len(returnningFields))
|
||||||
|
|
||||||
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
for idx, field := range db.Statement.Schema.PrimaryFields {
|
for idx, field := range returnningFields {
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
}
|
}
|
||||||
}
|
case reflect.Struct:
|
||||||
case reflect.Struct:
|
for idx, field := range returnningFields {
|
||||||
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
|
||||||
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
|
||||||
|
|
||||||
for idx, field := range db.Statement.Schema.PrimaryFields {
|
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,16 +142,6 @@ func Create(db *gorm.DB) {
|
||||||
|
|
||||||
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
||||||
values := callbacks.ConvertToCreateValues(db.Statement)
|
values := callbacks.ConvertToCreateValues(db.Statement)
|
||||||
setIdentityInsert := false
|
|
||||||
|
|
||||||
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
|
|
||||||
if field.DataType == schema.Int || field.DataType == schema.Uint {
|
|
||||||
setIdentityInsert = true
|
|
||||||
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
|
||||||
db.Statement.WriteQuoted(db.Statement.Table)
|
|
||||||
db.Statement.WriteString("ON;")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Statement.WriteString("MERGE INTO ")
|
db.Statement.WriteString("MERGE INTO ")
|
||||||
db.Statement.WriteQuoted(db.Statement.Table)
|
db.Statement.WriteQuoted(db.Statement.Table)
|
||||||
|
@ -174,23 +203,23 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
||||||
db.Statement.WriteString(")")
|
db.Statement.WriteString(")")
|
||||||
outputInserted(db)
|
outputInserted(db)
|
||||||
db.Statement.WriteString(";")
|
db.Statement.WriteString(";")
|
||||||
|
|
||||||
if setIdentityInsert {
|
|
||||||
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
|
||||||
db.Statement.WriteQuoted(db.Statement.Table)
|
|
||||||
db.Statement.WriteString("OFF;")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func outputInserted(db *gorm.DB) {
|
func outputInserted(db *gorm.DB) {
|
||||||
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||||
db.Statement.WriteString(" OUTPUT ")
|
sortedKeys := []string{}
|
||||||
for idx, field := range db.Statement.Schema.PrimaryFields {
|
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||||
|
sortedKeys = append(sortedKeys, field.DBName)
|
||||||
|
}
|
||||||
|
sort.Strings(sortedKeys)
|
||||||
|
|
||||||
|
db.Statement.WriteString(" OUTPUT")
|
||||||
|
for idx, key := range sortedKeys {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
db.Statement.WriteString(",")
|
db.Statement.WriteString(",")
|
||||||
}
|
}
|
||||||
db.Statement.WriteString(" INSERTED.")
|
db.Statement.WriteString(" INSERTED.")
|
||||||
db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName})
|
db.Statement.AddVar(db.Statement, clause.Column{Name: key})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue