diff --git a/README.md b/README.md index 6a22304c..69b6bd16 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ type Address struct { ID int Address1 string `sql:"not null;unique"` // Set field as not nullable and unique Address2 string `sql:"type:varchar(100);unique"` - Post sql.NullString `sql:not null` + Post sql.NullString `sql:"not null"` } type Language struct { diff --git a/common_dialect.go b/common_dialect.go index d8910dcd..87be97bf 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -86,13 +86,19 @@ func (s *commonDialect) databaseName(scope *Scope) string { func (s *commonDialect) HasTable(scope *Scope, tableName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) return count > 0 } func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) + return count > 0 +} + +func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) return count > 0 } diff --git a/customize_column_test.go b/customize_column_test.go index f5a31a5b..cf4f1d1a 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -23,7 +23,7 @@ func TestCustomizeColumn(t *testing.T) { DB.DropTable(&CustomizeColumn{}) DB.AutoMigrate(&CustomizeColumn{}) - scope := DB.Model("").NewScope(&CustomizeColumn{}) + scope := DB.NewScope(&CustomizeColumn{}) if !scope.Dialect().HasColumn(scope, scope.TableName(), col) { t.Errorf("CustomizeColumn should have column %s", col) } diff --git a/dialect.go b/dialect.go index 6e76b437..0c58d61a 100644 --- a/dialect.go +++ b/dialect.go @@ -16,6 +16,7 @@ type Dialect interface { Quote(key string) string HasTable(scope *Scope, tableName string) bool HasColumn(scope *Scope, tableName string, columnName string) bool + HasIndex(scope *Scope, tableName string, indexName string) bool RemoveIndex(scope *Scope, indexName string) } diff --git a/main.go b/main.go index 71477830..5992ba43 100644 --- a/main.go +++ b/main.go @@ -399,6 +399,16 @@ func (s *DB) AddIndex(indexName string, column ...string) *DB { return s } +func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { + s.clone().NewScope(s.Value).addIndex(true, indexName, column...) + return s +} + +func (s *DB) RemoveIndex(indexName string) *DB { + s.clone().NewScope(s.Value).removeIndex(indexName) + return s +} + /* Add foreign key to the given scope @@ -410,16 +420,6 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate return s } -func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { - s.clone().NewScope(s.Value).addIndex(true, indexName, column...) - return s -} - -func (s *DB) RemoveIndex(indexName string) *DB { - s.clone().NewScope(s.Value).removeIndex(indexName) - return s -} - func (s *DB) Association(column string) *Association { var err error scope := s.clone().NewScope(s.Value) diff --git a/migration_test.go b/migration_test.go index 0a2f0852..951b79c6 100644 --- a/migration_test.go +++ b/migration_test.go @@ -30,22 +30,43 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } + scope := DB.NewScope(&Email{}) + if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { + t.Errorf("Email should have index idx_email_email") + } + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { t.Errorf("Got error when tried to remove index: %+v", err) } + if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { + t.Errorf("Email's index idx_email_email should be deleted") + } + if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { t.Errorf("Got error when tried to create index: %+v", err) } + if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email should have index idx_email_email_and_user_id") + } + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { t.Errorf("Got error when tried to remove index: %+v", err) } + if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email's index idx_email_email_and_user_id should be deleted") + } + if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { t.Errorf("Got error when tried to create index: %+v", err) } + if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email should have index idx_email_email_and_user_id") + } + if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { t.Errorf("Should get to create duplicate record when having unique index") } @@ -54,6 +75,10 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } + if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email's index idx_email_email_and_user_id should be deleted") + } + if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { t.Errorf("Should be able to create duplicated emails after remove unique index") } diff --git a/mssql.go b/mssql.go index 720dc615..3323874c 100644 --- a/mssql.go +++ b/mssql.go @@ -98,6 +98,12 @@ func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) boo return count > 0 } +func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count) + return count > 0 +} + func (s *mssql) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } diff --git a/mysql.go b/mysql.go index 3c077b4b..e608619d 100644 --- a/mysql.go +++ b/mysql.go @@ -96,6 +96,12 @@ func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) boo return count > 0 } +func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) + return count > 0 +} + func (s *mysql) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } diff --git a/postgres.go b/postgres.go index 994ed9b8..b8722bac 100644 --- a/postgres.go +++ b/postgres.go @@ -87,7 +87,7 @@ func (s *postgres) HasTable(scope *Scope, tableName string) bool { func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) return count > 0 } @@ -95,6 +95,12 @@ func (s *postgres) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } +func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count) + return count > 0 +} + var hstoreType = reflect.TypeOf(Hstore{}) type Hstore map[string]*string diff --git a/sqlite3.go b/sqlite3.go index 2ff10790..e24d2410 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -80,6 +80,12 @@ func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) b return count > 0 } +func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count) + return count > 0 +} + func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) }