diff --git a/chainable_api.go b/chainable_api.go index fe11e474..4df8780e 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -41,17 +41,21 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { return } -var tableRegexp = regexp.MustCompile(`(?i).+ AS (\w+)\s*$`) +var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) // Table specify the table you would like to run db operations -func (db *DB) Table(name string) (tx *DB) { +func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if strings.Contains(name, " ") { - tx.Statement.TableExpr = &clause.Expr{SQL: name} + if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { + tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { tx.Statement.Table = results[1] return } + } else if tables := strings.Split(name, "."); len(tables) == 2 { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = tables[1] + return } tx.Statement.Table = name diff --git a/statement.go b/statement.go index 6641aed8..5f4238ef 100644 --- a/statement.go +++ b/statement.go @@ -377,6 +377,12 @@ func (stmt *Statement) Build(clauses ...string) { func (stmt *Statement) Parse(value interface{}) (err error) { if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { + stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} + stmt.Table = tables[1] + return + } + stmt.Table = stmt.Schema.Table } return err diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 2c593a70..1b002049 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -79,7 +79,7 @@ func TestMigrateWithUniqueIndex(t *testing.T) { } } -func TestTable(t *testing.T) { +func TestMigrateTable(t *testing.T) { type TableStruct struct { gorm.Model Name string @@ -112,7 +112,7 @@ func TestTable(t *testing.T) { } } -func TestIndexes(t *testing.T) { +func TestMigrateIndexes(t *testing.T) { type IndexStruct struct { gorm.Model Name string `gorm:"size:255;index"` @@ -162,7 +162,7 @@ func TestIndexes(t *testing.T) { } } -func TestColumns(t *testing.T) { +func TestMigrateColumns(t *testing.T) { type ColumnStruct struct { gorm.Model Name string diff --git a/tests/table_test.go b/tests/table_test.go new file mode 100644 index 00000000..b96af170 --- /dev/null +++ b/tests/table_test.go @@ -0,0 +1,52 @@ +package tests_test + +import ( + "regexp" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +type UserWithTable struct { + gorm.Model + Name string +} + +func (UserWithTable) TableName() string { + return "gorm.user" +} + +func TestTable(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + r := dryDB.Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("name").Find(&UserWithTable{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } +}