diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index 412d86c6..4707a637 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -23,6 +23,10 @@ func (m Migrator) HasTable(value interface{}) bool { func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + return m.DB.Raw( "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", name, stmt.Table, diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go index f06af25f..b144f573 100644 --- a/dialects/postgres/migrator.go +++ b/dialects/postgres/migrator.go @@ -37,11 +37,15 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem return } -func (m Migrator) HasIndex(value interface{}, indexName string) bool { +func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + return m.DB.Raw( - "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, indexName, + "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name, ).Row().Scan(&count) }) @@ -50,10 +54,7 @@ func (m Migrator) HasIndex(value interface{}, indexName string) bool { func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - err := fmt.Errorf("failed to create index with name %v", name) - indexes := stmt.Schema.ParseIndexes() - - if idx, ok := indexes[name]; ok { + if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} @@ -73,18 +74,9 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } return m.DB.Exec(createIndexSQL, values...).Error - } else if field := stmt.Schema.LookUpField(name); field != nil { - for _, idx := range indexes { - for _, idxOpt := range idx.Fields { - if idxOpt.Field == field { - if err = m.CreateIndex(value, idx.Name); err != nil { - return err - } - } - } - } } - return err + + return fmt.Errorf("failed to create index with name %v", name) }) } diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index 601de126..5f3671b4 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -2,6 +2,7 @@ package sqlite import ( "fmt" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" @@ -37,17 +38,6 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } -func (m Migrator) HasIndex(value interface{}, name string) bool { - var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND sql LIKE ?", - "index", stmt.Table, "%INDEX "+name+" ON%", - ).Row().Scan(&count) - }) - return count > 0 -} - func (m Migrator) CreateConstraint(interface{}, string) error { return gorm.ErrNotImplemented } @@ -83,10 +73,7 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - err := fmt.Errorf("failed to create index with name %v", name) - indexes := stmt.Schema.ParseIndexes() - - if idx, ok := indexes[name]; ok { + if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} @@ -106,17 +93,44 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } return m.DB.Exec(createIndexSQL, values...).Error - } else if field := stmt.Schema.LookUpField(name); field != nil { - for _, idx := range indexes { - for _, idxOpt := range idx.Fields { - if idxOpt.Field == field { - if err = m.CreateIndex(value, idx.Name); err != nil { - return err - } - } - } - } } - return err + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, + ).Row().Scan(&count) + return nil + }) + return count > 0 +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + var sql string + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) + if sql != "" { + return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error + } + return fmt.Errorf("failed to find index with name %v", oldName) + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error }) } diff --git a/migrator/migrator.go b/migrator/migrator.go index cab266a3..1b0edf68 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -418,10 +418,7 @@ type BuildIndexOptionsInterface interface { func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - err := fmt.Errorf("failed to create index with name %v", name) - indexes := stmt.Schema.ParseIndexes() - - if idx, ok := indexes[name]; ok { + if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} @@ -441,23 +438,18 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } return m.DB.Exec(createIndexSQL, values...).Error - } else if field := stmt.Schema.LookUpField(name); field != nil { - for _, idx := range indexes { - for _, idxOpt := range idx.Fields { - if idxOpt.Field == field { - if err = m.CreateIndex(value, idx.Name); err != nil { - return err - } - } - } - } } - return err + + return fmt.Errorf("failed to create index with name %v", name) }) } func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error }) } @@ -466,6 +458,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + return m.DB.Raw( "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name, diff --git a/schema/index.go b/schema/index.go index c5c96aa4..4228bba2 100644 --- a/schema/index.go +++ b/schema/index.go @@ -52,6 +52,23 @@ func (schema *Schema) ParseIndexes() map[string]Index { return indexes } +func (schema *Schema) LookIndex(name string) *Index { + indexes := schema.ParseIndexes() + for _, index := range indexes { + if index.Name == name { + return &index + } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } + } + + return nil +} + func parseFieldIndexes(field *Field) (indexes []Index) { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { if value != "" { diff --git a/tests/delete_test.go b/tests/delete_test.go index 8be072d3..3f17f1a1 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -46,3 +46,21 @@ func TestDelete(t *testing.T) { } } } + +func TestInlineCondDelete(t *testing.T) { + user1 := *GetUser("inline_delete_1", Config{}) + user2 := *GetUser("inline_delete_2", Config{}) + DB.Save(&user1).Save(&user2) + + if DB.Delete(&User{}, user1.ID).Error != nil { + t.Errorf("No error should happen when delete a record") + } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { + t.Errorf("User can't be found after delete") + } + + if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Errorf("No error should happen when delete a record, err=%s", err) + } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { + t.Errorf("User can't be found after delete") + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 917fba75..d944dfa2 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" ) @@ -27,3 +28,46 @@ func TestMigrate(t *testing.T) { } } } + +func TestIndexes(t *testing.T) { + type User struct { + gorm.Model + Name string `gorm:"index"` + } + + if err := DB.Migrator().CreateIndex(&User{}, "Name"); err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } + + if !DB.Migrator().HasIndex(&User{}, "Name") { + t.Errorf("Failed to find index for user's name") + } + + if err := DB.Migrator().DropIndex(&User{}, "Name"); err != nil { + t.Errorf("Failed to drop index for user's name, got err %v", err) + } + + if DB.Migrator().HasIndex(&User{}, "Name") { + t.Errorf("Should not find index for user's name after delete") + } + + if err := DB.Migrator().CreateIndex(&User{}, "Name"); err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } + + if err := DB.Migrator().RenameIndex(&User{}, "idx_users_name", "idx_users_name_1"); err != nil { + t.Errorf("no error should happen when rename index, but got %v", err) + } + + if !DB.Migrator().HasIndex(&User{}, "idx_users_name_1") { + t.Errorf("Should find index for user's name after rename") + } + + if err := DB.Migrator().DropIndex(&User{}, "idx_users_name_1"); err != nil { + t.Errorf("Failed to drop index for user's name, got err %v", err) + } + + if DB.Migrator().HasIndex(&User{}, "idx_users_name_1") { + t.Errorf("Should not find index for user's name after delete") + } +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 0c24a888..9435b2b1 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -7,6 +7,8 @@ fi for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then + echo "testing ${dialect}..." + if [ "$GORM_VERBOSE" = "" ] then DEBUG=false GORM_DIALECT=${dialect} go test -race ./...