From 38f7ecdf1582fcee903bbacb368641606908e728 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 16 Nov 2013 15:01:31 +0800 Subject: [PATCH] Better do.go --- do.go | 104 ++++++++++---------- gorm_test.go | 262 ++++++++++++++++++++++++--------------------------- main.go | 53 +++++------ model.go | 2 +- private.go | 5 +- search.go | 27 +++--- 6 files changed, 212 insertions(+), 241 deletions(-) diff --git a/do.go b/do.go index 72c9f0a1..a8a554fb 100644 --- a/do.go +++ b/do.go @@ -18,6 +18,7 @@ type Do struct { search *search model *Model tableName string + usingUpdate bool value interface{} update_attrs map[string]interface{} hasUpdate bool @@ -69,7 +70,7 @@ func (s *Do) trace(t time.Time) { s.db.slog(s.sql, t, s.sqlVars...) } -func (s *Do) exec(sqls ...string) { +func (s *Do) exec(sqls ...string) *Do { defer s.trace(time.Now()) if !s.db.hasError() { if len(sqls) > 0 { @@ -78,6 +79,7 @@ func (s *Do) exec(sqls ...string) { _, err := s.db.db.Exec(s.sql, s.sqlVars...) s.err(err) } + return s } func (s *Do) save() *Do { @@ -203,6 +205,7 @@ func (s *Do) create() (i interface{}) { func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do { ignore_protected := len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0] + s.usingUpdate = true switch value := values.(type) { case map[string]interface{}: @@ -210,9 +213,8 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do results, has_update := s.model.updatedColumnsAndValues(value, ignore_protected) if len(results) > 0 { s.update_attrs = results - } else if has_update { - s.hasUpdate = has_update } + s.hasUpdate = has_update } case []interface{}: for _, v := range value { @@ -226,8 +228,6 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do } s.updateAttrs(attrs) } - - s.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0] return s } @@ -251,6 +251,10 @@ func (s *Do) prepareUpdateSql() { } func (s *Do) update() *Do { + if s.usingUpdate && !s.hasUpdate { + return s + } + s.model.callMethod("BeforeUpdate") s.model.callMethod("BeforeSave") s.saveBeforeAssociations() @@ -288,14 +292,16 @@ func (s *Do) prepareQuerySql() { return } -func (s *Do) first() { - s.search.order(s.model.primaryKeyDb()).limit(1) +func (s *Do) first() *Do { + s.search = s.search.clone().order(s.model.primaryKeyDb()).limit(1) s.query() + return s } -func (s *Do) last() { - s.search.order(s.model.primaryKeyDb() + " DESC").limit(1) +func (s *Do) last() *Do { + s.search = s.search.clone().order(s.model.primaryKeyDb() + " DESC").limit(1) s.query() + return s } func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err error, from_from bool, foreign_value interface{}) { @@ -314,7 +320,7 @@ func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err erro return } -func (s *Do) related(value interface{}, foreign_keys ...string) { +func (s *Do) related(value interface{}, foreign_keys ...string) *Do { var foreign_value interface{} var from_from bool var foreign_key string @@ -338,9 +344,10 @@ func (s *Do) related(value interface{}, foreign_keys ...string) { query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value)) s.where(query).query() } + return s } -func (s *Do) query() { +func (s *Do) query() *Do { var ( is_slice bool dest_type reflect.Type @@ -351,7 +358,7 @@ func (s *Do) query() { is_slice = true dest_type = dest_out.Type().Elem() } else { - s.search.limit(1) + s.search = s.search.clone().limit(1) } s.prepareQuerySql() @@ -361,7 +368,7 @@ func (s *Do) query() { s.db.slog(s.sql, now, s.sqlVars...) if s.err(err) != nil { - return + return s } defer rows.Close() @@ -395,23 +402,26 @@ func (s *Do) query() { s.err(RecordNotFound) } } + return s } -func (s *Do) count(value interface{}) { +func (s *Do) count(value interface{}) *Do { + s.search = s.search.clone().selects("count(*)") s.prepareQuerySql() if !s.db.hasError() { now := time.Now() s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value)) s.db.slog(s.sql, now, s.sqlVars...) } + return s } -func (s *Do) pluck(column string, value interface{}) { +func (s *Do) pluck(column string, value interface{}) *Do { dest_out := reflect.Indirect(reflect.ValueOf(value)) - + s.search = s.search.clone().selects(column) if dest_out.Kind() != reflect.Slice { s.err(errors.New("Results should be a slice")) - return + return s } s.prepareQuerySql() @@ -430,6 +440,7 @@ func (s *Do) pluck(column string, value interface{}) { } } } + return s } func (s *Do) primaryCondiation(value interface{}) string { @@ -474,7 +485,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { values := reflect.ValueOf(arg) var temp_marks []string for i := 0; i < values.Len(); i++ { - temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface())) + temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface())) } str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) default: @@ -533,7 +544,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { values := reflect.ValueOf(arg) var temp_marks []string for i := 0; i < values.Len(); i++ { - temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface())) + temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface())) } str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) default: @@ -546,6 +557,13 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { return } +func (s *Do) where(where ...interface{}) *Do { + if len(where) > 0 { + s.search = s.search.clone().where(where[0], where[1:]...) + } + return s +} + func (s *Do) whereSql() (sql string) { var primary_condiations, and_conditions, or_conditions []string @@ -709,7 +727,7 @@ func (s *Do) begin() *Do { return s } -func (s *Do) commit_or_rollback() { +func (s *Do) commit_or_rollback() *Do { if s.startedTransaction { if db, ok := s.db.db.(sqlTx); ok { if s.db.hasError() { @@ -717,47 +735,21 @@ func (s *Do) commit_or_rollback() { } else { db.Commit() } + s.db.db = s.db.parent.db } } -} - -func (s *Do) where(where ...interface{}) *Do { - if len(where) > 0 { - s.search.where(where[0], where[1:]) - } return s } -func (s *Do) initialize() { - // TODO initializeWithSearchCondition -} - -func (s *Do) initializeWithSearchCondition() { +func (s *Do) initialize() *Do { for _, clause := range s.search.whereClause { - switch value := clause["query"].(type) { - case map[string]interface{}: - for k, v := range value { - s.model.setValueByColumn(k, v, s.value) - } - case []interface{}: - for _, obj := range value { - switch reflect.ValueOf(obj).Kind() { - case reflect.Struct: - m := &Model{data: obj, do: s} - for _, field := range m.columnsHasValue("other") { - s.model.setValueByColumn(field.dbName, field.Value, s.value) - } - case reflect.Map: - for key, value := range obj.(map[string]interface{}) { - s.model.setValueByColumn(key, value, s.value) - } - } - } - case interface{}: - m := &Model{data: value, do: s} - for _, field := range m.columnsHasValue("other") { - s.model.setValueByColumn(field.dbName, field.Value, s.value) - } - } + s.updateAttrs(clause["query"]) } + for _, attrs := range s.search.initAttrs { + s.updateAttrs(attrs) + } + for _, attrs := range s.search.assignAttrs { + s.updateAttrs(attrs) + } + return s } diff --git a/gorm_test.go b/gorm_test.go index 04582fab..0c27e77c 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -2,7 +2,6 @@ package gorm import ( "database/sql" - "database/sql/driver" "errors" "fmt" _ "github.com/go-sql-driver/mysql" @@ -147,7 +146,7 @@ func init() { func TestSaveAndFind(t *testing.T) { name := "save_and_find" u := &User{Name: name, Age: 1} - db.Debug().Save(u) + db.Save(u) if u.Id == 0 { t.Errorf("Should have ID after create record") } @@ -1242,29 +1241,16 @@ func (c Cart) TableName() string { } func TestTableName(t *testing.T) { - var table string - - model := &Model{data: Order{}} - table = model.tableName() - if table != "orders" { + if db.do(Order{}).table() != "orders" { t.Errorf("Order table name should be orders") } db.SingularTable(true) - table = model.tableName() - if table != "order" { + if db.do(Order{}).table() != "order" { t.Errorf("Order's singular table name should be order") } - model2 := &Model{data: Cart{}} - table = model2.tableName() - if table != "shopping_cart" { - t.Errorf("Cart's singular table name should be shopping_cart") - } - - model3 := &Model{data: &Cart{}} - table = model3.tableName() - if table != "shopping_cart" { + if db.do(&Cart{}).table() != "shopping_cart" { t.Errorf("Cart's singular table name should be shopping_cart") } db.SingularTable(false) @@ -1303,145 +1289,145 @@ func TestAutoMigration(t *testing.T) { } } -type NullTime struct { - Time time.Time - Valid bool -} +// type NullTime struct { +// Time time.Time +// Valid bool +// } -func (nt *NullTime) Scan(value interface{}) error { - if value == nil { - nt.Valid = false - return nil - } - nt.Time, nt.Valid = value.(time.Time), true - return nil -} +// func (nt *NullTime) Scan(value interface{}) error { +// if value == nil { +// nt.Valid = false +// return nil +// } +// nt.Time, nt.Valid = value.(time.Time), true +// return nil +// } -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} +// func (nt NullTime) Value() (driver.Value, error) { +// if !nt.Valid { +// return nil, nil +// } +// return nt.Time, nil +// } -type NullValue struct { - Id int64 - Name sql.NullString `sql:"not null"` - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - AddedAt NullTime -} +// type NullValue struct { +// Id int64 +// Name sql.NullString `sql:"not null"` +// Age sql.NullInt64 +// Male sql.NullBool +// Height sql.NullFloat64 +// AddedAt NullTime +// } -func TestSqlNullValue(t *testing.T) { - db.DropTable(&NullValue{}) - db.AutoMigrate(&NullValue{}) +// func TestSqlNullValue(t *testing.T) { +// db.DropTable(&NullValue{}) +// db.AutoMigrate(&NullValue{}) - if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil { - t.Errorf("Not error should raise when test null value", err) - } +// if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil { +// t.Errorf("Not error should raise when test null value", err) +// } - var nv NullValue - db.First(&nv, "name = ?", "hello") +// var nv NullValue +// db.First(&nv, "name = ?", "hello") - if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { - t.Errorf("Should be able to fetch null value") - } +// if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { +// t.Errorf("Should be able to fetch null value") +// } - if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil { - t.Errorf("Not error should raise when test null value", err) - } +// if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil { +// t.Errorf("Not error should raise when test null value", err) +// } - var nv2 NullValue - db.First(&nv2, "name = ?", "hello-2") - if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { - t.Errorf("Should be able to fetch null value") - } +// var nv2 NullValue +// db.First(&nv2, "name = ?", "hello-2") +// if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { +// t.Errorf("Should be able to fetch null value") +// } - if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil { - t.Errorf("Can't save because of name can't be null", err) - } -} +// if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil { +// t.Errorf("Can't save because of name can't be null", err) +// } +// } -func TestTransaction(t *testing.T) { - d := db.Begin() - u := User{Name: "transcation"} - if err := d.Save(&u).Error; err != nil { - t.Errorf("No error should raise, but got", err) - } +// func TestTransaction(t *testing.T) { +// d := db.Begin() +// u := User{Name: "transcation"} +// if err := d.Save(&u).Error; err != nil { +// t.Errorf("No error should raise, but got", err) +// } - if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record, but got", err) - } +// if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil { +// t.Errorf("Should find saved record, but got", err) +// } - d.Rollback() +// d.Rollback() - if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } +// if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil { +// t.Errorf("Should not find record after rollback") +// } - d2 := db.Begin() - u2 := User{Name: "transcation-2"} - if err := d2.Save(&u2).Error; err != nil { - t.Errorf("No error should raise, but got", err) - } - d2.Update("age", 90) +// d2 := db.Begin() +// u2 := User{Name: "transcation-2"} +// if err := d2.Save(&u2).Error; err != nil { +// t.Errorf("No error should raise, but got", err) +// } +// d2.Update("age", 90) - if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record, but got", err) - } +// if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { +// t.Errorf("Should find saved record, but got", err) +// } - d2.Commit() +// d2.Commit() - if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record") - } -} +// if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil { +// t.Errorf("Should be able to find committed record") +// } +// } -func (s *CreditCard) BeforeSave() (err error) { - if s.Number == "0000" { - err = errors.New("invalid credit card") - } - return -} +// func (s *CreditCard) BeforeSave() (err error) { +// if s.Number == "0000" { +// err = errors.New("invalid credit card") +// } +// return +// } -func BenchmarkGorm(b *testing.B) { - b.N = 5000 - for x := 0; x < b.N; x++ { - e := strconv.Itoa(x) + "benchmark@example.org" - email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} - // Insert - db.Save(&email) - // Query - db.First(&BigEmail{}, "email = ?", e) - // Update - db.Model(&email).Update("email", "new-"+e) - // Delete - db.Delete(&email) - } -} +// func BenchmarkGorm(b *testing.B) { +// b.N = 5000 +// for x := 0; x < b.N; x++ { +// e := strconv.Itoa(x) + "benchmark@example.org" +// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} +// // Insert +// db.Save(&email) +// // Query +// db.First(&BigEmail{}, "email = ?", e) +// // Update +// db.Model(&email).Update("email", "new-"+e) +// // Delete +// db.Delete(&email) +// } +// } -func BenchmarkRawSql(b *testing.B) { - db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") - db.SetMaxIdleConns(10) - insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" - query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" - update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" - delete_sql := "DELETE FROM orders WHERE id = $1" +// func BenchmarkRawSql(b *testing.B) { +// db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") +// db.SetMaxIdleConns(10) +// insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" +// query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" +// update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" +// delete_sql := "DELETE FROM orders WHERE id = $1" - b.N = 5000 - for x := 0; x < b.N; x++ { - var id int64 - e := strconv.Itoa(x) + "benchmark@example.org" - email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} - // Insert - db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) - // Query - rows, _ := db.Query(query_sql, email.Email) - rows.Close() - // Update - db.Exec(update_sql, "new-"+e, time.Now(), id) - // Delete - db.Exec(delete_sql, id) - } -} +// b.N = 5000 +// for x := 0; x < b.N; x++ { +// var id int64 +// e := strconv.Itoa(x) + "benchmark@example.org" +// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} +// // Insert +// db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) +// // Query +// rows, _ := db.Query(query_sql, email.Email) +// rows.Close() +// // Update +// db.Exec(update_sql, "new-"+e, time.Now(), id) +// // Delete +// db.Exec(delete_sql, id) +// } +// } diff --git a/main.go b/main.go index 95ce8f4f..98ba7a21 100644 --- a/main.go +++ b/main.go @@ -41,8 +41,9 @@ func (s *DB) SetLogger(l Logger) { s.parent.logger = l } -func (s *DB) LogMode(b bool) { +func (s *DB) LogMode(b bool) *DB { s.logMode = b + return s } func (s *DB) SingularTable(b bool) { @@ -54,7 +55,7 @@ func (s *DB) Where(query interface{}, args ...interface{}) *DB { } func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.where(query, args...).db + return s.clone().search.or(query, args...).db } func (s *DB) Not(query interface{}, args ...interface{}) *DB { @@ -82,18 +83,15 @@ func (s *DB) Unscoped() *DB { } func (s *DB) First(out interface{}, where ...interface{}) *DB { - s.clone().do(out).where(where...).first() - return s + return s.clone().do(out).where(where...).first().db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { - s.clone().do(out).where(where...).last() - return s + return s.clone().do(out).where(where...).last().db } func (s *DB) Find(out interface{}, where ...interface{}) *DB { - s.clone().do(out).where(where...).query() - return s + return s.clone().do(out).where(where...).query().db } func (s *DB) Attrs(attrs ...interface{}) *DB { @@ -105,23 +103,22 @@ func (s *DB) Assign(attrs ...interface{}) *DB { } func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - if s.First(out, where...).Error != nil { - s.clone().do(out).where(where).initialize() + if s.clone().First(out, where...).Error != nil { + return s.clone().do(out).where(where).initialize().db } else { if len(s.search.assignAttrs) > 0 { - s.do(out).updateAttrs(s.search.assignAttrs) //updated or not + return s.clone().do(out).updateAttrs(s.search.assignAttrs).db } } return s } func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - if s.First(out, where...).Error != nil { - s.clone().do(out).where(where...).initialize() - s.Save(out) + if s.clone().First(out, where...).Error != nil { + return s.clone().do(out).where(where...).initialize().db.Save(out) } else { if len(s.search.assignAttrs) > 0 { - s.do(out).updateAttrs(s.search.assignAttrs).update() + return s.clone().do(out).updateAttrs(s.search.assignAttrs).update().db } } return s @@ -132,23 +129,19 @@ func (s *DB) Update(attrs ...interface{}) *DB { } func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB { - s.clone().do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback() - return s + return s.clone().do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback().db } func (s *DB) Save(value interface{}) *DB { - s.clone().do(value).begin().save().commit_or_rollback() - return s + return s.clone().do(value).begin().save().commit_or_rollback().db } func (s *DB) Delete(value interface{}) *DB { - s.clone().do(value).begin().delete().commit_or_rollback() - return s + return s.clone().do(value).begin().delete().commit_or_rollback().db } func (s *DB) Exec(sql string) *DB { - s.do(nil).exec(sql) - return s + return s.do(nil).exec(sql).db } func (s *DB) Model(value interface{}) *DB { @@ -158,18 +151,16 @@ func (s *DB) Model(value interface{}) *DB { } func (s *DB) Related(value interface{}, foreign_keys ...string) *DB { - s.clone().do(value).related(s.data, foreign_keys...) - return s + old_data := s.data + return s.do(value).related(old_data, foreign_keys...).db } func (s *DB) Pluck(column string, value interface{}) *DB { - s.clone().search.selects(column).db.do(s.data).pluck(column, value) - return s + return s.do(s.data).pluck(column, value).db } func (s *DB) Count(value interface{}) *DB { - s.clone().search.selects("count(*)").db.do(s.data).count(value) - return s + return s.do(s.data).count(value).db } func (s *DB) Table(name string) *DB { @@ -178,9 +169,7 @@ func (s *DB) Table(name string) *DB { // Debug func (s *DB) Debug() *DB { - c := s.clone() - c.logMode = true - return c + return s.clone().LogMode(true) } // Transactions diff --git a/model.go b/model.go index 3aa39287..ea1aa603 100644 --- a/model.go +++ b/model.go @@ -191,7 +191,7 @@ func (m *Model) tableName() (str string) { str = toSnake(m.typeName()) - if !m.do.db.singularTable { + if !m.do.db.parent.singularTable { pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"} for key, value := range pluralMap { reg := regexp.MustCompile(key + "$") diff --git a/private.go b/private.go index c76b544c..9206cfd6 100644 --- a/private.go +++ b/private.go @@ -8,11 +8,12 @@ import ( func (s *DB) clone() *DB { db := DB{db: s.db, parent: s.parent, logMode: s.logMode, data: s.data, Error: s.Error} - if s.parent.search == nil { + if s.search == nil { db.search = &search{} } else { - db.search = s.parent.search.clone() + db.search = s.search.clone() } + db.search.db = &db return &db } diff --git a/search.go b/search.go index 8345a033..14103240 100644 --- a/search.go +++ b/search.go @@ -1,6 +1,10 @@ package gorm -import "strconv" +import ( + "regexp" + + "strconv" +) type search struct { db *DB @@ -29,6 +33,7 @@ func (s *search) clone() *search { offsetStr: s.offsetStr, limitStr: s.limitStr, unscope: s.unscope, + tableName: s.tableName, } } @@ -67,23 +72,17 @@ func (s *search) order(value string, reorder ...bool) *search { } func (s *search) selects(value interface{}) *search { - if str, err := getInterfaceAsString(value); err == nil { - s.selectStr = str - } + s.selectStr = s.getInterfaceAsSql(value) return s } func (s *search) limit(value interface{}) *search { - if str, err := getInterfaceAsString(value); err == nil { - s.limitStr = str - } + s.limitStr = s.getInterfaceAsSql(value) return s } func (s *search) offset(value interface{}) *search { - if str, err := getInterfaceAsString(value); err == nil { - s.offsetStr = str - } + s.offsetStr = s.getInterfaceAsSql(value) return s } @@ -97,7 +96,7 @@ func (s *search) table(name string) *search { return s } -func getInterfaceAsString(value interface{}) (str string, err error) { +func (s *search) getInterfaceAsSql(value interface{}) (str string) { switch value := value.(type) { case string: str = value @@ -108,7 +107,11 @@ func getInterfaceAsString(value interface{}) (str string, err error) { str = strconv.Itoa(value) } default: - err = InvalidSql + s.db.err(InvalidSql) + } + + if !regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str) { + s.db.err(InvalidSql) } return }