diff --git a/main.go b/main.go index d5582685..e4af5873 100644 --- a/main.go +++ b/main.go @@ -44,40 +44,40 @@ func Open(dialect string, args ...interface{}) (*DB, error) { if len(args) == 0 { err = errors.New("invalid database source") - } else { - var source string - var dbSQL sqlCommon + return nil, err + } + var source string + var dbSQL sqlCommon - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - case sqlCommon: - source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() - dbSQL = value + switch value := args[0].(type) { + case string: + var driver = dialect + if len(args) == 1 { + source = value + } else if len(args) >= 2 { + driver = value + source = args[1].(string) } + dbSQL, err = sql.Open(driver, source) + case sqlCommon: + source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() + dbSQL = value + } - db = DB{ - dialect: newDialect(dialect, dbSQL.(*sql.DB)), - logger: defaultLogger, - callbacks: DefaultCallback, - source: source, - values: map[string]interface{}{}, - db: dbSQL, - } - db.parent = &db + db = DB{ + dialect: newDialect(dialect, dbSQL.(*sql.DB)), + logger: defaultLogger, + callbacks: DefaultCallback, + source: source, + values: map[string]interface{}{}, + db: dbSQL, + } + db.parent = &db - if err == nil { - err = db.DB().Ping() // Send a ping to make sure the database connection is alive. - if err != nil { - db.DB().Close() - } + if err == nil { + err = db.DB().Ping() // Send a ping to make sure the database connection is alive. + if err != nil { + db.DB().Close() } } diff --git a/main_test.go b/main_test.go index 7b62f46f..729e6eb2 100644 --- a/main_test.go +++ b/main_test.go @@ -761,6 +761,16 @@ func TestDdlErrors(t *testing.T) { } } +func TestOpenWithOneParameter(t *testing.T) { + db, err := gorm.Open("dialect") + if db != nil { + t.Error("Open with one parameter returned non nil for db") + } + if err == nil { + t.Error("Open with one parameter returned err as nil") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ {