diff --git a/association.go b/association.go index dcf5803f..542a27ca 100644 --- a/association.go +++ b/association.go @@ -7,54 +7,93 @@ import ( ) type Association struct { - Scope *Scope - Column string - Error error + Scope *Scope + PrimaryKey interface{} + Column string + Error error + Field *Field +} + +func (association *Association) err(err error) *Association { + if err != nil { + association.Error = err + } + return association } func (association *Association) Find(value interface{}) *Association { + association.Scope.related(value, association.Column) + return association.err(association.Scope.db.Error) +} + +func (association *Association) Append(values ...interface{}) *Association { scope := association.Scope - primaryKey := scope.PrimaryKeyValue() - if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { - association.Error = errors.New("primary key can't be nil") - } - - scopeType := scope.IndirectValue().Type() - if f, ok := scopeType.FieldByName(SnakeToUpperCamel(association.Column)); ok { - field := scope.fieldFromStruct(f) - joinTable := field.JoinTable - if joinTable != nil && joinTable.foreignKey != "" { - if joinTable.joinTable != "" { - newScope := scope.New(value) - joinSql := fmt.Sprintf( - "INNER JOIN %v ON %v.%v = %v.%v", - scope.Quote(joinTable.joinTable), - scope.Quote(joinTable.joinTable), - scope.Quote(ToSnake(joinTable.associationForeignKey)), - newScope.QuotedTableName(), - scope.Quote(newScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(joinTable.joinTable), scope.Quote(ToSnake(joinTable.foreignKey))) - scope.db.Joins(joinSql).Where(whereSql, primaryKey).Find(value) - } else { + field := scope.IndirectValue().FieldByName(association.Column) + for _, value := range values { + reflectvalue := reflect.ValueOf(value) + if reflectvalue.Kind() == reflect.Ptr { + if reflectvalue.Elem().Kind() == reflect.Struct { + if field.Type().Elem().Kind() == reflect.Ptr { + field.Set(reflect.Append(field, reflectvalue)) + } else if field.Type().Elem().Kind() == reflect.Struct { + field.Set(reflect.Append(field, reflectvalue.Elem())) + } + } else if reflectvalue.Elem().Kind() == reflect.Slice { + if field.Type().Elem().Kind() == reflect.Ptr { + field.Set(reflect.AppendSlice(field, reflectvalue)) + } else if field.Type().Elem().Kind() == reflect.Struct { + field.Set(reflect.AppendSlice(field, reflectvalue.Elem())) + } } + } else if reflectvalue.Kind() == reflect.Struct && field.Type().Elem().Kind() == reflect.Struct { + field.Set(reflect.Append(field, reflectvalue)) + } else if reflectvalue.Kind() == reflect.Slice && field.Type().Elem() == reflectvalue.Type().Elem() { + field.Set(reflect.AppendSlice(field, reflectvalue)) } else { - association.Error = errors.New(fmt.Sprintf("invalid association %v for %v", association.Column, scopeType)) + association.err(errors.New("invalid association type")) } - } else { - association.Error = errors.New(fmt.Sprintf("%v doesn't have column %v", scopeType, association.Column)) } - return association + scope.callCallbacks(scope.db.parent.callback.updates) + return association.err(scope.db.Error) } -func (association *Association) Append(values interface{}) *Association { - return association -} +func (association *Association) Delete(values ...interface{}) *Association { + primaryKeys := []interface{}{} + scope := association.Scope + for _, value := range values { + reflectValue := reflect.ValueOf(value) + if reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + if reflectValue.Kind() == reflect.Slice { + for i := 0; i < reflectValue.Len(); i++ { + newScope := scope.New(reflectValue.Index(i).Interface()) + primaryKey := newScope.PrimaryKeyValue() + if !reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { + primaryKeys = append(primaryKeys, primaryKey) + } + } + } else if reflectValue.Kind() == reflect.Struct { + newScope := scope.New(value) + primaryKey := newScope.PrimaryKeyValue() + if !reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { + primaryKeys = append(primaryKeys, primaryKey) + } + } + } -func (association *Association) Delete(value interface{}) *Association { - return association -} - -func (association *Association) Clear(value interface{}) *Association { + if len(primaryKeys) == 0 { + association.err(errors.New("no primary key found")) + } else { + joinTable := association.Field.JoinTable + // many to many + if joinTable.joinTable != "" { + whereSql := fmt.Sprintf("%v.%v IN (?)", joinTable.joinTable, scope.Quote(ToSnake(joinTable.associationForeignKey))) + scope.db.Table(joinTable.joinTable).Where(whereSql, primaryKeys).Delete("") + } else { + association.err(errors.New("only many to many support delete")) + } + } return association } @@ -62,6 +101,29 @@ func (association *Association) Replace(values interface{}) *Association { return association } -func (association *Association) Count(values interface{}) int { - return 0 +func (association *Association) Clear(value interface{}) *Association { + return association +} + +func (association *Association) Count() (count int) { + joinTable := association.Field.JoinTable + scope := association.Scope + field := scope.IndirectValue().FieldByName(association.Column) + fieldValue := field.Interface() + + // many to many + if joinTable.joinTable != "" { + newScope := scope.New(fieldValue) + whereSql := fmt.Sprintf("%v.%v IN (SELECT %v.%v FROM %v WHERE %v.%v = ?)", + newScope.QuotedTableName(), + scope.Quote(newScope.PrimaryKey()), + joinTable.joinTable, + scope.Quote(joinTable.associationForeignKey), + joinTable.joinTable, + joinTable.joinTable, + scope.Quote(joinTable.foreignKey)) + scope.db.Table(newScope.QuotedTableName()).Where(whereSql, scope.PrimaryKey()).Count(&count) + } + // association.Scope.related(value, association.Column) + return -1 } diff --git a/association_test.go b/association_test.go index a1b94b07..3c8f7f5c 100644 --- a/association_test.go +++ b/association_test.go @@ -127,24 +127,69 @@ func TestRelated(t *testing.T) { } func TestManyToMany(t *testing.T) { - var languages = []Language{{Name: "ZH"}, {Name: "EN"}, {Name: "DE"}} + db.Raw("delete from languages") + var languages = []Language{{Name: "ZH"}, {Name: "EN"}} user := User{Name: "Many2Many", Languages: languages} db.Save(&user) + // Query var newLanguages []Language - // db.Model(&user).Related(&newLanguages, "Languages") - // if len(newLanguages) != 3 { - // t.Errorf("Query many to many relations") - // } + db.Model(&user).Related(&newLanguages, "Languages") + if len(newLanguages) != len([]string{"ZH", "EN"}) { + t.Errorf("Query many to many relations") + } newLanguages = []Language{} db.Model(&user).Association("Languages").Find(&newLanguages) - if len(newLanguages) != 3 { + if len(newLanguages) != len([]string{"ZH", "EN"}) { t.Errorf("Should be able to find many to many relations") } - // db.Model(&User{}).Many2Many("Languages").Add(&Language{}) - // db.Model(&User{}).Many2Many("Languages").Remove(&Language{}) + // Append + db.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) + if db.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { + t.Errorf("New record should be saved when append") + } + + languageA := Language{Name: "AA"} + db.Save(&languageA) + db.Model(&User{Id: user.Id}).Association("Languages").Append(languageA) + languageC := Language{Name: "CC"} + db.Save(&languageC) + db.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) + db.Model(&User{Id: user.Id}).Association("Languages").Append([]Language{{Name: "DD"}, {Name: "EE"}}) + + totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} + + newLanguages = []Language{} + db.Model(&user).Related(&newLanguages, "Languages") + if len(newLanguages) != len(totalLanguages) { + t.Errorf("All appended languages should be saved") + } + + // Delete + var language Language + db.Where("name = ?", "EE").First(&language) + db.Model(&user).Association("Languages").Delete(language, &language) + + newLanguages = []Language{} + db.Model(&user).Related(&newLanguages, "Languages") + if len(newLanguages) != len(totalLanguages)-1 { + t.Errorf("Relations should be deleted with Delete") + } + if db.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { + t.Errorf("Language EE should not be deleted") + } + + languages = []Language{} + db.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) + db.Model(&user).Association("Languages").Delete(languages, &languages) + newLanguages = []Language{} + db.Model(&user).Related(&newLanguages, "Languages") + if len(newLanguages) != len(totalLanguages)-3 { + t.Errorf("Relations should be deleted with Delete") + } + // db.Model(&User{}).Many2Many("Languages").Replace(&[]Language{}) // db.Model(&User{}).Related(&[]Language{}, "Languages") // SELECT `languages`.* FROM `languages` INNER JOIN `user_languages` ON `languages`.`id` = `user_languages`.`language_id` WHERE `user_languages`.`user_id` = 111 diff --git a/main.go b/main.go index 4aef5885..e5e2fdb9 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,9 @@ package gorm import ( "database/sql" + "errors" + "fmt" + "reflect" ) type DB struct { @@ -353,5 +356,22 @@ func (s *DB) RemoveIndex(indexName string) *DB { func (s *DB) Association(column string) *Association { scope := s.clone().NewScope(s.Value) - return &Association{Scope: scope, Column: column} + + primaryKey := scope.PrimaryKeyValue() + if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { + scope.Err(errors.New("primary key can't be nil")) + } + + var field *Field + scopeType := scope.IndirectValue().Type() + if f, ok := scopeType.FieldByName(SnakeToUpperCamel(column)); ok { + field = scope.fieldFromStruct(f) + if field.JoinTable == nil || field.JoinTable.foreignKey == "" { + scope.Err(errors.New(fmt.Sprintf("invalid association %v for %v", column, scopeType))) + } + } else { + scope.Err(errors.New(fmt.Sprintf("%v doesn't have column %v", scopeType, column))) + } + + return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field} } diff --git a/migration_test.go b/migration_test.go index 607668ad..008329ae 100644 --- a/migration_test.go +++ b/migration_test.go @@ -19,6 +19,7 @@ func runMigration() { db.Exec("drop table companies") db.Exec("drop table animals") db.Exec("drop table user_languages") + db.Exec("drop table languages") if err := db.CreateTable(&Animal{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) diff --git a/scope.go b/scope.go index cb1313c5..fb7ac207 100644 --- a/scope.go +++ b/scope.go @@ -122,12 +122,12 @@ func (scope *Scope) FieldByName(name string) (interface{}, bool) { } // SetColumn to set the column's value -func (scope *Scope) SetColumn(column string, value interface{}) { +func (scope *Scope) SetColumn(column string, value interface{}) bool { if scope.Value == nil { - return + return false } - setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value) + return setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value) } // CallMethod invoke method with necessary argument diff --git a/scope_private.go b/scope_private.go index ae87dbcd..de7fa8fd 100644 --- a/scope_private.go +++ b/scope_private.go @@ -416,13 +416,42 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - if foreignValue, ok := scope.FieldByName(foreignKey); ok { - return toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) - } else if toScope.HasColumn(foreignKey) { + scopeType := scope.IndirectValue().Type() + if f, ok := scopeType.FieldByName(SnakeToUpperCamel(foreignKey)); ok { + field := scope.fieldFromStruct(f) + joinTable := field.JoinTable + if joinTable != nil && joinTable.foreignKey != "" { + foreignKey = joinTable.foreignKey + + // many to many relations + if joinTable.joinTable != "" { + joinSql := fmt.Sprintf( + "INNER JOIN %v ON %v.%v = %v.%v", + scope.Quote(joinTable.joinTable), + scope.Quote(joinTable.joinTable), + scope.Quote(ToSnake(joinTable.associationForeignKey)), + toScope.QuotedTableName(), + scope.Quote(toScope.PrimaryKey())) + whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(joinTable.joinTable), scope.Quote(ToSnake(joinTable.foreignKey))) + toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) + return scope + } + } + + // has one + if foreignValue, ok := scope.FieldByName(foreignKey); ok { + toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) + return scope + } + } + + // has many + if toScope.HasColumn(foreignKey) { sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))) return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) } } + scope.Err(errors.New(fmt.Sprintf("invalid association %v", foreignKeys))) return scope }