forked from mirror/gorm
Add FirstOrInit support
This commit is contained in:
parent
5be642a435
commit
ac8708b500
|
@ -224,13 +224,15 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.attrs = attrs
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
|
||||
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.assigns = attrs
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -171,14 +171,3 @@ func (like Like) NegationBuild(builder Builder) {
|
|||
builder.WriteString(" NOT LIKE ")
|
||||
builder.AddVar(builder, like.Value)
|
||||
}
|
||||
|
||||
// Map
|
||||
type Map map[interface{}]interface{}
|
||||
|
||||
func (m Map) Build(builder Builder) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
func (m Map) NegationBuild(builder Builder) {
|
||||
// TODO
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
|
@ -99,13 +100,56 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrInit(dest interface{}, where ...interface{}) (tx *DB) {
|
||||
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
if field := tx.Statement.Schema.LookUpField(column); field != nil {
|
||||
field.Set(tx.Statement.ReflectValue, eq.Value)
|
||||
}
|
||||
case clause.Column:
|
||||
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||
field.Set(tx.Statement.ReflectValue, eq.Value)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
|
||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
tx.assignExprsToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.attrs) > 0 {
|
||||
exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:])
|
||||
tx.assignExprsToValue(exprs)
|
||||
}
|
||||
tx.Error = nil
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:])
|
||||
tx.assignExprsToValue(exprs)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
// if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) {
|
||||
// // initialize with attrs, conds
|
||||
// }
|
||||
|
||||
// assign dest
|
||||
return
|
||||
}
|
||||
|
||||
|
|
49
statement.go
49
statement.go
|
@ -34,6 +34,8 @@ type Statement struct {
|
|||
SQL strings.Builder
|
||||
Vars []interface{}
|
||||
NamedVars []sql.NamedArg
|
||||
attrs []interface{}
|
||||
assigns []interface{}
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
|
@ -195,7 +197,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
|||
}
|
||||
|
||||
// BuildCondtion build condition
|
||||
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
|
||||
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) {
|
||||
if sql, ok := query.(string); ok {
|
||||
if i, err := strconv.Atoi(sql); err == nil {
|
||||
query = i
|
||||
|
@ -212,42 +214,53 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
|
|||
|
||||
switch v := arg.(type) {
|
||||
case clause.Expression:
|
||||
conditions = append(conditions, v)
|
||||
conds = append(conds, v)
|
||||
case *DB:
|
||||
if v.Statement == nil {
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
conditions = append(conditions, cs.Expression)
|
||||
conds = append(conds, cs.Expression)
|
||||
}
|
||||
}
|
||||
case map[interface{}]interface{}:
|
||||
var clauseMap = clause.Map{}
|
||||
for i, j := range v {
|
||||
clauseMap[i] = j
|
||||
conds = append(conds, clause.Eq{Column: i, Value: j})
|
||||
}
|
||||
conditions = append(conditions, clauseMap)
|
||||
case map[string]string:
|
||||
var clauseMap = clause.Map{}
|
||||
for i, j := range v {
|
||||
clauseMap[i] = j
|
||||
conds = append(conds, clause.Eq{Column: i, Value: j})
|
||||
}
|
||||
conditions = append(conditions, clauseMap)
|
||||
case map[string]interface{}:
|
||||
var clauseMap = clause.Map{}
|
||||
for i, j := range v {
|
||||
clauseMap[i] = j
|
||||
conds = append(conds, clause.Eq{Column: i, Value: j})
|
||||
}
|
||||
conditions = append(conditions, clauseMap)
|
||||
default:
|
||||
// TODO check is struct
|
||||
// struct, slice -> ids
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
||||
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, field := range s.FieldsByDBName {
|
||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
||||
conds = append(conds, clause.Eq{Column: field.DBName, Value: v})
|
||||
}
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
for _, field := range s.FieldsByDBName {
|
||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
||||
conds = append(conds, clause.Eq{Column: field.DBName, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
||||
if len(conds) == 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
||||
}
|
||||
|
||||
return conditions
|
||||
return
|
||||
}
|
||||
|
||||
// Build build sql with clauses names
|
||||
|
@ -337,7 +350,7 @@ func (stmt *Statement) reinit() {
|
|||
// return true
|
||||
// })
|
||||
|
||||
stmt.Schema = nil
|
||||
// stmt.Schema = nil
|
||||
stmt.SQL.Reset()
|
||||
stmt.Vars = nil
|
||||
stmt.NamedVars = nil
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestFindOrInitialize(t *testing.T) {
|
||||
var user1, user2, user3, user4, user5, user6 User
|
||||
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
|
||||
t.Errorf("no error should happen when FirstOrInit, but got %v", err)
|
||||
}
|
||||
if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 {
|
||||
t.Errorf("user should be initialized with search value")
|
||||
}
|
||||
|
||||
DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2)
|
||||
if user2.Name != "find or init" || user2.ID != 0 || user2.Age != 33 {
|
||||
t.Errorf("user should be initialized with search value")
|
||||
}
|
||||
|
||||
DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"})
|
||||
if user3.Name != "find or init 2" || user3.ID != 0 {
|
||||
t.Errorf("user should be initialized with inline search value")
|
||||
}
|
||||
|
||||
DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4)
|
||||
if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 {
|
||||
t.Errorf("user should be initialized with search value and attrs")
|
||||
}
|
||||
|
||||
DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4)
|
||||
if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 {
|
||||
t.Errorf("user should be initialized with search value and assign attrs")
|
||||
}
|
||||
|
||||
DB.Save(&User{Name: "find or init", Age: 33})
|
||||
DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5)
|
||||
if user5.Name != "find or init" || user5.ID == 0 || user5.Age != 33 {
|
||||
t.Errorf("user should be found and not initialized by Attrs")
|
||||
}
|
||||
|
||||
DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6)
|
||||
if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 33 {
|
||||
t.Errorf("user should be found with FirstOrInit")
|
||||
}
|
||||
|
||||
DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6)
|
||||
if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 44 {
|
||||
t.Errorf("user should be found and updated with assigned attrs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindOrCreate(t *testing.T) {
|
||||
}
|
Loading…
Reference in New Issue