diff --git a/chain.go b/chain.go index 91ab87b0..f38ba241 100644 --- a/chain.go +++ b/chain.go @@ -171,24 +171,29 @@ func (s *Chain) Exec(sql string) *Chain { func (s *Chain) First(out interface{}, where ...interface{}) *Chain { do := s.do(out) do.limitStr = "1" - do.query(where...) + do.where(where...).query() return s } func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain { - if s.First(out).Error != nil { - s.do(out).initializedWithSearchCondition() + if s.First(out, where...).Error != nil { + s.do(out).where(where...).initializeWithSearchCondition() s.deleteLastError() } return s } func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain { + if s.First(out, where...).Error != nil { + s.do(out).where(where...).initializeWithSearchCondition() + s.deleteLastError() + s.Save(out) + } return s } func (s *Chain) Find(out interface{}, where ...interface{}) *Chain { - s.do(out).query(where...) + s.do(out).where(where...).query() return s } diff --git a/do.go b/do.go index 1e35108b..8b700090 100644 --- a/do.go +++ b/do.go @@ -206,11 +206,7 @@ func (s *Do) prepareQuerySql() { return } -func (s *Do) query(where ...interface{}) { - if len(where) > 0 { - s.where(where[0], where[1:len(where)]...) - } - +func (s *Do) query() { var ( is_slice bool dest_type reflect.Type @@ -320,9 +316,11 @@ func (s *Do) pluck(column string, value interface{}) { return } -func (s *Do) where(querystring interface{}, args ...interface{}) { - s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args}) - return +func (s *Do) where(where ...interface{}) *Do { + if len(where) > 0 { + s.whereClause = append(s.whereClause, map[string]interface{}{"query": where[0], "args": where[1:len(where)]}) + } + return s } func (s *Do) primaryCondiation(value interface{}) string { @@ -466,7 +464,7 @@ func (s *Do) createTable() *Do { return s } -func (s *Do) initializedWithSearchCondition() { +func (s *Do) initializeWithSearchCondition() { m := Model{data: s.value, driver: s.driver} for _, clause := range s.whereClause { diff --git a/gorm_test.go b/gorm_test.go index 7b84a60c..cc740c6b 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -800,14 +800,27 @@ func TestSoftDelete(t *testing.T) { } func TestFindOrInitialize(t *testing.T) { - var user1 User - db.Where(&User{Name: "hello world", Age: 33}).FirstOrInit(&user1) - if user1.Name != "hello world" || user1.Id != 0 || user1.Age != 33 { + var user1, user2 User + db.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) + if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { t.Errorf("user should be initialized with search value") } - // db.FirstOrInit(&user2, map[string]interface{}{"name": "hahaha"}) - // if user2.Name != "hahaha" || user2.Id != 0 { - // t.Errorf("user should be initialized with search value") - // } + db.FirstOrInit(&user2, map[string]interface{}{"name": "find or init 2"}) + if user2.Name != "find or init 2" || user2.Id != 0 { + t.Errorf("user should be initialized with inline search value") + } +} + +func TestFindOrCreate(t *testing.T) { + var user1, user2 User + db.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) + if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { + t.Errorf("user should be created with search value") + } + + db.FirstOrCreate(&user2, map[string]interface{}{"name": "find or create 2"}) + if user2.Name != "find or create 2" || user2.Id == 0 { + t.Errorf("user should be created with inline search value") + } }