From 98869cb124aa89ab5e7514dc54bd5039c4671138 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 5 Nov 2013 07:46:06 +0800 Subject: [PATCH] Add Method Related to find associations --- README.md | 13 +++++++++++++ chain.go | 6 ++++++ do.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ gorm_test.go | 37 +++++++++++++++++++++++++++++++++---- model.go | 32 +++++++++++++++++++++++++++----- 5 files changed, 125 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 4f354620..bfc8c492 100644 --- a/README.md +++ b/README.md @@ -257,6 +257,19 @@ db.Where("name = 'jinzhu'").Or(User{Name: "jinzhu 2"}).Find(&users) db.Where("name = 'jinzhu'").Or(map[string]interface{}{"name": "jinzhu 2"}).Find(&users) ``` +### Query With Related + +```go +db.Model(&user).Related(&emails) +//// SELECT * FROM emails WHERE user_id = 111; + +db.Model(&user).Related(&address1, "BillingAddressId") +//// SELECT * FROM addresses WHERE id = 123; // 123 is the value of user's BillingAddressId + +db.Model(&email).Related(&user) +//// SELECT * FROM users WHERE id = 111; // 111 is the value of email's UserId +``` + ### Query Chains Gorm has a chainable API, so you could query like this diff --git a/chain.go b/chain.go index 65296ce7..79a8db0e 100644 --- a/chain.go +++ b/chain.go @@ -253,6 +253,12 @@ func (s *Chain) Table(name string) *Chain { return s } +func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain { + original_value := s.value + s.do(value).related(original_value, foreign_keys...) + return s +} + func (s *Chain) Debug() *Chain { s.debug = true return s diff --git a/do.go b/do.go index 3e0a215e..694b5868 100644 --- a/do.go +++ b/do.go @@ -303,6 +303,52 @@ func (s *Do) last() { s.query() } +func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err error, from_from bool, foreign_value interface{}) { + if has_column, is_slice, value := from.ColumnAndValue(foreign_key); has_column { + from_from = true + if is_slice { + foreign_value = to.primaryKeyValue() + } else { + foreign_value = value + } + } else if has_column, is_slice, value := to.ColumnAndValue(foreign_key); has_column { + if is_slice { + foreign_value = from.primaryKeyValue() + } else { + foreign_value = value + } + } else { + err = errors.New("Can't find valid foreign Key") + } + return +} + +func (s *Do) related(value interface{}, foreign_keys ...string) { + var foreign_value interface{} + var from_from bool + var foreign_key string + var err error + + from := &Model{data: value, driver: s.driver} + to := &Model{data: s.value, driver: s.driver} + foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id") + + for _, fk := range foreign_keys { + err, from_from, foreign_value = s.getForeignKey(from, to, snakeToUpperCamel(fk)) + if err == nil { + foreign_key = fk + break + } + } + + if from_from { + s.where(foreign_value).query() + } else { + query := fmt.Sprintf("%v = %v", toSnake(foreign_key), foreign_value) + s.where(query).query() + } +} + func (s *Do) query() { var ( is_slice bool diff --git a/gorm_test.go b/gorm_test.go index 1e6261e1..69f5d2c9 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -83,7 +83,10 @@ func init() { if err != nil { fmt.Printf("Got error when try to delete table users, %+v\n", err) } - db.Exec("drop table products") + + db.Exec("drop table products;") + db.Exec("drop table emails;") + db.Exec("drop table addresses") err = db.CreateTable(&User{}).Error if err != nil { @@ -95,6 +98,16 @@ func init() { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } + err = db.CreateTable(Email{}).Error + if err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + + err = db.CreateTable(Address{}).Error + if err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + var shortForm = "2006-01-02 15:04:05" t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00") @@ -1115,7 +1128,7 @@ func TestSubStruct(t *testing.T) { db.Save(&comment3) } -func TestForReadme(t *testing.T) { +func TestRelated(t *testing.T) { user := User{ Name: "jinzhu", BillingAddress: Address{Address1: "Billing Address - Address 1"}, @@ -1124,6 +1137,22 @@ func TestForReadme(t *testing.T) { } db.Save(&user) - var users []User - db.Where("name in (?)", []string{"jinzhu", "jinzhu 2"}).Find(&users) + + var emails []Email + db.Model(&user).Related(&emails) + if len(emails) != 2 { + t.Errorf("Should have two emails") + } + + var address1 Address + db.Model(&user).Related(&address1, "BillingAddressId") + if address1.Address1 != "Billing Address - Address 1" { + t.Errorf("Should get billing address from user correctly") + } + + var user2 User + db.Model(&emails[0]).Related(&user2) + if user2.Id != user.Id || user2.Name != user.Name { + t.Errorf("Should get user from email correctly") + } } diff --git a/model.go b/model.go index 1b578734..4d02f980 100644 --- a/model.go +++ b/model.go @@ -242,18 +242,40 @@ func (m *Model) hasColumn(name string) bool { } } +func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) { + if m.data == nil { + return + } + + data := reflect.Indirect(reflect.ValueOf(m.data)) + if data.Kind() == reflect.Slice { + has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() + is_slice = true + } else { + if has_column = data.FieldByName(name).IsValid(); has_column { + value = data.FieldByName(name).Interface() + } + } + return +} + +func (m *Model) typeName() string { + typ := reflect.Indirect(reflect.ValueOf(m.data)).Type() + if typ.Kind() == reflect.Slice { + typ = typ.Elem() + } + + return typ.Name() +} + func (m *Model) tableName() (str string, err error) { if m.data == nil { err = errors.New("Model haven't been set") return } - typ := reflect.Indirect(reflect.ValueOf(m.data)).Type() - if typ.Kind() == reflect.Slice { - typ = typ.Elem() - } + str = toSnake(m.typeName()) - str = toSnake(typ.Name()) pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"} for key, value := range pluralMap { reg := regexp.MustCompile(key + "$")