diff --git a/chainable_api.go b/chainable_api.go index 3e509f12..7ee20324 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -2,6 +2,7 @@ package gorm import ( "fmt" + "regexp" "strings" "gorm.io/gorm/clause" @@ -40,9 +41,19 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { return } +var tableRegexp = regexp.MustCompile("(?i).+ AS (\\w+)\\s*$") + // Table specify the table you would like to run db operations func (db *DB) Table(name string) (tx *DB) { tx = db.getInstance() + if strings.Contains(name, " ") { + tx.Statement.FullTable = name + if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { + tx.Statement.Table = results[1] + return + } + } + tx.Statement.Table = name return } diff --git a/statement.go b/statement.go index 00feeac5..142c7c31 100644 --- a/statement.go +++ b/statement.go @@ -19,6 +19,7 @@ import ( // Statement statement type Statement struct { *DB + FullTable string Table string Model interface{} Unscoped bool @@ -69,7 +70,11 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { - stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + if stmt.FullTable != "" { + writer.WriteString(stmt.FullTable) + } else { + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + } } else if v.Raw { writer.WriteString(v.Name) } else { @@ -374,6 +379,7 @@ func (stmt *Statement) Build(clauses ...string) { func (stmt *Statement) Parse(value interface{}) (err error) { if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { stmt.Table = stmt.Schema.Table + stmt.FullTable = stmt.Schema.Table } return err } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 634ee1cb..e6038947 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -24,6 +24,22 @@ func TestRow(t *testing.T) { if age != 10 { t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) } + + table := "gorm.users" + if DB.Dialector.Name() != "mysql" { + table = "users" // other databases doesn't support select with `database.table` + } + + DB.Table(table).Where(map[string]interface{}{"name": user2.Name}).Update("age", 20) + + row = DB.Table(table+" as u").Where("u.name = ?", user2.Name).Select("age").Row() + if err := row.Scan(&age); err != nil { + t.Fatalf("Failed to scan age, got %v", err) + } + + if age != 20 { + t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) + } } func TestRows(t *testing.T) {