From 7bcd95d4b882544c613fa3609a5fc91a0c0e2714 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 19 Apr 2020 23:11:56 +0800 Subject: [PATCH] Add save associations for bulk create --- callbacks/associations.go | 330 +++++++++++++++++++++++++------------- callbacks/helper.go | 11 +- gorm.go | 3 +- 3 files changed, 229 insertions(+), 115 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 6d976eac..8cc96029 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -10,41 +10,75 @@ import ( func SaveBeforeAssociations(db *gorm.DB) { if db.Statement.Schema != nil { + selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) + // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { - creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) - if !(creatable || updatable) { + if !saveAssociationCheck(db, rel, selectColumns, restricted) { continue } switch db.Statement.ReflectValue.Kind() { case reflect.Slice: + var ( + objs []reflect.Value + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } else { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(rv) + ref.ForeignKey.Set(objs[i], pv) + } + } + } + } + } + + if elems.Len() > 0 { + if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil { + for i := 0; i < elems.Len(); i++ { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(elems.Index(i)) + ref.ForeignKey.Set(objs[i], pv) + } + } + } + } + } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) - - if isZero && creatable { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(f.Interface()) + rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { + if rv.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(rv.Interface()) } else { - db.Session(&gorm.Session{}).Create(f.Addr().Interface()) + db.Session(&gorm.Session{}).Create(rv.Addr().Interface()) } - } else if !isZero && updatable { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Save(f.Interface()) - } else { - db.Session(&gorm.Session{}).Save(f.Addr().Interface()) - } - } else { - continue - } - if saveRef { for _, ref := range rel.References { if !ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(f) - ref.ForeignKey.Set(db.Statement.ReflectValue, fv) + pv, _ := ref.PrimaryKey.ValueOf(rv) + ref.ForeignKey.Set(db.Statement.ReflectValue, pv) } } } @@ -55,20 +89,58 @@ func SaveBeforeAssociations(db *gorm.DB) { } func SaveAfterAssociations(db *gorm.DB) { - // Save Has One associations - for _, rel := range db.Statement.Schema.Relationships.HasOne { - creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) - if !(creatable || updatable) { - continue - } + if db.Statement.Schema != nil { + selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if !saveAssociationCheck(db, rel, selectColumns, restricted) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if rv, zero := rel.Field.ValueOf(obj); !zero { + rv := reflect.ValueOf(rv) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + ref.ForeignKey.Set(rv, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(rv, ref.PrimaryValue) + } + } + + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } + } + } + + if elems.Len() > 0 { + db.Session(&gorm.Session{}).Create(elems.Interface()) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if saveRef { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) @@ -77,98 +149,134 @@ func SaveAfterAssociations(db *gorm.DB) { ref.ForeignKey.Set(f, ref.PrimaryValue) } } - } - _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) - - if isZero && creatable { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(f.Interface()) - } else { - db.Session(&gorm.Session{}).Create(f.Addr().Interface()) - } - } else if !isZero && updatable { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Save(f.Interface()) - } else { - db.Session(&gorm.Session{}).Save(f.Addr().Interface()) - } - } else { - continue - } - } - } - } - - // Save Has Many associations - for _, rel := range db.Statement.Schema.Relationships.HasMany { - creatable, updatable, _ := saveAssociationCheck(db, rel.Field) - if !(creatable || updatable) { - continue - } - - fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := true - if fieldType.Kind() != reflect.Ptr { - isPtr = false - fieldType = reflect.PtrTo(fieldType) - } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.ReflectValue.Index(i) - } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue)) - - for i := 0; i < f.Len(); i++ { - elem := f.Index(i) - _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(elem, fv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) - } - } - - if isZero && creatable { - if isPtr { - elems = reflect.Append(elems, elem) + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(f.Interface()) } else { - elems = reflect.Append(elems, elem.Addr()) + db.Session(&gorm.Session{}).Create(f.Addr().Interface()) } } } } } - if elems.Len() > 0 { - db.Session(&gorm.Session{}).Create(elems.Interface()) + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if !saveAssociationCheck(db, rel, selectColumns, restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := true + if fieldType.Kind() != reflect.Ptr { + isPtr = false + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(v) + ref.ForeignKey.Set(elem, pv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(elem, ref.PrimaryValue) + } + } + + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + appendToElems(db.Statement.ReflectValue.Index(i)) + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + if elems.Len() > 0 { + db.Session(&gorm.Session{}).Create(elems.Interface()) + } + } + + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if !saveAssociationCheck(db, rel, selectColumns, restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := true + if fieldType.Kind() != reflect.Ptr { + isPtr = false + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.ReflectValue.Index(i) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(elem, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(elem, ref.PrimaryValue) + } + } + + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + } + + if elems.Len() > 0 { + db.Session(&gorm.Session{}).Create(elems.Interface()) + } } } } -func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { - creatable := field.Creatable - updatable := field.Updatable - saveRef := true - - if value, ok := db.Get("gorm:association_autocreate"); creatable && ok { - creatable = utils.CheckTruth(value) +func saveAssociationCheck(db *gorm.DB, rel *schema.Relationship, selectColumns map[string]bool, restricted bool) bool { + savable := true + if value, ok := db.Get("gorm:save_association"); ok { + savable = utils.CheckTruth(value) } - if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok { - updatable = utils.CheckTruth(value) + if savable { + if v, ok := selectColumns[rel.Name]; (ok && v) || (!ok && !restricted) { + return true + } } - if value, ok := db.Get("gorm:association_save_reference"); ok { - saveRef = utils.CheckTruth(value) - } - - return creatable, updatable, saveRef + return false } diff --git a/callbacks/helper.go b/callbacks/helper.go index 8a69fbd1..092c9c37 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -37,11 +37,16 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo } if stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, field := range stmt.Schema.Fields { + name := field.DBName + if name == "" { + name = field.Name + } + if requireCreate && !field.Creatable { - results[field.DBName] = false + results[name] = false } else if requireUpdate && !field.Updatable { - results[field.DBName] = false + results[name] = false } } } diff --git a/gorm.go b/gorm.go index 2d78c8d9..f8c944af 100644 --- a/gorm.go +++ b/gorm.go @@ -161,12 +161,13 @@ func (db *DB) AutoMigrate(dst ...interface{}) error { } // AddError add error to db -func (db *DB) AddError(err error) { +func (db *DB) AddError(err error) error { if db.Error == nil { db.Error = err } else if err != nil { db.Error = fmt.Errorf("%v; %w", db.Error, err) } + return db.Error } func (db *DB) getInstance() *DB {