diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go index 7fd5e373..f079ad60 100644 --- a/dialects/mysql/mysql_test.go +++ b/dialects/mysql/mysql_test.go @@ -9,10 +9,6 @@ import ( "github.com/jinzhu/gorm/tests" ) -func TestOpen(t *testing.T) { - gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil) -} - var ( DB *gorm.DB err error diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go index 35101bf3..f06af25f 100644 --- a/dialects/postgres/migrator.go +++ b/dialects/postgres/migrator.go @@ -87,3 +87,29 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { return err }) } + +func (m Migrator) HasTable(value interface{}) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", + stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 4ffc4204..bb9726a8 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -3,6 +3,7 @@ package postgres import ( "database/sql" "fmt" + "strconv" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" @@ -29,13 +30,14 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "?" + return "$" + strconv.Itoa(len(stmt.Vars)) } func (dialector Dialector) QuoteChars() [2]byte { diff --git a/dialects/postgres/postgres_test.go b/dialects/postgres/postgres_test.go new file mode 100644 index 00000000..84c0fe53 --- /dev/null +++ b/dialects/postgres/postgres_test.go @@ -0,0 +1,29 @@ +package postgres_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/postgres" + "github.com/jinzhu/gorm/tests" +) + +var ( + DB *gorm.DB + err error +) + +func init() { + if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil { + panic(fmt.Sprintf("failed to initialize database, got error %v", err)) + } +} + +func TestCURD(t *testing.T) { + tests.RunTestsSuit(t, DB) +} + +func TestMigrate(t *testing.T) { + tests.TestMigrate(t, DB) +}