From f999240e106552c62eef70d29d1da93d95f76a5b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 May 2020 20:54:50 +0800 Subject: [PATCH] Define association API, add conds to when preloading --- association.go | 54 +++++++++++++++++++++++++++++++++++++++++++- callbacks/preload.go | 10 ++++---- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/association.go b/association.go index 14bc54b6..a9345255 100644 --- a/association.go +++ b/association.go @@ -1,9 +1,61 @@ package gorm +import ( + "fmt" + + "github.com/jinzhu/gorm/schema" +) + // Association Mode contains some helper methods to handle relationship things easily. type Association struct { + DB *DB + Relationship *schema.Relationship + Error error } func (db *DB) Association(column string) *Association { - return nil + association := &Association{DB: db} + + if err := db.Statement.Parse(db.Statement.Model); err == nil { + association.Relationship = db.Statement.Schema.Relationships.Relations[column] + + if association.Relationship == nil { + association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) + } + } else { + association.Error = err + } + + return association +} + +func (association *Association) Find(out interface{}, conds ...interface{}) error { + if association.Error == nil { + for _, ref := range association.Relationship.References { + if ref.OwnPrimaryKey { + } + } + } + + return association.Error +} + +func (association *Association) Append(values ...interface{}) error { + return association.Error +} + +func (association *Association) Replace(values ...interface{}) error { + return association.Error +} + +func (association *Association) Delete(values ...interface{}) error { + return association.Error +} + +func (association *Association) Clear() error { + return association.Error +} + +func (association *Association) Count() int { + return 0 } diff --git a/callbacks/preload.go b/callbacks/preload.go index 8ab014f6..aaac31b5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -84,7 +84,7 @@ func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Fiel return dataResults, results } -func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}) reflect.Value { +func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}, conds []interface{}) reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -94,12 +94,12 @@ func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, for idx, r := range foreignValues { queryValues[idx] = r[0] } - tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Interface()) + tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Interface(), conds...) } else { for idx, r := range foreignValues { queryValues[idx] = r } - tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Interface()) + tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Interface(), conds...) } return results.Elem() @@ -139,7 +139,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields) - joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues) + joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues, nil) // convert join identity map to relation identity map fieldValues := make([]reflect.Value, len(foreignFields)) @@ -177,7 +177,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields) } - reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues) + reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues, conds) fieldValues := make([]reflect.Value, len(foreignFields)) for i := 0; i < reflectResults.Len(); i++ {