Try to implement the FirstOrInit method

This commit is contained in:
Jinzhu 2013-10-29 19:05:54 +08:00
parent 2e84f3b216
commit 6ac0012879
4 changed files with 50 additions and 10 deletions

View File

@ -35,6 +35,11 @@ func (s *Chain) err(err error) error {
return err
}
func (s *Chain) deleteLastError() {
s.Error = nil
s.Errors = s.Errors[:len(s.Errors)-1]
}
func (s *Chain) do(value interface{}) *Do {
var do Do
do.chain = s
@ -160,6 +165,10 @@ func (s *Chain) First(out interface{}, where ...interface{}) *Chain {
}
func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain {
if s.First(out).Error != nil {
s.do(out).initializedWithSearchCondition()
s.deleteLastError()
}
return s
}

24
do.go
View File

@ -350,14 +350,14 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
for key, value := range query.(map[string]interface{}) {
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", key, s.addToVars(value)))
}
return strings.Join(sqls, ",")
return strings.Join(sqls, " AND ")
case interface{}:
m := &Model{data: query, driver: s.driver}
var sqls []string
for _, field := range m.columnsHasValue("") {
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value)))
}
return strings.Join(sqls, ",")
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
@ -383,7 +383,6 @@ func (s *Do) whereSql() (sql string) {
if !s.unscoped && s.model.hasColumn("DeletedAt") {
primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')")
}
if !s.model.primaryKeyZero() {
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
}
@ -466,3 +465,22 @@ func (s *Do) createTable() *Do {
)
return s
}
func (s *Do) initializedWithSearchCondition() {
m := Model{data: s.value, driver: s.driver}
for _, clause := range s.whereClause {
query := clause["query"]
switch query.(type) {
case map[string]interface{}:
for key, value := range query.(map[string]interface{}) {
m.setValueByColumn(key, value, s.value)
}
case interface{}:
m := &Model{data: query, driver: s.driver}
for _, field := range m.columnsHasValue("") {
m.setValueByColumn(field.DbName, field.Value, s.value)
}
}
}
}

View File

@ -799,9 +799,14 @@ func TestSoftDelete(t *testing.T) {
}
func TestFindOrInitialize(t *testing.T) {
var user User
db.Where(User{Name: "hello world"}).FirstOrInit(&user)
if user.Name != "hello world" || user.Id != 0 {
var user1 User
db.Where(&User{Name: "hello world", Age: 33}).FirstOrInit(&user1)
if user1.Name != "hello world" || user1.Id != 0 || user1.Age != 33 {
t.Errorf("user should be initialized with search value")
}
// db.FirstOrInit(&user2, map[string]interface{}{"name": "hahaha"})
// if user2.Name != "hahaha" || user2.Id != 0 {
// t.Errorf("user should be initialized with search value")
// }
}

View File

@ -188,13 +188,21 @@ func (m *Model) callMethod(method string) error {
return nil
}
func (model *Model) returningStr() (str string) {
if model.driver == "postgres" {
str = fmt.Sprintf("RETURNING \"%v\"", model.primaryKeyDb())
func (m *Model) returningStr() (str string) {
if m.driver == "postgres" {
str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb())
}
return
}
func (model *Model) missingColumns() (results []string) {
func (m *Model) missingColumns() (results []string) {
return
}
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
data := reflect.ValueOf(out).Elem()
field := data.FieldByName(snakeToUpperCamel(name))
if field.IsValid() {
field.Set(reflect.ValueOf(value))
}
}