diff --git a/statement.go b/statement.go index 3b76f653..bea4f7f0 100644 --- a/statement.go +++ b/statement.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "regexp" "sort" "strconv" "strings" @@ -627,6 +628,8 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } +var nameMatcher = regexp.MustCompile(`\.[\W]?(.+?)[\W]?$`) + // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} @@ -647,6 +650,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true + } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { + results[matches[1]] = true } else { results[column] = true } @@ -662,6 +667,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false + } else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 { + results[matches[1]] = false } else { results[omit] = false } diff --git a/statement_test.go b/statement_test.go index 03ad81dc..3f099d61 100644 --- a/statement_test.go +++ b/statement_test.go @@ -34,3 +34,16 @@ func TestWhereCloneCorruption(t *testing.T) { }) } } + +func TestNameMatcher(t *testing.T) { + for k, v := range map[string]string{ + "table.name": "name", + "`table`.`name`": "name", + "'table'.'name'": "name", + "'table'.name": "name", + } { + if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + } + } +}