package postgres import ( "fmt" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" ) type Migrator struct { migrator.Migrator } func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) return } func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) if opt.Expression != "" { str = opt.Expression } if opt.Collate != "" { str += " COLLATE " + opt.Collate } if opt.Sort != "" { str += " " + opt.Sort } results = append(results, clause.Expr{SQL: str}) } return } func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name } return m.DB.Raw( "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name, ).Row().Scan(&count) }) return count > 0 } func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} createIndexSQL := "CREATE " if idx.Class != "" { createIndexSQL += idx.Class + " " } createIndexSQL += "INDEX ?" if idx.Type != "" { createIndexSQL += " USING " + idx.Type } createIndexSQL += " ON ??" if idx.Where != "" { createIndexSQL += " WHERE " + idx.Where } return m.DB.Exec(createIndexSQL, values...).Error } return fmt.Errorf("failed to create index with name %v", name) }) } func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count) }) return count > 0 } func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { name := field if field := stmt.Schema.LookUpField(field); field != nil { name = field.DBName } return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", stmt.Table, name, ).Row().Scan(&count) }) return count > 0 }