gorm/callbacks/associations.go

300 lines
8.2 KiB
Go
Raw Normal View History

package callbacks
import (
"reflect"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/utils"
)
func SaveBeforeAssociations(db *gorm.DB) {
if db.Statement.Schema != nil {
2020-04-19 18:11:56 +03:00
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
2020-04-17 03:23:47 +03:00
// Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
2020-04-19 18:11:56 +03:00
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
2020-04-17 03:23:47 +03:00
continue
}
2020-04-20 06:47:29 +03:00
setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(obj, pv)
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
2020-04-19 18:11:56 +03:00
var (
objs []reflect.Value
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
2020-04-19 18:11:56 +03:00
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
} else {
2020-04-20 06:47:29 +03:00
setupReferences(obj, rv)
}
2020-04-19 18:11:56 +03:00
}
}
if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil {
for i := 0; i < elems.Len(); i++ {
2020-04-20 06:47:29 +03:00
setupReferences(objs[i], elems.Index(i))
}
}
2020-04-19 18:11:56 +03:00
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
2020-04-20 06:47:29 +03:00
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
2020-04-20 06:47:29 +03:00
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Session(&gorm.Session{}).Create(rv.Interface())
setupReferences(db.Statement.ReflectValue, rv)
}
}
}
}
}
}
2020-04-17 03:23:47 +03:00
func SaveAfterAssociations(db *gorm.DB) {
2020-04-19 18:11:56 +03:00
if db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
// Save Has One associations
for _, rel := range db.Statement.Schema.Relationships.HasOne {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
var (
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
2020-04-17 03:23:47 +03:00
2020-04-19 18:11:56 +03:00
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
2020-04-20 06:47:29 +03:00
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
2020-04-19 18:11:56 +03:00
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(rv, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(rv, ref.PrimaryValue)
}
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
2020-04-20 06:47:29 +03:00
elems = reflect.Append(elems, rv)
2020-04-19 18:11:56 +03:00
}
}
}
if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface())
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
2020-04-20 06:47:29 +03:00
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
2020-04-17 03:23:47 +03:00
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv)
2020-04-17 03:40:07 +03:00
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue)
2020-04-17 03:23:47 +03:00
}
}
2020-04-19 18:11:56 +03:00
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero {
2020-04-20 06:47:29 +03:00
db.Session(&gorm.Session{}).Create(f.Interface())
2020-04-19 18:11:56 +03:00
}
2020-04-17 03:23:47 +03:00
}
2020-04-19 18:11:56 +03:00
}
}
2020-04-17 03:23:47 +03:00
2020-04-19 18:11:56 +03:00
// Save Has Many associations
for _, rel := range db.Statement.Schema.Relationships.HasMany {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
}
2020-04-17 03:23:47 +03:00
2020-04-19 18:11:56 +03:00
fieldType := rel.Field.IndirectFieldType.Elem()
2020-04-20 06:47:29 +03:00
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
2020-04-19 18:11:56 +03:00
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v)
ref.ForeignKey.Set(elem, pv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
}
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero {
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
}
2020-04-17 03:23:47 +03:00
}
}
}
2020-04-19 09:29:31 +03:00
2020-04-19 18:11:56 +03:00
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i))
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
2020-04-19 09:29:31 +03:00
2020-04-19 18:11:56 +03:00
if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface())
}
2020-04-19 09:29:31 +03:00
}
2020-04-19 18:11:56 +03:00
// Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
2020-04-19 09:29:31 +03:00
}
2020-04-19 18:11:56 +03:00
fieldType := rel.Field.IndirectFieldType.Elem()
2020-04-20 06:47:29 +03:00
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
2020-04-19 18:11:56 +03:00
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
2020-04-20 18:35:18 +03:00
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 0)
2020-04-20 06:47:29 +03:00
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(joinValue, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
} else {
fv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(joinValue, fv)
}
2020-04-19 18:11:56 +03:00
}
2020-04-20 06:47:29 +03:00
joins = reflect.Append(joins, joinValue)
}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
2020-04-19 18:11:56 +03:00
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
2020-04-19 09:29:31 +03:00
2020-04-19 18:11:56 +03:00
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero {
2020-04-20 06:47:29 +03:00
objs = append(objs, v)
2020-04-19 18:11:56 +03:00
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
2020-04-20 06:47:29 +03:00
} else {
appendToJoins(v, elem)
2020-04-19 09:29:31 +03:00
}
}
}
}
2020-04-20 06:47:29 +03:00
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i))
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
2020-04-19 18:11:56 +03:00
if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface())
2020-04-20 06:47:29 +03:00
for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i))
}
}
if joins.Len() > 0 {
db.Session(&gorm.Session{}).Create(joins.Interface())
2020-04-19 18:11:56 +03:00
}
2020-04-19 09:29:31 +03:00
}
}
2020-04-17 03:23:47 +03:00
}
2020-04-19 18:11:56 +03:00
func saveAssociationCheck(db *gorm.DB, rel *schema.Relationship, selectColumns map[string]bool, restricted bool) bool {
savable := true
if value, ok := db.Get("gorm:save_association"); ok {
savable = utils.CheckTruth(value)
}
2020-04-19 18:11:56 +03:00
if savable {
if v, ok := selectColumns[rel.Name]; (ok && v) || (!ok && !restricted) {
return true
}
}
2020-04-19 18:11:56 +03:00
return false
}