diff --git a/README.md b/README.md index 64476405..eb013d17 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ Yet Another ORM library for Go, aims for developer friendly ## TODO -* Complex where query (in) * Order * Limit * Select diff --git a/orm_test.go b/orm_test.go index f12d084b..bb318bc6 100644 --- a/orm_test.go +++ b/orm_test.go @@ -182,10 +182,42 @@ func TestComplexWhere(t *testing.T) { } users = []User{} - a := db.Where("name in (?)", []string{"1", "3"}).Find(&users) - a.db.Query("SELECT * FROM users WHERE ( name in ($1) )", "'1', '2'") + db.Where("name in (?)", []string{"1", "3"}).Find(&users) if len(users) != 3 { t.Errorf("Should only found 3 users's name in (1, 3), but have %v", len(users)) } + + var user_ids []int64 + for _, user := range users { + user_ids = append(user_ids, user.Id) + } + users = []User{} + db.Where("id in (?)", user_ids).Find(&users) + if len(users) != 3 { + t.Errorf("Should only found 3 users's name in (1, 3) - search by id, but have %v", len(users)) + } + + users = []User{} + db.Where("name in (?)", []string{"1", "2"}).Find(&users) + + if len(users) != 2 { + t.Errorf("Should only found 2 users's name in (1, 2), but have %v", len(users)) + } + + user_ids = []int64{} + for _, user := range users { + user_ids = append(user_ids, user.Id) + } + users = []User{} + db.Where("id in (?)", user_ids).Find(&users) + if len(users) != 2 { + t.Errorf("Should only found 2 users's name in (1, 2) - search by id, but have %v", len(users)) + } + + users = []User{} + db.Where("id in (?)", user_ids[0]).Find(&users) + if len(users) != 1 { + t.Errorf("Should only found 1 users's name in (1, 2) - search by the first id, but have %v", len(users)) + } } diff --git a/sql.go b/sql.go index d8857314..60d6a642 100644 --- a/sql.go +++ b/sql.go @@ -144,16 +144,19 @@ func (s *Orm) whereSql() (sql string) { str := "( " + clause["query"].(string) + " )" args := clause["args"].([]interface{}) for _, arg := range args { + switch reflect.TypeOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + v := reflect.ValueOf(arg) - switch arg.(type) { - case []string, []int, []int64, []int32: var temp_marks []string - for _ = range arg.([]string) { + for i := 0; i < v.Len(); i++ { temp_marks = append(temp_marks, "?") } + str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) - for _, a := range arg.([]string) { - str = strings.Replace(str, "?", s.addToVars(a), 1) + + for i := 0; i < v.Len(); i++ { + str = strings.Replace(str, "?", s.addToVars(v.Index(i).Addr().Interface()), 1) } default: str = strings.Replace(str, "?", s.addToVars(arg), 1)