forked from mirror/gorm
Yay, all tests passed
This commit is contained in:
parent
38f7ecdf15
commit
0ad707b410
|
@ -3,8 +3,9 @@ package gorm
|
|||
import "errors"
|
||||
|
||||
var (
|
||||
RecordNotFound = errors.New("Record Not Found")
|
||||
InvalidSql = errors.New("Invalid SQL")
|
||||
NoNewAttrs = errors.New("No new Attributes")
|
||||
NoValidTransaction = errors.New("No valid transaction")
|
||||
RecordNotFound = errors.New("Record Not Found")
|
||||
InvalidSql = errors.New("Invalid SQL")
|
||||
NoNewAttrs = errors.New("No new Attributes")
|
||||
NoValidTransaction = errors.New("No valid transaction")
|
||||
CantStartTransaction = errors.New("Can't start transaction")
|
||||
)
|
||||
|
|
2
field.go
2
field.go
|
@ -53,7 +53,7 @@ func (f *Field) sqlTag() (str string) {
|
|||
}
|
||||
}
|
||||
|
||||
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.tagIdentifier))
|
||||
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.parent.tagIdentifier))
|
||||
|
||||
if typ == "-" {
|
||||
return
|
||||
|
|
241
gorm_test.go
241
gorm_test.go
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
|
@ -1289,145 +1290,145 @@ func TestAutoMigration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// type NullTime struct {
|
||||
// Time time.Time
|
||||
// Valid bool
|
||||
// }
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool
|
||||
}
|
||||
|
||||
// func (nt *NullTime) Scan(value interface{}) error {
|
||||
// if value == nil {
|
||||
// nt.Valid = false
|
||||
// return nil
|
||||
// }
|
||||
// nt.Time, nt.Valid = value.(time.Time), true
|
||||
// return nil
|
||||
// }
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
nt.Valid = false
|
||||
return nil
|
||||
}
|
||||
nt.Time, nt.Valid = value.(time.Time), true
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (nt NullTime) Value() (driver.Value, error) {
|
||||
// if !nt.Valid {
|
||||
// return nil, nil
|
||||
// }
|
||||
// return nt.Time, nil
|
||||
// }
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
||||
// type NullValue struct {
|
||||
// Id int64
|
||||
// Name sql.NullString `sql:"not null"`
|
||||
// Age sql.NullInt64
|
||||
// Male sql.NullBool
|
||||
// Height sql.NullFloat64
|
||||
// AddedAt NullTime
|
||||
// }
|
||||
type NullValue struct {
|
||||
Id int64
|
||||
Name sql.NullString `sql:"not null"`
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
AddedAt NullTime
|
||||
}
|
||||
|
||||
// func TestSqlNullValue(t *testing.T) {
|
||||
// db.DropTable(&NullValue{})
|
||||
// db.AutoMigrate(&NullValue{})
|
||||
func TestSqlNullValue(t *testing.T) {
|
||||
db.DropTable(&NullValue{})
|
||||
db.AutoMigrate(&NullValue{})
|
||||
|
||||
// if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
|
||||
// t.Errorf("Not error should raise when test null value", err)
|
||||
// }
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
|
||||
t.Errorf("Not error should raise when test null value", err)
|
||||
}
|
||||
|
||||
// var nv NullValue
|
||||
// db.First(&nv, "name = ?", "hello")
|
||||
var nv NullValue
|
||||
db.First(&nv, "name = ?", "hello")
|
||||
|
||||
// if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
|
||||
// t.Errorf("Should be able to fetch null value")
|
||||
// }
|
||||
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
|
||||
t.Errorf("Should be able to fetch null value")
|
||||
}
|
||||
|
||||
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
|
||||
// t.Errorf("Not error should raise when test null value", err)
|
||||
// }
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
|
||||
t.Errorf("Not error should raise when test null value", err)
|
||||
}
|
||||
|
||||
// var nv2 NullValue
|
||||
// db.First(&nv2, "name = ?", "hello-2")
|
||||
// if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
|
||||
// t.Errorf("Should be able to fetch null value")
|
||||
// }
|
||||
var nv2 NullValue
|
||||
db.First(&nv2, "name = ?", "hello-2")
|
||||
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
|
||||
t.Errorf("Should be able to fetch null value")
|
||||
}
|
||||
|
||||
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
|
||||
// t.Errorf("Can't save because of name can't be null", err)
|
||||
// }
|
||||
// }
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
|
||||
t.Errorf("Can't save because of name can't be null", err)
|
||||
}
|
||||
}
|
||||
|
||||
// func TestTransaction(t *testing.T) {
|
||||
// d := db.Begin()
|
||||
// u := User{Name: "transcation"}
|
||||
// if err := d.Save(&u).Error; err != nil {
|
||||
// t.Errorf("No error should raise, but got", err)
|
||||
// }
|
||||
func TestTransaction(t *testing.T) {
|
||||
d := db.Begin()
|
||||
u := User{Name: "transcation"}
|
||||
if err := d.Save(&u).Error; err != nil {
|
||||
t.Errorf("No error should raise, but got", err)
|
||||
}
|
||||
|
||||
// if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||
// t.Errorf("Should find saved record, but got", err)
|
||||
// }
|
||||
if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||
t.Errorf("Should find saved record, but got", err)
|
||||
}
|
||||
|
||||
// d.Rollback()
|
||||
d.Rollback()
|
||||
|
||||
// if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||
// t.Errorf("Should not find record after rollback")
|
||||
// }
|
||||
if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||
t.Errorf("Should not find record after rollback")
|
||||
}
|
||||
|
||||
// d2 := db.Begin()
|
||||
// u2 := User{Name: "transcation-2"}
|
||||
// if err := d2.Save(&u2).Error; err != nil {
|
||||
// t.Errorf("No error should raise, but got", err)
|
||||
// }
|
||||
// d2.Update("age", 90)
|
||||
d2 := db.Begin()
|
||||
u2 := User{Name: "transcation-2"}
|
||||
if err := d2.Save(&u2).Error; err != nil {
|
||||
t.Errorf("No error should raise, but got", err)
|
||||
}
|
||||
d2.Update("age", 90)
|
||||
|
||||
// if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
// t.Errorf("Should find saved record, but got", err)
|
||||
// }
|
||||
if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
t.Errorf("Should find saved record, but got", err)
|
||||
}
|
||||
|
||||
// d2.Commit()
|
||||
d2.Commit()
|
||||
|
||||
// if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
// t.Errorf("Should be able to find committed record")
|
||||
// }
|
||||
// }
|
||||
if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
t.Errorf("Should be able to find committed record")
|
||||
}
|
||||
}
|
||||
|
||||
// func (s *CreditCard) BeforeSave() (err error) {
|
||||
// if s.Number == "0000" {
|
||||
// err = errors.New("invalid credit card")
|
||||
// }
|
||||
// return
|
||||
// }
|
||||
func (s *CreditCard) BeforeSave() (err error) {
|
||||
if s.Number == "0000" {
|
||||
err = errors.New("invalid credit card")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// func BenchmarkGorm(b *testing.B) {
|
||||
// b.N = 5000
|
||||
// for x := 0; x < b.N; x++ {
|
||||
// e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
// // Insert
|
||||
// db.Save(&email)
|
||||
// // Query
|
||||
// db.First(&BigEmail{}, "email = ?", e)
|
||||
// // Update
|
||||
// db.Model(&email).Update("email", "new-"+e)
|
||||
// // Delete
|
||||
// db.Delete(&email)
|
||||
// }
|
||||
// }
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
b.N = 5000
|
||||
for x := 0; x < b.N; x++ {
|
||||
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
// Insert
|
||||
db.Save(&email)
|
||||
// Query
|
||||
db.First(&BigEmail{}, "email = ?", e)
|
||||
// Update
|
||||
db.Model(&email).Update("email", "new-"+e)
|
||||
// Delete
|
||||
db.Delete(&email)
|
||||
}
|
||||
}
|
||||
|
||||
// func BenchmarkRawSql(b *testing.B) {
|
||||
// db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
// db.SetMaxIdleConns(10)
|
||||
// insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
|
||||
// query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
|
||||
// update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
|
||||
// delete_sql := "DELETE FROM orders WHERE id = $1"
|
||||
func BenchmarkRawSql(b *testing.B) {
|
||||
db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
db.SetMaxIdleConns(10)
|
||||
insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
|
||||
query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
|
||||
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
|
||||
delete_sql := "DELETE FROM orders WHERE id = $1"
|
||||
|
||||
// b.N = 5000
|
||||
// for x := 0; x < b.N; x++ {
|
||||
// var id int64
|
||||
// e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
// // Insert
|
||||
// db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
|
||||
// // Query
|
||||
// rows, _ := db.Query(query_sql, email.Email)
|
||||
// rows.Close()
|
||||
// // Update
|
||||
// db.Exec(update_sql, "new-"+e, time.Now(), id)
|
||||
// // Delete
|
||||
// db.Exec(delete_sql, id)
|
||||
// }
|
||||
// }
|
||||
b.N = 5000
|
||||
for x := 0; x < b.N; x++ {
|
||||
var id int64
|
||||
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
// Insert
|
||||
db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
|
||||
// Query
|
||||
rows, _ := db.Query(query_sql, email.Email)
|
||||
rows.Close()
|
||||
// Update
|
||||
db.Exec(update_sql, "new-"+e, time.Now(), id)
|
||||
// Delete
|
||||
db.Exec(delete_sql, id)
|
||||
}
|
||||
}
|
||||
|
|
53
main.go
53
main.go
|
@ -2,10 +2,9 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
import "github.com/jinzhu/gorm/dialect"
|
||||
"github.com/jinzhu/gorm/dialect"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db sqlCommon
|
||||
|
@ -23,12 +22,13 @@ type DB struct {
|
|||
func Open(driver, source string) (db DB, err error) {
|
||||
db.db, err = sql.Open(driver, source)
|
||||
db.dialect = dialect.New(driver)
|
||||
db.tagIdentifier = "sql"
|
||||
db.parent = &db
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DB) SetPool(n int) {
|
||||
if db, ok := s.db.(sqlDb); ok {
|
||||
if db, ok := s.parent.db.(sqlDb); ok {
|
||||
db.SetMaxIdleConns(n)
|
||||
}
|
||||
}
|
||||
|
@ -103,25 +103,23 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
|||
}
|
||||
|
||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||
if s.clone().First(out, where...).Error != nil {
|
||||
return s.clone().do(out).where(where).initialize().db
|
||||
} else {
|
||||
if len(s.search.assignAttrs) > 0 {
|
||||
return s.clone().do(out).updateAttrs(s.search.assignAttrs).db
|
||||
}
|
||||
c := s.clone()
|
||||
if c.First(out, where...).Error == RecordNotFound {
|
||||
return c.do(out).where(where).initialize().db
|
||||
} else if len(s.search.assignAttrs) > 0 {
|
||||
return c.do(out).updateAttrs(s.search.assignAttrs).db
|
||||
}
|
||||
return s
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||
if s.clone().First(out, where...).Error != nil {
|
||||
return s.clone().do(out).where(where...).initialize().db.Save(out)
|
||||
} else {
|
||||
if len(s.search.assignAttrs) > 0 {
|
||||
return s.clone().do(out).updateAttrs(s.search.assignAttrs).update().db
|
||||
}
|
||||
c := s.clone()
|
||||
if c.First(out, where...).Error == RecordNotFound {
|
||||
return c.do(out).where(where...).initialize().db.Save(out)
|
||||
} else if len(s.search.assignAttrs) > 0 {
|
||||
return c.do(out).updateAttrs(s.search.assignAttrs).update().db
|
||||
}
|
||||
return s
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||
|
@ -167,12 +165,10 @@ func (s *DB) Table(name string) *DB {
|
|||
return s.clone().search.table(name).db
|
||||
}
|
||||
|
||||
// Debug
|
||||
func (s *DB) Debug() *DB {
|
||||
return s.clone().LogMode(true)
|
||||
}
|
||||
|
||||
// Transactions
|
||||
func (s *DB) Begin() *DB {
|
||||
c := s.clone()
|
||||
if db, ok := c.db.(sqlDb); ok {
|
||||
|
@ -180,7 +176,7 @@ func (s *DB) Begin() *DB {
|
|||
c.db = interface{}(tx).(sqlCommon)
|
||||
c.err(err)
|
||||
} else {
|
||||
c.err(errors.New("Can't start a transaction."))
|
||||
c.err(CantStartTransaction)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
@ -205,22 +201,19 @@ func (s *DB) Rollback() *DB {
|
|||
|
||||
// Migrations
|
||||
func (s *DB) CreateTable(value interface{}) *DB {
|
||||
s.do(value).createTable()
|
||||
return s
|
||||
return s.clone().do(value).createTable().db
|
||||
}
|
||||
|
||||
func (s *DB) DropTable(value interface{}) *DB {
|
||||
s.do(value).dropTable()
|
||||
return s
|
||||
return s.clone().do(value).dropTable().db
|
||||
}
|
||||
|
||||
func (s *DB) AutoMigrate(value interface{}) *DB {
|
||||
s.do(value).autoMigrate()
|
||||
return s
|
||||
return s.clone().do(value).autoMigrate().db
|
||||
}
|
||||
|
||||
func (s *DB) UpdateColumn(column string, typ string) *DB {
|
||||
s.do(s.data).updateColumn(column, typ)
|
||||
s.clone().do(s.data).updateColumn(column, typ)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -230,11 +223,11 @@ func (s *DB) DropColumn(column string) *DB {
|
|||
}
|
||||
|
||||
func (s *DB) AddIndex(column string, index_name ...string) *DB {
|
||||
s.do(s.data).addIndex(column, index_name...)
|
||||
s.clone().do(s.data).addIndex(column, index_name...)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DB) RemoveIndex(column string) *DB {
|
||||
s.do(s.data).removeIndex(column)
|
||||
s.clone().do(s.data).removeIndex(column)
|
||||
return s
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue