Add FirstOrCreate support

This commit is contained in:
Jinzhu 2020-05-28 16:10:10 +08:00
parent ac8708b500
commit dca5244387
3 changed files with 116 additions and 11 deletions

View File

@ -129,7 +129,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.attrs) > 0 { if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]) exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
tx.assignExprsToValue(exprs) tx.assignExprsToValue(exprs)
} }
tx.Error = nil tx.Error = nil
@ -137,19 +137,54 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.assigns) > 0 { if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
tx.assignExprsToValue(exprs) tx.assignExprsToValue(exprs)
} }
return return
} }
func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
// if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) {
// // initialize with attrs, conds tx.Error = nil
// }
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
tx.assignExprsToValue(exprs)
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
tx.assignExprsToValue(exprs)
}
return tx.Create(dest)
} else if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:])
assigns := map[string]interface{}{}
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
default:
}
}
}
return tx.Model(dest).Updates(assigns)
}
// assign dest
return return
} }
@ -307,3 +342,7 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx.callbacks.Raw().Execute(tx) tx.callbacks.Raw().Execute(tx)
return return
} }
func (db *DB) RecordNotFound() bool {
return errors.Is(db.Error, ErrRecordNotFound)
}

View File

@ -203,6 +203,8 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
query = i query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
} else if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}}
} }
} }
@ -238,15 +240,22 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Struct: case reflect.Struct:
for _, field := range s.FieldsByDBName { for _, field := range s.Fields {
if v, isZero := field.ValueOf(reflectValue); !isZero { if v, isZero := field.ValueOf(reflectValue); !isZero {
if field.DBName == "" {
conds = append(conds, clause.Eq{Column: field.Name, Value: v})
} else {
conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) conds = append(conds, clause.Eq{Column: field.DBName, Value: v})
} }
} }
}
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.FieldsByDBName { for _, field := range s.Fields {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
if field.DBName == "" {
conds = append(conds, clause.Eq{Column: field.Name, Value: v})
} else {
conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) conds = append(conds, clause.Eq{Column: field.DBName, Value: v})
} }
} }
@ -255,6 +264,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
} }
} }
} }
}
if len(conds) == 0 { if len(conds) == 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"testing" "testing"
"time"
. "github.com/jinzhu/gorm/tests" . "github.com/jinzhu/gorm/tests"
) )
@ -53,4 +54,59 @@ func TestFindOrInitialize(t *testing.T) {
} }
func TestFindOrCreate(t *testing.T) { func TestFindOrCreate(t *testing.T) {
var user1, user2, user3, user4, user5, user6, user7, user8 User
DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1)
if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 {
t.Errorf("user should be created with search value")
}
DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2)
if user1.ID != user2.ID || user2.Name != "find or create" || user2.ID == 0 || user2.Age != 33 {
t.Errorf("user should be created with search value")
}
DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"})
if user3.Name != "find or create 2" || user3.ID == 0 {
t.Errorf("user should be created with inline search value")
}
DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4)
if user4.Name != "find or create 3" || user4.ID == 0 || user4.Age != 44 {
t.Errorf("user should be created with search value and attrs")
}
updatedAt1 := user4.UpdatedAt
DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("UpdateAt should be changed when update values with assign")
}
DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4)
if user4.Name != "find or create 4" || user4.ID == 0 || user4.Age != 44 {
t.Errorf("user should be created with search value and assigned attrs")
}
DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5)
if user5.Name != "find or create" || user5.ID == 0 || user5.Age != 33 {
t.Errorf("user should be found and not initialized by Attrs")
}
DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6)
if user6.Name != "find or create" || user6.ID == 0 || user6.Age != 44 {
t.Errorf("user should be found and updated with assigned attrs")
}
DB.Where(&User{Name: "find or create"}).Find(&user7)
if user7.Name != "find or create" || user7.ID == 0 || user7.Age != 44 {
t.Errorf("user should be found and updated with assigned attrs")
}
DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, Account: Account{Number: "1231231231"}, Pets: []*Pet{{Name: "first_or_create_pet1"}, {Name: "first_or_create_pet2"}}}).FirstOrCreate(&user8)
if DB.Where("name = ?", "first_or_create_pet1").First(&Pet{}).RecordNotFound() {
t.Errorf("has many association should be saved")
}
if DB.Where("number = ?", "1231231231").First(&Account{}).RecordNotFound() {
t.Errorf("belongs to association should be saved")
}
} }