From 436cca753cd784969a19477f022db4eb3d84f2ec Mon Sep 17 00:00:00 2001 From: Stephano George Date: Sat, 23 Dec 2023 21:19:41 +0800 Subject: [PATCH] fix: join and select mytable.* not working (#6761) * fix: select mytable.* not working * fix: select mytable.*: will not match `mytable."*"`. feat: increase readability of code matching table name column name --- statement.go | 22 ++++++++++++++++++---- statement_test.go | 10 ++++++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/statement.go b/statement.go index 59c0b772..b24228b2 100644 --- a/statement.go +++ b/statement.go @@ -665,7 +665,21 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) +var matchName = func() func(tableColumn string) (table, column string) { + nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`) + return func(tableColumn string) (table, column string) { + if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 { + table = matches[1] + star := matches[2] + columnName := matches[3] + if star != "" { + return table, star + } + return table, columnName + } + return "", "" + } +}() // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { @@ -686,13 +700,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = result - } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") { - if matches[2] == "*" { + } else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") { + if col == "*" { for _, dbName := range stmt.Schema.DBNames { results[dbName] = result } } else { - results[matches[2]] = result + results[col] = result } } else { results[column] = result diff --git a/statement_test.go b/statement_test.go index 648bc875..0995d547 100644 --- a/statement_test.go +++ b/statement_test.go @@ -56,9 +56,15 @@ func TestNameMatcher(t *testing.T) { "`name_1`": {"", "name_1"}, "`Name_1`": {"", "Name_1"}, "`Table`.`nAme`": {"Table", "nAme"}, + "my_table.*": {"my_table", "*"}, + "`my_table`.*": {"my_table", "*"}, + "User__Company.*": {"User__Company", "*"}, + "`User__Company`.*": {"User__Company", "*"}, + `"User__Company".*`: {"User__Company", "*"}, + `"table"."*"`: {"", ""}, } { - if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { - t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + if table, column := matchName(k); table != v[0] || column != v[1] { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v) } } }