fix: AfterQuery using safer right trim while clearing from clause's join added as part of https://github.com/go-gorm/gorm/pull/7027 (#7153)

Co-authored-by: Abhijeet Bhowmik <abhijeet.bhowmik@cambiumnetworks.com>
This commit is contained in:
abhijeet45 2024-08-22 16:33:42 +05:30 committed by GitHub
parent 0dbfda5d7e
commit 0daaf1747c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 73 additions and 1 deletions

View File

@ -288,7 +288,7 @@ func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it // clear the joins after query because preload need it
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause := db.Statement.Clauses["FROM"] fromClause := db.Statement.Clauses["FROM"]
fromClause.Expression = clause.From{Tables: v.Tables, Joins: v.Joins[:len(v.Joins)-len(db.Statement.Joins)]} // keep the original From Joins fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
db.Statement.Clauses["FROM"] = fromClause db.Statement.Clauses["FROM"] = fromClause
} }
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {

View File

@ -166,3 +166,14 @@ func SplitNestedRelationName(name string) []string {
func JoinNestedRelationNames(relationNames []string) string { func JoinNestedRelationNames(relationNames []string) string {
return strings.Join(relationNames, nestedRelationSplit) return strings.Join(relationNames, nestedRelationSplit)
} }
// RTrimSlice Right trims the given slice by given length
func RTrimSlice[T any](v []T, trimLen int) []T {
if trimLen >= len(v) { // trimLen greater than slice len means fully sliced
return v[:0]
}
if trimLen < 0 { // negative trimLen is ignored
return v[:]
}
return v[:len(v)-trimLen]
}

View File

@ -138,3 +138,64 @@ func TestToString(t *testing.T) {
}) })
} }
} }
func TestRTrimSlice(t *testing.T) {
tests := []struct {
name string
input []int
trimLen int
expected []int
}{
{
name: "Trim two elements from end",
input: []int{1, 2, 3, 4, 5},
trimLen: 2,
expected: []int{1, 2, 3},
},
{
name: "Trim entire slice",
input: []int{1, 2, 3},
trimLen: 3,
expected: []int{},
},
{
name: "Trim length greater than slice length",
input: []int{1, 2, 3},
trimLen: 5,
expected: []int{},
},
{
name: "Zero trim length",
input: []int{1, 2, 3},
trimLen: 0,
expected: []int{1, 2, 3},
},
{
name: "Trim one element from end",
input: []int{1, 2, 3},
trimLen: 1,
expected: []int{1, 2},
},
{
name: "Empty slice",
input: []int{},
trimLen: 2,
expected: []int{},
},
{
name: "Negative trim length (should be treated as zero)",
input: []int{1, 2, 3},
trimLen: -1,
expected: []int{1, 2, 3},
},
}
for _, testcase := range tests {
t.Run(testcase.name, func(t *testing.T) {
result := RTrimSlice(testcase.input, testcase.trimLen)
if !AssertEqual(result, testcase.expected) {
t.Errorf("RTrimSlice(%v, %d) = %v; want %v", testcase.input, testcase.trimLen, result, testcase.expected)
}
})
}
}