diff --git a/README.md b/README.md index 3ed16aa0..8a5cfb0f 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,10 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True") // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") +// You can also use an existing database connection handle +// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") +// db := gorm.Open("postgres", dbSql) + // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) db.DB() diff --git a/main.go b/main.go index af626b06..d626d180 100644 --- a/main.go +++ b/main.go @@ -35,28 +35,41 @@ type DB struct { values map[string]interface{} } -func Open(dialect string, drivesources ...string) (DB, error) { +func Open(dialect string, args ...interface{}) (DB, error) { var db DB var err error - var driver = dialect var source string + var dbSql sqlCommon - if len(drivesources) == 0 { + if len(args) == 0 { err = errors.New("invalid database source") - } else { - if len(drivesources) == 1 { - source = drivesources[0] - } else if len(drivesources) >= 2 { - driver = drivesources[0] - source = drivesources[1] - } - - db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", - logger: defaultLogger, callback: DefaultCallback, source: source, - values: map[string]interface{}{}} - db.db, err = sql.Open(driver, source) - db.parent = &db } + + switch value := args[0].(type) { + case string: + var driver = dialect + if len(args) == 1 { + source = value + } else if len(args) >= 2 { + driver = value + source = args[1].(string) + } + dbSql, err = sql.Open(driver, source) + case sqlCommon: + if reflect.TypeOf(value).Kind() == reflect.Ptr { + source = reflect.ValueOf(value).Elem().FieldByName("dsn").String() + } else { + source = reflect.ValueOf(value).FieldByName("dsn").String() + } + dbSql = value + } + + db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", + logger: defaultLogger, callback: DefaultCallback, source: source, + values: map[string]interface{}{}} + db.db = dbSql + db.parent = &db + return db, err } diff --git a/main_test.go b/main_test.go index 1f7c5add..75f17df2 100644 --- a/main_test.go +++ b/main_test.go @@ -554,6 +554,22 @@ func TestCompatibilityMode(t *testing.T) { } } +func TestOpenExistingDB(t *testing.T) { + DB.Save(&User{Name: "jnfeinstein"}) + dialect := os.Getenv("GORM_DIALECT") + + db, err := gorm.Open(dialect, DB.DB()) + if err != nil { + t.Errorf("Should have wrapped the existing DB connection") + } + + var user User + if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { + t.Errorf("Should have found existing record") + } + +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ {