Add save associations for bulk create

This commit is contained in:
Jinzhu 2020-04-19 23:11:56 +08:00
parent 158bacefbe
commit 7bcd95d4b8
3 changed files with 229 additions and 115 deletions

View File

@ -10,41 +10,75 @@ import (
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
// Save Belongs To associations // Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) if !saveAssociationCheck(db, rel, selectColumns, restricted) {
if !(creatable || updatable) {
continue continue
} }
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice: case reflect.Slice:
var (
objs []reflect.Value
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
} else {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(rv)
ref.ForeignKey.Set(objs[i], pv)
}
}
}
}
}
if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil {
for i := 0; i < elems.Len(); i++ {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elems.Index(i))
ref.ForeignKey.Set(objs[i], pv)
}
}
}
}
}
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
_, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
if rv.Kind() == reflect.Ptr {
if isZero && creatable { db.Session(&gorm.Session{}).Create(rv.Interface())
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Create(f.Interface())
} else { } else {
db.Session(&gorm.Session{}).Create(f.Addr().Interface()) db.Session(&gorm.Session{}).Create(rv.Addr().Interface())
} }
} else if !isZero && updatable {
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Save(f.Interface())
} else {
db.Session(&gorm.Session{}).Save(f.Addr().Interface())
}
} else {
continue
}
if saveRef {
for _, ref := range rel.References { for _, ref := range rel.References {
if !ref.OwnPrimaryKey { if !ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(f) pv, _ := ref.PrimaryKey.ValueOf(rv)
ref.ForeignKey.Set(db.Statement.ReflectValue, fv) ref.ForeignKey.Set(db.Statement.ReflectValue, pv)
} }
} }
} }
@ -55,20 +89,58 @@ func SaveBeforeAssociations(db *gorm.DB) {
} }
func SaveAfterAssociations(db *gorm.DB) { func SaveAfterAssociations(db *gorm.DB) {
// Save Has One associations if db.Statement.Schema != nil {
for _, rel := range db.Statement.Schema.Relationships.HasOne { selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field)
if !(creatable || updatable) {
continue
}
switch db.Statement.ReflectValue.Kind() { // Save Has One associations
case reflect.Slice: for _, rel := range db.Statement.Schema.Relationships.HasOne {
case reflect.Struct: if !saveAssociationCheck(db, rel, selectColumns, restricted) {
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { continue
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) }
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
var (
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if rv, zero := rel.Field.ValueOf(obj); !zero {
rv := reflect.ValueOf(rv)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(rv, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(rv, ref.PrimaryValue)
}
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
}
}
}
if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface())
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if saveRef {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
@ -77,98 +149,134 @@ func SaveAfterAssociations(db *gorm.DB) {
ref.ForeignKey.Set(f, ref.PrimaryValue) ref.ForeignKey.Set(f, ref.PrimaryValue)
} }
} }
}
_, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero {
if f.Kind() == reflect.Ptr {
if isZero && creatable { db.Session(&gorm.Session{}).Create(f.Interface())
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Create(f.Interface())
} else {
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
}
} else if !isZero && updatable {
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Save(f.Interface())
} else {
db.Session(&gorm.Session{}).Save(f.Addr().Interface())
}
} else {
continue
}
}
}
}
// Save Has Many associations
for _, rel := range db.Statement.Schema.Relationships.HasMany {
creatable, updatable, _ := saveAssociationCheck(db, rel.Field)
if !(creatable || updatable) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := true
if fieldType.Kind() != reflect.Ptr {
isPtr = false
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
db.Statement.ReflectValue.Index(i)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
_, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(elem, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
}
}
if isZero && creatable {
if isPtr {
elems = reflect.Append(elems, elem)
} else { } else {
elems = reflect.Append(elems, elem.Addr()) db.Session(&gorm.Session{}).Create(f.Addr().Interface())
} }
} }
} }
} }
} }
if elems.Len() > 0 { // Save Has Many associations
db.Session(&gorm.Session{}).Create(elems.Interface()) for _, rel := range db.Statement.Schema.Relationships.HasMany {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := true
if fieldType.Kind() != reflect.Ptr {
isPtr = false
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v)
ref.ForeignKey.Set(elem, pv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
}
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero {
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i))
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface())
}
}
// Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := true
if fieldType.Kind() != reflect.Ptr {
isPtr = false
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
db.Statement.ReflectValue.Index(i)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(elem, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
}
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero {
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
}
}
}
}
if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface())
}
} }
} }
} }
func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { func saveAssociationCheck(db *gorm.DB, rel *schema.Relationship, selectColumns map[string]bool, restricted bool) bool {
creatable := field.Creatable savable := true
updatable := field.Updatable if value, ok := db.Get("gorm:save_association"); ok {
saveRef := true savable = utils.CheckTruth(value)
if value, ok := db.Get("gorm:association_autocreate"); creatable && ok {
creatable = utils.CheckTruth(value)
} }
if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok { if savable {
updatable = utils.CheckTruth(value) if v, ok := selectColumns[rel.Name]; (ok && v) || (!ok && !restricted) {
return true
}
} }
if value, ok := db.Get("gorm:association_save_reference"); ok { return false
saveRef = utils.CheckTruth(value)
}
return creatable, updatable, saveRef
} }

View File

@ -37,11 +37,16 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo
} }
if stmt.Schema != nil { if stmt.Schema != nil {
for _, field := range stmt.Schema.FieldsByDBName { for _, field := range stmt.Schema.Fields {
name := field.DBName
if name == "" {
name = field.Name
}
if requireCreate && !field.Creatable { if requireCreate && !field.Creatable {
results[field.DBName] = false results[name] = false
} else if requireUpdate && !field.Updatable { } else if requireUpdate && !field.Updatable {
results[field.DBName] = false results[name] = false
} }
} }
} }

View File

@ -161,12 +161,13 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
} }
// AddError add error to db // AddError add error to db
func (db *DB) AddError(err error) { func (db *DB) AddError(err error) error {
if db.Error == nil { if db.Error == nil {
db.Error = err db.Error = err
} else if err != nil { } else if err != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, err) db.Error = fmt.Errorf("%v; %w", db.Error, err)
} }
return db.Error
} }
func (db *DB) getInstance() *DB { func (db *DB) getInstance() *DB {