forked from mirror/gorm
Possible to specify table name
This commit is contained in:
parent
3bd7daf889
commit
2a20e551ed
36
README.md
36
README.md
|
@ -20,12 +20,12 @@ Yet Another ORM library for Go, aims for developer friendly
|
||||||
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Id int64
|
Id int64
|
||||||
Age int64
|
Age int64
|
||||||
Birthday time.Time
|
Birthday time.Time
|
||||||
Name string
|
Name string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set database pool
|
// Set database pool
|
||||||
|
@ -127,12 +127,12 @@ Below callbacks are defined now:
|
||||||
|
|
||||||
Callbacks is a function defined to a model, if the function return error, will prevent the database operations.
|
Callbacks is a function defined to a model, if the function return error, will prevent the database operations.
|
||||||
|
|
||||||
func (u *User) BeforeUpdate() (err error) {
|
func (u *User) BeforeUpdate() (err error) {
|
||||||
if u.readonly() {
|
if u.readonly() {
|
||||||
err = errors.New("Read Only User")
|
err = errors.New("Read Only User")
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Pluck (get users's age as map)
|
// Pluck (get users's age as map)
|
||||||
var ages []int64
|
var ages []int64
|
||||||
|
@ -149,6 +149,16 @@ db.Where("name <> ?", "jinzhu").Where("age >= ? and role <> ?", 20, "admin").Fin
|
||||||
// Create Table with struct
|
// Create Table with struct
|
||||||
db.CreateTable(&User{})
|
db.CreateTable(&User{})
|
||||||
|
|
||||||
|
// Specify Table Name
|
||||||
|
db.Table("deleted_users").CreateTable(&User{})
|
||||||
|
db.Table("users").Pluck("age", &ages)
|
||||||
|
//// ages -> select age from users;
|
||||||
|
var deleted_users []User
|
||||||
|
db.Table("deleted_users").Find(&deleted_users)
|
||||||
|
//// deleted_users -> select * from deleted_users;
|
||||||
|
db.Table("deleted_users").Find(&deleted_user)
|
||||||
|
//// deleted_user -> select * from deleted_users limit 1;
|
||||||
|
|
||||||
// Run Raw SQL
|
// Run Raw SQL
|
||||||
db.Exec("drop table users;")
|
db.Exec("drop table users;")
|
||||||
|
|
||||||
|
@ -185,6 +195,10 @@ db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_cart)
|
||||||
//// shopping_cart -> select * from carts where product_name = 'fancy_product'
|
//// shopping_cart -> select * from carts where product_name = 'fancy_product'
|
||||||
// Do you noticed the search table is different for above query, yay
|
// Do you noticed the search table is different for above query, yay
|
||||||
|
|
||||||
|
db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&user2)
|
||||||
|
//// users1 -> select * from users where mail_type = 'TEXT';
|
||||||
|
//// users2 -> select * from deleted_users where mail_type = 'TEXT';
|
||||||
|
|
||||||
// Open your mind, add more cool examples
|
// Open your mind, add more cool examples
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
19
chain.go
19
chain.go
|
@ -17,12 +17,13 @@ type Chain struct {
|
||||||
Errors []error
|
Errors []error
|
||||||
Error error
|
Error error
|
||||||
|
|
||||||
whereClause []map[string]interface{}
|
whereClause []map[string]interface{}
|
||||||
orClause []map[string]interface{}
|
orClause []map[string]interface{}
|
||||||
selectStr string
|
selectStr string
|
||||||
orderStrs []string
|
orderStrs []string
|
||||||
offsetStr string
|
offsetStr string
|
||||||
limitStr string
|
limitStr string
|
||||||
|
specifiedTableName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) err(err error) {
|
func (s *Chain) err(err error) {
|
||||||
|
@ -44,6 +45,7 @@ func (s *Chain) do(value interface{}) *Do {
|
||||||
do.orderStrs = s.orderStrs
|
do.orderStrs = s.orderStrs
|
||||||
do.offsetStr = s.offsetStr
|
do.offsetStr = s.offsetStr
|
||||||
do.limitStr = s.limitStr
|
do.limitStr = s.limitStr
|
||||||
|
do.specifiedTableName = s.specifiedTableName
|
||||||
|
|
||||||
s.value = value
|
s.value = value
|
||||||
do.setModel(value)
|
do.setModel(value)
|
||||||
|
@ -174,6 +176,11 @@ func (s *Chain) CreateTable(value interface{}) *Chain {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Chain) Table(name string) *Chain {
|
||||||
|
s.specifiedTableName = name
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Chain) validSql(str string) (result bool) {
|
func (s *Chain) validSql(str string) (result bool) {
|
||||||
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
|
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
|
||||||
if !result {
|
if !result {
|
||||||
|
|
41
do.go
41
do.go
|
@ -12,11 +12,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Do struct {
|
type Do struct {
|
||||||
chain *Chain
|
chain *Chain
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
driver string
|
driver string
|
||||||
TableName string
|
guessedTableName string
|
||||||
Errors []error
|
specifiedTableName string
|
||||||
|
Errors []error
|
||||||
|
|
||||||
model *Model
|
model *Model
|
||||||
value interface{}
|
value interface{}
|
||||||
|
@ -34,6 +35,14 @@ type Do struct {
|
||||||
operation string
|
operation string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Do) tableName() string {
|
||||||
|
if s.specifiedTableName == "" {
|
||||||
|
return s.guessedTableName
|
||||||
|
} else {
|
||||||
|
return s.specifiedTableName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Do) err(err error) {
|
func (s *Do) err(err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Errors = append(s.Errors, err)
|
s.Errors = append(s.Errors, err)
|
||||||
|
@ -49,8 +58,10 @@ func (s *Do) setModel(value interface{}) {
|
||||||
s.value = value
|
s.value = value
|
||||||
s.model = &Model{Data: value, driver: s.driver}
|
s.model = &Model{Data: value, driver: s.driver}
|
||||||
var err error
|
var err error
|
||||||
s.TableName, err = s.model.tableName()
|
if s.specifiedTableName == "" {
|
||||||
s.err(err)
|
s.guessedTableName, err = s.model.tableName()
|
||||||
|
s.err(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) addToVars(value interface{}) string {
|
func (s *Do) addToVars(value interface{}) string {
|
||||||
|
@ -91,7 +102,7 @@ func (s *Do) prepareCreateSql() *Do {
|
||||||
|
|
||||||
s.Sql = fmt.Sprintf(
|
s.Sql = fmt.Sprintf(
|
||||||
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
|
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
|
||||||
s.TableName,
|
s.tableName(),
|
||||||
strings.Join(s.quoteMap(columns), ","),
|
strings.Join(s.quoteMap(columns), ","),
|
||||||
strings.Join(sqls, ","),
|
strings.Join(sqls, ","),
|
||||||
s.model.returningStr(),
|
s.model.returningStr(),
|
||||||
|
@ -135,7 +146,7 @@ func (s *Do) prepareUpdateSql() *Do {
|
||||||
|
|
||||||
s.Sql = fmt.Sprintf(
|
s.Sql = fmt.Sprintf(
|
||||||
"UPDATE %v SET %v %v",
|
"UPDATE %v SET %v %v",
|
||||||
s.TableName,
|
s.tableName(),
|
||||||
strings.Join(sets, ", "),
|
strings.Join(sets, ", "),
|
||||||
s.combinedSql(),
|
s.combinedSql(),
|
||||||
)
|
)
|
||||||
|
@ -154,7 +165,7 @@ func (s *Do) update() *Do {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareDeleteSql() *Do {
|
func (s *Do) prepareDeleteSql() *Do {
|
||||||
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
|
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,7 +179,7 @@ func (s *Do) delete() *Do {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareQuerySql() *Do {
|
func (s *Do) prepareQuerySql() *Do {
|
||||||
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql())
|
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -216,7 +227,10 @@ func (s *Do) query(where ...interface{}) {
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
for _, value := range columns {
|
for _, value := range columns {
|
||||||
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
|
field := dest.FieldByName(snakeToUpperCamel(value))
|
||||||
|
if field.IsValid() {
|
||||||
|
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s.err(rows.Scan(values...))
|
s.err(rows.Scan(values...))
|
||||||
|
|
||||||
|
@ -400,10 +414,9 @@ func (s *Do) createTable() *Do {
|
||||||
for _, field := range s.model.fields("null") {
|
for _, field := range s.model.fields("null") {
|
||||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Sql = fmt.Sprintf(
|
s.Sql = fmt.Sprintf(
|
||||||
"CREATE TABLE \"%v\" (%v)",
|
"CREATE TABLE \"%v\" (%v)",
|
||||||
s.TableName,
|
s.tableName(),
|
||||||
strings.Join(sqls, ","),
|
strings.Join(sqls, ","),
|
||||||
)
|
)
|
||||||
return s
|
return s
|
||||||
|
|
53
gorm_test.go
53
gorm_test.go
|
@ -1,11 +1,11 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
_ "github.com/lib/pq"
|
|
||||||
|
|
||||||
"errors"
|
"errors"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -567,3 +567,52 @@ func TestNoUnExpectedHappenWithInvalidSql(t *testing.T) {
|
||||||
|
|
||||||
db.Where("unexisting = ?", "3").Find(&[]User{})
|
db.Where("unexisting = ?", "3").Find(&[]User{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetTableDirectly(t *testing.T) {
|
||||||
|
var ages []int64
|
||||||
|
if db.Table("users").Pluck("age", &ages).Error != nil {
|
||||||
|
t.Errorf("No errors should happen if only set table")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ages) == 0 {
|
||||||
|
t.Errorf("Should find out some records")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
if db.Table("users").Find(&users).Error != nil {
|
||||||
|
t.Errorf("No errors should happen if set table to an existing table")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Table("unexisting_users_table").Find(&users).Error == nil {
|
||||||
|
t.Errorf("Should got some errors if set table to an unexisting table")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Table("products").Find(&users).Error == nil {
|
||||||
|
t.Errorf("Should got some errors if set table to an unexisting table")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Exec("drop table deleted_users;")
|
||||||
|
if db.Table("deleted_users").CreateTable(&User{}).Error != nil {
|
||||||
|
t.Errorf("Should create deleted_users table")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Table("deleted_users").Save(&User{Name: "DeletedUser"})
|
||||||
|
|
||||||
|
var deleted_users []User
|
||||||
|
db.Table("deleted_users").Find(&deleted_users)
|
||||||
|
if len(deleted_users) != 1 {
|
||||||
|
t.Errorf("Should query from deleted_users table")
|
||||||
|
}
|
||||||
|
|
||||||
|
var deleted_user User
|
||||||
|
db.Table("deleted_users").Find(&deleted_user)
|
||||||
|
if deleted_user.Name != "DeletedUser" {
|
||||||
|
t.Errorf("Should query from deleted_users table")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user1, user2, user3 User
|
||||||
|
db.First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
|
||||||
|
if !((user1.Name != user2.Name) && (user1.Name == user3.Name)) {
|
||||||
|
t.Errorf("Set Table Chain Should works well")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
4
main.go
4
main.go
|
@ -65,6 +65,10 @@ func (s *DB) Model(value interface{}) *Chain {
|
||||||
return s.buildORM().Model(value)
|
return s.buildORM().Model(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DB) Table(name string) *Chain {
|
||||||
|
return s.buildORM().Table(name)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DB) CreateTable(value interface{}) *Chain {
|
func (s *DB) CreateTable(value interface{}) *Chain {
|
||||||
return s.buildORM().CreateTable(value)
|
return s.buildORM().CreateTable(value)
|
||||||
}
|
}
|
||||||
|
|
4
model.go
4
model.go
|
@ -29,6 +29,10 @@ func (m *Model) primaryKeyZero() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) primaryKeyValue() int64 {
|
func (m *Model) primaryKeyValue() int64 {
|
||||||
|
if m.Data == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
t := reflect.TypeOf(m.Data).Elem()
|
t := reflect.TypeOf(m.Data).Elem()
|
||||||
switch t.Kind() {
|
switch t.Kind() {
|
||||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||||
|
|
Loading…
Reference in New Issue