diff --git a/chain.go b/chain.go index bacce7b2..d26fadc8 100644 --- a/chain.go +++ b/chain.go @@ -35,6 +35,11 @@ func (s *Chain) err(err error) error { return err } +func (s *Chain) deleteLastError() { + s.Error = nil + s.Errors = s.Errors[:len(s.Errors)-1] +} + func (s *Chain) do(value interface{}) *Do { var do Do do.chain = s @@ -160,6 +165,10 @@ func (s *Chain) First(out interface{}, where ...interface{}) *Chain { } func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain { + if s.First(out).Error != nil { + s.do(out).initializedWithSearchCondition() + s.deleteLastError() + } return s } diff --git a/do.go b/do.go index 511bc725..1e35108b 100644 --- a/do.go +++ b/do.go @@ -350,14 +350,14 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { for key, value := range query.(map[string]interface{}) { sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", key, s.addToVars(value))) } - return strings.Join(sqls, ",") + return strings.Join(sqls, " AND ") case interface{}: m := &Model{data: query, driver: s.driver} var sqls []string for _, field := range m.columnsHasValue("") { sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value))) } - return strings.Join(sqls, ",") + return strings.Join(sqls, " AND ") } args := clause["args"].([]interface{}) @@ -383,7 +383,6 @@ func (s *Do) whereSql() (sql string) { if !s.unscoped && s.model.hasColumn("DeletedAt") { primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')") } - if !s.model.primaryKeyZero() { primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue()))) } @@ -466,3 +465,22 @@ func (s *Do) createTable() *Do { ) return s } + +func (s *Do) initializedWithSearchCondition() { + m := Model{data: s.value, driver: s.driver} + + for _, clause := range s.whereClause { + query := clause["query"] + switch query.(type) { + case map[string]interface{}: + for key, value := range query.(map[string]interface{}) { + m.setValueByColumn(key, value, s.value) + } + case interface{}: + m := &Model{data: query, driver: s.driver} + for _, field := range m.columnsHasValue("") { + m.setValueByColumn(field.DbName, field.Value, s.value) + } + } + } +} diff --git a/gorm_test.go b/gorm_test.go index d9b9af99..e1d96806 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -799,9 +799,14 @@ func TestSoftDelete(t *testing.T) { } func TestFindOrInitialize(t *testing.T) { - var user User - db.Where(User{Name: "hello world"}).FirstOrInit(&user) - if user.Name != "hello world" || user.Id != 0 { + var user1 User + db.Where(&User{Name: "hello world", Age: 33}).FirstOrInit(&user1) + if user1.Name != "hello world" || user1.Id != 0 || user1.Age != 33 { t.Errorf("user should be initialized with search value") } + + // db.FirstOrInit(&user2, map[string]interface{}{"name": "hahaha"}) + // if user2.Name != "hahaha" || user2.Id != 0 { + // t.Errorf("user should be initialized with search value") + // } } diff --git a/model.go b/model.go index ac64d319..4d200626 100644 --- a/model.go +++ b/model.go @@ -188,13 +188,21 @@ func (m *Model) callMethod(method string) error { return nil } -func (model *Model) returningStr() (str string) { - if model.driver == "postgres" { - str = fmt.Sprintf("RETURNING \"%v\"", model.primaryKeyDb()) +func (m *Model) returningStr() (str string) { + if m.driver == "postgres" { + str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb()) } return } -func (model *Model) missingColumns() (results []string) { +func (m *Model) missingColumns() (results []string) { return } + +func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) { + data := reflect.ValueOf(out).Elem() + field := data.FieldByName(snakeToUpperCamel(name)) + if field.IsValid() { + field.Set(reflect.ValueOf(value)) + } +}