From 1d6b59f2237af563cffd65bade1a753265168a1c Mon Sep 17 00:00:00 2001 From: jnfeinstein Date: Sat, 22 Nov 2014 11:24:26 -0800 Subject: [PATCH] Allow .Open to reuse existing database connections. This commit allows you to pass a string or an existing database connection as the source for gorm. The dialect is still required because a) there is no common reference to it as far as i know, and b) gorm allows the dialect to differ from the driver. So, for the sake of simplicity, you still have to specity the dialect. This is useful if you have an existing transaction, but still want to use gorm to format your queries. This is dependent on the defintion of DB in pkg database/sql having the field 'dsn', which is the database source, obtained via reflect. --- README.md | 4 ++++ main.go | 45 +++++++++++++++++++++++++++++---------------- main_test.go | 16 ++++++++++++++++ 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 3ed16aa0..8a5cfb0f 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,10 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True") // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") +// You can also use an existing database connection handle +// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") +// db := gorm.Open("postgres", dbSql) + // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) db.DB() diff --git a/main.go b/main.go index af626b06..d626d180 100644 --- a/main.go +++ b/main.go @@ -35,28 +35,41 @@ type DB struct { values map[string]interface{} } -func Open(dialect string, drivesources ...string) (DB, error) { +func Open(dialect string, args ...interface{}) (DB, error) { var db DB var err error - var driver = dialect var source string + var dbSql sqlCommon - if len(drivesources) == 0 { + if len(args) == 0 { err = errors.New("invalid database source") - } else { - if len(drivesources) == 1 { - source = drivesources[0] - } else if len(drivesources) >= 2 { - driver = drivesources[0] - source = drivesources[1] - } - - db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", - logger: defaultLogger, callback: DefaultCallback, source: source, - values: map[string]interface{}{}} - db.db, err = sql.Open(driver, source) - db.parent = &db } + + switch value := args[0].(type) { + case string: + var driver = dialect + if len(args) == 1 { + source = value + } else if len(args) >= 2 { + driver = value + source = args[1].(string) + } + dbSql, err = sql.Open(driver, source) + case sqlCommon: + if reflect.TypeOf(value).Kind() == reflect.Ptr { + source = reflect.ValueOf(value).Elem().FieldByName("dsn").String() + } else { + source = reflect.ValueOf(value).FieldByName("dsn").String() + } + dbSql = value + } + + db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", + logger: defaultLogger, callback: DefaultCallback, source: source, + values: map[string]interface{}{}} + db.db = dbSql + db.parent = &db + return db, err } diff --git a/main_test.go b/main_test.go index 1f7c5add..75f17df2 100644 --- a/main_test.go +++ b/main_test.go @@ -554,6 +554,22 @@ func TestCompatibilityMode(t *testing.T) { } } +func TestOpenExistingDB(t *testing.T) { + DB.Save(&User{Name: "jnfeinstein"}) + dialect := os.Getenv("GORM_DIALECT") + + db, err := gorm.Open(dialect, DB.DB()) + if err != nil { + t.Errorf("Should have wrapped the existing DB connection") + } + + var user User + if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { + t.Errorf("Should have found existing record") + } + +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ {