Refactor dialect

This commit is contained in:
Jinzhu 2016-01-18 20:32:52 +08:00
parent 896ee534e2
commit e159ca1914
11 changed files with 124 additions and 138 deletions

View File

@ -7,17 +7,19 @@ import (
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
HasTop() bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string
ReturningStr(tableName, key string) string
SelectFromDummyTable() string
Quote(key string) string
HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string
HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string)
HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool
CurrentDatabase(scope *Scope) string
ReturningStr(tableName, key string) string
LimitAndOffsetSQL(limit, offset int) string
SelectFromDummyTable() string
SupportLastInsertId() bool
}
func NewDialect(driver string) Dialect {

View File

@ -12,12 +12,8 @@ func (commonDialect) BinVar(i int) string {
return "$$" // ?
}
func (commonDialect) SupportLastInsertId() bool {
return true
}
func (commonDialect) HasTop() bool {
return false
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
@ -56,16 +52,17 @@ func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) st
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
}
func (commonDialect) ReturningStr(tableName, key string) string {
return ""
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var (
count int
databaseName = c.CurrentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
return count > 0
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
}
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
@ -86,19 +83,6 @@ func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName stri
return count > 0
}
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var (
count int
databaseName = c.CurrentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
return count > 0
}
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
}
// RawScanInt scans the first column of the first row into the `scan' int pointer.
// This function captures raw query errors and propagates them to the original scope.
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
@ -115,3 +99,25 @@ func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
return
}
func (commonDialect) ReturningStr(tableName, key string) string {
return ""
}
func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
if limit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", limit)
}
if offset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", offset)
}
return
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) SupportLastInsertId() bool {
return true
}

View File

@ -10,10 +10,6 @@ type mssql struct {
commonDialect
}
func (mssql) HasTop() bool {
return true
}
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
@ -50,6 +46,12 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
}
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
return count > 0
}
func (s mssql) HasTable(scope *Scope, tableName string) bool {
var (
count int
@ -68,13 +70,24 @@ func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool
return count > 0
}
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
return count > 0
}
func (s mssql) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
return
}
func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) {
if limit < 0 && offset < 0 {
return
}
if offset < 0 {
offset = 0
}
sql += fmt.Sprintf(" OFFSET %d ROWS", offset)
if limit >= 0 {
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit)
}
return
}

View File

@ -10,6 +10,10 @@ type mysql struct {
commonDialect
}
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
@ -56,15 +60,11 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DATABASE()")
return
}
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DATABASE()")
return
}

View File

@ -19,10 +19,6 @@ func (postgres) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (postgres) SupportLastInsertId() bool {
return false
}
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
@ -62,23 +58,14 @@ func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
}
var byteType = reflect.TypeOf(uint8(0))
func isByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
return count > 0
}
func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
}
typename := value.Type().Name()
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}
func (s postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}
func (s postgres) HasTable(scope *Scope, tableName string) bool {
@ -93,21 +80,19 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b
return count > 0
}
func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
return count > 0
}
func (s postgres) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
return
}
func (s postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}
func (postgres) SupportLastInsertId() bool {
return false
}
var hstoreType = reflect.TypeOf(Hstore{})
type Hstore map[string]*string
@ -152,3 +137,16 @@ func (h *Hstore) Scan(value interface{}) error {
return nil
}
func isByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}
func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
}
typename := value.Type().Name()
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}

View File

@ -43,6 +43,16 @@ func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
return count > 0
}
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
@ -55,16 +65,6 @@ func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bo
return count > 0
}
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
return count > 0
}
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
var (
ifaces = make([]interface{}, 3)

View File

@ -146,12 +146,12 @@ func (s *DB) Not(query interface{}, args ...interface{}) *DB {
return s.clone().search.Not(query, args...).db
}
func (s *DB) Limit(value interface{}) *DB {
return s.clone().search.Limit(value).db
func (s *DB) Limit(limit int) *DB {
return s.clone().search.Limit(limit).db
}
func (s *DB) Offset(value interface{}) *DB {
return s.clone().search.Offset(value).db
func (s *DB) Offset(offset int) *DB {
return s.clone().search.Offset(offset).db
}
func (s *DB) Order(value string, reorder ...bool) *DB {

View File

@ -10,7 +10,7 @@ func (s *DB) clone() *DB {
}
if s.search == nil {
db.search = &search{}
db.search = &search{limit: -1, offset: -1}
} else {
db.search = s.search.clone()
}

View File

@ -272,7 +272,7 @@ func (scope *Scope) QuotedTableName() (name string) {
// CombinedConditionSql get combined condition sql
func (scope *Scope) CombinedConditionSql() string {
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
scope.havingSql() + scope.orderSql() + scope.limitAndOffsetSql()
}
// FieldByName find gorm.Field with name and db name

View File

@ -245,41 +245,8 @@ func (scope *Scope) orderSql() string {
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
}
func (scope *Scope) limitSql() string {
if !scope.Dialect().HasTop() {
if len(scope.Search.limit) == 0 {
return ""
}
return " LIMIT " + scope.Search.limit
}
return ""
}
func (scope *Scope) topSql() string {
if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
if len(scope.Search.limit) == 0 {
return ""
}
return " TOP(" + scope.Search.limit + ")"
}
return ""
}
func (scope *Scope) offsetSql() string {
if len(scope.Search.offset) == 0 {
return ""
}
if scope.Dialect().HasTop() {
sql := " OFFSET " + scope.Search.offset + " ROW "
if len(scope.Search.limit) > 0 {
sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
}
return sql
}
return " OFFSET " + scope.Search.offset
func (scope *Scope) limitAndOffsetSql() string {
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
}
func (scope *Scope) groupSql() string {
@ -318,7 +285,7 @@ func (scope *Scope) prepareQuerySql() {
if scope.Search.raw {
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
} else {
scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
}
return
}

View File

@ -15,8 +15,8 @@ type search struct {
orders []string
joins string
preload []searchPreload
offset string
limit string
offset int
limit int
group string
tableName string
raw bool
@ -82,13 +82,13 @@ func (s *search) Omit(columns ...string) *search {
return s
}
func (s *search) Limit(value interface{}) *search {
s.limit = s.getInterfaceAsSql(value)
func (s *search) Limit(limit int) *search {
s.limit = limit
return s
}
func (s *search) Offset(value interface{}) *search {
s.offset = s.getInterfaceAsSql(value)
func (s *search) Offset(offset int) *search {
s.offset = offset
return s
}