diff --git a/clause/group_by.go b/clause/group_by.go index c1383c36..88231916 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -30,8 +30,13 @@ func (groupBy GroupBy) Build(builder Builder) { // MergeClause merge group by clause func (groupBy GroupBy) MergeClause(clause *Clause) { if v, ok := clause.Expression.(GroupBy); ok { - groupBy.Columns = append(v.Columns, groupBy.Columns...) - groupBy.Having = append(v.Having, groupBy.Having...) + copiedColumns := make([]Column, len(v.Columns)) + copy(copiedColumns, v.Columns) + groupBy.Columns = append(copiedColumns, groupBy.Columns...) + + copiedHaving := make([]Expression, len(v.Having)) + copy(copiedHaving, v.Having) + groupBy.Having = append(copiedHaving, groupBy.Having...) } clause.Expression = groupBy } diff --git a/clause/order_by.go b/clause/order_by.go index 307bf930..a8a9539a 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -40,7 +40,9 @@ func (orderBy OrderBy) MergeClause(clause *Clause) { } } - orderBy.Columns = append(v.Columns, orderBy.Columns...) + copiedColumns := make([]OrderByColumn, len(v.Columns)) + copy(copiedColumns, v.Columns) + orderBy.Columns = append(copiedColumns, orderBy.Columns...) } clause.Expression = orderBy diff --git a/clause/set.go b/clause/set.go index 7704ca36..2d3965d3 100644 --- a/clause/set.go +++ b/clause/set.go @@ -32,7 +32,9 @@ func (set Set) Build(builder Builder) { // MergeClause merge assignments clauses func (set Set) MergeClause(clause *Clause) { - clause.Expression = set + copiedAssignments := make([]Assignment, len(set)) + copy(copiedAssignments, set) + clause.Expression = Set(copiedAssignments) } func Assignments(values map[string]interface{}) Set { diff --git a/clause/where.go b/clause/where.go index 015addf8..806565d1 100644 --- a/clause/where.go +++ b/clause/where.go @@ -40,7 +40,9 @@ func (where Where) Build(builder Builder) { // MergeClause merge where clauses func (where Where) MergeClause(clause *Clause) { if w, ok := clause.Expression.(Where); ok { - where.Exprs = append(w.Exprs, where.Exprs...) + copiedExpressions := make([]Expression, len(w.Exprs)) + copy(copiedExpressions, w.Exprs) + where.Exprs = append(copiedExpressions, where.Exprs...) } clause.Expression = where diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 00000000..7d730875 --- /dev/null +++ b/statement_test.go @@ -0,0 +1,37 @@ +package gorm + +import ( + "fmt" + "reflect" + "testing" + + "gorm.io/gorm/clause" +) + +func TestWhereCloneCorruption(t *testing.T) { + for whereCount := 1; whereCount <= 8; whereCount++ { + t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) { + s := new(Statement) + for w := 0; w < whereCount; w++ { + s = s.clone() + s.AddClause(clause.Where{ + Exprs: s.BuildCondtion(fmt.Sprintf("where%d", w)), + }) + } + + s1 := s.clone() + s1.AddClause(clause.Where{ + Exprs: s.BuildCondtion("FINAL1"), + }) + s2 := s.clone() + s2.AddClause(clause.Where{ + Exprs: s.BuildCondtion("FINAL2"), + }) + + if reflect.DeepEqual(s1.Clauses["WHERE"], s2.Clauses["WHERE"]) { + t.Errorf("Where conditions should be different") + } + }) + } +} +