Support specify select/omit columns with table

This commit is contained in:
Jinzhu 2021-10-08 17:51:27 +08:00
parent d4c838c1ce
commit 6312d86c54
2 changed files with 20 additions and 0 deletions

View File

@ -6,6 +6,7 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -627,6 +628,8 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false return false
} }
var nameMatcher = regexp.MustCompile(`\.[\W]?(.+?)[\W]?$`)
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{} results := map[string]bool{}
@ -647,6 +650,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
} }
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true results[field.DBName] = true
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
results[matches[1]] = true
} else { } else {
results[column] = true results[column] = true
} }
@ -662,6 +667,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
} }
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false results[field.DBName] = false
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
results[matches[1]] = false
} else { } else {
results[omit] = false results[omit] = false
} }

View File

@ -34,3 +34,16 @@ func TestWhereCloneCorruption(t *testing.T) {
}) })
} }
} }
func TestNameMatcher(t *testing.T) {
for k, v := range map[string]string{
"table.name": "name",
"`table`.`name`": "name",
"'table'.'name'": "name",
"'table'.name": "name",
} {
if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v {
t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
}
}
}