From 9807fffdbce47865d911eca391a76c8ba0f02db1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 00:03:38 +0800 Subject: [PATCH] Fix mssql tests --- dialects/mssql/create.go | 95 ++++++++++++++++++++++++++-------------- 1 file changed, 62 insertions(+), 33 deletions(-) diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index 6820bb7b..84732427 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -2,6 +2,7 @@ package mssql import ( "reflect" + "sort" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" @@ -17,10 +18,35 @@ func Create(db *gorm.DB) { } if db.Statement.SQL.String() == "" { + setIdentityInsert := false c := db.Statement.Clauses["ON CONFLICT"] onConflict, hasConflict := c.Expression.(clause.OnConflict) - if hasConflict { + if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { + setIdentityInsert = false + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + _, isZero := field.ValueOf(db.Statement.ReflectValue) + setIdentityInsert = !isZero + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + _, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i)) + setIdentityInsert = !isZero + break + } + } + + if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) { + setIdentityInsert = true + db.Statement.WriteString("SET IDENTITY_INSERT ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString(" ON;") + } else { + setIdentityInsert = false + } + } + + if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 { MergeCreate(db, onConflict) } else { db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) @@ -55,10 +81,16 @@ func Create(db *gorm.DB) { db.Statement.WriteString(";") } else { - db.Statement.WriteString("DEFAULT VALUES") + db.Statement.WriteString("DEFAULT VALUES;") } } } + + if setIdentityInsert { + db.Statement.WriteString("SET IDENTITY_INSERT ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString(" OFF;") + } } if !db.DryRun { @@ -67,25 +99,32 @@ func Create(db *gorm.DB) { if err == nil { defer rows.Close() - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if len(db.Statement.Schema.PrimaryFields) > 0 { - values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) + if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + sortedKeys := []string{} + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + sortedKeys = append(sortedKeys, field.DBName) + } + sort.Strings(sortedKeys) + returnningFields := make([]*schema.Field, len(sortedKeys)) + for idx, key := range sortedKeys { + returnningFields[idx] = db.Statement.Schema.LookUpField(key) + } + + values := make([]interface{}, len(returnningFields)) + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: for rows.Next() { - for idx, field := range db.Statement.Schema.PrimaryFields { + for idx, field := range returnningFields { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() } db.RowsAffected++ db.AddError(rows.Scan(values...)) } - } - case reflect.Struct: - if len(db.Statement.Schema.PrimaryFields) > 0 { - values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - - for idx, field := range db.Statement.Schema.PrimaryFields { + case reflect.Struct: + for idx, field := range returnningFields { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } @@ -103,16 +142,6 @@ func Create(db *gorm.DB) { func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { values := callbacks.ConvertToCreateValues(db.Statement) - setIdentityInsert := false - - if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { - if field.DataType == schema.Int || field.DataType == schema.Uint { - setIdentityInsert = true - db.Statement.WriteString("SET IDENTITY_INSERT ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString("ON;") - } - } db.Statement.WriteString("MERGE INTO ") db.Statement.WriteQuoted(db.Statement.Table) @@ -174,23 +203,23 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { db.Statement.WriteString(")") outputInserted(db) db.Statement.WriteString(";") - - if setIdentityInsert { - db.Statement.WriteString("SET IDENTITY_INSERT ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString("OFF;") - } } func outputInserted(db *gorm.DB) { - if len(db.Statement.Schema.PrimaryFields) > 0 { - db.Statement.WriteString(" OUTPUT ") - for idx, field := range db.Statement.Schema.PrimaryFields { + if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + sortedKeys := []string{} + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + sortedKeys = append(sortedKeys, field.DBName) + } + sort.Strings(sortedKeys) + + db.Statement.WriteString(" OUTPUT") + for idx, key := range sortedKeys { if idx > 0 { db.Statement.WriteString(",") } db.Statement.WriteString(" INSERTED.") - db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName}) + db.Statement.AddVar(db.Statement, clause.Column{Name: key}) } } }