2016-03-05 13:54:59 +03:00
package mssql
2016-03-05 14:22:29 +03:00
import (
2018-07-27 02:30:57 +03:00
"database/sql/driver"
"encoding/json"
"errors"
2016-03-05 14:22:29 +03:00
"fmt"
2016-03-05 17:50:49 +03:00
"reflect"
2016-06-21 06:13:33 +03:00
"strconv"
2016-03-05 17:50:49 +03:00
"strings"
"time"
2016-03-05 14:22:29 +03:00
2018-07-27 02:30:57 +03:00
// Import the mssql driver only in this dialect file; it is not needed otherwise.
2016-03-05 14:22:29 +03:00
_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
)
2017-03-22 17:57:13 +03:00
func setIdentityInsert ( scope * gorm . Scope ) {
if scope . Dialect ( ) . GetName ( ) == "mssql" {
for _ , field := range scope . PrimaryFields ( ) {
2018-09-10 02:11:00 +03:00
if _ , ok := field . TagSettingsGet ( "AUTO_INCREMENT" ) ; ok && ! field . IsBlank {
2017-03-22 17:57:13 +03:00
scope . NewDB ( ) . Exec ( fmt . Sprintf ( "SET IDENTITY_INSERT %v ON" , scope . TableName ( ) ) )
scope . InstanceSet ( "mssql:identity_insert_on" , true )
}
}
}
}
// turnOffIdentityInsert is a create callback. For the mssql dialect, it turns
// IDENTITY_INSERT back OFF for the scope's table, but only when the scope
// instance was marked with "mssql:identity_insert_on" (as setIdentityInsert
// does). The result of the Exec call is not checked.
func turnOffIdentityInsert(scope *gorm.Scope) {
	if scope.Dialect().GetName() != "mssql" {
		return
	}
	if _, wasOn := scope.InstanceGet("mssql:identity_insert_on"); !wasOn {
		return
	}
	scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
}
2016-03-05 14:22:29 +03:00
func init ( ) {
2017-03-22 17:57:13 +03:00
gorm . DefaultCallback . Create ( ) . After ( "gorm:begin_transaction" ) . Register ( "mssql:set_identity_insert" , setIdentityInsert )
gorm . DefaultCallback . Create ( ) . Before ( "gorm:commit_or_rollback_transaction" ) . Register ( "mssql:turn_off_identity_insert" , turnOffIdentityInsert )
2016-03-05 17:50:49 +03:00
gorm . RegisterDialect ( "mssql" , & mssql { } )
}
type mssql struct {
2017-03-14 23:32:38 +03:00
db gorm . SQLCommon
2016-05-22 01:13:26 +03:00
gorm . DefaultForeignKeyNamer
2016-03-05 17:50:49 +03:00
}
// GetName returns the name this dialect is registered under.
func (mssql) GetName() string {
	const dialectName = "mssql"
	return dialectName
}
2017-03-14 23:32:38 +03:00
func ( s * mssql ) SetDB ( db gorm . SQLCommon ) {
2016-03-05 17:50:49 +03:00
s . db = db
}
func ( mssql ) BindVar ( i int ) string {
2017-04-18 11:13:02 +03:00
return "$$$" // ?
2016-03-05 17:50:49 +03:00
}
func ( mssql ) Quote ( key string ) string {
2018-02-10 07:28:01 +03:00
return fmt . Sprintf ( ` [%s] ` , key )
2016-03-05 17:50:49 +03:00
}
2017-01-15 16:24:53 +03:00
func ( s * mssql ) DataTypeOf ( field * gorm . StructField ) string {
var dataValue , sqlType , size , additionalType = gorm . ParseFieldStructForDialect ( field , s )
2016-03-05 17:50:49 +03:00
if sqlType == "" {
switch dataValue . Kind ( ) {
case reflect . Bool :
sqlType = "bit"
case reflect . Int , reflect . Int8 , reflect . Int16 , reflect . Int32 , reflect . Uint , reflect . Uint8 , reflect . Uint16 , reflect . Uint32 , reflect . Uintptr :
2018-02-10 03:30:05 +03:00
if s . fieldCanAutoIncrement ( field ) {
2018-09-10 02:11:00 +03:00
field . TagSettingsSet ( "AUTO_INCREMENT" , "AUTO_INCREMENT" )
2016-03-05 17:50:49 +03:00
sqlType = "int IDENTITY(1,1)"
} else {
sqlType = "int"
}
case reflect . Int64 , reflect . Uint64 :
2018-02-10 03:30:05 +03:00
if s . fieldCanAutoIncrement ( field ) {
2018-09-10 02:11:00 +03:00
field . TagSettingsSet ( "AUTO_INCREMENT" , "AUTO_INCREMENT" )
2016-03-05 17:50:49 +03:00
sqlType = "bigint IDENTITY(1,1)"
} else {
sqlType = "bigint"
}
case reflect . Float32 , reflect . Float64 :
sqlType = "float"
case reflect . String :
2017-04-19 10:21:56 +03:00
if size > 0 && size < 8000 {
2016-03-05 17:50:49 +03:00
sqlType = fmt . Sprintf ( "nvarchar(%d)" , size )
} else {
2017-04-19 10:21:56 +03:00
sqlType = "nvarchar(max)"
2016-03-05 17:50:49 +03:00
}
case reflect . Struct :
if _ , ok := dataValue . Interface ( ) . ( time . Time ) ; ok {
2017-08-11 12:18:49 +03:00
sqlType = "datetimeoffset"
2016-03-05 17:50:49 +03:00
}
default :
2017-04-19 10:21:56 +03:00
if gorm . IsByteArrayOrSlice ( dataValue ) {
if size > 0 && size < 8000 {
sqlType = fmt . Sprintf ( "varbinary(%d)" , size )
2016-03-05 17:50:49 +03:00
} else {
2017-04-19 10:21:56 +03:00
sqlType = "varbinary(max)"
2016-03-05 17:50:49 +03:00
}
}
}
}
if sqlType == "" {
panic ( fmt . Sprintf ( "invalid sql type %s (%s) for mssql" , dataValue . Type ( ) . Name ( ) , dataValue . Kind ( ) . String ( ) ) )
}
if strings . TrimSpace ( additionalType ) == "" {
return sqlType
}
return fmt . Sprintf ( "%v %v" , sqlType , additionalType )
}
2018-02-10 03:30:05 +03:00
func ( s mssql ) fieldCanAutoIncrement ( field * gorm . StructField ) bool {
2018-09-10 02:11:00 +03:00
if value , ok := field . TagSettingsGet ( "AUTO_INCREMENT" ) ; ok {
2018-02-10 03:30:05 +03:00
return value != "FALSE"
}
return field . IsPrimaryKey
}
2016-03-05 17:50:49 +03:00
// HasIndex reports whether sys.indexes has an index named indexName on the
// object OBJECT_ID(tableName) resolves to. The Scan error is not checked, so a
// failed query leaves the count at zero and reports false.
func (s mssql) HasIndex(tableName string, indexName string) bool {
	n := 0
	row := s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
	row.Scan(&n)
	return n > 0
}
// RemoveIndex drops indexName from tableName and returns any Exec error.
// The table name is passed through Quote; the index name is inserted as given.
func (s mssql) RemoveIndex(tableName string, indexName string) error {
	stmt := fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))
	if _, err := s.db.Exec(stmt); err != nil {
		return err
	}
	return nil
}
func ( s mssql ) HasForeignKey ( tableName string , foreignKeyName string ) bool {
2018-05-12 09:29:00 +03:00
var count int
currentDatabase , tableName := currentDatabaseAndTable ( & s , tableName )
s . db . QueryRow ( ` SELECT count ( * )
FROM sys . foreign_keys as F inner join sys . tables as T on F . parent_object_id = T . object_id
inner join information_schema . tables as I on I . TABLE_NAME = T . name
WHERE F . name = ?
AND T . Name = ? AND I . TABLE_CATALOG = ? ; ` , foreignKeyName , tableName , currentDatabase ) . Scan ( & count )
return count > 0
2016-03-05 17:50:49 +03:00
}
func ( s mssql ) HasTable ( tableName string ) bool {
var count int
2018-02-03 15:27:19 +03:00
currentDatabase , tableName := currentDatabaseAndTable ( & s , tableName )
s . db . QueryRow ( "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?" , tableName , currentDatabase ) . Scan ( & count )
2016-03-05 17:50:49 +03:00
return count > 0
}
func ( s mssql ) HasColumn ( tableName string , columnName string ) bool {
var count int
2018-02-03 15:27:19 +03:00
currentDatabase , tableName := currentDatabaseAndTable ( & s , tableName )
s . db . QueryRow ( "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?" , currentDatabase , tableName , columnName ) . Scan ( & count )
2016-03-05 17:50:49 +03:00
return count > 0
}
2018-02-09 17:58:34 +03:00
// ModifyColumn changes columnName on tableName to the SQL type typ using
// ALTER TABLE ... ALTER COLUMN. Neither name is quoted. The Exec error is
// returned.
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
	stmt := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)
	_, err := s.db.Exec(stmt)
	return err
}
2016-07-11 16:37:44 +03:00
func ( s mssql ) CurrentDatabase ( ) ( name string ) {
2016-03-05 17:50:49 +03:00
s . db . QueryRow ( "SELECT DB_NAME() AS [Current Database]" ) . Scan ( & name )
return
}
2016-06-21 06:13:33 +03:00
func ( mssql ) LimitAndOffsetSQL ( limit , offset interface { } ) ( sql string ) {
if offset != nil {
2017-04-18 18:31:56 +03:00
if parsedOffset , err := strconv . ParseInt ( fmt . Sprint ( offset ) , 0 , 0 ) ; err == nil && parsedOffset >= 0 {
2016-06-21 06:13:33 +03:00
sql += fmt . Sprintf ( " OFFSET %d ROWS" , parsedOffset )
2016-03-05 17:50:49 +03:00
}
}
2017-02-22 00:23:01 +03:00
if limit != nil {
2017-04-18 18:31:56 +03:00
if parsedLimit , err := strconv . ParseInt ( fmt . Sprint ( limit ) , 0 , 0 ) ; err == nil && parsedLimit >= 0 {
2017-02-22 00:23:01 +03:00
if sql == "" {
// add default zero offset
sql += " OFFSET 0 ROWS"
}
sql += fmt . Sprintf ( " FETCH NEXT %d ROWS ONLY" , parsedLimit )
}
}
2016-03-05 17:50:49 +03:00
return
}
// SelectFromDummyTable returns an empty string: this dialect adds no
// FROM clause for table-less SELECTs.
func (mssql) SelectFromDummyTable() string {
	var none string
	return none
}
2016-03-07 09:54:20 +03:00
func ( mssql ) LastInsertIDReturningSuffix ( tableName , columnName string ) string {
2016-03-05 17:50:49 +03:00
return ""
2016-03-05 14:22:29 +03:00
}
2018-02-03 15:27:19 +03:00
2018-02-12 12:39:34 +03:00
// DefaultValueStr returns the clause used for an INSERT that supplies no
// column values.
func (mssql) DefaultValueStr() string {
	const clause = "DEFAULT VALUES"
	return clause
}
2018-02-03 15:27:19 +03:00
// currentDatabaseAndTable splits tableName at its first dot and returns
// (prefix, rest). A name without a dot returns dialect.CurrentDatabase()
// together with the unchanged name.
//
// NOTE(review): in T-SQL a two-part name is schema.table, yet the callers in
// this file compare the prefix against table_catalog (the database).
// Confirm this is intended.
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
	dot := strings.Index(tableName, ".")
	if dot < 0 {
		return dialect.CurrentDatabase(), tableName
	}
	return tableName[:dot], tableName[dot+1:]
}
2018-07-27 02:30:57 +03:00
// JSON holds JSON text stored in a character column. It embeds
// json.RawMessage, so the payload is kept as raw bytes and only decoded or
// encoded when the caller chooses to.
type JSON struct {
	json.RawMessage
}

// Value implements driver.Valuer. An empty JSON is written as NULL; otherwise
// the raw bytes are returned unchanged.
func (j JSON) Value() (driver.Value, error) {
	if len(j.RawMessage) == 0 {
		return nil, nil
	}
	return j.MarshalJSON()
}

// Scan implements sql.Scanner. It accepts the column value as a string or
// []byte and stores it after checking that it is valid JSON.
//
// A NULL column (nil) resets j to empty. This mirrors Value, which writes an
// empty JSON as NULL, so an empty value now survives a round trip instead of
// failing on read. Any other type is an error.
func (j *JSON) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		j.RawMessage = nil
		return nil
	case string:
		raw = []byte(v)
	case []byte:
		raw = v
	default:
		return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
	}
	// *JSON gets UnmarshalJSON from the embedded RawMessage. json.Unmarshal
	// therefore validates raw and then stores a copy of it in j.RawMessage, so
	// a driver-owned []byte is not retained.
	return json.Unmarshal(raw, j)
}