mirror of https://github.com/go-gorm/gorm.git
Possible to specify table name
This commit is contained in:
parent
3bd7daf889
commit
2a20e551ed
14
README.md
14
README.md
|
@ -149,6 +149,16 @@ db.Where("name <> ?", "jinzhu").Where("age >= ? and role <> ?", 20, "admin").Fin
|
|||
// Create Table with struct
|
||||
db.CreateTable(&User{})
|
||||
|
||||
// Specify Table Name
|
||||
db.Table("deleted_users").CreateTable(&User{})
|
||||
db.Table("users").Pluck("age", &ages)
|
||||
//// ages -> select age from users;
|
||||
var deleted_users []User
|
||||
db.Table("deleted_users").Find(&deleted_users)
|
||||
//// deleted_users -> select * from deleted_users;
|
||||
db.Table("deleted_users").Find(&deleted_user)
|
||||
//// deleted_user -> select * from deleted_users limit 1;
|
||||
|
||||
// Run Raw SQL
|
||||
db.Exec("drop table users;")
|
||||
|
||||
|
@ -185,6 +195,10 @@ db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_cart)
|
|||
//// shopping_cart -> select * from carts where product_name = 'fancy_product'
|
||||
// Do you noticed the search table is different for above query, yay
|
||||
|
||||
db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&user2)
|
||||
//// users1 -> select * from users where mail_type = 'TEXT';
|
||||
//// users2 -> select * from deleted_users where mail_type = 'TEXT';
|
||||
|
||||
// Open your mind, add more cool examples
|
||||
```
|
||||
|
||||
|
|
7
chain.go
7
chain.go
|
@ -23,6 +23,7 @@ type Chain struct {
|
|||
orderStrs []string
|
||||
offsetStr string
|
||||
limitStr string
|
||||
specifiedTableName string
|
||||
}
|
||||
|
||||
func (s *Chain) err(err error) {
|
||||
|
@ -44,6 +45,7 @@ func (s *Chain) do(value interface{}) *Do {
|
|||
do.orderStrs = s.orderStrs
|
||||
do.offsetStr = s.offsetStr
|
||||
do.limitStr = s.limitStr
|
||||
do.specifiedTableName = s.specifiedTableName
|
||||
|
||||
s.value = value
|
||||
do.setModel(value)
|
||||
|
@ -174,6 +176,11 @@ func (s *Chain) CreateTable(value interface{}) *Chain {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Table(name string) *Chain {
|
||||
s.specifiedTableName = name
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) validSql(str string) (result bool) {
|
||||
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
|
||||
if !result {
|
||||
|
|
29
do.go
29
do.go
|
@ -15,7 +15,8 @@ type Do struct {
|
|||
chain *Chain
|
||||
db *sql.DB
|
||||
driver string
|
||||
TableName string
|
||||
guessedTableName string
|
||||
specifiedTableName string
|
||||
Errors []error
|
||||
|
||||
model *Model
|
||||
|
@ -34,6 +35,14 @@ type Do struct {
|
|||
operation string
|
||||
}
|
||||
|
||||
func (s *Do) tableName() string {
|
||||
if s.specifiedTableName == "" {
|
||||
return s.guessedTableName
|
||||
} else {
|
||||
return s.specifiedTableName
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Do) err(err error) {
|
||||
if err != nil {
|
||||
s.Errors = append(s.Errors, err)
|
||||
|
@ -49,9 +58,11 @@ func (s *Do) setModel(value interface{}) {
|
|||
s.value = value
|
||||
s.model = &Model{Data: value, driver: s.driver}
|
||||
var err error
|
||||
s.TableName, err = s.model.tableName()
|
||||
if s.specifiedTableName == "" {
|
||||
s.guessedTableName, err = s.model.tableName()
|
||||
s.err(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Do) addToVars(value interface{}) string {
|
||||
s.SqlVars = append(s.SqlVars, value)
|
||||
|
@ -91,7 +102,7 @@ func (s *Do) prepareCreateSql() *Do {
|
|||
|
||||
s.Sql = fmt.Sprintf(
|
||||
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
|
||||
s.TableName,
|
||||
s.tableName(),
|
||||
strings.Join(s.quoteMap(columns), ","),
|
||||
strings.Join(sqls, ","),
|
||||
s.model.returningStr(),
|
||||
|
@ -135,7 +146,7 @@ func (s *Do) prepareUpdateSql() *Do {
|
|||
|
||||
s.Sql = fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
s.TableName,
|
||||
s.tableName(),
|
||||
strings.Join(sets, ", "),
|
||||
s.combinedSql(),
|
||||
)
|
||||
|
@ -154,7 +165,7 @@ func (s *Do) update() *Do {
|
|||
}
|
||||
|
||||
func (s *Do) prepareDeleteSql() *Do {
|
||||
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
|
||||
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -168,7 +179,7 @@ func (s *Do) delete() *Do {
|
|||
}
|
||||
|
||||
func (s *Do) prepareQuerySql() *Do {
|
||||
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql())
|
||||
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -216,8 +227,11 @@ func (s *Do) query(where ...interface{}) {
|
|||
columns, _ := rows.Columns()
|
||||
var values []interface{}
|
||||
for _, value := range columns {
|
||||
field := dest.FieldByName(snakeToUpperCamel(value))
|
||||
if field.IsValid() {
|
||||
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
|
||||
}
|
||||
}
|
||||
s.err(rows.Scan(values...))
|
||||
|
||||
if is_slice {
|
||||
|
@ -400,10 +414,9 @@ func (s *Do) createTable() *Do {
|
|||
for _, field := range s.model.fields("null") {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
}
|
||||
|
||||
s.Sql = fmt.Sprintf(
|
||||
"CREATE TABLE \"%v\" (%v)",
|
||||
s.TableName,
|
||||
s.tableName(),
|
||||
strings.Join(sqls, ","),
|
||||
)
|
||||
return s
|
||||
|
|
53
gorm_test.go
53
gorm_test.go
|
@ -1,11 +1,11 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"errors"
|
||||
_ "github.com/lib/pq"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -567,3 +567,52 @@ func TestNoUnExpectedHappenWithInvalidSql(t *testing.T) {
|
|||
|
||||
db.Where("unexisting = ?", "3").Find(&[]User{})
|
||||
}
|
||||
|
||||
func TestSetTableDirectly(t *testing.T) {
|
||||
var ages []int64
|
||||
if db.Table("users").Pluck("age", &ages).Error != nil {
|
||||
t.Errorf("No errors should happen if only set table")
|
||||
}
|
||||
|
||||
if len(ages) == 0 {
|
||||
t.Errorf("Should find out some records")
|
||||
}
|
||||
|
||||
var users []User
|
||||
if db.Table("users").Find(&users).Error != nil {
|
||||
t.Errorf("No errors should happen if set table to an existing table")
|
||||
}
|
||||
|
||||
if db.Table("unexisting_users_table").Find(&users).Error == nil {
|
||||
t.Errorf("Should got some errors if set table to an unexisting table")
|
||||
}
|
||||
|
||||
if db.Table("products").Find(&users).Error == nil {
|
||||
t.Errorf("Should got some errors if set table to an unexisting table")
|
||||
}
|
||||
|
||||
db.Exec("drop table deleted_users;")
|
||||
if db.Table("deleted_users").CreateTable(&User{}).Error != nil {
|
||||
t.Errorf("Should create deleted_users table")
|
||||
}
|
||||
|
||||
db.Table("deleted_users").Save(&User{Name: "DeletedUser"})
|
||||
|
||||
var deleted_users []User
|
||||
db.Table("deleted_users").Find(&deleted_users)
|
||||
if len(deleted_users) != 1 {
|
||||
t.Errorf("Should query from deleted_users table")
|
||||
}
|
||||
|
||||
var deleted_user User
|
||||
db.Table("deleted_users").Find(&deleted_user)
|
||||
if deleted_user.Name != "DeletedUser" {
|
||||
t.Errorf("Should query from deleted_users table")
|
||||
}
|
||||
|
||||
var user1, user2, user3 User
|
||||
db.First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
|
||||
if !((user1.Name != user2.Name) && (user1.Name == user3.Name)) {
|
||||
t.Errorf("Set Table Chain Should works well")
|
||||
}
|
||||
}
|
||||
|
|
4
main.go
4
main.go
|
@ -65,6 +65,10 @@ func (s *DB) Model(value interface{}) *Chain {
|
|||
return s.buildORM().Model(value)
|
||||
}
|
||||
|
||||
func (s *DB) Table(name string) *Chain {
|
||||
return s.buildORM().Table(name)
|
||||
}
|
||||
|
||||
func (s *DB) CreateTable(value interface{}) *Chain {
|
||||
return s.buildORM().CreateTable(value)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue