2016-03-05 13:54:59 +03:00
package mssql
2016-03-05 14:22:29 +03:00
import (
2016-03-05 17:50:49 +03:00
"database/sql"
2016-03-05 14:22:29 +03:00
"fmt"
2016-03-05 17:50:49 +03:00
"reflect"
"strings"
"time"
2016-03-05 14:22:29 +03:00
_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
)
func setIdentityInsert ( scope * gorm . Scope ) {
2016-03-05 16:24:54 +03:00
if scope . Dialect ( ) . GetName ( ) == "mssql" {
scope . NewDB ( ) . Exec ( fmt . Sprintf ( "SET IDENTITY_INSERT %v ON" , scope . TableName ( ) ) )
}
2016-03-05 14:22:29 +03:00
}
func init ( ) {
gorm . DefaultCallback . Create ( ) . After ( "gorm:begin_transaction" ) . Register ( "mssql:set_identity_insert" , setIdentityInsert )
2016-03-05 17:50:49 +03:00
gorm . RegisterDialect ( "mssql" , & mssql { } )
}
type mssql struct {
db * sql . DB
2016-05-22 01:13:26 +03:00
gorm . DefaultForeignKeyNamer
2016-03-05 17:50:49 +03:00
}
// GetName returns the identifier this dialect is registered under.
func (mssql) GetName() string {
	const dialectName = "mssql"
	return dialectName
}
// SetDB stores db on the dialect. The schema-inspection methods
// (HasIndex, HasTable, HasColumn, RemoveIndex) issue their queries through it.
func (s *mssql) SetDB(db *sql.DB) {
	s.db = db
}
// BindVar returns the placeholder text for the i-th bound argument. This
// dialect ignores i and uses the same "$$" marker for every position.
// NOTE(review): the original carried a "?" remark here; how "$$" is turned
// into a driver placeholder is not visible in this file — confirm.
func (mssql) BindVar(i int) string {
	const placeholder = "$$"
	return placeholder
}
// Quote returns key wrapped in double quotes for use as a delimited SQL
// Server identifier.
//
// Any double quote inside key is doubled, which is how Transact-SQL escapes
// it inside a double-quoted identifier. Without this, an embedded quote would
// terminate the identifier early and corrupt the generated statement.
func (mssql) Quote(key string) string {
	return fmt.Sprintf(`"%s"`, strings.Replace(key, `"`, `""`, -1))
}
// DataTypeOf returns the SQL Server column type to use for field.
//
// An explicit type reported by gorm.ParseFieldStructForDialect is used as-is.
// Otherwise the type is derived from the kind of the field's value:
//
//   - bool -> bit
//   - int, int8..int32, uint, uint8..uint32, uintptr -> int
//   - int64, uint64 -> bigint
//   - float32, float64 -> float
//   - string -> nvarchar(size) when 1 <= size <= 4000, else nvarchar(max)
//   - time.Time -> datetime2
//   - []byte -> varbinary(size) when 1 <= size <= 8000, else varbinary(max)
//
// Integer fields that are primary keys or carry an AUTO_INCREMENT tag get
// IDENTITY(1,1). Any non-blank additionalType from
// ParseFieldStructForDialect is appended after a single space.
//
// It panics when no SQL Server type can be determined for the field.
func (mssql) DataTypeOf(field *gorm.StructField) string {
	var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field)

	if sqlType == "" {
		switch dataValue.Kind() {
		case reflect.Bool:
			sqlType = "bit"
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
			// NOTE(review): SQL Server int is 32-bit signed, so uint32/uint
			// values above 2^31-1 will not fit; mapping kept unchanged.
			if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
				sqlType = "int IDENTITY(1,1)"
			} else {
				sqlType = "int"
			}
		case reflect.Int64, reflect.Uint64:
			if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
				sqlType = "bigint IDENTITY(1,1)"
			} else {
				sqlType = "bigint"
			}
		case reflect.Float32, reflect.Float64:
			sqlType = "float"
		case reflect.String:
			// nvarchar(n) only accepts n in 1..4000; anything larger (or
			// unsized) must be nvarchar(max). The legacy text type is
			// deprecated, non-Unicode, and cannot be used with "=".
			if size > 0 && size <= 4000 {
				sqlType = fmt.Sprintf("nvarchar(%d)", size)
			} else {
				sqlType = "nvarchar(max)"
			}
		case reflect.Struct:
			if _, ok := dataValue.Interface().(time.Time); ok {
				sqlType = "datetime2"
			}
		default:
			if _, ok := dataValue.Interface().([]byte); ok {
				// Raw bytes belong in a binary column, not a character one.
				// varbinary(n) accepts n in 1..8000, otherwise varbinary(max).
				if size > 0 && size <= 8000 {
					sqlType = fmt.Sprintf("varbinary(%d)", size)
				} else {
					sqlType = "varbinary(max)"
				}
			}
		}
	}

	if sqlType == "" {
		panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
	}

	if strings.TrimSpace(additionalType) == "" {
		return sqlType
	}
	return fmt.Sprintf("%v %v", sqlType, additionalType)
}
// HasIndex reports whether sys.indexes holds an index named indexName whose
// object_id matches OBJECT_ID(tableName).
//
// Query/scan errors are not surfaced. The count then stays 0 and the method
// reports false.
func (s mssql) HasIndex(tableName string, indexName string) bool {
	const query = "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)"

	var n int
	row := s.db.QueryRow(query, indexName, tableName)
	_ = row.Scan(&n) // best-effort: an error leaves n == 0
	return n > 0
}
// RemoveIndex executes "DROP INDEX <indexName> ON <quoted tableName>" and
// returns the error from Exec, if any. Only the table name is quoted;
// indexName is interpolated verbatim.
func (s mssql) RemoveIndex(tableName string, indexName string) error {
	stmt := fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))
	if _, err := s.db.Exec(stmt); err != nil {
		return err
	}
	return nil
}
// HasForeignKey reports whether a foreign-key constraint named
// foreignKeyName exists on tableName.
//
// It looks the constraint up in sys.foreign_keys, keyed on
// parent_object_id = OBJECT_ID(tableName). This mirrors how HasIndex looks
// up sys.indexes. The previous implementation always returned false, so an
// existing constraint was never recognized.
//
// Query/scan errors are not surfaced. The count then stays 0 and the method
// reports false, matching the best-effort style of HasIndex.
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
	var count int
	_ = s.db.QueryRow(
		"SELECT count(*) FROM sys.foreign_keys WHERE name = ? AND parent_object_id = OBJECT_ID(?)",
		foreignKeyName, tableName,
	).Scan(&count)
	return count > 0
}
// HasTable reports whether INFORMATION_SCHEMA.tables lists tableName in the
// connection's current database.
//
// The current database is resolved inside the same statement via DB_NAME().
// Previously it was fetched by a separate query whose error was discarded;
// a failure there silently compared table_catalog against "" and reported an
// existing table as missing. The inline form also saves a round trip.
//
// Query/scan errors are not surfaced. The count then stays 0 and the method
// reports false.
func (s mssql) HasTable(tableName string) bool {
	var count int
	_ = s.db.QueryRow(
		"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = DB_NAME()",
		tableName,
	).Scan(&count)
	return count > 0
}
// HasColumn reports whether information_schema.columns lists columnName on
// tableName in the connection's current database.
//
// The current database is resolved inline with DB_NAME(). This avoids a
// separate query whose swallowed error would have turned into
// table_catalog = '' and a false "column missing".
//
// Query/scan errors are not surfaced. The count then stays 0 and the method
// reports false.
func (s mssql) HasColumn(tableName string, columnName string) bool {
	var count int
	_ = s.db.QueryRow(
		"SELECT count(*) FROM information_schema.columns WHERE table_catalog = DB_NAME() AND table_name = ? AND column_name = ?",
		tableName, columnName,
	).Scan(&count)
	return count > 0
}
// currentDatabase returns the result of SELECT DB_NAME() on s.db.
// Any query/scan error is discarded, in which case the empty string is returned.
func (s mssql) currentDatabase() (name string) {
	row := s.db.QueryRow("SELECT DB_NAME() AS [Current Database]")
	_ = row.Scan(&name)
	return name
}
2016-06-16 11:31:50 +03:00
func ( mssql ) LimitAndOffsetSQL ( limit , offset int ) ( sql string ) {
2016-03-05 17:50:49 +03:00
if limit > 0 || offset > 0 {
if offset < 0 {
offset = 0
}
sql += fmt . Sprintf ( " OFFSET %d ROWS" , offset )
if limit >= 0 {
sql += fmt . Sprintf ( " FETCH NEXT %d ROWS ONLY" , limit )
}
}
return
}
// SelectFromDummyTable returns the FROM clause needed when selecting without
// a real table. This dialect always returns the empty string, presumably
// because SQL Server accepts SELECT without a FROM clause.
func (mssql) SelectFromDummyTable() string {
	var fromClause string
	return fromClause
}
2016-03-07 09:54:20 +03:00
func ( mssql ) LastInsertIDReturningSuffix ( tableName , columnName string ) string {
2016-03-05 17:50:49 +03:00
return ""
2016-03-05 14:22:29 +03:00
}