From e159ca1914e74d9fdbbac34274472a65ea3576f2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Jan 2016 20:32:52 +0800 Subject: [PATCH] Refactor dialect --- dialect.go | 16 ++++++------ dialect_common.go | 60 +++++++++++++++++++++++++-------------------- dialect_mssql.go | 33 +++++++++++++++++-------- dialect_mysql.go | 14 +++++------ dialect_postgres.go | 56 ++++++++++++++++++++---------------------- dialect_sqlite3.go | 20 +++++++-------- main.go | 8 +++--- main_private.go | 2 +- scope.go | 2 +- scope_private.go | 39 +++-------------------------- search.go | 12 ++++----- 11 files changed, 124 insertions(+), 138 deletions(-) diff --git a/dialect.go b/dialect.go index aa23e31b..1fa61925 100644 --- a/dialect.go +++ b/dialect.go @@ -7,17 +7,19 @@ import ( type Dialect interface { BinVar(i int) string - SupportLastInsertId() bool - HasTop() bool - SqlTag(value reflect.Value, size int, autoIncrease bool) string - ReturningStr(tableName, key string) string - SelectFromDummyTable() string Quote(key string) string - HasTable(scope *Scope, tableName string) bool - HasColumn(scope *Scope, tableName string, columnName string) bool + SqlTag(value reflect.Value, size int, autoIncrease bool) string + HasIndex(scope *Scope, tableName string, indexName string) bool RemoveIndex(scope *Scope, indexName string) + HasTable(scope *Scope, tableName string) bool + HasColumn(scope *Scope, tableName string, columnName string) bool CurrentDatabase(scope *Scope) string + + ReturningStr(tableName, key string) string + LimitAndOffsetSQL(limit, offset int) string + SelectFromDummyTable() string + SupportLastInsertId() bool } func NewDialect(driver string) Dialect { diff --git a/dialect_common.go b/dialect_common.go index 7f08b04f..ade7c068 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -12,12 +12,8 @@ func (commonDialect) BinVar(i int) string { return "$$" // ? } -func (commonDialect) SupportLastInsertId() bool { - return true -} - -func (commonDialect) HasTop() bool { - return false +func (commonDialect) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) } func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { @@ -56,16 +52,17 @@ func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) st panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) } -func (commonDialect) ReturningStr(tableName, key string) string { - return "" +func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { + var ( + count int + databaseName = c.CurrentDatabase(scope) + ) + c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName) + return count > 0 } -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) +func (commonDialect) RemoveIndex(scope *Scope, indexName string) { + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) } func (c commonDialect) HasTable(scope *Scope, tableName string) bool { @@ -86,19 +83,6 @@ func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName stri return count > 0 } -func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { - var ( - count int - databaseName = c.CurrentDatabase(scope) - ) - c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName) - return count > 0 -} - -func (commonDialect) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) -} - // RawScanInt scans the first column of the first row into the `scan' int pointer. // This function captures raw query errors and propagates them to the original scope. func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) { @@ -115,3 +99,25 @@ func (commonDialect) CurrentDatabase(scope *Scope) (name string) { scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) return } + +func (commonDialect) ReturningStr(tableName, key string) string { + return "" +} + +func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) { + if limit >= 0 { + sql += fmt.Sprintf(" LIMIT %d", limit) + } + if offset >= 0 { + sql += fmt.Sprintf(" OFFSET %d", offset) + } + return +} + +func (commonDialect) SelectFromDummyTable() string { + return "" +} + +func (commonDialect) SupportLastInsertId() bool { + return true +} diff --git a/dialect_mssql.go b/dialect_mssql.go index a9bd1e52..82fba7d1 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -10,10 +10,6 @@ type mssql struct { commonDialect } -func (mssql) HasTop() bool { - return true -} - func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: @@ -50,6 +46,12 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) } +func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName) + return count > 0 +} + func (s mssql) HasTable(scope *Scope, tableName string) bool { var ( count int @@ -68,13 +70,24 @@ func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool return count > 0 } -func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName) - return count > 0 -} - func (s mssql) CurrentDatabase(scope *Scope) (name string) { s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]") return } + +func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) { + if limit < 0 && offset < 0 { + return + } + + if offset < 0 { + offset = 0 + } + + sql += fmt.Sprintf(" OFFSET %d ROWS", offset) + + if limit >= 0 { + sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit) + } + return +} diff --git a/dialect_mysql.go b/dialect_mysql.go index 9e1d56d3..b6f9a22b 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -10,6 +10,10 @@ type mysql struct { commonDialect } +func (mysql) Quote(key string) string { + return fmt.Sprintf("`%s`", key) +} + func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: @@ -56,15 +60,11 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) +func (s mysql) CurrentDatabase(scope *Scope) (name string) { + s.RawScanString(scope, &name, "SELECT DATABASE()") + return } func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } - -func (s mysql) CurrentDatabase(scope *Scope) (name string) { - s.RawScanString(scope, &name, "SELECT DATABASE()") - return -} diff --git a/dialect_postgres.go b/dialect_postgres.go index 3b083dfa..0b16816c 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -19,10 +19,6 @@ func (postgres) BinVar(i int) string { return fmt.Sprintf("$%v", i) } -func (postgres) SupportLastInsertId() bool { - return false -} - func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: @@ -62,23 +58,14 @@ func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) } -var byteType = reflect.TypeOf(uint8(0)) - -func isByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType +func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName) + return count > 0 } -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} - -func (s postgres) ReturningStr(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) +func (postgres) RemoveIndex(scope *Scope, indexName string) { + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) } func (s postgres) HasTable(scope *Scope, tableName string) bool { @@ -93,21 +80,19 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b return count > 0 } -func (postgres) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) -} - -func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName) - return count > 0 -} - func (s postgres) CurrentDatabase(scope *Scope) (name string) { s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()") return } +func (s postgres) ReturningStr(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", tableName, key) +} + +func (postgres) SupportLastInsertId() bool { + return false +} + var hstoreType = reflect.TypeOf(Hstore{}) type Hstore map[string]*string @@ -152,3 +137,16 @@ func (h *Hstore) Scan(value interface{}) error { return nil } + +func isByteArrayOrSlice(value reflect.Value) bool { + return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) +} + +func isUUID(value reflect.Value) bool { + if value.Kind() != reflect.Array || value.Type().Len() != 16 { + return false + } + typename := value.Type().Name() + lower := strings.ToLower(typename) + return "uuid" == lower || "guid" == lower +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index d052d2c1..82546dbb 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -43,6 +43,16 @@ func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } +func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName) + return count > 0 +} + +func (sqlite3) RemoveIndex(scope *Scope, indexName string) { + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) +} + func (s sqlite3) HasTable(scope *Scope, tableName string) bool { var count int s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName) @@ -55,16 +65,6 @@ func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bo return count > 0 } -func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName) - return count > 0 -} - -func (sqlite3) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) -} - func (sqlite3) CurrentDatabase(scope *Scope) (name string) { var ( ifaces = make([]interface{}, 3) diff --git a/main.go b/main.go index a56f282f..461329fa 100644 --- a/main.go +++ b/main.go @@ -146,12 +146,12 @@ func (s *DB) Not(query interface{}, args ...interface{}) *DB { return s.clone().search.Not(query, args...).db } -func (s *DB) Limit(value interface{}) *DB { - return s.clone().search.Limit(value).db +func (s *DB) Limit(limit int) *DB { + return s.clone().search.Limit(limit).db } -func (s *DB) Offset(value interface{}) *DB { - return s.clone().search.Offset(value).db +func (s *DB) Offset(offset int) *DB { + return s.clone().search.Offset(offset).db } func (s *DB) Order(value string, reorder ...bool) *DB { diff --git a/main_private.go b/main_private.go index bd097ce0..a6e5a6a9 100644 --- a/main_private.go +++ b/main_private.go @@ -10,7 +10,7 @@ func (s *DB) clone() *DB { } if s.search == nil { - db.search = &search{} + db.search = &search{limit: -1, offset: -1} } else { db.search = s.search.clone() } diff --git a/scope.go b/scope.go index 1608a99b..8ee4bdd5 100644 --- a/scope.go +++ b/scope.go @@ -272,7 +272,7 @@ func (scope *Scope) QuotedTableName() (name string) { // CombinedConditionSql get combined condition sql func (scope *Scope) CombinedConditionSql() string { return scope.joinsSql() + scope.whereSql() + scope.groupSql() + - scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() + scope.havingSql() + scope.orderSql() + scope.limitAndOffsetSql() } // FieldByName find gorm.Field with name and db name diff --git a/scope_private.go b/scope_private.go index ef16cf93..dc1676e8 100644 --- a/scope_private.go +++ b/scope_private.go @@ -245,41 +245,8 @@ func (scope *Scope) orderSql() string { return " ORDER BY " + strings.Join(scope.Search.orders, ",") } -func (scope *Scope) limitSql() string { - if !scope.Dialect().HasTop() { - if len(scope.Search.limit) == 0 { - return "" - } - return " LIMIT " + scope.Search.limit - } - - return "" -} - -func (scope *Scope) topSql() string { - if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 { - if len(scope.Search.limit) == 0 { - return "" - } - return " TOP(" + scope.Search.limit + ")" - } - - return "" -} - -func (scope *Scope) offsetSql() string { - if len(scope.Search.offset) == 0 { - return "" - } - - if scope.Dialect().HasTop() { - sql := " OFFSET " + scope.Search.offset + " ROW " - if len(scope.Search.limit) > 0 { - sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY" - } - return sql - } - return " OFFSET " + scope.Search.offset +func (scope *Scope) limitAndOffsetSql() string { + return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) } func (scope *Scope) groupSql() string { @@ -318,7 +285,7 @@ func (scope *Scope) prepareQuerySql() { if scope.Search.raw { scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) } else { - scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) } return } diff --git a/search.go b/search.go index 166b9a86..c6d070f0 100644 --- a/search.go +++ b/search.go @@ -15,8 +15,8 @@ type search struct { orders []string joins string preload []searchPreload - offset string - limit string + offset int + limit int group string tableName string raw bool @@ -82,13 +82,13 @@ func (s *search) Omit(columns ...string) *search { return s } -func (s *search) Limit(value interface{}) *search { - s.limit = s.getInterfaceAsSql(value) +func (s *search) Limit(limit int) *search { + s.limit = limit return s } -func (s *search) Offset(value interface{}) *search { - s.offset = s.getInterfaceAsSql(value) +func (s *search) Offset(offset int) *search { + s.offset = offset return s }