forked from mirror/gorm
improve association
This commit is contained in:
parent
efd3b7678b
commit
e1e68debdc
134
association.go
134
association.go
|
@ -8,53 +8,92 @@ import (
|
||||||
|
|
||||||
type Association struct {
|
type Association struct {
|
||||||
Scope *Scope
|
Scope *Scope
|
||||||
|
PrimaryKey interface{}
|
||||||
Column string
|
Column string
|
||||||
Error error
|
Error error
|
||||||
|
Field *Field
|
||||||
|
}
|
||||||
|
|
||||||
|
func (association *Association) err(err error) *Association {
|
||||||
|
if err != nil {
|
||||||
|
association.Error = err
|
||||||
|
}
|
||||||
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) Find(value interface{}) *Association {
|
func (association *Association) Find(value interface{}) *Association {
|
||||||
|
association.Scope.related(value, association.Column)
|
||||||
|
return association.err(association.Scope.db.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (association *Association) Append(values ...interface{}) *Association {
|
||||||
scope := association.Scope
|
scope := association.Scope
|
||||||
primaryKey := scope.PrimaryKeyValue()
|
field := scope.IndirectValue().FieldByName(association.Column)
|
||||||
if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
|
for _, value := range values {
|
||||||
association.Error = errors.New("primary key can't be nil")
|
reflectvalue := reflect.ValueOf(value)
|
||||||
|
if reflectvalue.Kind() == reflect.Ptr {
|
||||||
|
if reflectvalue.Elem().Kind() == reflect.Struct {
|
||||||
|
if field.Type().Elem().Kind() == reflect.Ptr {
|
||||||
|
field.Set(reflect.Append(field, reflectvalue))
|
||||||
|
} else if field.Type().Elem().Kind() == reflect.Struct {
|
||||||
|
field.Set(reflect.Append(field, reflectvalue.Elem()))
|
||||||
}
|
}
|
||||||
|
} else if reflectvalue.Elem().Kind() == reflect.Slice {
|
||||||
|
if field.Type().Elem().Kind() == reflect.Ptr {
|
||||||
|
field.Set(reflect.AppendSlice(field, reflectvalue))
|
||||||
|
} else if field.Type().Elem().Kind() == reflect.Struct {
|
||||||
|
field.Set(reflect.AppendSlice(field, reflectvalue.Elem()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if reflectvalue.Kind() == reflect.Struct && field.Type().Elem().Kind() == reflect.Struct {
|
||||||
|
field.Set(reflect.Append(field, reflectvalue))
|
||||||
|
} else if reflectvalue.Kind() == reflect.Slice && field.Type().Elem() == reflectvalue.Type().Elem() {
|
||||||
|
field.Set(reflect.AppendSlice(field, reflectvalue))
|
||||||
|
} else {
|
||||||
|
association.err(errors.New("invalid association type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scope.callCallbacks(scope.db.parent.callback.updates)
|
||||||
|
return association.err(scope.db.Error)
|
||||||
|
}
|
||||||
|
|
||||||
scopeType := scope.IndirectValue().Type()
|
func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
if f, ok := scopeType.FieldByName(SnakeToUpperCamel(association.Column)); ok {
|
primaryKeys := []interface{}{}
|
||||||
field := scope.fieldFromStruct(f)
|
scope := association.Scope
|
||||||
joinTable := field.JoinTable
|
for _, value := range values {
|
||||||
if joinTable != nil && joinTable.foreignKey != "" {
|
reflectValue := reflect.ValueOf(value)
|
||||||
if joinTable.joinTable != "" {
|
if reflectValue.Kind() == reflect.Ptr {
|
||||||
|
reflectValue = reflectValue.Elem()
|
||||||
|
}
|
||||||
|
if reflectValue.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
|
newScope := scope.New(reflectValue.Index(i).Interface())
|
||||||
|
primaryKey := newScope.PrimaryKeyValue()
|
||||||
|
if !reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
|
||||||
|
primaryKeys = append(primaryKeys, primaryKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if reflectValue.Kind() == reflect.Struct {
|
||||||
newScope := scope.New(value)
|
newScope := scope.New(value)
|
||||||
joinSql := fmt.Sprintf(
|
primaryKey := newScope.PrimaryKeyValue()
|
||||||
"INNER JOIN %v ON %v.%v = %v.%v",
|
if !reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
|
||||||
scope.Quote(joinTable.joinTable),
|
primaryKeys = append(primaryKeys, primaryKey)
|
||||||
scope.Quote(joinTable.joinTable),
|
|
||||||
scope.Quote(ToSnake(joinTable.associationForeignKey)),
|
|
||||||
newScope.QuotedTableName(),
|
|
||||||
scope.Quote(newScope.PrimaryKey()))
|
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(joinTable.joinTable), scope.Quote(ToSnake(joinTable.foreignKey)))
|
|
||||||
scope.db.Joins(joinSql).Where(whereSql, primaryKey).Find(value)
|
|
||||||
} else {
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
association.Error = errors.New(fmt.Sprintf("invalid association %v for %v", association.Column, scopeType))
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
association.Error = errors.New(fmt.Sprintf("%v doesn't have column %v", scopeType, association.Column))
|
|
||||||
}
|
}
|
||||||
return association
|
|
||||||
}
|
|
||||||
|
|
||||||
func (association *Association) Append(values interface{}) *Association {
|
if len(primaryKeys) == 0 {
|
||||||
return association
|
association.err(errors.New("no primary key found"))
|
||||||
}
|
} else {
|
||||||
|
joinTable := association.Field.JoinTable
|
||||||
func (association *Association) Delete(value interface{}) *Association {
|
// many to many
|
||||||
return association
|
if joinTable.joinTable != "" {
|
||||||
}
|
whereSql := fmt.Sprintf("%v.%v IN (?)", joinTable.joinTable, scope.Quote(ToSnake(joinTable.associationForeignKey)))
|
||||||
|
scope.db.Table(joinTable.joinTable).Where(whereSql, primaryKeys).Delete("")
|
||||||
func (association *Association) Clear(value interface{}) *Association {
|
} else {
|
||||||
|
association.err(errors.New("only many to many support delete"))
|
||||||
|
}
|
||||||
|
}
|
||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,6 +101,29 @@ func (association *Association) Replace(values interface{}) *Association {
|
||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) Count(values interface{}) int {
|
func (association *Association) Clear(value interface{}) *Association {
|
||||||
return 0
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
func (association *Association) Count() (count int) {
|
||||||
|
joinTable := association.Field.JoinTable
|
||||||
|
scope := association.Scope
|
||||||
|
field := scope.IndirectValue().FieldByName(association.Column)
|
||||||
|
fieldValue := field.Interface()
|
||||||
|
|
||||||
|
// many to many
|
||||||
|
if joinTable.joinTable != "" {
|
||||||
|
newScope := scope.New(fieldValue)
|
||||||
|
whereSql := fmt.Sprintf("%v.%v IN (SELECT %v.%v FROM %v WHERE %v.%v = ?)",
|
||||||
|
newScope.QuotedTableName(),
|
||||||
|
scope.Quote(newScope.PrimaryKey()),
|
||||||
|
joinTable.joinTable,
|
||||||
|
scope.Quote(joinTable.associationForeignKey),
|
||||||
|
joinTable.joinTable,
|
||||||
|
joinTable.joinTable,
|
||||||
|
scope.Quote(joinTable.foreignKey))
|
||||||
|
scope.db.Table(newScope.QuotedTableName()).Where(whereSql, scope.PrimaryKey()).Count(&count)
|
||||||
|
}
|
||||||
|
// association.Scope.related(value, association.Column)
|
||||||
|
return -1
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,24 +127,69 @@ func TestRelated(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManyToMany(t *testing.T) {
|
func TestManyToMany(t *testing.T) {
|
||||||
var languages = []Language{{Name: "ZH"}, {Name: "EN"}, {Name: "DE"}}
|
db.Raw("delete from languages")
|
||||||
|
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
|
||||||
user := User{Name: "Many2Many", Languages: languages}
|
user := User{Name: "Many2Many", Languages: languages}
|
||||||
db.Save(&user)
|
db.Save(&user)
|
||||||
|
|
||||||
|
// Query
|
||||||
var newLanguages []Language
|
var newLanguages []Language
|
||||||
// db.Model(&user).Related(&newLanguages, "Languages")
|
db.Model(&user).Related(&newLanguages, "Languages")
|
||||||
// if len(newLanguages) != 3 {
|
if len(newLanguages) != len([]string{"ZH", "EN"}) {
|
||||||
// t.Errorf("Query many to many relations")
|
t.Errorf("Query many to many relations")
|
||||||
// }
|
}
|
||||||
|
|
||||||
newLanguages = []Language{}
|
newLanguages = []Language{}
|
||||||
db.Model(&user).Association("Languages").Find(&newLanguages)
|
db.Model(&user).Association("Languages").Find(&newLanguages)
|
||||||
if len(newLanguages) != 3 {
|
if len(newLanguages) != len([]string{"ZH", "EN"}) {
|
||||||
t.Errorf("Should be able to find many to many relations")
|
t.Errorf("Should be able to find many to many relations")
|
||||||
}
|
}
|
||||||
|
|
||||||
// db.Model(&User{}).Many2Many("Languages").Add(&Language{})
|
// Append
|
||||||
// db.Model(&User{}).Many2Many("Languages").Remove(&Language{})
|
db.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
|
||||||
|
if db.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
|
||||||
|
t.Errorf("New record should be saved when append")
|
||||||
|
}
|
||||||
|
|
||||||
|
languageA := Language{Name: "AA"}
|
||||||
|
db.Save(&languageA)
|
||||||
|
db.Model(&User{Id: user.Id}).Association("Languages").Append(languageA)
|
||||||
|
languageC := Language{Name: "CC"}
|
||||||
|
db.Save(&languageC)
|
||||||
|
db.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
|
||||||
|
db.Model(&User{Id: user.Id}).Association("Languages").Append([]Language{{Name: "DD"}, {Name: "EE"}})
|
||||||
|
|
||||||
|
totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
|
||||||
|
|
||||||
|
newLanguages = []Language{}
|
||||||
|
db.Model(&user).Related(&newLanguages, "Languages")
|
||||||
|
if len(newLanguages) != len(totalLanguages) {
|
||||||
|
t.Errorf("All appended languages should be saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
var language Language
|
||||||
|
db.Where("name = ?", "EE").First(&language)
|
||||||
|
db.Model(&user).Association("Languages").Delete(language, &language)
|
||||||
|
|
||||||
|
newLanguages = []Language{}
|
||||||
|
db.Model(&user).Related(&newLanguages, "Languages")
|
||||||
|
if len(newLanguages) != len(totalLanguages)-1 {
|
||||||
|
t.Errorf("Relations should be deleted with Delete")
|
||||||
|
}
|
||||||
|
if db.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
|
||||||
|
t.Errorf("Language EE should not be deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
languages = []Language{}
|
||||||
|
db.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
|
||||||
|
db.Model(&user).Association("Languages").Delete(languages, &languages)
|
||||||
|
newLanguages = []Language{}
|
||||||
|
db.Model(&user).Related(&newLanguages, "Languages")
|
||||||
|
if len(newLanguages) != len(totalLanguages)-3 {
|
||||||
|
t.Errorf("Relations should be deleted with Delete")
|
||||||
|
}
|
||||||
|
|
||||||
// db.Model(&User{}).Many2Many("Languages").Replace(&[]Language{})
|
// db.Model(&User{}).Many2Many("Languages").Replace(&[]Language{})
|
||||||
// db.Model(&User{}).Related(&[]Language{}, "Languages")
|
// db.Model(&User{}).Related(&[]Language{}, "Languages")
|
||||||
// SELECT `languages`.* FROM `languages` INNER JOIN `user_languages` ON `languages`.`id` = `user_languages`.`language_id` WHERE `user_languages`.`user_id` = 111
|
// SELECT `languages`.* FROM `languages` INNER JOIN `user_languages` ON `languages`.`id` = `user_languages`.`language_id` WHERE `user_languages`.`user_id` = 111
|
||||||
|
|
22
main.go
22
main.go
|
@ -2,6 +2,9 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
|
@ -353,5 +356,22 @@ func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
|
|
||||||
func (s *DB) Association(column string) *Association {
|
func (s *DB) Association(column string) *Association {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
return &Association{Scope: scope, Column: column}
|
|
||||||
|
primaryKey := scope.PrimaryKeyValue()
|
||||||
|
if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
|
||||||
|
scope.Err(errors.New("primary key can't be nil"))
|
||||||
|
}
|
||||||
|
|
||||||
|
var field *Field
|
||||||
|
scopeType := scope.IndirectValue().Type()
|
||||||
|
if f, ok := scopeType.FieldByName(SnakeToUpperCamel(column)); ok {
|
||||||
|
field = scope.fieldFromStruct(f)
|
||||||
|
if field.JoinTable == nil || field.JoinTable.foreignKey == "" {
|
||||||
|
scope.Err(errors.New(fmt.Sprintf("invalid association %v for %v", column, scopeType)))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.Err(errors.New(fmt.Sprintf("%v doesn't have column %v", scopeType, column)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ func runMigration() {
|
||||||
db.Exec("drop table companies")
|
db.Exec("drop table companies")
|
||||||
db.Exec("drop table animals")
|
db.Exec("drop table animals")
|
||||||
db.Exec("drop table user_languages")
|
db.Exec("drop table user_languages")
|
||||||
|
db.Exec("drop table languages")
|
||||||
|
|
||||||
if err := db.CreateTable(&Animal{}).Error; err != nil {
|
if err := db.CreateTable(&Animal{}).Error; err != nil {
|
||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
|
|
6
scope.go
6
scope.go
|
@ -122,12 +122,12 @@ func (scope *Scope) FieldByName(name string) (interface{}, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetColumn to set the column's value
|
// SetColumn to set the column's value
|
||||||
func (scope *Scope) SetColumn(column string, value interface{}) {
|
func (scope *Scope) SetColumn(column string, value interface{}) bool {
|
||||||
if scope.Value == nil {
|
if scope.Value == nil {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value)
|
return setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallMethod invoke method with necessary argument
|
// CallMethod invoke method with necessary argument
|
||||||
|
|
|
@ -416,13 +416,42 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
toScope := scope.db.NewScope(value)
|
toScope := scope.db.NewScope(value)
|
||||||
|
|
||||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||||
|
scopeType := scope.IndirectValue().Type()
|
||||||
|
if f, ok := scopeType.FieldByName(SnakeToUpperCamel(foreignKey)); ok {
|
||||||
|
field := scope.fieldFromStruct(f)
|
||||||
|
joinTable := field.JoinTable
|
||||||
|
if joinTable != nil && joinTable.foreignKey != "" {
|
||||||
|
foreignKey = joinTable.foreignKey
|
||||||
|
|
||||||
|
// many to many relations
|
||||||
|
if joinTable.joinTable != "" {
|
||||||
|
joinSql := fmt.Sprintf(
|
||||||
|
"INNER JOIN %v ON %v.%v = %v.%v",
|
||||||
|
scope.Quote(joinTable.joinTable),
|
||||||
|
scope.Quote(joinTable.joinTable),
|
||||||
|
scope.Quote(ToSnake(joinTable.associationForeignKey)),
|
||||||
|
toScope.QuotedTableName(),
|
||||||
|
scope.Quote(toScope.PrimaryKey()))
|
||||||
|
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(joinTable.joinTable), scope.Quote(ToSnake(joinTable.foreignKey)))
|
||||||
|
toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// has one
|
||||||
if foreignValue, ok := scope.FieldByName(foreignKey); ok {
|
if foreignValue, ok := scope.FieldByName(foreignKey); ok {
|
||||||
return toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries)
|
toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries)
|
||||||
} else if toScope.HasColumn(foreignKey) {
|
return scope
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// has many
|
||||||
|
if toScope.HasColumn(foreignKey) {
|
||||||
sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey)))
|
sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey)))
|
||||||
return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)
|
return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
scope.Err(errors.New(fmt.Sprintf("invalid association %v", foreignKeys)))
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue