diff --git a/README.md b/README.md index fea460fe..fe80cfcd 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ Yet Another ORM library for Go, aims for developer friendly ## TODO -* Better First method (First(&user, primary_key, where conditions)) * Update, Updates * Soft Delete * Even more complex where query (with map or struct) diff --git a/chain.go b/chain.go index 58300400..399538c7 100644 --- a/chain.go +++ b/chain.go @@ -145,13 +145,13 @@ func (s *Chain) Exec(sql string) *Chain { return s } -func (s *Chain) First(out interface{}) *Chain { - s.do(out).query() +func (s *Chain) First(out interface{}, where ...interface{}) *Chain { + s.do(out).query(where...) return s } -func (s *Chain) Find(out interface{}) *Chain { - s.do(out).query() +func (s *Chain) Find(out interface{}, where ...interface{}) *Chain { + s.do(out).query(where...) return s } diff --git a/do.go b/do.go index c8d7d45a..8b0141a7 100644 --- a/do.go +++ b/do.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "reflect" + "regexp" + "strconv" "strings" ) @@ -160,7 +162,11 @@ func (s *Do) prepareQuerySql() *Do { return s } -func (s *Do) query() { +func (s *Do) query(where ...interface{}) { + if len(where) > 0 { + s.where(where[0], where[1:len(where)]...) + } + var ( is_slice bool dest_type reflect.Type @@ -176,7 +182,6 @@ func (s *Do) query() { rows, err := s.db.Query(s.Sql, s.SqlVars...) defer rows.Close() s.err(err) - if rows.Err() != nil { s.err(rows.Err()) } @@ -246,8 +251,28 @@ func (s *Do) pluck(value interface{}) *Do { return s } -func (s *Do) buildWhereCondition(clause map[string]interface{}) string { - str := "( " + clause["query"].(string) + " )" +func (s *Do) where(querystring interface{}, args ...interface{}) *Do { + s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args}) + return s +} + +func (s *Do) primaryCondiation(value interface{}) string { + return fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), value) +} + +func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { + switch clause["query"].(type) { + case string: + value := clause["query"].(string) + if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + id, _ := strconv.Atoi(value) + return s.primaryCondiation(s.addToVars(id)) + } else { + str = "( " + value + " )" + } + case int, int64, int32: + return s.primaryCondiation(s.addToVars(clause["query"])) + } args := clause["args"].([]interface{}) for _, arg := range args { @@ -269,7 +294,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) string { str = strings.Replace(str, "?", s.addToVars(arg), 1) } } - return str + return } func (s *Do) whereSql() (sql string) { @@ -277,7 +302,7 @@ func (s *Do) whereSql() (sql string) { var and_conditions, or_conditions []string if !s.model.PrimaryKeyZero() { - primary_condiation = fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue())) + primary_condiation = s.primaryCondiation(s.addToVars(s.model.PrimaryKeyValue())) } for _, clause := range s.whereClause { diff --git a/main.go b/main.go index 1ae314b7..0e5f6a32 100644 --- a/main.go +++ b/main.go @@ -25,12 +25,12 @@ func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { return s.buildORM().Where(querystring, args...) } -func (s *DB) First(out interface{}) *Chain { - return s.buildORM().First(out) +func (s *DB) First(out interface{}, where ...interface{}) *Chain { + return s.buildORM().First(out, where...) } -func (s *DB) Find(out interface{}) *Chain { - return s.buildORM().Find(out) +func (s *DB) Find(out interface{}, where ...interface{}) *Chain { + return s.buildORM().Find(out, where...) } func (s *DB) Limit(value interface{}) *Chain { diff --git a/orm_test.go b/orm_test.go index f7bf4cfd..7b7476ce 100644 --- a/orm_test.go +++ b/orm_test.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "reflect" + "strconv" "testing" "time" ) @@ -63,8 +64,34 @@ func init() { } func TestFirst(t *testing.T) { - var u1, u2 User + var u1, u2, u3, u4, u5, u6, u7 User db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2) + + db.Where("name = ?", "3").First(&u3, "age = 22").First(&u4, "age = ?", 24).First(&u5, "age = ?", 26) + if !((u5.Id == 0) && (u3.Age == 22 && u3.Name == "3") && (u4.Age == 24 && u4.Name == "3")) { + t.Errorf("Inline where condition for first when search") + } + + var us1, us2, us3, us4 []User + db.Find(&us1, "age = 22").Find(&us2, "name = ?", "3").Find(&us3, "age > ?", 20) + if !(len(us1) == 1 && len(us2) == 2 && len(us3) == 3) { + t.Errorf("Inline where condition for find when search") + } + + db.Find(&us4, "name = ? and age > ?", "3", "22") + if len(us4) != 1 { + t.Errorf("More complex inline where condition for find, %v", us4) + } + + db.First(&u6, u1.Id) + if !(u6.Id == u1.Id && u6.Id != 0) { + t.Errorf("Should find out user with int id") + } + + db.First(&u7, strconv.Itoa(int(u1.Id))) + if !(u6.Id == u1.Id && u6.Id != 0) { + t.Errorf("Should find out user with string id") + } } func TestSaveAndFind(t *testing.T) {