Support save slice of data

This commit is contained in:
Jinzhu 2020-06-10 00:02:14 +08:00
parent 22ff8377df
commit f3424c6864
3 changed files with 62 additions and 17 deletions

View File

@ -185,19 +185,19 @@ func AfterCreate(db *gorm.DB) {
}
// ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch value := stmt.Dest.(type) {
case map[string]interface{}:
return ConvertMapToValuesForCreate(stmt, value)
values = ConvertMapToValuesForCreate(stmt, value)
case []map[string]interface{}:
return ConvertSliceOfMapToValuesForCreate(stmt, value)
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
default:
var (
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
curTime = stmt.DB.NowFunc()
isZero bool
)
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
@ -274,7 +274,30 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
}
}
}
}
if stmt.UpdatingColumn {
if stmt.Schema != nil {
columns := make([]string, 0, len(stmt.Schema.DBNames)-1)
for _, name := range stmt.Schema.DBNames {
if field := stmt.Schema.LookUpField(name); field != nil {
if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 {
columns = append(columns, name)
}
}
}
onConflict := clause.OnConflict{
Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
DoUpdates: clause.AssignmentColumns(columns),
}
for idx, field := range stmt.Schema.PrimaryFields {
onConflict.Columns[idx] = clause.Column{Name: field.DBName}
}
stmt.AddClause(onConflict)
}
}
return values
}
}

View File

@ -22,13 +22,14 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
tx.AddError(ErrPtrStructSupported)
tx.Statement.UpdatingColumn = true
tx.callbacks.Create().Execute(tx)
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
for idx, pf := range tx.Statement.Schema.PrimaryFields {
if pv, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx)
@ -40,12 +41,16 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.AddClause(where)
}
}
fallthrough
default:
if len(tx.Statement.Selects) == 0 {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
tx.callbacks.Update().Execute(tx)
}
return
}

View File

@ -90,6 +90,23 @@ func TestUpsertSlice(t *testing.T) {
}
}
func TestUpsertWithSave(t *testing.T) {
langs := []Language{
{Code: "upsert-save-1", Name: "Upsert-save-1"},
{Code: "upsert-save-2", Name: "Upsert-save-2"},
}
if err := DB.Save(&langs).Error; err != nil {
t.Errorf("Failed to create, got error %v", err)
}
for _, lang := range langs {
var result Language
if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
t.Errorf("Failed to query lang, got error %v", err)
}
}
}
func TestFindOrInitialize(t *testing.T) {
var user1, user2, user3, user4, user5, user6 User
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {