Add HasIndex method for dialect interface

This commit is contained in:
Jinzhu 2015-03-02 23:02:40 +08:00
parent 61a878dc2d
commit 34997385b0
10 changed files with 71 additions and 15 deletions

View File

@ -60,7 +60,7 @@ type Address struct {
ID int
Address1 string `sql:"not null;unique"` // Set field as not nullable and unique
Address2 string `sql:"type:varchar(100);unique"`
Post sql.NullString `sql:not null`
Post sql.NullString `sql:"not null"`
}
type Language struct {

View File

@ -86,13 +86,19 @@ func (s *commonDialect) databaseName(scope *Scope) string {
func (s *commonDialect) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}

View File

@ -23,7 +23,7 @@ func TestCustomizeColumn(t *testing.T) {
DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
scope := DB.Model("").NewScope(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}

View File

@ -16,6 +16,7 @@ type Dialect interface {
Quote(key string) string
HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool
HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string)
}

20
main.go
View File

@ -399,6 +399,16 @@ func (s *DB) AddIndex(indexName string, column ...string) *DB {
return s
}
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
s.clone().NewScope(s.Value).addIndex(true, indexName, column...)
return s
}
func (s *DB) RemoveIndex(indexName string) *DB {
s.clone().NewScope(s.Value).removeIndex(indexName)
return s
}
/*
Add foreign key to the given scope
@ -410,16 +420,6 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate
return s
}
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
s.clone().NewScope(s.Value).addIndex(true, indexName, column...)
return s
}
func (s *DB) RemoveIndex(indexName string) *DB {
s.clone().NewScope(s.Value).removeIndex(indexName)
return s
}
func (s *DB) Association(column string) *Association {
var err error
scope := s.clone().NewScope(s.Value)

View File

@ -30,22 +30,43 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err)
}
scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted")
}
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
t.Errorf("Should get to create duplicate record when having unique index")
}
@ -54,6 +75,10 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
t.Errorf("Should be able to create duplicated emails after remove unique index")
}

View File

@ -98,6 +98,12 @@ func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) boo
return count > 0
}
func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
return count > 0
}
func (s *mssql) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}

View File

@ -96,6 +96,12 @@ func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) boo
return count > 0
}
func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
func (s *mysql) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}

View File

@ -87,7 +87,7 @@ func (s *postgres) HasTable(scope *Scope, tableName string) bool {
func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
return count > 0
}
@ -95,6 +95,12 @@ func (s *postgres) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
}
func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
var hstoreType = reflect.TypeOf(Hstore{})
type Hstore map[string]*string

View File

@ -80,6 +80,12 @@ func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) b
return count > 0
}
func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count)
return count > 0
}
func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
}