Better do.go

This commit is contained in:
Jinzhu 2013-11-16 15:01:31 +08:00
parent e0d3100294
commit 38f7ecdf15
6 changed files with 212 additions and 241 deletions

104
do.go
View File

@ -18,6 +18,7 @@ type Do struct {
search *search
model *Model
tableName string
usingUpdate bool
value interface{}
update_attrs map[string]interface{}
hasUpdate bool
@ -69,7 +70,7 @@ func (s *Do) trace(t time.Time) {
s.db.slog(s.sql, t, s.sqlVars...)
}
func (s *Do) exec(sqls ...string) {
func (s *Do) exec(sqls ...string) *Do {
defer s.trace(time.Now())
if !s.db.hasError() {
if len(sqls) > 0 {
@ -78,6 +79,7 @@ func (s *Do) exec(sqls ...string) {
_, err := s.db.db.Exec(s.sql, s.sqlVars...)
s.err(err)
}
return s
}
func (s *Do) save() *Do {
@ -203,6 +205,7 @@ func (s *Do) create() (i interface{}) {
func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
ignore_protected := len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
s.usingUpdate = true
switch value := values.(type) {
case map[string]interface{}:
@ -210,9 +213,8 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do
results, has_update := s.model.updatedColumnsAndValues(value, ignore_protected)
if len(results) > 0 {
s.update_attrs = results
} else if has_update {
s.hasUpdate = has_update
}
s.hasUpdate = has_update
}
case []interface{}:
for _, v := range value {
@ -226,8 +228,6 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do
}
s.updateAttrs(attrs)
}
s.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
return s
}
@ -251,6 +251,10 @@ func (s *Do) prepareUpdateSql() {
}
func (s *Do) update() *Do {
if s.usingUpdate && !s.hasUpdate {
return s
}
s.model.callMethod("BeforeUpdate")
s.model.callMethod("BeforeSave")
s.saveBeforeAssociations()
@ -288,14 +292,16 @@ func (s *Do) prepareQuerySql() {
return
}
func (s *Do) first() {
s.search.order(s.model.primaryKeyDb()).limit(1)
func (s *Do) first() *Do {
s.search = s.search.clone().order(s.model.primaryKeyDb()).limit(1)
s.query()
return s
}
func (s *Do) last() {
s.search.order(s.model.primaryKeyDb() + " DESC").limit(1)
func (s *Do) last() *Do {
s.search = s.search.clone().order(s.model.primaryKeyDb() + " DESC").limit(1)
s.query()
return s
}
func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err error, from_from bool, foreign_value interface{}) {
@ -314,7 +320,7 @@ func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err erro
return
}
func (s *Do) related(value interface{}, foreign_keys ...string) {
func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
var foreign_value interface{}
var from_from bool
var foreign_key string
@ -338,9 +344,10 @@ func (s *Do) related(value interface{}, foreign_keys ...string) {
query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value))
s.where(query).query()
}
return s
}
func (s *Do) query() {
func (s *Do) query() *Do {
var (
is_slice bool
dest_type reflect.Type
@ -351,7 +358,7 @@ func (s *Do) query() {
is_slice = true
dest_type = dest_out.Type().Elem()
} else {
s.search.limit(1)
s.search = s.search.clone().limit(1)
}
s.prepareQuerySql()
@ -361,7 +368,7 @@ func (s *Do) query() {
s.db.slog(s.sql, now, s.sqlVars...)
if s.err(err) != nil {
return
return s
}
defer rows.Close()
@ -395,23 +402,26 @@ func (s *Do) query() {
s.err(RecordNotFound)
}
}
return s
}
func (s *Do) count(value interface{}) {
func (s *Do) count(value interface{}) *Do {
s.search = s.search.clone().selects("count(*)")
s.prepareQuerySql()
if !s.db.hasError() {
now := time.Now()
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
s.db.slog(s.sql, now, s.sqlVars...)
}
return s
}
func (s *Do) pluck(column string, value interface{}) {
func (s *Do) pluck(column string, value interface{}) *Do {
dest_out := reflect.Indirect(reflect.ValueOf(value))
s.search = s.search.clone().selects(column)
if dest_out.Kind() != reflect.Slice {
s.err(errors.New("Results should be a slice"))
return
return s
}
s.prepareQuerySql()
@ -430,6 +440,7 @@ func (s *Do) pluck(column string, value interface{}) {
}
}
}
return s
}
func (s *Do) primaryCondiation(value interface{}) string {
@ -474,7 +485,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
values := reflect.ValueOf(arg)
var temp_marks []string
for i := 0; i < values.Len(); i++ {
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface()))
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
@ -533,7 +544,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
values := reflect.ValueOf(arg)
var temp_marks []string
for i := 0; i < values.Len(); i++ {
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface()))
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
@ -546,6 +557,13 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
return
}
func (s *Do) where(where ...interface{}) *Do {
if len(where) > 0 {
s.search = s.search.clone().where(where[0], where[1:]...)
}
return s
}
func (s *Do) whereSql() (sql string) {
var primary_condiations, and_conditions, or_conditions []string
@ -709,7 +727,7 @@ func (s *Do) begin() *Do {
return s
}
func (s *Do) commit_or_rollback() {
func (s *Do) commit_or_rollback() *Do {
if s.startedTransaction {
if db, ok := s.db.db.(sqlTx); ok {
if s.db.hasError() {
@ -717,47 +735,21 @@ func (s *Do) commit_or_rollback() {
} else {
db.Commit()
}
s.db.db = s.db.parent.db
}
}
}
func (s *Do) where(where ...interface{}) *Do {
if len(where) > 0 {
s.search.where(where[0], where[1:])
}
return s
}
func (s *Do) initialize() {
// TODO initializeWithSearchCondition
}
func (s *Do) initializeWithSearchCondition() {
func (s *Do) initialize() *Do {
for _, clause := range s.search.whereClause {
switch value := clause["query"].(type) {
case map[string]interface{}:
for k, v := range value {
s.model.setValueByColumn(k, v, s.value)
}
case []interface{}:
for _, obj := range value {
switch reflect.ValueOf(obj).Kind() {
case reflect.Struct:
m := &Model{data: obj, do: s}
for _, field := range m.columnsHasValue("other") {
s.model.setValueByColumn(field.dbName, field.Value, s.value)
}
case reflect.Map:
for key, value := range obj.(map[string]interface{}) {
s.model.setValueByColumn(key, value, s.value)
}
}
}
case interface{}:
m := &Model{data: value, do: s}
for _, field := range m.columnsHasValue("other") {
s.model.setValueByColumn(field.dbName, field.Value, s.value)
}
}
s.updateAttrs(clause["query"])
}
for _, attrs := range s.search.initAttrs {
s.updateAttrs(attrs)
}
for _, attrs := range s.search.assignAttrs {
s.updateAttrs(attrs)
}
return s
}

View File

@ -2,7 +2,6 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
_ "github.com/go-sql-driver/mysql"
@ -147,7 +146,7 @@ func init() {
func TestSaveAndFind(t *testing.T) {
name := "save_and_find"
u := &User{Name: name, Age: 1}
db.Debug().Save(u)
db.Save(u)
if u.Id == 0 {
t.Errorf("Should have ID after create record")
}
@ -1242,29 +1241,16 @@ func (c Cart) TableName() string {
}
func TestTableName(t *testing.T) {
var table string
model := &Model{data: Order{}}
table = model.tableName()
if table != "orders" {
if db.do(Order{}).table() != "orders" {
t.Errorf("Order table name should be orders")
}
db.SingularTable(true)
table = model.tableName()
if table != "order" {
if db.do(Order{}).table() != "order" {
t.Errorf("Order's singular table name should be order")
}
model2 := &Model{data: Cart{}}
table = model2.tableName()
if table != "shopping_cart" {
t.Errorf("Cart's singular table name should be shopping_cart")
}
model3 := &Model{data: &Cart{}}
table = model3.tableName()
if table != "shopping_cart" {
if db.do(&Cart{}).table() != "shopping_cart" {
t.Errorf("Cart's singular table name should be shopping_cart")
}
db.SingularTable(false)
@ -1303,145 +1289,145 @@ func TestAutoMigration(t *testing.T) {
}
}
type NullTime struct {
Time time.Time
Valid bool
}
// type NullTime struct {
// Time time.Time
// Valid bool
// }
func (nt *NullTime) Scan(value interface{}) error {
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
}
// func (nt *NullTime) Scan(value interface{}) error {
// if value == nil {
// nt.Valid = false
// return nil
// }
// nt.Time, nt.Valid = value.(time.Time), true
// return nil
// }
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
// func (nt NullTime) Value() (driver.Value, error) {
// if !nt.Valid {
// return nil, nil
// }
// return nt.Time, nil
// }
type NullValue struct {
Id int64
Name sql.NullString `sql:"not null"`
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
AddedAt NullTime
}
// type NullValue struct {
// Id int64
// Name sql.NullString `sql:"not null"`
// Age sql.NullInt64
// Male sql.NullBool
// Height sql.NullFloat64
// AddedAt NullTime
// }
func TestSqlNullValue(t *testing.T) {
db.DropTable(&NullValue{})
db.AutoMigrate(&NullValue{})
// func TestSqlNullValue(t *testing.T) {
// db.DropTable(&NullValue{})
// db.AutoMigrate(&NullValue{})
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
// if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
// t.Errorf("Not error should raise when test null value", err)
// }
var nv NullValue
db.First(&nv, "name = ?", "hello")
// var nv NullValue
// db.First(&nv, "name = ?", "hello")
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
t.Errorf("Should be able to fetch null value")
}
// if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
// t.Errorf("Should be able to fetch null value")
// }
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
// t.Errorf("Not error should raise when test null value", err)
// }
var nv2 NullValue
db.First(&nv2, "name = ?", "hello-2")
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
t.Errorf("Should be able to fetch null value")
}
// var nv2 NullValue
// db.First(&nv2, "name = ?", "hello-2")
// if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
// t.Errorf("Should be able to fetch null value")
// }
if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
t.Errorf("Can't save because of name can't be null", err)
}
}
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
// t.Errorf("Can't save because of name can't be null", err)
// }
// }
func TestTransaction(t *testing.T) {
d := db.Begin()
u := User{Name: "transcation"}
if err := d.Save(&u).Error; err != nil {
t.Errorf("No error should raise, but got", err)
}
// func TestTransaction(t *testing.T) {
// d := db.Begin()
// u := User{Name: "transcation"}
// if err := d.Save(&u).Error; err != nil {
// t.Errorf("No error should raise, but got", err)
// }
if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record, but got", err)
}
// if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
// t.Errorf("Should find saved record, but got", err)
// }
d.Rollback()
// d.Rollback()
if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback")
}
// if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
// t.Errorf("Should not find record after rollback")
// }
d2 := db.Begin()
u2 := User{Name: "transcation-2"}
if err := d2.Save(&u2).Error; err != nil {
t.Errorf("No error should raise, but got", err)
}
d2.Update("age", 90)
// d2 := db.Begin()
// u2 := User{Name: "transcation-2"}
// if err := d2.Save(&u2).Error; err != nil {
// t.Errorf("No error should raise, but got", err)
// }
// d2.Update("age", 90)
if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should find saved record, but got", err)
}
// if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
// t.Errorf("Should find saved record, but got", err)
// }
d2.Commit()
// d2.Commit()
if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record")
}
}
// if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
// t.Errorf("Should be able to find committed record")
// }
// }
func (s *CreditCard) BeforeSave() (err error) {
if s.Number == "0000" {
err = errors.New("invalid credit card")
}
return
}
// func (s *CreditCard) BeforeSave() (err error) {
// if s.Number == "0000" {
// err = errors.New("invalid credit card")
// }
// return
// }
func BenchmarkGorm(b *testing.B) {
b.N = 5000
for x := 0; x < b.N; x++ {
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// Insert
db.Save(&email)
// Query
db.First(&BigEmail{}, "email = ?", e)
// Update
db.Model(&email).Update("email", "new-"+e)
// Delete
db.Delete(&email)
}
}
// func BenchmarkGorm(b *testing.B) {
// b.N = 5000
// for x := 0; x < b.N; x++ {
// e := strconv.Itoa(x) + "benchmark@example.org"
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// // Insert
// db.Save(&email)
// // Query
// db.First(&BigEmail{}, "email = ?", e)
// // Update
// db.Model(&email).Update("email", "new-"+e)
// // Delete
// db.Delete(&email)
// }
// }
func BenchmarkRawSql(b *testing.B) {
db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
db.SetMaxIdleConns(10)
insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
delete_sql := "DELETE FROM orders WHERE id = $1"
// func BenchmarkRawSql(b *testing.B) {
// db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db.SetMaxIdleConns(10)
// insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
// query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
// update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
// delete_sql := "DELETE FROM orders WHERE id = $1"
b.N = 5000
for x := 0; x < b.N; x++ {
var id int64
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// Insert
db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
// Query
rows, _ := db.Query(query_sql, email.Email)
rows.Close()
// Update
db.Exec(update_sql, "new-"+e, time.Now(), id)
// Delete
db.Exec(delete_sql, id)
}
}
// b.N = 5000
// for x := 0; x < b.N; x++ {
// var id int64
// e := strconv.Itoa(x) + "benchmark@example.org"
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// // Insert
// db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
// // Query
// rows, _ := db.Query(query_sql, email.Email)
// rows.Close()
// // Update
// db.Exec(update_sql, "new-"+e, time.Now(), id)
// // Delete
// db.Exec(delete_sql, id)
// }
// }

53
main.go
View File

@ -41,8 +41,9 @@ func (s *DB) SetLogger(l Logger) {
s.parent.logger = l
}
func (s *DB) LogMode(b bool) {
func (s *DB) LogMode(b bool) *DB {
s.logMode = b
return s
}
func (s *DB) SingularTable(b bool) {
@ -54,7 +55,7 @@ func (s *DB) Where(query interface{}, args ...interface{}) *DB {
}
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
return s.clone().search.where(query, args...).db
return s.clone().search.or(query, args...).db
}
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
@ -82,18 +83,15 @@ func (s *DB) Unscoped() *DB {
}
func (s *DB) First(out interface{}, where ...interface{}) *DB {
s.clone().do(out).where(where...).first()
return s
return s.clone().do(out).where(where...).first().db
}
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
s.clone().do(out).where(where...).last()
return s
return s.clone().do(out).where(where...).last().db
}
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
s.clone().do(out).where(where...).query()
return s
return s.clone().do(out).where(where...).query().db
}
func (s *DB) Attrs(attrs ...interface{}) *DB {
@ -105,23 +103,22 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
}
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
if s.First(out, where...).Error != nil {
s.clone().do(out).where(where).initialize()
if s.clone().First(out, where...).Error != nil {
return s.clone().do(out).where(where).initialize().db
} else {
if len(s.search.assignAttrs) > 0 {
s.do(out).updateAttrs(s.search.assignAttrs) //updated or not
return s.clone().do(out).updateAttrs(s.search.assignAttrs).db
}
}
return s
}
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
if s.First(out, where...).Error != nil {
s.clone().do(out).where(where...).initialize()
s.Save(out)
if s.clone().First(out, where...).Error != nil {
return s.clone().do(out).where(where...).initialize().db.Save(out)
} else {
if len(s.search.assignAttrs) > 0 {
s.do(out).updateAttrs(s.search.assignAttrs).update()
return s.clone().do(out).updateAttrs(s.search.assignAttrs).update().db
}
}
return s
@ -132,23 +129,19 @@ func (s *DB) Update(attrs ...interface{}) *DB {
}
func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB {
s.clone().do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
return s
return s.clone().do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback().db
}
func (s *DB) Save(value interface{}) *DB {
s.clone().do(value).begin().save().commit_or_rollback()
return s
return s.clone().do(value).begin().save().commit_or_rollback().db
}
func (s *DB) Delete(value interface{}) *DB {
s.clone().do(value).begin().delete().commit_or_rollback()
return s
return s.clone().do(value).begin().delete().commit_or_rollback().db
}
func (s *DB) Exec(sql string) *DB {
s.do(nil).exec(sql)
return s
return s.do(nil).exec(sql).db
}
func (s *DB) Model(value interface{}) *DB {
@ -158,18 +151,16 @@ func (s *DB) Model(value interface{}) *DB {
}
func (s *DB) Related(value interface{}, foreign_keys ...string) *DB {
s.clone().do(value).related(s.data, foreign_keys...)
return s
old_data := s.data
return s.do(value).related(old_data, foreign_keys...).db
}
func (s *DB) Pluck(column string, value interface{}) *DB {
s.clone().search.selects(column).db.do(s.data).pluck(column, value)
return s
return s.do(s.data).pluck(column, value).db
}
func (s *DB) Count(value interface{}) *DB {
s.clone().search.selects("count(*)").db.do(s.data).count(value)
return s
return s.do(s.data).count(value).db
}
func (s *DB) Table(name string) *DB {
@ -178,9 +169,7 @@ func (s *DB) Table(name string) *DB {
// Debug
func (s *DB) Debug() *DB {
c := s.clone()
c.logMode = true
return c
return s.clone().LogMode(true)
}
// Transactions

View File

@ -191,7 +191,7 @@ func (m *Model) tableName() (str string) {
str = toSnake(m.typeName())
if !m.do.db.singularTable {
if !m.do.db.parent.singularTable {
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
for key, value := range pluralMap {
reg := regexp.MustCompile(key + "$")

View File

@ -8,11 +8,12 @@ import (
func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, data: s.data, Error: s.Error}
if s.parent.search == nil {
if s.search == nil {
db.search = &search{}
} else {
db.search = s.parent.search.clone()
db.search = s.search.clone()
}
db.search.db = &db
return &db
}

View File

@ -1,6 +1,10 @@
package gorm
import "strconv"
import (
"regexp"
"strconv"
)
type search struct {
db *DB
@ -29,6 +33,7 @@ func (s *search) clone() *search {
offsetStr: s.offsetStr,
limitStr: s.limitStr,
unscope: s.unscope,
tableName: s.tableName,
}
}
@ -67,23 +72,17 @@ func (s *search) order(value string, reorder ...bool) *search {
}
func (s *search) selects(value interface{}) *search {
if str, err := getInterfaceAsString(value); err == nil {
s.selectStr = str
}
s.selectStr = s.getInterfaceAsSql(value)
return s
}
func (s *search) limit(value interface{}) *search {
if str, err := getInterfaceAsString(value); err == nil {
s.limitStr = str
}
s.limitStr = s.getInterfaceAsSql(value)
return s
}
func (s *search) offset(value interface{}) *search {
if str, err := getInterfaceAsString(value); err == nil {
s.offsetStr = str
}
s.offsetStr = s.getInterfaceAsSql(value)
return s
}
@ -97,7 +96,7 @@ func (s *search) table(name string) *search {
return s
}
func getInterfaceAsString(value interface{}) (str string, err error) {
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
switch value := value.(type) {
case string:
str = value
@ -108,7 +107,11 @@ func getInterfaceAsString(value interface{}) (str string, err error) {
str = strconv.Itoa(value)
}
default:
err = InvalidSql
s.db.err(InvalidSql)
}
if !regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str) {
s.db.err(InvalidSql)
}
return
}