Save associations based on creatable/updatable permission, close #4056

This commit is contained in:
Jinzhu 2021-02-07 14:24:11 +08:00
parent 4373aa01ab
commit deff0594ee
3 changed files with 230 additions and 224 deletions

View File

@ -9,79 +9,81 @@ import (
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil { return func(db *gorm.DB) {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Belongs To associations // Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue continue
} }
setupReferences := func(obj reflect.Value, elem reflect.Value) { setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References { for _, ref := range rel.References {
if !ref.OwnPrimaryKey { if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem) pv, _ := ref.PrimaryKey.ValueOf(elem)
db.AddError(ref.ForeignKey.Set(obj, pv)) db.AddError(ref.ForeignKey.Set(obj, pv))
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv dest[ref.ForeignKey.DBName] = pv
if _, ok := dest[rel.Name]; ok { if _, ok := dest[rel.Name]; ok {
dest[rel.Name] = elem.Interface() dest[rel.Name] = elem.Interface()
}
} }
} }
} }
} }
}
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var ( var (
objs []reflect.Value objs []reflect.Value
fieldType = rel.Field.FieldType fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr isPtr = fieldType.Kind() == reflect.Ptr
) )
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PtrTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct { if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj) objs = append(objs, obj)
if isPtr { if isPtr {
elems = reflect.Append(elems, rv) elems = reflect.Append(elems, rv)
} else { } else {
elems = reflect.Append(elems, rv.Addr()) elems = reflect.Append(elems, rv.Addr())
}
}
} else {
break
}
}
if elems.Len() > 0 {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
} }
} }
} else {
break
} }
} case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
if elems.Len() > 0 { rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { if rv.Kind() != reflect.Ptr {
for i := 0; i < elems.Len(); i++ { rv = rv.Addr()
setupReferences(objs[i], elems.Index(i))
} }
}
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv) setupReferences(db.Statement.ReflectValue, rv)
}
} }
} }
} }
@ -89,53 +91,133 @@ func SaveBeforeAssociations(db *gorm.DB) {
} }
} }
func SaveAfterAssociations(db *gorm.DB) { func SaveAfterAssociations(create bool) func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil { return func(db *gorm.DB) {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Has One associations // Save Has One associations
for _, rel := range db.Statement.Schema.Relationships.HasOne { for _, rel := range db.Statement.Schema.Relationships.HasOne {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue continue
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
db.AddError(ref.ForeignKey.Set(rv, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
}
}
elems = reflect.Append(elems, rv)
}
}
}
if elems.Len() > 0 {
assignmentColumns := []string{}
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
assignmentColumns := []string{}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue)
}
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
}
}
} }
switch db.Statement.ReflectValue.Kind() { // Save Has Many associations
case reflect.Slice, reflect.Array: for _, rel := range db.Statement.Schema.Relationships.HasMany {
var ( if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
fieldType = rel.Field.FieldType continue
isPtr = fieldType.Kind() == reflect.Ptr }
)
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr { if !isPtr {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PtrTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < f.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i) elem := f.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj) pv, _ := ref.PrimaryKey.ValueOf(v)
db.AddError(ref.ForeignKey.Set(rv, fv)) ref.ForeignKey.Set(elem, pv)
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) ref.ForeignKey.Set(elem, ref.PrimaryValue)
} }
} }
elems = reflect.Append(elems, rv) if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
} }
} }
} }
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 { if elems.Len() > 0 {
assignmentColumns := []string{} assignmentColumns := []string{}
for _, ref := range rel.References { for _, ref := range rel.References {
@ -144,162 +226,84 @@ func SaveAfterAssociations(db *gorm.DB) {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
} }
case reflect.Struct: }
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
assignmentColumns := []string{} // Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(f, fv) ref.ForeignKey.Set(joinValue, fv)
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue) ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
} else {
fv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(joinValue, fv)
} }
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
joins = reflect.Append(joins, joinValue)
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
} }
}
}
// Save Has Many associations appendToElems := func(v reflect.Value) {
for _, rel := range db.Statement.Schema.Relationships.HasMany { if _, zero := rel.Field.ValueOf(v); !zero {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { f := reflect.Indirect(rel.Field.ReflectValueOf(v))
continue
}
fieldType := rel.Field.IndirectFieldType.Elem() for i := 0; i < f.Len(); i++ {
isPtr := fieldType.Kind() == reflect.Ptr elem := f.Index(i)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ { objs = append(objs, v)
elem := f.Index(i) if isPtr {
for _, ref := range rel.References { elems = reflect.Append(elems, elem)
if ref.OwnPrimaryKey { } else {
pv, _ := ref.PrimaryKey.ValueOf(v) elems = reflect.Append(elems, elem.Addr())
ref.ForeignKey.Set(elem, pv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
} }
} }
}
}
if isPtr { switch db.Statement.ReflectValue.Kind() {
elems = reflect.Append(elems, elem) case reflect.Slice, reflect.Array:
} else { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
elems = reflect.Append(elems, elem.Addr()) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
} }
} }
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
} }
}
switch db.Statement.ReflectValue.Kind() { if elems.Len() > 0 {
case reflect.Slice, reflect.Array: if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
obj := db.Statement.ReflectValue.Index(i) }
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj) for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i))
} }
} }
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 { if joins.Len() > 0 {
assignmentColumns := []string{} db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
for _, ref := range rel.References { SkipHooks: db.Statement.SkipHooks,
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) DisableNestedTransaction: true,
}).Create(joins.Interface()).Error)
} }
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
}
}
// Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(joinValue, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
} else {
fv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(joinValue, fv)
}
}
joins = reflect.Append(joins, joinValue)
}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
objs = append(objs, v)
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
}
for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i))
}
}
if joins.Len() > 0 {
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
}).Create(joins.Interface()).Error)
} }
} }
} }

View File

@ -17,9 +17,9 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback := db.Callback().Create() createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config)) createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate) createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
@ -40,9 +40,9 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update) updateCallback.Register("gorm:update", Update)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)

View File

@ -235,6 +235,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil { if schema.parseRelation(field); schema.err != nil {
return schema, schema.err return schema, schema.err
} else {
schema.FieldsByName[field.Name] = field
} }
} }