diff --git a/chain.go b/chain.go index a80eaaa1..fdc7767a 100644 --- a/chain.go +++ b/chain.go @@ -215,21 +215,6 @@ func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain { return s } -func (s *Chain) CreateTable(value interface{}) *Chain { - s.do(value).createTable() - return s -} - -func (s *Chain) DropTable(value interface{}) *Chain { - s.do(value).dropTable() - return s -} - -func (s *Chain) AutoMigrate(value interface{}) *Chain { - s.do(value).autoMigrate() - return s -} - func (s *Chain) Unscoped() *Chain { s.unscoped = true return s @@ -281,6 +266,41 @@ func (s *Chain) Rollback() *Chain { return s } +func (s *Chain) CreateTable(value interface{}) *Chain { + s.do(value).createTable() + return s +} + +func (s *Chain) DropTable(value interface{}) *Chain { + s.do(value).dropTable() + return s +} + +func (s *Chain) AutoMigrate(value interface{}) *Chain { + s.do(value).autoMigrate() + return s +} + +func (s *Chain) UpdateColumn(column string, typ string) *Chain { + s.do(s.value).updateColumn(column, typ) + return s +} + +func (s *Chain) DropColumn(column string) *Chain { + s.do(s.value).dropColumn(column) + return s +} + +func (s *Chain) AddIndex(column string, index_name ...string) *Chain { + s.do(s.value).addIndex(column, index_name...) + return s +} + +func (s *Chain) RemoveIndex(column string) *Chain { + s.do(s.value).removeIndex(column) + return s +} + func (s *Chain) validSql(str string) (result bool) { result = regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str) if !result { diff --git a/do.go b/do.go index a624fffa..9a170da1 100644 --- a/do.go +++ b/do.go @@ -649,25 +649,45 @@ func (s *Do) createTable() *Do { } } - s.sql = fmt.Sprintf( - "CREATE TABLE %v (%v)", - s.tableName(), - strings.Join(sqls, ","), - ) + s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.tableName(), strings.Join(sqls, ",")) s.exec() return s } func (s *Do) dropTable() *Do { - s.sql = fmt.Sprintf( - "DROP TABLE %v", - s.tableName(), - ) + s.sql = fmt.Sprintf("DROP TABLE %v", s.tableName()) s.exec() return s } +func (s *Do) updateColumn(column string, typ string) { + s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.tableName(), column, typ) + s.exec() +} + +func (s *Do) dropColumn(column string) { + s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.tableName(), column) + s.exec() +} + +func (s *Do) addIndex(column string, names ...string) { + var index_name string + if len(names) > 0 { + index_name = names[0] + } else { + index_name = fmt.Sprintf("index_%v_on_%v", s.tableName(), column) + } + + s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.tableName(), column) + s.exec() +} + +func (s *Do) removeIndex(index_name string) { + s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.tableName()) + s.exec() +} + func (s *Do) autoMigrate() *Do { var table_name string sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.tableName()))