From deff0594eee29ae94d66ae476771522252f5b6a7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Feb 2021 14:24:11 +0800 Subject: [PATCH] Save associations based on creatable/updatable permission, close #4056 --- callbacks/associations.go | 444 +++++++++++++++++++------------------- callbacks/callbacks.go | 8 +- schema/schema.go | 2 + 3 files changed, 230 insertions(+), 224 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 7b01247e..28c769e7 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -9,79 +9,81 @@ import ( "gorm.io/gorm/schema" ) -func SaveBeforeAssociations(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) +func SaveBeforeAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) - // Save Belongs To associations - for _, rel := range db.Statement.Schema.Relationships.BelongsTo { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } + // Save Belongs To associations + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } - setupReferences := func(obj reflect.Value, elem reflect.Value) { - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elem) - db.AddError(ref.ForeignKey.Set(obj, pv)) + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(elem) + db.AddError(ref.ForeignKey.Set(obj, pv)) - if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { - dest[ref.ForeignKey.DBName] = pv - if _, ok := dest[rel.Name]; ok { - dest[rel.Name] = elem.Interface() + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } } } } } - } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - objs []reflect.Value - fieldType = rel.Field.FieldType - isPtr = fieldType.Kind() == reflect.Ptr - ) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + objs []reflect.Value + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } + } else { + break + } + } + + if elems.Len() > 0 { + if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + for i := 0; i < elems.Len(); i++ { + setupReferences(objs[i], elems.Index(i)) } } - } else { - break } - } - - if elems.Len() > 0 { - if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { - for i := 0; i < elems.Len(); i++ { - setupReferences(objs[i], elems.Index(i)) + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() } - } - } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { - setupReferences(db.Statement.ReflectValue, rv) + if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + setupReferences(db.Statement.ReflectValue, rv) + } } } } @@ -89,53 +91,133 @@ func SaveBeforeAssociations(db *gorm.DB) { } } -func SaveAfterAssociations(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) +func SaveAfterAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) - // Save Has One associations - for _, rel := range db.Statement.Schema.Relationships.HasOne { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } + } + } + + if elems.Len() > 0 { + assignmentColumns := []string{} + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } + + assignmentColumns := []string{} + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(f, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(f, ref.PrimaryValue) + } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + } + } } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - fieldType = rel.Field.FieldType - isPtr = fieldType.Kind() == reflect.Ptr - ) + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr if !isPtr { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) - - if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) + pv, _ := ref.PrimaryKey.ValueOf(v) + ref.ForeignKey.Set(elem, pv) } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + ref.ForeignKey.Set(elem, ref.PrimaryValue) } } - elems = reflect.Append(elems, rv) + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } } } } + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + if elems.Len() > 0 { assignmentColumns := []string{} for _, ref := range rel.References { @@ -144,162 +226,84 @@ func SaveAfterAssociations(db *gorm.DB) { saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if f.Kind() != reflect.Ptr { - f = f.Addr() - } + } - assignmentColumns := []string{} + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) + objs := []reflect.Value{} + + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(f, fv) + fv, _ := ref.PrimaryKey.ValueOf(obj) + ref.ForeignKey.Set(joinValue, fv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(f, ref.PrimaryValue) + ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + } else { + fv, _ := ref.PrimaryKey.ValueOf(elem) + ref.ForeignKey.Set(joinValue, fv) } - assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - - saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + joins = reflect.Append(joins, joinValue) } - } - } - // Save Has Many associations - for _, rel := range db.Statement.Schema.Relationships.HasMany { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) - fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := fieldType.Kind() == reflect.Ptr - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) - for i := 0; i < f.Len(); i++ { - elem := f.Index(i) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(v) - ref.ForeignKey.Set(elem, pv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) + objs = append(objs, v) + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) } } + } + } - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) } } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) } - } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - appendToElems(obj) + if elems.Len() > 0 { + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + } + + for i := 0; i < elems.Len(); i++ { + appendToJoins(objs[i], elems.Index(i)) } } - case reflect.Struct: - appendToElems(db.Statement.ReflectValue) - } - if elems.Len() > 0 { - assignmentColumns := []string{} - for _, ref := range rel.References { - assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + if joins.Len() > 0 { + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }).Create(joins.Interface()).Error) } - - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) - } - } - - // Save Many2Many associations - for _, rel := range db.Statement.Schema.Relationships.Many2Many { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } - - fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := fieldType.Kind() == reflect.Ptr - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) - objs := []reflect.Value{} - - appendToJoins := func(obj reflect.Value, elem reflect.Value) { - joinValue := reflect.New(rel.JoinTable.ModelType) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(joinValue, fv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(joinValue, ref.PrimaryValue) - } else { - fv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(joinValue, fv) - } - } - joins = reflect.Append(joins, joinValue) - } - - appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) - - for i := 0; i < f.Len(); i++ { - elem := f.Index(i) - - objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) - } - } - } - } - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - appendToElems(obj) - } - } - case reflect.Struct: - appendToElems(db.Statement.ReflectValue) - } - - if elems.Len() > 0 { - if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) - } - - for i := 0; i < elems.Len(); i++ { - appendToJoins(objs[i], elems.Index(i)) - } - } - - if joins.Len() > 0 { - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ - SkipHooks: db.Statement.SkipHooks, - DisableNestedTransaction: true, - }).Create(joins.Interface()).Error) } } } diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index dda4b046..7bb27318 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -17,9 +17,9 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback := db.Callback().Create() createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Register("gorm:before_create", BeforeCreate) - createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) createCallback.Register("gorm:create", Create(config)) - createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) @@ -40,9 +40,9 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:before_update", BeforeUpdate) - updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) updateCallback.Register("gorm:update", Update) - updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/schema/schema.go b/schema/schema.go index e36ed7b6..d08842e6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -235,6 +235,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if schema.parseRelation(field); schema.err != nil { return schema, schema.err + } else { + schema.FieldsByName[field.Name] = field } }