forked from mirror/gorm
Support save slice of data
This commit is contained in:
parent
22ff8377df
commit
f3424c6864
|
@ -185,19 +185,19 @@ func AfterCreate(db *gorm.DB) {
|
|||
}
|
||||
|
||||
// ConvertToCreateValues convert to create values
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
switch value := stmt.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
return ConvertMapToValuesForCreate(stmt, value)
|
||||
values = ConvertMapToValuesForCreate(stmt, value)
|
||||
case []map[string]interface{}:
|
||||
return ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
default:
|
||||
var (
|
||||
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
|
||||
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
|
||||
curTime = stmt.DB.NowFunc()
|
||||
isZero bool
|
||||
)
|
||||
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
|
||||
|
||||
for _, db := range stmt.Schema.DBNames {
|
||||
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
|
@ -274,7 +274,30 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.UpdatingColumn {
|
||||
if stmt.Schema != nil {
|
||||
columns := make([]string, 0, len(stmt.Schema.DBNames)-1)
|
||||
for _, name := range stmt.Schema.DBNames {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 {
|
||||
columns = append(columns, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onConflict := clause.OnConflict{
|
||||
Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
|
||||
DoUpdates: clause.AssignmentColumns(columns),
|
||||
}
|
||||
|
||||
for idx, field := range stmt.Schema.PrimaryFields {
|
||||
onConflict.Columns[idx] = clause.Column{Name: field.DBName}
|
||||
}
|
||||
stmt.AddClause(onConflict)
|
||||
}
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,13 +22,14 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
tx.AddError(ErrPtrStructSupported)
|
||||
tx.Statement.UpdatingColumn = true
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
case reflect.Struct:
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
||||
for idx, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
if pv, isZero := pf.ValueOf(reflectValue); isZero {
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
|
@ -40,12 +41,16 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||
|
||||
tx.Statement.AddClause(where)
|
||||
}
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
||||
}
|
||||
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -90,6 +90,23 @@ func TestUpsertSlice(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUpsertWithSave(t *testing.T) {
|
||||
langs := []Language{
|
||||
{Code: "upsert-save-1", Name: "Upsert-save-1"},
|
||||
{Code: "upsert-save-2", Name: "Upsert-save-2"},
|
||||
}
|
||||
if err := DB.Save(&langs).Error; err != nil {
|
||||
t.Errorf("Failed to create, got error %v", err)
|
||||
}
|
||||
|
||||
for _, lang := range langs {
|
||||
var result Language
|
||||
if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
|
||||
t.Errorf("Failed to query lang, got error %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindOrInitialize(t *testing.T) {
|
||||
var user1, user2, user3, user4, user5, user6 User
|
||||
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
|
||||
|
|
Loading…
Reference in New Issue