Merge pull request #1210 from baijum/error-zero-args

If no arguments, return error at once
This commit is contained in:
Jinzhu 2016-10-23 22:58:23 +08:00 committed by GitHub
commit c1b9cf186e
2 changed files with 40 additions and 30 deletions

60
main.go
View File

@ -44,40 +44,40 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
if len(args) == 0 { if len(args) == 0 {
err = errors.New("invalid database source") err = errors.New("invalid database source")
} else { return nil, err
var source string }
var dbSQL sqlCommon var source string
var dbSQL sqlCommon
switch value := args[0].(type) { switch value := args[0].(type) {
case string: case string:
var driver = dialect var driver = dialect
if len(args) == 1 { if len(args) == 1 {
source = value source = value
} else if len(args) >= 2 { } else if len(args) >= 2 {
driver = value driver = value
source = args[1].(string) source = args[1].(string)
}
dbSQL, err = sql.Open(driver, source)
case sqlCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSQL = value
} }
dbSQL, err = sql.Open(driver, source)
case sqlCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSQL = value
}
db = DB{ db = DB{
dialect: newDialect(dialect, dbSQL.(*sql.DB)), dialect: newDialect(dialect, dbSQL.(*sql.DB)),
logger: defaultLogger, logger: defaultLogger,
callbacks: DefaultCallback, callbacks: DefaultCallback,
source: source, source: source,
values: map[string]interface{}{}, values: map[string]interface{}{},
db: dbSQL, db: dbSQL,
} }
db.parent = &db db.parent = &db
if err == nil { if err == nil {
err = db.DB().Ping() // Send a ping to make sure the database connection is alive. err = db.DB().Ping() // Send a ping to make sure the database connection is alive.
if err != nil { if err != nil {
db.DB().Close() db.DB().Close()
}
} }
} }

View File

@ -761,6 +761,16 @@ func TestDdlErrors(t *testing.T) {
} }
} }
func TestOpenWithOneParameter(t *testing.T) {
db, err := gorm.Open("dialect")
if db != nil {
t.Error("Open with one parameter returned non nil for db")
}
if err == nil {
t.Error("Open with one parameter returned err as nil")
}
}
func BenchmarkGorm(b *testing.B) { func BenchmarkGorm(b *testing.B) {
b.N = 2000 b.N = 2000
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {