Add FirstOrInit support

This commit is contained in:
Jinzhu 2020-05-28 13:12:56 +08:00
parent 5be642a435
commit ac8708b500
5 changed files with 136 additions and 32 deletions

View File

@ -224,13 +224,15 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
return return
} }
func (db *DB) Assign(attrs ...interface{}) (tx *DB) { func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.attrs = attrs
return return
} }
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.assigns = attrs
return return
} }

View File

@ -171,14 +171,3 @@ func (like Like) NegationBuild(builder Builder) {
builder.WriteString(" NOT LIKE ") builder.WriteString(" NOT LIKE ")
builder.AddVar(builder, like.Value) builder.AddVar(builder, like.Value)
} }
// Map
type Map map[interface{}]interface{}
func (m Map) Build(builder Builder) {
// TODO
}
func (m Map) NegationBuild(builder Builder) {
// TODO
}

View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"errors"
"reflect" "reflect"
"strings" "strings"
@ -99,13 +100,56 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
return return
} }
func (db *DB) FirstOrInit(dest interface{}, where ...interface{}) (tx *DB) { func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
field.Set(tx.Statement.ReflectValue, eq.Value)
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
field.Set(tx.Statement.ReflectValue, eq.Value)
}
default:
}
}
}
}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:])
tx.assignExprsToValue(exprs)
}
tx.Error = nil
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:])
tx.assignExprsToValue(exprs)
}
return return
} }
func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
// if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) {
// // initialize with attrs, conds
// }
// assign dest
return return
} }

View File

@ -34,6 +34,8 @@ type Statement struct {
SQL strings.Builder SQL strings.Builder
Vars []interface{} Vars []interface{}
NamedVars []sql.NamedArg NamedVars []sql.NamedArg
attrs []interface{}
assigns []interface{}
} }
// StatementModifier statement modifier interface // StatementModifier statement modifier interface
@ -195,7 +197,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
} }
// BuildCondtion build condition // BuildCondtion build condition
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) {
if sql, ok := query.(string); ok { if sql, ok := query.(string); ok {
if i, err := strconv.Atoi(sql); err == nil { if i, err := strconv.Atoi(sql); err == nil {
query = i query = i
@ -212,42 +214,53 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
switch v := arg.(type) { switch v := arg.(type) {
case clause.Expression: case clause.Expression:
conditions = append(conditions, v) conds = append(conds, v)
case *DB: case *DB:
if v.Statement == nil { if v.Statement == nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok { if cs, ok := v.Statement.Clauses["WHERE"]; ok {
conditions = append(conditions, cs.Expression) conds = append(conds, cs.Expression)
} }
} }
case map[interface{}]interface{}: case map[interface{}]interface{}:
var clauseMap = clause.Map{}
for i, j := range v { for i, j := range v {
clauseMap[i] = j conds = append(conds, clause.Eq{Column: i, Value: j})
} }
conditions = append(conditions, clauseMap)
case map[string]string: case map[string]string:
var clauseMap = clause.Map{}
for i, j := range v { for i, j := range v {
clauseMap[i] = j conds = append(conds, clause.Eq{Column: i, Value: j})
} }
conditions = append(conditions, clauseMap)
case map[string]interface{}: case map[string]interface{}:
var clauseMap = clause.Map{}
for i, j := range v { for i, j := range v {
clauseMap[i] = j conds = append(conds, clause.Eq{Column: i, Value: j})
} }
conditions = append(conditions, clauseMap)
default: default:
// TODO check is struct reflectValue := reflect.Indirect(reflect.ValueOf(arg))
// struct, slice -> ids if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
switch reflectValue.Kind() {
case reflect.Struct:
for _, field := range s.FieldsByDBName {
if v, isZero := field.ValueOf(reflectValue); !isZero {
conds = append(conds, clause.Eq{Column: field.DBName, Value: v})
}
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.FieldsByDBName {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
conds = append(conds, clause.Eq{Column: field.DBName, Value: v})
}
}
}
}
}
} }
} }
if len(conditions) == 0 { if len(conds) == 0 {
conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args}) conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
} }
return conditions return
} }
// Build build sql with clauses names // Build build sql with clauses names
@ -337,7 +350,7 @@ func (stmt *Statement) reinit() {
// return true // return true
// }) // })
stmt.Schema = nil // stmt.Schema = nil
stmt.SQL.Reset() stmt.SQL.Reset()
stmt.Vars = nil stmt.Vars = nil
stmt.NamedVars = nil stmt.NamedVars = nil

56
tests/upsert_test.go Normal file
View File

@ -0,0 +1,56 @@
package tests_test
import (
"testing"
. "github.com/jinzhu/gorm/tests"
)
func TestFindOrInitialize(t *testing.T) {
var user1, user2, user3, user4, user5, user6 User
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
t.Errorf("no error should happen when FirstOrInit, but got %v", err)
}
if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 {
t.Errorf("user should be initialized with search value")
}
DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2)
if user2.Name != "find or init" || user2.ID != 0 || user2.Age != 33 {
t.Errorf("user should be initialized with search value")
}
DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"})
if user3.Name != "find or init 2" || user3.ID != 0 {
t.Errorf("user should be initialized with inline search value")
}
DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4)
if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 {
t.Errorf("user should be initialized with search value and attrs")
}
DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4)
if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 {
t.Errorf("user should be initialized with search value and assign attrs")
}
DB.Save(&User{Name: "find or init", Age: 33})
DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5)
if user5.Name != "find or init" || user5.ID == 0 || user5.Age != 33 {
t.Errorf("user should be found and not initialized by Attrs")
}
DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6)
if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 33 {
t.Errorf("user should be found with FirstOrInit")
}
DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6)
if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 44 {
t.Errorf("user should be found and updated with assigned attrs")
}
}
func TestFindOrCreate(t *testing.T) {
}