forked from mirror/gorm
Allow .Open to reuse existing database connections.
This commit allows you to pass a string or an existing database connection as the source for gorm. The dialect is still required because a) there is no common reference to it as far as i know, and b) gorm allows the dialect to differ from the driver. So, for the sake of simplicity, you still have to specity the dialect. This is useful if you have an existing transaction, but still want to use gorm to format your queries. This is dependent on the defintion of DB in pkg database/sql having the field 'dsn', which is the database source, obtained via reflect.
This commit is contained in:
parent
0e22837c04
commit
1d6b59f223
|
@ -106,6 +106,10 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||||
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True")
|
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True")
|
||||||
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
|
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
|
||||||
|
|
||||||
|
// You can also use an existing database connection handle
|
||||||
|
// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||||
|
// db := gorm.Open("postgres", dbSql)
|
||||||
|
|
||||||
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
|
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
|
||||||
db.DB()
|
db.DB()
|
||||||
|
|
||||||
|
|
45
main.go
45
main.go
|
@ -35,28 +35,41 @@ type DB struct {
|
||||||
values map[string]interface{}
|
values map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Open(dialect string, drivesources ...string) (DB, error) {
|
func Open(dialect string, args ...interface{}) (DB, error) {
|
||||||
var db DB
|
var db DB
|
||||||
var err error
|
var err error
|
||||||
var driver = dialect
|
|
||||||
var source string
|
var source string
|
||||||
|
var dbSql sqlCommon
|
||||||
|
|
||||||
if len(drivesources) == 0 {
|
if len(args) == 0 {
|
||||||
err = errors.New("invalid database source")
|
err = errors.New("invalid database source")
|
||||||
} else {
|
|
||||||
if len(drivesources) == 1 {
|
|
||||||
source = drivesources[0]
|
|
||||||
} else if len(drivesources) >= 2 {
|
|
||||||
driver = drivesources[0]
|
|
||||||
source = drivesources[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
|
|
||||||
logger: defaultLogger, callback: DefaultCallback, source: source,
|
|
||||||
values: map[string]interface{}{}}
|
|
||||||
db.db, err = sql.Open(driver, source)
|
|
||||||
db.parent = &db
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch value := args[0].(type) {
|
||||||
|
case string:
|
||||||
|
var driver = dialect
|
||||||
|
if len(args) == 1 {
|
||||||
|
source = value
|
||||||
|
} else if len(args) >= 2 {
|
||||||
|
driver = value
|
||||||
|
source = args[1].(string)
|
||||||
|
}
|
||||||
|
dbSql, err = sql.Open(driver, source)
|
||||||
|
case sqlCommon:
|
||||||
|
if reflect.TypeOf(value).Kind() == reflect.Ptr {
|
||||||
|
source = reflect.ValueOf(value).Elem().FieldByName("dsn").String()
|
||||||
|
} else {
|
||||||
|
source = reflect.ValueOf(value).FieldByName("dsn").String()
|
||||||
|
}
|
||||||
|
dbSql = value
|
||||||
|
}
|
||||||
|
|
||||||
|
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
|
||||||
|
logger: defaultLogger, callback: DefaultCallback, source: source,
|
||||||
|
values: map[string]interface{}{}}
|
||||||
|
db.db = dbSql
|
||||||
|
db.parent = &db
|
||||||
|
|
||||||
return db, err
|
return db, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
16
main_test.go
16
main_test.go
|
@ -554,6 +554,22 @@ func TestCompatibilityMode(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenExistingDB(t *testing.T) {
|
||||||
|
DB.Save(&User{Name: "jnfeinstein"})
|
||||||
|
dialect := os.Getenv("GORM_DIALECT")
|
||||||
|
|
||||||
|
db, err := gorm.Open(dialect, DB.DB())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should have wrapped the existing DB connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
|
||||||
|
t.Errorf("Should have found existing record")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
func BenchmarkGorm(b *testing.B) {
|
||||||
b.N = 2000
|
b.N = 2000
|
||||||
for x := 0; x < b.N; x++ {
|
for x := 0; x < b.N; x++ {
|
||||||
|
|
Loading…
Reference in New Issue