MSSQL Create() fix: Add LastInsertIDReturningSuffix to dialect (#2690)

* MSSQL Create() fix: Add LastInsertIDReturningSuffix to dialect

Per https://github.com/denisenkom/go-mssqldb/issues/355

* MSSQL Create() fix: Added OUTPUT query to Create() builder
This commit is contained in:
Alex Stockwell 2019-10-17 07:54:11 -07:00 committed by Jinzhu
parent 1bca5dbdd9
commit 820b5f244a
5 changed files with 52 additions and 13 deletions

View File

@ -101,6 +101,7 @@ func createCallback(scope *Scope) {
}
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
@ -113,18 +114,19 @@ func createCallback(scope *Scope) {
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v (%v) VALUES (%v)%v%v",
"INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
addExtraSpaceIfExist(insertModifier),
scope.QuotedTableName(),
strings.Join(columns, ","),
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}
// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
// execute create sql: no primaryField
if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
@ -136,7 +138,26 @@ func createCallback(scope *Scope) {
}
}
}
} else {
return
}
// execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
return
}
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
@ -145,7 +166,7 @@ func createCallback(scope *Scope) {
} else {
scope.Err(ErrUnaddressable)
}
}
return
}
}

View File

@ -40,6 +40,8 @@ type Dialect interface {
LimitAndOffsetSQL(limit, offset interface{}) string
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
// DefaultValueStr

View File

@ -157,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}

View File

@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) {
return
}
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
return ""
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}

View File

@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string {
return ""
}
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
if len(columns) == 0 {
// No OUTPUT to query
return ""
}
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}