From 0daaf1747cfa4e4850376ad50a7834fb78b0cc0e Mon Sep 17 00:00:00 2001 From: abhijeet45 Date: Thu, 22 Aug 2024 16:33:42 +0530 Subject: [PATCH] fix: AfterQuery using safer right trim while clearing from clause's join added as part of https://github.com/go-gorm/gorm/pull/7027 (#7153) Co-authored-by: Abhijeet Bhowmik --- callbacks/query.go | 2 +- utils/utils.go | 11 ++++++++ utils/utils_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index 9b2b17ea..bbf238a9 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -288,7 +288,7 @@ func AfterQuery(db *gorm.DB) { // clear the joins after query because preload need it if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { fromClause := db.Statement.Clauses["FROM"] - fromClause.Expression = clause.From{Tables: v.Tables, Joins: v.Joins[:len(v.Joins)-len(db.Statement.Joins)]} // keep the original From Joins + fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins db.Statement.Clauses["FROM"] = fromClause } if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { diff --git a/utils/utils.go b/utils/utils.go index b8d30b35..fc615d73 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -166,3 +166,14 @@ func SplitNestedRelationName(name string) []string { func JoinNestedRelationNames(relationNames []string) string { return strings.Join(relationNames, nestedRelationSplit) } + +// RTrimSlice Right trims the given slice by given length +func RTrimSlice[T any](v []T, trimLen int) []T { + if trimLen >= len(v) { // trimLen greater than slice len means fully sliced + return v[:0] + } + if trimLen < 0 { // negative trimLen is ignored + return v[:] + } + return v[:len(v)-trimLen] +} diff --git a/utils/utils_test.go b/utils/utils_test.go index 8ff42af8..089cc4c8 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -138,3 +138,64 @@ func TestToString(t *testing.T) { }) } } + +func TestRTrimSlice(t *testing.T) { + tests := []struct { + name string + input []int + trimLen int + expected []int + }{ + { + name: "Trim two elements from end", + input: []int{1, 2, 3, 4, 5}, + trimLen: 2, + expected: []int{1, 2, 3}, + }, + { + name: "Trim entire slice", + input: []int{1, 2, 3}, + trimLen: 3, + expected: []int{}, + }, + { + name: "Trim length greater than slice length", + input: []int{1, 2, 3}, + trimLen: 5, + expected: []int{}, + }, + { + name: "Zero trim length", + input: []int{1, 2, 3}, + trimLen: 0, + expected: []int{1, 2, 3}, + }, + { + name: "Trim one element from end", + input: []int{1, 2, 3}, + trimLen: 1, + expected: []int{1, 2}, + }, + { + name: "Empty slice", + input: []int{}, + trimLen: 2, + expected: []int{}, + }, + { + name: "Negative trim length (should be treated as zero)", + input: []int{1, 2, 3}, + trimLen: -1, + expected: []int{1, 2, 3}, + }, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + result := RTrimSlice(testcase.input, testcase.trimLen) + if !AssertEqual(result, testcase.expected) { + t.Errorf("RTrimSlice(%v, %d) = %v; want %v", testcase.input, testcase.trimLen, result, testcase.expected) + } + }) + } +}