Possible to specify table name

This commit is contained in:
Jinzhu 2013-10-28 20:27:25 +08:00
parent 3bd7daf889
commit 2a20e551ed
6 changed files with 124 additions and 33 deletions

View File

@ -20,12 +20,12 @@ Yet Another ORM library for Go, aims for developer friendly
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable") db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
type User struct { type User struct {
Id int64 Id int64
Age int64 Age int64
Birthday time.Time Birthday time.Time
Name string Name string
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
} }
// Set database pool // Set database pool
@ -127,12 +127,12 @@ Below callbacks are defined now:
Callbacks is a function defined to a model, if the function return error, will prevent the database operations. Callbacks is a function defined to a model, if the function return error, will prevent the database operations.
func (u *User) BeforeUpdate() (err error) { func (u *User) BeforeUpdate() (err error) {
if u.readonly() { if u.readonly() {
err = errors.New("Read Only User") err = errors.New("Read Only User")
}
return
} }
return
}
// Pluck (get users's age as map) // Pluck (get users's age as map)
var ages []int64 var ages []int64
@ -149,6 +149,16 @@ db.Where("name <> ?", "jinzhu").Where("age >= ? and role <> ?", 20, "admin").Fin
// Create Table with struct // Create Table with struct
db.CreateTable(&User{}) db.CreateTable(&User{})
// Specify Table Name
db.Table("deleted_users").CreateTable(&User{})
db.Table("users").Pluck("age", &ages)
//// ages -> select age from users;
var deleted_users []User
db.Table("deleted_users").Find(&deleted_users)
//// deleted_users -> select * from deleted_users;
db.Table("deleted_users").Find(&deleted_user)
//// deleted_user -> select * from deleted_users limit 1;
// Run Raw SQL // Run Raw SQL
db.Exec("drop table users;") db.Exec("drop table users;")
@ -185,6 +195,10 @@ db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_cart)
//// shopping_cart -> select * from carts where product_name = 'fancy_product' //// shopping_cart -> select * from carts where product_name = 'fancy_product'
// Do you noticed the search table is different for above query, yay // Do you noticed the search table is different for above query, yay
db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&user2)
//// users1 -> select * from users where mail_type = 'TEXT';
//// users2 -> select * from deleted_users where mail_type = 'TEXT';
// Open your mind, add more cool examples // Open your mind, add more cool examples
``` ```

View File

