Fix mssql tests

This commit is contained in:
Jinzhu 2020-06-02 00:03:38 +08:00
parent e490e09db5
commit 9807fffdbc
1 changed files with 62 additions and 33 deletions

View File

@ -2,6 +2,7 @@ package mssql
import ( import (
"reflect" "reflect"
"sort"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
@ -17,10 +18,35 @@ func Create(db *gorm.DB) {
} }
if db.Statement.SQL.String() == "" { if db.Statement.SQL.String() == "" {
setIdentityInsert := false
c := db.Statement.Clauses["ON CONFLICT"] c := db.Statement.Clauses["ON CONFLICT"]
onConflict, hasConflict := c.Expression.(clause.OnConflict) onConflict, hasConflict := c.Expression.(clause.OnConflict)
if hasConflict { if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
setIdentityInsert = false
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
_, isZero := field.ValueOf(db.Statement.ReflectValue)
setIdentityInsert = !isZero
case reflect.Slice:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
_, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i))
setIdentityInsert = !isZero
break
}
}
if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) {
setIdentityInsert = true
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString(" ON;")
} else {
setIdentityInsert = false
}
}
if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 {
MergeCreate(db, onConflict) MergeCreate(db, onConflict)
} else { } else {
db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
@ -55,10 +81,16 @@ func Create(db *gorm.DB) {
db.Statement.WriteString(";") db.Statement.WriteString(";")
} else { } else {
db.Statement.WriteString("DEFAULT VALUES") db.Statement.WriteString("DEFAULT VALUES;")
} }
} }
} }
if setIdentityInsert {
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString(" OFF;")
}
} }
if !db.DryRun { if !db.DryRun {
@ -67,25 +99,32 @@ func Create(db *gorm.DB) {
if err == nil { if err == nil {
defer rows.Close() defer rows.Close()
if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
sortedKeys := []string{}
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
sortedKeys = append(sortedKeys, field.DBName)
}
sort.Strings(sortedKeys)
returnningFields := make([]*schema.Field, len(sortedKeys))
for idx, key := range sortedKeys {
returnningFields[idx] = db.Statement.Schema.LookUpField(key)
}
values := make([]interface{}, len(returnningFields))
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if len(db.Statement.Schema.PrimaryFields) > 0 {
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
for rows.Next() { for rows.Next() {
for idx, field := range db.Statement.Schema.PrimaryFields { for idx, field := range returnningFields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
} }
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
} }
}
case reflect.Struct: case reflect.Struct:
if len(db.Statement.Schema.PrimaryFields) > 0 { for idx, field := range returnningFields {
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
for idx, field := range db.Statement.Schema.PrimaryFields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} }
@ -103,16 +142,6 @@ func Create(db *gorm.DB) {
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
values := callbacks.ConvertToCreateValues(db.Statement) values := callbacks.ConvertToCreateValues(db.Statement)
setIdentityInsert := false
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
if field.DataType == schema.Int || field.DataType == schema.Uint {
setIdentityInsert = true
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString("ON;")
}
}
db.Statement.WriteString("MERGE INTO ") db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table) db.Statement.WriteQuoted(db.Statement.Table)
@ -174,23 +203,23 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
db.Statement.WriteString(")") db.Statement.WriteString(")")
outputInserted(db) outputInserted(db)
db.Statement.WriteString(";") db.Statement.WriteString(";")
if setIdentityInsert {
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString("OFF;")
}
} }
func outputInserted(db *gorm.DB) { func outputInserted(db *gorm.DB) {
if len(db.Statement.Schema.PrimaryFields) > 0 { if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
sortedKeys := []string{}
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
sortedKeys = append(sortedKeys, field.DBName)
}
sort.Strings(sortedKeys)
db.Statement.WriteString(" OUTPUT") db.Statement.WriteString(" OUTPUT")
for idx, field := range db.Statement.Schema.PrimaryFields { for idx, key := range sortedKeys {
if idx > 0 { if idx > 0 {
db.Statement.WriteString(",") db.Statement.WriteString(",")
} }
db.Statement.WriteString(" INSERTED.") db.Statement.WriteString(" INSERTED.")
db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName}) db.Statement.AddVar(db.Statement, clause.Column{Name: key})
} }
} }
} }