Allow .Open to reuse existing database connections.

This commit allows you to pass a string or an existing database
connection as the source for gorm. The dialect is still required
because a) there is no common reference to it as far as i know, and
b) gorm allows the dialect to differ from the driver. So, for the sake
of simplicity, you still have to specity the dialect.

This is useful if you have an existing transaction, but still
want to use gorm to format your queries.

This is dependent on the defintion of DB in pkg database/sql having
the field 'dsn', which is the database source, obtained via reflect.
This commit is contained in:
jnfeinstein 2014-11-22 11:24:26 -08:00 committed by Jinzhu
parent 0e22837c04
commit 1d6b59f223
3 changed files with 49 additions and 16 deletions

View File

@ -106,6 +106,10 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True")
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
// You can also use an existing database connection handle
// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db := gorm.Open("postgres", dbSql)
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
db.DB()

45
main.go
View File

@ -35,28 +35,41 @@ type DB struct {
values map[string]interface{}
}
func Open(dialect string, drivesources ...string) (DB, error) {
func Open(dialect string, args ...interface{}) (DB, error) {
var db DB
var err error
var driver = dialect
var source string
var dbSql sqlCommon
if len(drivesources) == 0 {
if len(args) == 0 {
err = errors.New("invalid database source")
} else {
if len(drivesources) == 1 {
source = drivesources[0]
} else if len(drivesources) >= 2 {
driver = drivesources[0]
source = drivesources[1]
}
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
logger: defaultLogger, callback: DefaultCallback, source: source,
values: map[string]interface{}{}}
db.db, err = sql.Open(driver, source)
db.parent = &db
}
switch value := args[0].(type) {
case string:
var driver = dialect
if len(args) == 1 {
source = value
} else if len(args) >= 2 {
driver = value
source = args[1].(string)
}
dbSql, err = sql.Open(driver, source)
case sqlCommon:
if reflect.TypeOf(value).Kind() == reflect.Ptr {
source = reflect.ValueOf(value).Elem().FieldByName("dsn").String()
} else {
source = reflect.ValueOf(value).FieldByName("dsn").String()
}
dbSql = value
}
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
logger: defaultLogger, callback: DefaultCallback, source: source,
values: map[string]interface{}{}}
db.db = dbSql
db.parent = &db
return db, err
}

View File

@ -554,6 +554,22 @@ func TestCompatibilityMode(t *testing.T) {
}
}
func TestOpenExistingDB(t *testing.T) {
DB.Save(&User{Name: "jnfeinstein"})
dialect := os.Getenv("GORM_DIALECT")
db, err := gorm.Open(dialect, DB.DB())
if err != nil {
t.Errorf("Should have wrapped the existing DB connection")
}
var user User
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
t.Errorf("Should have found existing record")
}
}
func BenchmarkGorm(b *testing.B) {
b.N = 2000
for x := 0; x < b.N; x++ {