@ -17,12 +17,13 @@ type Chain struct {
Errors []error Errors []error
Error error Error error
whereClause []map[string]interface{} whereClause []map[string]interface{}
orClause []map[string]interface{} orClause []map[string]interface{}
selectStr string selectStr string
orderStrs []string orderStrs []string
offsetStr string offsetStr string
limitStr string limitStr string
specifiedTableName string
} }
func (s *Chain) err(err error) { func (s *Chain) err(err error) {
@ -44,6 +45,7 @@ func (s *Chain) do(value interface{}) *Do {
do.orderStrs = s.orderStrs do.orderStrs = s.orderStrs
do.offsetStr = s.offsetStr do.offsetStr = s.offsetStr
do.limitStr = s.limitStr do.limitStr = s.limitStr
do.specifiedTableName = s.specifiedTableName
s.value = value s.value = value
do.setModel(value) do.setModel(value)
@ -174,6 +176,11 @@ func (s *Chain) CreateTable(value interface{}) *Chain {
return s return s
} }
func (s *Chain) Table(name string) *Chain {
s.specifiedTableName = name
return s
}
func (s *Chain) validSql(str string) (result bool) { func (s *Chain) validSql(str string) (result bool) {
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str) result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
if !result { if !result {

41
do.go
View File

@ -12,11 +12,12 @@ import (
) )
type Do struct { type Do struct {
chain *Chain chain *Chain
db *sql.DB db *sql.DB
driver string driver string
TableName string guessedTableName string
Errors []error specifiedTableName string
Errors []error
model *Model model *Model
value interface{} value interface{}
@ -34,6 +35,14 @@ type Do struct {
operation string operation string
} }
func (s *Do) tableName() string {
if s.specifiedTableName == "" {
return s.guessedTableName
} else {
return s.specifiedTableName
}
}
func (s *Do) err(err error) { func (s *Do) err(err error) {
if err != nil { if err != nil {
s.Errors = append(s.Errors, err) s.Errors = append(s.Errors, err)
@ -49,8 +58,10 @@ func (s *Do) setModel(value interface{}) {
s.value = value s.value = value
s.model = &Model{Data: value, driver: s.driver} s.model = &Model{Data: value, driver: s.driver}
var err error var err error
s.TableName, err = s.model.tableName() if s.specifiedTableName == "" {
s.err(err) s.guessedTableName, err = s.model.tableName()
s.err(err)
}
} }
func (s *Do) addToVars(value interface{}) string { func (s *Do) addToVars(value interface{}) string {
@ -91,7 +102,7 @@ func (s *Do) prepareCreateSql() *Do {
s.Sql = fmt.Sprintf( s.Sql = fmt.Sprintf(
"INSERT INTO \"%v\" (%v) VALUES (%v) %v", "INSERT INTO \"%v\" (%v) VALUES (%v) %v",
s.TableName, s.tableName(),
strings.Join(s.quoteMap(columns), ","), strings.Join(s.quoteMap(columns), ","),
strings.Join(sqls, ","), strings.Join(sqls, ","),
s.model.returningStr(), s.model.returningStr(),
@ -135,7 +146,7 @@ func (s *Do) prepareUpdateSql() *Do {
s.Sql = fmt.Sprintf( s.Sql = fmt.Sprintf(
"UPDATE %v SET %v %v", "UPDATE %v SET %v %v",
s.TableName, s.tableName(),
strings.Join(sets, ", "), strings.Join(sets, ", "),
s.combinedSql(), s.combinedSql(),
) )
@ -154,7 +165,7 @@ func (s *Do) update() *Do {
} }
func (s *Do) prepareDeleteSql() *Do { func (s *Do) prepareDeleteSql() *Do {
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql()) s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
return s return s
} }
@ -168,7 +179,7 @@ func (s *Do) delete() *Do {
} }
func (s *Do) prepareQuerySql() *Do { func (s *Do) prepareQuerySql() *Do {
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql()) s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
return s return s
} }
@ -216,7 +227,10 @@ func (s *Do) query(where ...interface{}) {
columns, _ := rows.Columns() columns, _ := rows.Columns()
var values []interface{} var values []interface{}
for _, value := range columns { for _, value := range columns {
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface()) field := dest.FieldByName(snakeToUpperCamel(value))
if field.IsValid() {
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
}
} }
s.err(rows.Scan(values...)) s.err(rows.Scan(values...))
@ -400,10 +414,9 @@ func (s *Do) createTable() *Do {
for _, field := range s.model.fields("null") { for _, field := range s.model.fields("null") {
sqls = append(sqls, field.DbName+" "+field.SqlType) sqls = append(sqls, field.DbName+" "+field.SqlType)
} }
s.Sql = fmt.Sprintf( s.Sql = fmt.Sprintf(
"CREATE TABLE \"%v\" (%v)", "CREATE TABLE \"%v\" (%v)",
s.TableName, s.tableName(),
strings.Join(sqls, ","), strings.Join(sqls, ","),
) )
return s return s

View File

@ -1,11 +1,11 @@
package gorm package gorm
import ( import (
_ "github.com/lib/pq"
"errors" "errors"
_ "github.com/lib/pq"
"reflect" "reflect"
"strconv" "strconv"
"testing" "testing"
"time" "time"
) )
@ -567,3 +567,52 @@ func TestNoUnExpectedHappenWithInvalidSql(t *testing.T) {
db.Where("unexisting = ?", "3").Find(&[]User{}) db.Where("unexisting = ?", "3").Find(&[]User{})
} }
func TestSetTableDirectly(t *testing.T) {
var ages []int64
if db.Table("users").Pluck("age", &ages).Error != nil {
t.Errorf("No errors should happen if only set table")
}
if len(ages) == 0 {
t.Errorf("Should find out some records")
}
var users []User
if db.Table("users").Find(&users).Error != nil {
t.Errorf("No errors should happen if set table to an existing table")
}
if db.Table("unexisting_users_table").Find(&users).Error == nil {
t.Errorf("Should got some errors if set table to an unexisting table")
}
if db.Table("products").Find(&users).Error == nil {
t.Errorf("Should got some errors if set table to an unexisting table")
}
db.Exec("drop table deleted_users;")
if db.Table("deleted_users").CreateTable(&User{}).Error != nil {
t.Errorf("Should create deleted_users table")
}
db.Table("deleted_users").Save(&User{Name: "DeletedUser"})
var deleted_users []User
db.Table("deleted_users").Find(&deleted_users)
if len(deleted_users) != 1 {
t.Errorf("Should query from deleted_users table")
}
var deleted_user User
db.Table("deleted_users").Find(&deleted_user)
if deleted_user.Name != "DeletedUser" {
t.Errorf("Should query from deleted_users table")
}
var user1, user2, user3 User
db.First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
if !((user1.Name != user2.Name) && (user1.Name == user3.Name)) {
t.Errorf("Set Table Chain Should works well")
}
}

View File

@ -65,6 +65,10 @@ func (s *DB) Model(value interface{}) *Chain {
return s.buildORM().Model(value) return s.buildORM().Model(value)
} }
func (s *DB) Table(name string) *Chain {
return s.buildORM().Table(name)
}
func (s *DB) CreateTable(value interface{}) *Chain { func (s *DB) CreateTable(value interface{}) *Chain {
return s.buildORM().CreateTable(value) return s.buildORM().CreateTable(value)
} }

View File

@ -29,6 +29,10 @@ func (m *Model) primaryKeyZero() bool {
} }
func (m *Model) primaryKeyValue() int64 { func (m *Model) primaryKeyValue() int64 {
if m.Data == nil {
return 0
}
t := reflect.TypeOf(m.Data).Elem() t := reflect.TypeOf(m.Data).Elem()
switch t.Kind() { switch t.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: