mirror of https://github.com/go-gorm/gorm.git
Add Method Related to find associations
This commit is contained in:
parent
0b22775dd7
commit
98869cb124
13
README.md
13
README.md
|
@ -257,6 +257,19 @@ db.Where("name = 'jinzhu'").Or(User{Name: "jinzhu 2"}).Find(&users)
|
|||
db.Where("name = 'jinzhu'").Or(map[string]interface{}{"name": "jinzhu 2"}).Find(&users)
|
||||
```
|
||||
|
||||
### Query With Related
|
||||
|
||||
```go
|
||||
db.Model(&user).Related(&emails)
|
||||
//// SELECT * FROM emails WHERE user_id = 111;
|
||||
|
||||
db.Model(&user).Related(&address1, "BillingAddressId")
|
||||
//// SELECT * FROM addresses WHERE id = 123; // 123 is the value of user's BillingAddressId
|
||||
|
||||
db.Model(&email).Related(&user)
|
||||
//// SELECT * FROM users WHERE id = 111; // 111 is the value of email's UserId
|
||||
```
|
||||
|
||||
### Query Chains
|
||||
|
||||
Gorm has a chainable API, so you could query like this
|
||||
|
|
6
chain.go
6
chain.go
|
@ -253,6 +253,12 @@ func (s *Chain) Table(name string) *Chain {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain {
|
||||
original_value := s.value
|
||||
s.do(value).related(original_value, foreign_keys...)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Debug() *Chain {
|
||||
s.debug = true
|
||||
return s
|
||||
|
|
46
do.go
46
do.go
|
@ -303,6 +303,52 @@ func (s *Do) last() {
|
|||
s.query()
|
||||
}
|
||||
|
||||
func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err error, from_from bool, foreign_value interface{}) {
|
||||
if has_column, is_slice, value := from.ColumnAndValue(foreign_key); has_column {
|
||||
from_from = true
|
||||
if is_slice {
|
||||
foreign_value = to.primaryKeyValue()
|
||||
} else {
|
||||
foreign_value = value
|
||||
}
|
||||
} else if has_column, is_slice, value := to.ColumnAndValue(foreign_key); has_column {
|
||||
if is_slice {
|
||||
foreign_value = from.primaryKeyValue()
|
||||
} else {
|
||||
foreign_value = value
|
||||
}
|
||||
} else {
|
||||
err = errors.New("Can't find valid foreign Key")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) related(value interface{}, foreign_keys ...string) {
|
||||
var foreign_value interface{}
|
||||
var from_from bool
|
||||
var foreign_key string
|
||||
var err error
|
||||
|
||||
from := &Model{data: value, driver: s.driver}
|
||||
to := &Model{data: s.value, driver: s.driver}
|
||||
foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id")
|
||||
|
||||
for _, fk := range foreign_keys {
|
||||
err, from_from, foreign_value = s.getForeignKey(from, to, snakeToUpperCamel(fk))
|
||||
if err == nil {
|
||||
foreign_key = fk
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if from_from {
|
||||
s.where(foreign_value).query()
|
||||
} else {
|
||||
query := fmt.Sprintf("%v = %v", toSnake(foreign_key), foreign_value)
|
||||
s.where(query).query()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Do) query() {
|
||||
var (
|
||||
is_slice bool
|
||||
|
|
37
gorm_test.go
37
gorm_test.go
|
@ -83,7 +83,10 @@ func init() {
|
|||
if err != nil {
|
||||
fmt.Printf("Got error when try to delete table users, %+v\n", err)
|
||||
}
|
||||
db.Exec("drop table products")
|
||||
|
||||
db.Exec("drop table products;")
|
||||
db.Exec("drop table emails;")
|
||||
db.Exec("drop table addresses")
|
||||
|
||||
err = db.CreateTable(&User{}).Error
|
||||
if err != nil {
|
||||
|
@ -95,6 +98,16 @@ func init() {
|
|||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
err = db.CreateTable(Email{}).Error
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
err = db.CreateTable(Address{}).Error
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
var shortForm = "2006-01-02 15:04:05"
|
||||
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
||||
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
|
||||
|
@ -1115,7 +1128,7 @@ func TestSubStruct(t *testing.T) {
|
|||
db.Save(&comment3)
|
||||
}
|
||||
|
||||
func TestForReadme(t *testing.T) {
|
||||
func TestRelated(t *testing.T) {
|
||||
user := User{
|
||||
Name: "jinzhu",
|
||||
BillingAddress: Address{Address1: "Billing Address - Address 1"},
|
||||
|
@ -1124,6 +1137,22 @@ func TestForReadme(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Save(&user)
|
||||
var users []User
|
||||
db.Where("name in (?)", []string{"jinzhu", "jinzhu 2"}).Find(&users)
|
||||
|
||||
var emails []Email
|
||||
db.Model(&user).Related(&emails)
|
||||
if len(emails) != 2 {
|
||||
t.Errorf("Should have two emails")
|
||||
}
|
||||
|
||||
var address1 Address
|
||||
db.Model(&user).Related(&address1, "BillingAddressId")
|
||||
if address1.Address1 != "Billing Address - Address 1" {
|
||||
t.Errorf("Should get billing address from user correctly")
|
||||
}
|
||||
|
||||
var user2 User
|
||||
db.Model(&emails[0]).Related(&user2)
|
||||
if user2.Id != user.Id || user2.Name != user.Name {
|
||||
t.Errorf("Should get user from email correctly")
|
||||
}
|
||||
}
|
||||
|
|
32
model.go
32
model.go
|
@ -242,18 +242,40 @@ func (m *Model) hasColumn(name string) bool {
|
|||
}
|
||||
}
|
||||
|
||||
func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) {
|
||||
if m.data == nil {
|
||||
return
|
||||
}
|
||||
|
||||
data := reflect.Indirect(reflect.ValueOf(m.data))
|
||||
if data.Kind() == reflect.Slice {
|
||||
has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
|
||||
is_slice = true
|
||||
} else {
|
||||
if has_column = data.FieldByName(name).IsValid(); has_column {
|
||||
value = data.FieldByName(name).Interface()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) typeName() string {
|
||||
typ := reflect.Indirect(reflect.ValueOf(m.data)).Type()
|
||||
if typ.Kind() == reflect.Slice {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
return typ.Name()
|
||||
}
|
||||
|
||||
func (m *Model) tableName() (str string, err error) {
|
||||
if m.data == nil {
|
||||
err = errors.New("Model haven't been set")
|
||||
return
|
||||
}
|
||||
|
||||
typ := reflect.Indirect(reflect.ValueOf(m.data)).Type()
|
||||
if typ.Kind() == reflect.Slice {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
str = toSnake(m.typeName())
|
||||
|
||||
str = toSnake(typ.Name())
|
||||
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
|
||||
for key, value := range pluralMap {
|
||||
reg := regexp.MustCompile(key + "$")
|
||||
|
|
Loading…
Reference in New Issue