diff --git a/chain.go b/chain.go index 2c54236a..ecdd8247 100644 --- a/chain.go +++ b/chain.go @@ -21,6 +21,7 @@ type Chain struct { whereClause []map[string]interface{} orClause []map[string]interface{} initAttrs []interface{} + assignAttrs []interface{} selectStr string orderStrs []string offsetStr string @@ -157,10 +158,7 @@ func (s *Chain) Update(column string, value interface{}) *Chain { } func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...bool) *Chain { - do := s.do(s.value) - do.updateAttrs = values - do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0] - do.update() + s.do(s.value).setUpdateAttrs(values, ignore_protected_attrs...).update() return s } @@ -181,17 +179,24 @@ func (s *Chain) Attrs(attrs interface{}) *Chain { return s } +func (s *Chain) Assign(attrs interface{}) *Chain { + s.assignAttrs = append(s.assignAttrs, attrs) + return s +} + func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain { if s.First(out, where...).Error != nil { - s.do(out).where(where...).where(s.initAttrs).initializeWithSearchCondition() + s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition() s.deleteLastError() + } else { + s.do(out).update() } return s } func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain { if s.First(out, where...).Error != nil { - s.do(out).where(where...).where(s.initAttrs).initializeWithSearchCondition() + s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition() s.deleteLastError() s.Save(out) } diff --git a/do.go b/do.go index 2aa7acda..5f19927d 100644 --- a/do.go +++ b/do.go @@ -146,6 +146,12 @@ func (s *Do) create() { return } +func (s *Do) setUpdateAttrs(values map[string]interface{}, ignore_protected_attrs ...bool) *Do { + s.updateAttrs = values + s.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0] + return s +} + func (s *Do) prepareUpdateAttrs() (results map[string]interface{}, update bool) { if len(s.updateAttrs) > 0 { results, update = s.model.updatedColumnsAndValues(s.updateAttrs) diff --git a/gorm_test.go b/gorm_test.go index fe225343..fe48fbc5 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -867,6 +867,11 @@ func TestFindOrInitialize(t *testing.T) { t.Errorf("user should be initialized with search value and attrs") } + db.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and assigned attrs") + } + db.Save(&User{Name: "find or init", Age: 33}) db.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user5) @@ -878,6 +883,11 @@ func TestFindOrInitialize(t *testing.T) { if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { t.Errorf("user should be found with FirstOrInit") } + + db.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { + // t.Errorf("user should be found and updated with assigned attrs") + } } func TestFindOrCreate(t *testing.T) { diff --git a/main.go b/main.go index 27c5f535..f7ca92c2 100644 --- a/main.go +++ b/main.go @@ -34,6 +34,10 @@ func (s *DB) Attrs(attrs interface{}) *Chain { return s.buildChain().Attrs(attrs) } +func (s *DB) Assign(attrs interface{}) *Chain { + return s.buildChain().Assign(attrs) +} + func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *Chain { return s.buildChain().FirstOrInit(out, where...) }