diff --git a/gorm.go b/gorm.go index 0de6860b..a5f8bbfd 100644 --- a/gorm.go +++ b/gorm.go @@ -3,6 +3,7 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" "sync" "time" @@ -220,6 +221,21 @@ func (db *DB) AddError(err error) error { return db.Error } +// DB returns `*sql.DB` +func (db *DB) DB() (*sql.DB, error) { + connPool := db.ConnPool + + if stmtDB, ok := connPool.(*PreparedStmtDB); ok { + connPool = stmtDB.ConnPool + } + + if sqldb, ok := connPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, errors.New("invalid db") +} + func (db *DB) getInstance() *DB { if db.clone > 0 { tx := &DB{Config: db.Config} diff --git a/tests/tests_test.go b/tests/tests_test.go index 09850003..c80fb849 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -24,6 +24,15 @@ func init() { log.Printf("failed to connect database, got error %v\n", err) os.Exit(1) } else { + sqlDB, err := DB.DB() + if err == nil { + err = sqlDB.Ping() + } + + if err != nil { + log.Printf("failed to connect database, got error %v\n", err) + } + RunMigrations() } }