Support save slice of data

This commit is contained in:
Jinzhu 2020-06-10 00:02:14 +08:00
parent 22ff8377df
commit f3424c6864
3 changed files with 62 additions and 17 deletions

View File

@ -185,19 +185,19 @@ func AfterCreate(db *gorm.DB) {
} }
// ConvertToCreateValues convert to create values // ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch value := stmt.Dest.(type) { switch value := stmt.Dest.(type) {
case map[string]interface{}: case map[string]interface{}:
return ConvertMapToValuesForCreate(stmt, value) values = ConvertMapToValuesForCreate(stmt, value)
case []map[string]interface{}: case []map[string]interface{}:
return ConvertSliceOfMapToValuesForCreate(stmt, value) values = ConvertSliceOfMapToValuesForCreate(stmt, value)
default: default:
var ( var (
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
curTime = stmt.DB.NowFunc() curTime = stmt.DB.NowFunc()
isZero bool isZero bool
) )
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
for _, db := range stmt.Schema.DBNames { for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
@ -274,7 +274,30 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
} }
} }
} }
return values
} }
if stmt.UpdatingColumn {
if stmt.Schema != nil {
columns := make([]string, 0, len(stmt.Schema.DBNames)-1)
for _, name := range stmt.Schema.DBNames {
if field := stmt.Schema.LookUpField(name); field != nil {
if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 {
columns = append(columns, name)
}
}
}
onConflict := clause.OnConflict{
Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
DoUpdates: clause.AssignmentColumns(columns),
}
for idx, field := range stmt.Schema.PrimaryFields {
onConflict.Columns[idx] = clause.Column{Name: field.DBName}
}
stmt.AddClause(onConflict)
}
}
return values
} }

View File

@ -22,13 +22,14 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { reflectValue := reflect.Indirect(reflect.ValueOf(value))
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} switch reflectValue.Kind() {
reflectValue := reflect.Indirect(reflect.ValueOf(value)) case reflect.Slice, reflect.Array:
switch reflectValue.Kind() { tx.Statement.UpdatingColumn = true
case reflect.Slice, reflect.Array: tx.callbacks.Create().Execute(tx)
tx.AddError(ErrPtrStructSupported) case reflect.Struct:
case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
for idx, pf := range tx.Statement.Schema.PrimaryFields { for idx, pf := range tx.Statement.Schema.PrimaryFields {
if pv, isZero := pf.ValueOf(reflectValue); isZero { if pv, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx) tx.callbacks.Create().Execute(tx)
@ -40,12 +41,16 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.AddClause(where) tx.Statement.AddClause(where)
} }
fallthrough
default:
if len(tx.Statement.Selects) == 0 {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
tx.callbacks.Update().Execute(tx)
} }
if len(tx.Statement.Selects) == 0 {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
tx.callbacks.Update().Execute(tx)
return return
} }

View File

@ -90,6 +90,23 @@ func TestUpsertSlice(t *testing.T) {
} }
} }
func TestUpsertWithSave(t *testing.T) {
langs := []Language{
{Code: "upsert-save-1", Name: "Upsert-save-1"},
{Code: "upsert-save-2", Name: "Upsert-save-2"},
}
if err := DB.Save(&langs).Error; err != nil {
t.Errorf("Failed to create, got error %v", err)
}
for _, lang := range langs {
var result Language
if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
t.Errorf("Failed to query lang, got error %v", err)
}
}
}
func TestFindOrInitialize(t *testing.T) { func TestFindOrInitialize(t *testing.T) {
var user1, user2, user3, user4, user5, user6 User var user1, user2, user3, user4, user5, user6 User
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {