mirror of https://github.com/go-gorm/gorm.git
Better do.go
This commit is contained in:
parent
e0d3100294
commit
38f7ecdf15
104
do.go
104
do.go
|
@ -18,6 +18,7 @@ type Do struct {
|
||||||
search *search
|
search *search
|
||||||
model *Model
|
model *Model
|
||||||
tableName string
|
tableName string
|
||||||
|
usingUpdate bool
|
||||||
value interface{}
|
value interface{}
|
||||||
update_attrs map[string]interface{}
|
update_attrs map[string]interface{}
|
||||||
hasUpdate bool
|
hasUpdate bool
|
||||||
|
@ -69,7 +70,7 @@ func (s *Do) trace(t time.Time) {
|
||||||
s.db.slog(s.sql, t, s.sqlVars...)
|
s.db.slog(s.sql, t, s.sqlVars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) exec(sqls ...string) {
|
func (s *Do) exec(sqls ...string) *Do {
|
||||||
defer s.trace(time.Now())
|
defer s.trace(time.Now())
|
||||||
if !s.db.hasError() {
|
if !s.db.hasError() {
|
||||||
if len(sqls) > 0 {
|
if len(sqls) > 0 {
|
||||||
|
@ -78,6 +79,7 @@ func (s *Do) exec(sqls ...string) {
|
||||||
_, err := s.db.db.Exec(s.sql, s.sqlVars...)
|
_, err := s.db.db.Exec(s.sql, s.sqlVars...)
|
||||||
s.err(err)
|
s.err(err)
|
||||||
}
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) save() *Do {
|
func (s *Do) save() *Do {
|
||||||
|
@ -203,6 +205,7 @@ func (s *Do) create() (i interface{}) {
|
||||||
|
|
||||||
func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
|
func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
|
||||||
ignore_protected := len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
|
ignore_protected := len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
|
||||||
|
s.usingUpdate = true
|
||||||
|
|
||||||
switch value := values.(type) {
|
switch value := values.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
|
@ -210,9 +213,8 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do
|
||||||
results, has_update := s.model.updatedColumnsAndValues(value, ignore_protected)
|
results, has_update := s.model.updatedColumnsAndValues(value, ignore_protected)
|
||||||
if len(results) > 0 {
|
if len(results) > 0 {
|
||||||
s.update_attrs = results
|
s.update_attrs = results
|
||||||
} else if has_update {
|
|
||||||
s.hasUpdate = has_update
|
|
||||||
}
|
}
|
||||||
|
s.hasUpdate = has_update
|
||||||
}
|
}
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
for _, v := range value {
|
for _, v := range value {
|
||||||
|
@ -226,8 +228,6 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do
|
||||||
}
|
}
|
||||||
s.updateAttrs(attrs)
|
s.updateAttrs(attrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,6 +251,10 @@ func (s *Do) prepareUpdateSql() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) update() *Do {
|
func (s *Do) update() *Do {
|
||||||
|
if s.usingUpdate && !s.hasUpdate {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
s.model.callMethod("BeforeUpdate")
|
s.model.callMethod("BeforeUpdate")
|
||||||
s.model.callMethod("BeforeSave")
|
s.model.callMethod("BeforeSave")
|
||||||
s.saveBeforeAssociations()
|
s.saveBeforeAssociations()
|
||||||
|
@ -288,14 +292,16 @@ func (s *Do) prepareQuerySql() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) first() {
|
func (s *Do) first() *Do {
|
||||||
s.search.order(s.model.primaryKeyDb()).limit(1)
|
s.search = s.search.clone().order(s.model.primaryKeyDb()).limit(1)
|
||||||
s.query()
|
s.query()
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) last() {
|
func (s *Do) last() *Do {
|
||||||
s.search.order(s.model.primaryKeyDb() + " DESC").limit(1)
|
s.search = s.search.clone().order(s.model.primaryKeyDb() + " DESC").limit(1)
|
||||||
s.query()
|
s.query()
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err error, from_from bool, foreign_value interface{}) {
|
func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err error, from_from bool, foreign_value interface{}) {
|
||||||
|
@ -314,7 +320,7 @@ func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err erro
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) related(value interface{}, foreign_keys ...string) {
|
func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
|
||||||
var foreign_value interface{}
|
var foreign_value interface{}
|
||||||
var from_from bool
|
var from_from bool
|
||||||
var foreign_key string
|
var foreign_key string
|
||||||
|
@ -338,9 +344,10 @@ func (s *Do) related(value interface{}, foreign_keys ...string) {
|
||||||
query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value))
|
query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value))
|
||||||
s.where(query).query()
|
s.where(query).query()
|
||||||
}
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) query() {
|
func (s *Do) query() *Do {
|
||||||
var (
|
var (
|
||||||
is_slice bool
|
is_slice bool
|
||||||
dest_type reflect.Type
|
dest_type reflect.Type
|
||||||
|
@ -351,7 +358,7 @@ func (s *Do) query() {
|
||||||
is_slice = true
|
is_slice = true
|
||||||
dest_type = dest_out.Type().Elem()
|
dest_type = dest_out.Type().Elem()
|
||||||
} else {
|
} else {
|
||||||
s.search.limit(1)
|
s.search = s.search.clone().limit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
|
@ -361,7 +368,7 @@ func (s *Do) query() {
|
||||||
s.db.slog(s.sql, now, s.sqlVars...)
|
s.db.slog(s.sql, now, s.sqlVars...)
|
||||||
|
|
||||||
if s.err(err) != nil {
|
if s.err(err) != nil {
|
||||||
return
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
@ -395,23 +402,26 @@ func (s *Do) query() {
|
||||||
s.err(RecordNotFound)
|
s.err(RecordNotFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) count(value interface{}) {
|
func (s *Do) count(value interface{}) *Do {
|
||||||
|
s.search = s.search.clone().selects("count(*)")
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
if !s.db.hasError() {
|
if !s.db.hasError() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
|
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
|
||||||
s.db.slog(s.sql, now, s.sqlVars...)
|
s.db.slog(s.sql, now, s.sqlVars...)
|
||||||
}
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) pluck(column string, value interface{}) {
|
func (s *Do) pluck(column string, value interface{}) *Do {
|
||||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
s.search = s.search.clone().selects(column)
|
||||||
if dest_out.Kind() != reflect.Slice {
|
if dest_out.Kind() != reflect.Slice {
|
||||||
s.err(errors.New("Results should be a slice"))
|
s.err(errors.New("Results should be a slice"))
|
||||||
return
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
|
@ -430,6 +440,7 @@ func (s *Do) pluck(column string, value interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) primaryCondiation(value interface{}) string {
|
func (s *Do) primaryCondiation(value interface{}) string {
|
||||||
|
@ -474,7 +485,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
values := reflect.ValueOf(arg)
|
values := reflect.ValueOf(arg)
|
||||||
var temp_marks []string
|
var temp_marks []string
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface()))
|
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||||
default:
|
default:
|
||||||
|
@ -533,7 +544,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
values := reflect.ValueOf(arg)
|
values := reflect.ValueOf(arg)
|
||||||
var temp_marks []string
|
var temp_marks []string
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface()))
|
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||||
default:
|
default:
|
||||||
|
@ -546,6 +557,13 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Do) where(where ...interface{}) *Do {
|
||||||
|
if len(where) > 0 {
|
||||||
|
s.search = s.search.clone().where(where[0], where[1:]...)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Do) whereSql() (sql string) {
|
func (s *Do) whereSql() (sql string) {
|
||||||
var primary_condiations, and_conditions, or_conditions []string
|
var primary_condiations, and_conditions, or_conditions []string
|
||||||
|
|
||||||
|
@ -709,7 +727,7 @@ func (s *Do) begin() *Do {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) commit_or_rollback() {
|
func (s *Do) commit_or_rollback() *Do {
|
||||||
if s.startedTransaction {
|
if s.startedTransaction {
|
||||||
if db, ok := s.db.db.(sqlTx); ok {
|
if db, ok := s.db.db.(sqlTx); ok {
|
||||||
if s.db.hasError() {
|
if s.db.hasError() {
|
||||||
|
@ -717,47 +735,21 @@ func (s *Do) commit_or_rollback() {
|
||||||
} else {
|
} else {
|
||||||
db.Commit()
|
db.Commit()
|
||||||
}
|
}
|
||||||
|
s.db.db = s.db.parent.db
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Do) where(where ...interface{}) *Do {
|
|
||||||
if len(where) > 0 {
|
|
||||||
s.search.where(where[0], where[1:])
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) initialize() {
|
func (s *Do) initialize() *Do {
|
||||||
// TODO initializeWithSearchCondition
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Do) initializeWithSearchCondition() {
|
|
||||||
for _, clause := range s.search.whereClause {
|
for _, clause := range s.search.whereClause {
|
||||||
switch value := clause["query"].(type) {
|
s.updateAttrs(clause["query"])
|
||||||
case map[string]interface{}:
|
|
||||||
for k, v := range value {
|
|
||||||
s.model.setValueByColumn(k, v, s.value)
|
|
||||||
}
|
|
||||||
case []interface{}:
|
|
||||||
for _, obj := range value {
|
|
||||||
switch reflect.ValueOf(obj).Kind() {
|
|
||||||
case reflect.Struct:
|
|
||||||
m := &Model{data: obj, do: s}
|
|
||||||
for _, field := range m.columnsHasValue("other") {
|
|
||||||
s.model.setValueByColumn(field.dbName, field.Value, s.value)
|
|
||||||
}
|
|
||||||
case reflect.Map:
|
|
||||||
for key, value := range obj.(map[string]interface{}) {
|
|
||||||
s.model.setValueByColumn(key, value, s.value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case interface{}:
|
|
||||||
m := &Model{data: value, do: s}
|
|
||||||
for _, field := range m.columnsHasValue("other") {
|
|
||||||
s.model.setValueByColumn(field.dbName, field.Value, s.value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
for _, attrs := range s.search.initAttrs {
|
||||||
|
s.updateAttrs(attrs)
|
||||||
|
}
|
||||||
|
for _, attrs := range s.search.assignAttrs {
|
||||||
|
s.updateAttrs(attrs)
|
||||||
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
262
gorm_test.go
262
gorm_test.go
|
@ -2,7 +2,6 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
@ -147,7 +146,7 @@ func init() {
|
||||||
func TestSaveAndFind(t *testing.T) {
|
func TestSaveAndFind(t *testing.T) {
|
||||||
name := "save_and_find"
|
name := "save_and_find"
|
||||||
u := &User{Name: name, Age: 1}
|
u := &User{Name: name, Age: 1}
|
||||||
db.Debug().Save(u)
|
db.Save(u)
|
||||||
if u.Id == 0 {
|
if u.Id == 0 {
|
||||||
t.Errorf("Should have ID after create record")
|
t.Errorf("Should have ID after create record")
|
||||||
}
|
}
|
||||||
|
@ -1242,29 +1241,16 @@ func (c Cart) TableName() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTableName(t *testing.T) {
|
func TestTableName(t *testing.T) {
|
||||||
var table string
|
if db.do(Order{}).table() != "orders" {
|
||||||
|
|
||||||
model := &Model{data: Order{}}
|
|
||||||
table = model.tableName()
|
|
||||||
if table != "orders" {
|
|
||||||
t.Errorf("Order table name should be orders")
|
t.Errorf("Order table name should be orders")
|
||||||
}
|
}
|
||||||
|
|
||||||
db.SingularTable(true)
|
db.SingularTable(true)
|
||||||
table = model.tableName()
|
if db.do(Order{}).table() != "order" {
|
||||||
if table != "order" {
|
|
||||||
t.Errorf("Order's singular table name should be order")
|
t.Errorf("Order's singular table name should be order")
|
||||||
}
|
}
|
||||||
|
|
||||||
model2 := &Model{data: Cart{}}
|
if db.do(&Cart{}).table() != "shopping_cart" {
|
||||||
table = model2.tableName()
|
|
||||||
if table != "shopping_cart" {
|
|
||||||
t.Errorf("Cart's singular table name should be shopping_cart")
|
|
||||||
}
|
|
||||||
|
|
||||||
model3 := &Model{data: &Cart{}}
|
|
||||||
table = model3.tableName()
|
|
||||||
if table != "shopping_cart" {
|
|
||||||
t.Errorf("Cart's singular table name should be shopping_cart")
|
t.Errorf("Cart's singular table name should be shopping_cart")
|
||||||
}
|
}
|
||||||
db.SingularTable(false)
|
db.SingularTable(false)
|
||||||
|
@ -1303,145 +1289,145 @@ func TestAutoMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type NullTime struct {
|
// type NullTime struct {
|
||||||
Time time.Time
|
// Time time.Time
|
||||||
Valid bool
|
// Valid bool
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (nt *NullTime) Scan(value interface{}) error {
|
// func (nt *NullTime) Scan(value interface{}) error {
|
||||||
if value == nil {
|
// if value == nil {
|
||||||
nt.Valid = false
|
// nt.Valid = false
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
nt.Time, nt.Valid = value.(time.Time), true
|
// nt.Time, nt.Valid = value.(time.Time), true
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (nt NullTime) Value() (driver.Value, error) {
|
// func (nt NullTime) Value() (driver.Value, error) {
|
||||||
if !nt.Valid {
|
// if !nt.Valid {
|
||||||
return nil, nil
|
// return nil, nil
|
||||||
}
|
// }
|
||||||
return nt.Time, nil
|
// return nt.Time, nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
type NullValue struct {
|
// type NullValue struct {
|
||||||
Id int64
|
// Id int64
|
||||||
Name sql.NullString `sql:"not null"`
|
// Name sql.NullString `sql:"not null"`
|
||||||
Age sql.NullInt64
|
// Age sql.NullInt64
|
||||||
Male sql.NullBool
|
// Male sql.NullBool
|
||||||
Height sql.NullFloat64
|
// Height sql.NullFloat64
|
||||||
AddedAt NullTime
|
// AddedAt NullTime
|
||||||
}
|
// }
|
||||||
|
|
||||||
func TestSqlNullValue(t *testing.T) {
|
// func TestSqlNullValue(t *testing.T) {
|
||||||
db.DropTable(&NullValue{})
|
// db.DropTable(&NullValue{})
|
||||||
db.AutoMigrate(&NullValue{})
|
// db.AutoMigrate(&NullValue{})
|
||||||
|
|
||||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
|
// if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
|
||||||
t.Errorf("Not error should raise when test null value", err)
|
// t.Errorf("Not error should raise when test null value", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
var nv NullValue
|
// var nv NullValue
|
||||||
db.First(&nv, "name = ?", "hello")
|
// db.First(&nv, "name = ?", "hello")
|
||||||
|
|
||||||
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
|
// if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
|
||||||
t.Errorf("Should be able to fetch null value")
|
// t.Errorf("Should be able to fetch null value")
|
||||||
}
|
// }
|
||||||
|
|
||||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
|
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
|
||||||
t.Errorf("Not error should raise when test null value", err)
|
// t.Errorf("Not error should raise when test null value", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
var nv2 NullValue
|
// var nv2 NullValue
|
||||||
db.First(&nv2, "name = ?", "hello-2")
|
// db.First(&nv2, "name = ?", "hello-2")
|
||||||
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
|
// if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
|
||||||
t.Errorf("Should be able to fetch null value")
|
// t.Errorf("Should be able to fetch null value")
|
||||||
}
|
// }
|
||||||
|
|
||||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
|
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
|
||||||
t.Errorf("Can't save because of name can't be null", err)
|
// t.Errorf("Can't save because of name can't be null", err)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func TestTransaction(t *testing.T) {
|
// func TestTransaction(t *testing.T) {
|
||||||
d := db.Begin()
|
// d := db.Begin()
|
||||||
u := User{Name: "transcation"}
|
// u := User{Name: "transcation"}
|
||||||
if err := d.Save(&u).Error; err != nil {
|
// if err := d.Save(&u).Error; err != nil {
|
||||||
t.Errorf("No error should raise, but got", err)
|
// t.Errorf("No error should raise, but got", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
// if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||||
t.Errorf("Should find saved record, but got", err)
|
// t.Errorf("Should find saved record, but got", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
d.Rollback()
|
// d.Rollback()
|
||||||
|
|
||||||
if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
// if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||||
t.Errorf("Should not find record after rollback")
|
// t.Errorf("Should not find record after rollback")
|
||||||
}
|
// }
|
||||||
|
|
||||||
d2 := db.Begin()
|
// d2 := db.Begin()
|
||||||
u2 := User{Name: "transcation-2"}
|
// u2 := User{Name: "transcation-2"}
|
||||||
if err := d2.Save(&u2).Error; err != nil {
|
// if err := d2.Save(&u2).Error; err != nil {
|
||||||
t.Errorf("No error should raise, but got", err)
|
// t.Errorf("No error should raise, but got", err)
|
||||||
}
|
// }
|
||||||
d2.Update("age", 90)
|
// d2.Update("age", 90)
|
||||||
|
|
||||||
if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
// if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||||
t.Errorf("Should find saved record, but got", err)
|
// t.Errorf("Should find saved record, but got", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
d2.Commit()
|
// d2.Commit()
|
||||||
|
|
||||||
if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
// if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||||
t.Errorf("Should be able to find committed record")
|
// t.Errorf("Should be able to find committed record")
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (s *CreditCard) BeforeSave() (err error) {
|
// func (s *CreditCard) BeforeSave() (err error) {
|
||||||
if s.Number == "0000" {
|
// if s.Number == "0000" {
|
||||||
err = errors.New("invalid credit card")
|
// err = errors.New("invalid credit card")
|
||||||
}
|
// }
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
// func BenchmarkGorm(b *testing.B) {
|
||||||
b.N = 5000
|
// b.N = 5000
|
||||||
for x := 0; x < b.N; x++ {
|
// for x := 0; x < b.N; x++ {
|
||||||
e := strconv.Itoa(x) + "benchmark@example.org"
|
// e := strconv.Itoa(x) + "benchmark@example.org"
|
||||||
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||||
// Insert
|
// // Insert
|
||||||
db.Save(&email)
|
// db.Save(&email)
|
||||||
// Query
|
// // Query
|
||||||
db.First(&BigEmail{}, "email = ?", e)
|
// db.First(&BigEmail{}, "email = ?", e)
|
||||||
// Update
|
// // Update
|
||||||
db.Model(&email).Update("email", "new-"+e)
|
// db.Model(&email).Update("email", "new-"+e)
|
||||||
// Delete
|
// // Delete
|
||||||
db.Delete(&email)
|
// db.Delete(&email)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func BenchmarkRawSql(b *testing.B) {
|
// func BenchmarkRawSql(b *testing.B) {
|
||||||
db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
// db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||||
db.SetMaxIdleConns(10)
|
// db.SetMaxIdleConns(10)
|
||||||
insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
|
// insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
|
||||||
query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
|
// query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
|
||||||
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
|
// update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
|
||||||
delete_sql := "DELETE FROM orders WHERE id = $1"
|
// delete_sql := "DELETE FROM orders WHERE id = $1"
|
||||||
|
|
||||||
b.N = 5000
|
// b.N = 5000
|
||||||
for x := 0; x < b.N; x++ {
|
// for x := 0; x < b.N; x++ {
|
||||||
var id int64
|
// var id int64
|
||||||
e := strconv.Itoa(x) + "benchmark@example.org"
|
// e := strconv.Itoa(x) + "benchmark@example.org"
|
||||||
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||||
// Insert
|
// // Insert
|
||||||
db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
|
// db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
|
||||||
// Query
|
// // Query
|
||||||
rows, _ := db.Query(query_sql, email.Email)
|
// rows, _ := db.Query(query_sql, email.Email)
|
||||||
rows.Close()
|
// rows.Close()
|
||||||
// Update
|
// // Update
|
||||||
db.Exec(update_sql, "new-"+e, time.Now(), id)
|
// db.Exec(update_sql, "new-"+e, time.Now(), id)
|
||||||
// Delete
|
// // Delete
|
||||||
db.Exec(delete_sql, id)
|
// db.Exec(delete_sql, id)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
53
main.go
53
main.go
|
@ -41,8 +41,9 @@ func (s *DB) SetLogger(l Logger) {
|
||||||
s.parent.logger = l
|
s.parent.logger = l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) LogMode(b bool) {
|
func (s *DB) LogMode(b bool) *DB {
|
||||||
s.logMode = b
|
s.logMode = b
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) SingularTable(b bool) {
|
func (s *DB) SingularTable(b bool) {
|
||||||
|
@ -54,7 +55,7 @@ func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.where(query, args...).db
|
return s.clone().search.or(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
||||||
|
@ -82,18 +83,15 @@ func (s *DB) Unscoped() *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
s.clone().do(out).where(where...).first()
|
return s.clone().do(out).where(where...).first().db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
s.clone().do(out).where(where...).last()
|
return s.clone().do(out).where(where...).last().db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||||
s.clone().do(out).where(where...).query()
|
return s.clone().do(out).where(where...).query().db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
||||||
|
@ -105,23 +103,22 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
if s.First(out, where...).Error != nil {
|
if s.clone().First(out, where...).Error != nil {
|
||||||
s.clone().do(out).where(where).initialize()
|
return s.clone().do(out).where(where).initialize().db
|
||||||
} else {
|
} else {
|
||||||
if len(s.search.assignAttrs) > 0 {
|
if len(s.search.assignAttrs) > 0 {
|
||||||
s.do(out).updateAttrs(s.search.assignAttrs) //updated or not
|
return s.clone().do(out).updateAttrs(s.search.assignAttrs).db
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||||
if s.First(out, where...).Error != nil {
|
if s.clone().First(out, where...).Error != nil {
|
||||||
s.clone().do(out).where(where...).initialize()
|
return s.clone().do(out).where(where...).initialize().db.Save(out)
|
||||||
s.Save(out)
|
|
||||||
} else {
|
} else {
|
||||||
if len(s.search.assignAttrs) > 0 {
|
if len(s.search.assignAttrs) > 0 {
|
||||||
s.do(out).updateAttrs(s.search.assignAttrs).update()
|
return s.clone().do(out).updateAttrs(s.search.assignAttrs).update().db
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
|
@ -132,23 +129,19 @@ func (s *DB) Update(attrs ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB {
|
func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB {
|
||||||
s.clone().do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
|
return s.clone().do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback().db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Save(value interface{}) *DB {
|
func (s *DB) Save(value interface{}) *DB {
|
||||||
s.clone().do(value).begin().save().commit_or_rollback()
|
return s.clone().do(value).begin().save().commit_or_rollback().db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Delete(value interface{}) *DB {
|
func (s *DB) Delete(value interface{}) *DB {
|
||||||
s.clone().do(value).begin().delete().commit_or_rollback()
|
return s.clone().do(value).begin().delete().commit_or_rollback().db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Exec(sql string) *DB {
|
func (s *DB) Exec(sql string) *DB {
|
||||||
s.do(nil).exec(sql)
|
return s.do(nil).exec(sql).db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Model(value interface{}) *DB {
|
func (s *DB) Model(value interface{}) *DB {
|
||||||
|
@ -158,18 +151,16 @@ func (s *DB) Model(value interface{}) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Related(value interface{}, foreign_keys ...string) *DB {
|
func (s *DB) Related(value interface{}, foreign_keys ...string) *DB {
|
||||||
s.clone().do(value).related(s.data, foreign_keys...)
|
old_data := s.data
|
||||||
return s
|
return s.do(value).related(old_data, foreign_keys...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Pluck(column string, value interface{}) *DB {
|
func (s *DB) Pluck(column string, value interface{}) *DB {
|
||||||
s.clone().search.selects(column).db.do(s.data).pluck(column, value)
|
return s.do(s.data).pluck(column, value).db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Count(value interface{}) *DB {
|
func (s *DB) Count(value interface{}) *DB {
|
||||||
s.clone().search.selects("count(*)").db.do(s.data).count(value)
|
return s.do(s.data).count(value).db
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Table(name string) *DB {
|
func (s *DB) Table(name string) *DB {
|
||||||
|
@ -178,9 +169,7 @@ func (s *DB) Table(name string) *DB {
|
||||||
|
|
||||||
// Debug
|
// Debug
|
||||||
func (s *DB) Debug() *DB {
|
func (s *DB) Debug() *DB {
|
||||||
c := s.clone()
|
return s.clone().LogMode(true)
|
||||||
c.logMode = true
|
|
||||||
return c
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transactions
|
// Transactions
|
||||||
|
|
2
model.go
2
model.go
|
@ -191,7 +191,7 @@ func (m *Model) tableName() (str string) {
|
||||||
|
|
||||||
str = toSnake(m.typeName())
|
str = toSnake(m.typeName())
|
||||||
|
|
||||||
if !m.do.db.singularTable {
|
if !m.do.db.parent.singularTable {
|
||||||
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
|
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
|
||||||
for key, value := range pluralMap {
|
for key, value := range pluralMap {
|
||||||
reg := regexp.MustCompile(key + "$")
|
reg := regexp.MustCompile(key + "$")
|
||||||
|
|
|
@ -8,11 +8,12 @@ import (
|
||||||
func (s *DB) clone() *DB {
|
func (s *DB) clone() *DB {
|
||||||
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, data: s.data, Error: s.Error}
|
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, data: s.data, Error: s.Error}
|
||||||
|
|
||||||
if s.parent.search == nil {
|
if s.search == nil {
|
||||||
db.search = &search{}
|
db.search = &search{}
|
||||||
} else {
|
} else {
|
||||||
db.search = s.parent.search.clone()
|
db.search = s.search.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
db.search.db = &db
|
db.search.db = &db
|
||||||
return &db
|
return &db
|
||||||
}
|
}
|
||||||
|
|
27
search.go
27
search.go
|
@ -1,6 +1,10 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "strconv"
|
import (
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
type search struct {
|
type search struct {
|
||||||
db *DB
|
db *DB
|
||||||
|
@ -29,6 +33,7 @@ func (s *search) clone() *search {
|
||||||
offsetStr: s.offsetStr,
|
offsetStr: s.offsetStr,
|
||||||
limitStr: s.limitStr,
|
limitStr: s.limitStr,
|
||||||
unscope: s.unscope,
|
unscope: s.unscope,
|
||||||
|
tableName: s.tableName,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,23 +72,17 @@ func (s *search) order(value string, reorder ...bool) *search {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) selects(value interface{}) *search {
|
func (s *search) selects(value interface{}) *search {
|
||||||
if str, err := getInterfaceAsString(value); err == nil {
|
s.selectStr = s.getInterfaceAsSql(value)
|
||||||
s.selectStr = str
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) limit(value interface{}) *search {
|
func (s *search) limit(value interface{}) *search {
|
||||||
if str, err := getInterfaceAsString(value); err == nil {
|
s.limitStr = s.getInterfaceAsSql(value)
|
||||||
s.limitStr = str
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) offset(value interface{}) *search {
|
func (s *search) offset(value interface{}) *search {
|
||||||
if str, err := getInterfaceAsString(value); err == nil {
|
s.offsetStr = s.getInterfaceAsSql(value)
|
||||||
s.offsetStr = str
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,7 +96,7 @@ func (s *search) table(name string) *search {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInterfaceAsString(value interface{}) (str string, err error) {
|
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
|
||||||
switch value := value.(type) {
|
switch value := value.(type) {
|
||||||
case string:
|
case string:
|
||||||
str = value
|
str = value
|
||||||
|
@ -108,7 +107,11 @@ func getInterfaceAsString(value interface{}) (str string, err error) {
|
||||||
str = strconv.Itoa(value)
|
str = strconv.Itoa(value)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
err = InvalidSql
|
s.db.err(InvalidSql)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str) {
|
||||||
|
s.db.err(InvalidSql)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue