mirror of https://github.com/go-gorm/gorm.git
Refactor dialect
This commit is contained in:
parent
896ee534e2
commit
e159ca1914
16
dialect.go
16
dialect.go
|
@ -7,17 +7,19 @@ import (
|
|||
|
||||
type Dialect interface {
|
||||
BinVar(i int) string
|
||||
SupportLastInsertId() bool
|
||||
HasTop() bool
|
||||
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
||||
ReturningStr(tableName, key string) string
|
||||
SelectFromDummyTable() string
|
||||
Quote(key string) string
|
||||
HasTable(scope *Scope, tableName string) bool
|
||||
HasColumn(scope *Scope, tableName string, columnName string) bool
|
||||
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
||||
|
||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
||||
RemoveIndex(scope *Scope, indexName string)
|
||||
HasTable(scope *Scope, tableName string) bool
|
||||
HasColumn(scope *Scope, tableName string, columnName string) bool
|
||||
CurrentDatabase(scope *Scope) string
|
||||
|
||||
ReturningStr(tableName, key string) string
|
||||
LimitAndOffsetSQL(limit, offset int) string
|
||||
SelectFromDummyTable() string
|
||||
SupportLastInsertId() bool
|
||||
}
|
||||
|
||||
func NewDialect(driver string) Dialect {
|
||||
|
|
|
@ -12,12 +12,8 @@ func (commonDialect) BinVar(i int) string {
|
|||
return "$$" // ?
|
||||
}
|
||||
|
||||
func (commonDialect) SupportLastInsertId() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (commonDialect) HasTop() bool {
|
||||
return false
|
||||
func (commonDialect) Quote(key string) string {
|
||||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
|
@ -56,16 +52,17 @@ func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) st
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (commonDialect) ReturningStr(tableName, key string) string {
|
||||
return ""
|
||||
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var (
|
||||
count int
|
||||
databaseName = c.CurrentDatabase(scope)
|
||||
)
|
||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (commonDialect) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) Quote(key string) string {
|
||||
return fmt.Sprintf(`"%s"`, key)
|
||||
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
|
||||
}
|
||||
|
||||
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
||||
|
@ -86,19 +83,6 @@ func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName stri
|
|||
return count > 0
|
||||
}
|
||||
|
||||
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var (
|
||||
count int
|
||||
databaseName = c.CurrentDatabase(scope)
|
||||
)
|
||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
|
||||
}
|
||||
|
||||
// RawScanInt scans the first column of the first row into the `scan' int pointer.
|
||||
// This function captures raw query errors and propagates them to the original scope.
|
||||
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
|
||||
|
@ -115,3 +99,25 @@ func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
|
|||
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
|
||||
return
|
||||
}
|
||||
|
||||
func (commonDialect) ReturningStr(tableName, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
|
||||
if limit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", limit)
|
||||
}
|
||||
if offset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", offset)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (commonDialect) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) SupportLastInsertId() bool {
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -10,10 +10,6 @@ type mssql struct {
|
|||
commonDialect
|
||||
}
|
||||
|
||||
func (mssql) HasTop() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
|
@ -50,6 +46,12 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
||||
var (
|
||||
count int
|
||||
|
@ -68,13 +70,24 @@ func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool
|
|||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) CurrentDatabase(scope *Scope) (name string) {
|
||||
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
|
||||
return
|
||||
}
|
||||
|
||||
func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) {
|
||||
if limit < 0 && offset < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
sql += fmt.Sprintf(" OFFSET %d ROWS", offset)
|
||||
|
||||
if limit >= 0 {
|
||||
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -10,6 +10,10 @@ type mysql struct {
|
|||
commonDialect
|
||||
}
|
||||
|
||||
func (mysql) Quote(key string) string {
|
||||
return fmt.Sprintf("`%s`", key)
|
||||
}
|
||||
|
||||
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
|
@ -56,15 +60,11 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (mysql) Quote(key string) string {
|
||||
return fmt.Sprintf("`%s`", key)
|
||||
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
|
||||
s.RawScanString(scope, &name, "SELECT DATABASE()")
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql) SelectFromDummyTable() string {
|
||||
return "FROM DUAL"
|
||||
}
|
||||
|
||||
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
|
||||
s.RawScanString(scope, &name, "SELECT DATABASE()")
|
||||
return
|
||||
}
|
||||
|
|
|
@ -19,10 +19,6 @@ func (postgres) BinVar(i int) string {
|
|||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (postgres) SupportLastInsertId() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
|
@ -62,23 +58,14 @@ func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
var byteType = reflect.TypeOf(uint8(0))
|
||||
|
||||
func isByteArrayOrSlice(value reflect.Value) bool {
|
||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType
|
||||
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func isUUID(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||
return false
|
||||
}
|
||||
typename := value.Type().Name()
|
||||
lower := strings.ToLower(typename)
|
||||
return "uuid" == lower || "guid" == lower
|
||||
}
|
||||
|
||||
func (s postgres) ReturningStr(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
func (postgres) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
||||
}
|
||||
|
||||
func (s postgres) HasTable(scope *Scope, tableName string) bool {
|
||||
|
@ -93,21 +80,19 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b
|
|||
return count > 0
|
||||
}
|
||||
|
||||
func (postgres) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
||||
}
|
||||
|
||||
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) CurrentDatabase(scope *Scope) (name string) {
|
||||
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
|
||||
return
|
||||
}
|
||||
|
||||
func (s postgres) ReturningStr(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
}
|
||||
|
||||
func (postgres) SupportLastInsertId() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var hstoreType = reflect.TypeOf(Hstore{})
|
||||
|
||||
type Hstore map[string]*string
|
||||
|
@ -152,3 +137,16 @@ func (h *Hstore) Scan(value interface{}) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isByteArrayOrSlice(value reflect.Value) bool {
|
||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||
}
|
||||
|
||||
func isUUID(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||
return false
|
||||
}
|
||||
typename := value.Type().Name()
|
||||
lower := strings.ToLower(typename)
|
||||
return "uuid" == lower || "guid" == lower
|
||||
}
|
||||
|
|
|
@ -43,6 +43,16 @@ func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
||||
}
|
||||
|
||||
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
|
||||
|
@ -55,16 +65,6 @@ func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bo
|
|||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
||||
}
|
||||
|
||||
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
|
||||
var (
|
||||
ifaces = make([]interface{}, 3)
|
||||
|
|
8
main.go
8
main.go
|
@ -146,12 +146,12 @@ func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
|||
return s.clone().search.Not(query, args...).db
|
||||
}
|
||||
|
||||
func (s *DB) Limit(value interface{}) *DB {
|
||||
return s.clone().search.Limit(value).db
|
||||
func (s *DB) Limit(limit int) *DB {
|
||||
return s.clone().search.Limit(limit).db
|
||||
}
|
||||
|
||||
func (s *DB) Offset(value interface{}) *DB {
|
||||
return s.clone().search.Offset(value).db
|
||||
func (s *DB) Offset(offset int) *DB {
|
||||
return s.clone().search.Offset(offset).db
|
||||
}
|
||||
|
||||
func (s *DB) Order(value string, reorder ...bool) *DB {
|
||||
|
|
|
@ -10,7 +10,7 @@ func (s *DB) clone() *DB {
|
|||
}
|
||||
|
||||
if s.search == nil {
|
||||
db.search = &search{}
|
||||
db.search = &search{limit: -1, offset: -1}
|
||||
} else {
|
||||
db.search = s.search.clone()
|
||||
}
|
||||
|
|
2
scope.go
2
scope.go
|
@ -272,7 +272,7 @@ func (scope *Scope) QuotedTableName() (name string) {
|
|||
// CombinedConditionSql get combined condition sql
|
||||
func (scope *Scope) CombinedConditionSql() string {
|
||||
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
|
||||
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
|
||||
scope.havingSql() + scope.orderSql() + scope.limitAndOffsetSql()
|
||||
}
|
||||
|
||||
// FieldByName find gorm.Field with name and db name
|
||||
|
|
|
@ -245,41 +245,8 @@ func (scope *Scope) orderSql() string {
|
|||
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
||||
}
|
||||
|
||||
func (scope *Scope) limitSql() string {
|
||||
if !scope.Dialect().HasTop() {
|
||||
if len(scope.Search.limit) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " LIMIT " + scope.Search.limit
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (scope *Scope) topSql() string {
|
||||
if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
|
||||
if len(scope.Search.limit) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " TOP(" + scope.Search.limit + ")"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (scope *Scope) offsetSql() string {
|
||||
if len(scope.Search.offset) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if scope.Dialect().HasTop() {
|
||||
sql := " OFFSET " + scope.Search.offset + " ROW "
|
||||
if len(scope.Search.limit) > 0 {
|
||||
sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
|
||||
}
|
||||
return sql
|
||||
}
|
||||
return " OFFSET " + scope.Search.offset
|
||||
func (scope *Scope) limitAndOffsetSql() string {
|
||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
}
|
||||
|
||||
func (scope *Scope) groupSql() string {
|
||||
|
@ -318,7 +285,7 @@ func (scope *Scope) prepareQuerySql() {
|
|||
if scope.Search.raw {
|
||||
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
12
search.go
12
search.go
|
@ -15,8 +15,8 @@ type search struct {
|
|||
orders []string
|
||||
joins string
|
||||
preload []searchPreload
|
||||
offset string
|
||||
limit string
|
||||
offset int
|
||||
limit int
|
||||
group string
|
||||
tableName string
|
||||
raw bool
|
||||
|
@ -82,13 +82,13 @@ func (s *search) Omit(columns ...string) *search {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *search) Limit(value interface{}) *search {
|
||||
s.limit = s.getInterfaceAsSql(value)
|
||||
func (s *search) Limit(limit int) *search {
|
||||
s.limit = limit
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Offset(value interface{}) *search {
|
||||
s.offset = s.getInterfaceAsSql(value)
|
||||
func (s *search) Offset(offset int) *search {
|
||||
s.offset = offset
|
||||
return s
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue