forked from mirror/gorm
Try to implement the FirstOrInit method
This commit is contained in:
parent
2e84f3b216
commit
6ac0012879
9
chain.go
9
chain.go
|
@ -35,6 +35,11 @@ func (s *Chain) err(err error) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *Chain) deleteLastError() {
|
||||
s.Error = nil
|
||||
s.Errors = s.Errors[:len(s.Errors)-1]
|
||||
}
|
||||
|
||||
func (s *Chain) do(value interface{}) *Do {
|
||||
var do Do
|
||||
do.chain = s
|
||||
|
@ -160,6 +165,10 @@ func (s *Chain) First(out interface{}, where ...interface{}) *Chain {
|
|||
}
|
||||
|
||||
func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain {
|
||||
if s.First(out).Error != nil {
|
||||
s.do(out).initializedWithSearchCondition()
|
||||
s.deleteLastError()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
|
|
24
do.go
24
do.go
|
@ -350,14 +350,14 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
for key, value := range query.(map[string]interface{}) {
|
||||
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", key, s.addToVars(value)))
|
||||
}
|
||||
return strings.Join(sqls, ",")
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
m := &Model{data: query, driver: s.driver}
|
||||
var sqls []string
|
||||
for _, field := range m.columnsHasValue("") {
|
||||
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value)))
|
||||
}
|
||||
return strings.Join(sqls, ",")
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
|
@ -383,7 +383,6 @@ func (s *Do) whereSql() (sql string) {
|
|||
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
||||
primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')")
|
||||
}
|
||||
|
||||
if !s.model.primaryKeyZero() {
|
||||
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
|
||||
}
|
||||
|
@ -466,3 +465,22 @@ func (s *Do) createTable() *Do {
|
|||
)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) initializedWithSearchCondition() {
|
||||
m := Model{data: s.value, driver: s.driver}
|
||||
|
||||
for _, clause := range s.whereClause {
|
||||
query := clause["query"]
|
||||
switch query.(type) {
|
||||
case map[string]interface{}:
|
||||
for key, value := range query.(map[string]interface{}) {
|
||||
m.setValueByColumn(key, value, s.value)
|
||||
}
|
||||
case interface{}:
|
||||
m := &Model{data: query, driver: s.driver}
|
||||
for _, field := range m.columnsHasValue("") {
|
||||
m.setValueByColumn(field.DbName, field.Value, s.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
11
gorm_test.go
11
gorm_test.go
|
@ -799,9 +799,14 @@ func TestSoftDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFindOrInitialize(t *testing.T) {
|
||||
var user User
|
||||
db.Where(User{Name: "hello world"}).FirstOrInit(&user)
|
||||
if user.Name != "hello world" || user.Id != 0 {
|
||||
var user1 User
|
||||
db.Where(&User{Name: "hello world", Age: 33}).FirstOrInit(&user1)
|
||||
if user1.Name != "hello world" || user1.Id != 0 || user1.Age != 33 {
|
||||
t.Errorf("user should be initialized with search value")
|
||||
}
|
||||
|
||||
// db.FirstOrInit(&user2, map[string]interface{}{"name": "hahaha"})
|
||||
// if user2.Name != "hahaha" || user2.Id != 0 {
|
||||
// t.Errorf("user should be initialized with search value")
|
||||
// }
|
||||
}
|
||||
|
|
16
model.go
16
model.go
|
@ -188,13 +188,21 @@ func (m *Model) callMethod(method string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (model *Model) returningStr() (str string) {
|
||||
if model.driver == "postgres" {
|
||||
str = fmt.Sprintf("RETURNING \"%v\"", model.primaryKeyDb())
|
||||
func (m *Model) returningStr() (str string) {
|
||||
if m.driver == "postgres" {
|
||||
str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (model *Model) missingColumns() (results []string) {
|
||||
func (m *Model) missingColumns() (results []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
|
||||
data := reflect.ValueOf(out).Elem()
|
||||
field := data.FieldByName(snakeToUpperCamel(name))
|
||||
if field.IsValid() {
|
||||
field.Set(reflect.ValueOf(value))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue