diff --git a/join_table_handler.go b/join_table_handler.go index ac909966..07ecee2e 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -13,6 +13,8 @@ type JoinTableHandlerInterface interface { Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + SourceForeignKeys() []JoinTableForeignKey + DestinationForeignKeys() []JoinTableForeignKey } type JoinTableForeignKey struct { @@ -31,6 +33,14 @@ type JoinTableHandler struct { Destination JoinTableSource `sql:"-"` } +func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { + return s.Source.ForeignKeys +} + +func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { + return s.Destination.ForeignKeys +} + func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { s.TableName = tableName diff --git a/scope.go b/scope.go index 11bad777..de1b6159 100644 --- a/scope.go +++ b/scope.go @@ -110,6 +110,14 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } +func (scope *Scope) PrimaryFields() []*Field { + var fields = []*Field{} + for _, field := range scope.GetModelStruct().PrimaryFields { + fields = append(fields, scope.Fields()[field.DBName]) + } + return fields +} + func (scope *Scope) PrimaryField() *Field { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { if len(primaryFields) > 1 {