Yay, all tests passed

This commit is contained in:
Jinzhu 2013-11-16 18:01:44 +08:00
parent 38f7ecdf15
commit 0ad707b410
4 changed files with 150 additions and 155 deletions

View File

@ -3,8 +3,9 @@ package gorm
import "errors"
var (
RecordNotFound = errors.New("Record Not Found")
InvalidSql = errors.New("Invalid SQL")
NoNewAttrs = errors.New("No new Attributes")
NoValidTransaction = errors.New("No valid transaction")
RecordNotFound = errors.New("Record Not Found")
InvalidSql = errors.New("Invalid SQL")
NoNewAttrs = errors.New("No new Attributes")
NoValidTransaction = errors.New("No valid transaction")
CantStartTransaction = errors.New("Can't start transaction")
)

View File

@ -53,7 +53,7 @@ func (f *Field) sqlTag() (str string) {
}
}
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.tagIdentifier))
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.parent.tagIdentifier))
if typ == "-" {
return

View File

@ -2,6 +2,7 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
_ "github.com/go-sql-driver/mysql"
@ -1289,145 +1290,145 @@ func TestAutoMigration(t *testing.T) {
}
}
// type NullTime struct {
// Time time.Time
// Valid bool
// }
type NullTime struct {
Time time.Time
Valid bool
}
// func (nt *NullTime) Scan(value interface{}) error {
// if value == nil {
// nt.Valid = false
// return nil
// }
// nt.Time, nt.Valid = value.(time.Time), true
// return nil
// }
func (nt *NullTime) Scan(value interface{}) error {
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
}
// func (nt NullTime) Value() (driver.Value, error) {
// if !nt.Valid {
// return nil, nil
// }
// return nt.Time, nil
// }
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
// type NullValue struct {
// Id int64
// Name sql.NullString `sql:"not null"`
// Age sql.NullInt64
// Male sql.NullBool
// Height sql.NullFloat64
// AddedAt NullTime
// }
type NullValue struct {
Id int64
Name sql.NullString `sql:"not null"`
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
AddedAt NullTime
}
// func TestSqlNullValue(t *testing.T) {
// db.DropTable(&NullValue{})
// db.AutoMigrate(&NullValue{})
func TestSqlNullValue(t *testing.T) {
db.DropTable(&NullValue{})
db.AutoMigrate(&NullValue{})
// if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
// t.Errorf("Not error should raise when test null value", err)
// }
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
// var nv NullValue
// db.First(&nv, "name = ?", "hello")
var nv NullValue
db.First(&nv, "name = ?", "hello")
// if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
// t.Errorf("Should be able to fetch null value")
// }
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
t.Errorf("Should be able to fetch null value")
}
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
// t.Errorf("Not error should raise when test null value", err)
// }
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
// var nv2 NullValue
// db.First(&nv2, "name = ?", "hello-2")
// if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
// t.Errorf("Should be able to fetch null value")
// }
var nv2 NullValue
db.First(&nv2, "name = ?", "hello-2")
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
t.Errorf("Should be able to fetch null value")
}
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
// t.Errorf("Can't save because of name can't be null", err)
// }
// }
if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
t.Errorf("Can't save because of name can't be null", err)
}
}
// func TestTransaction(t *testing.T) {
// d := db.Begin()
// u := User{Name: "transcation"}
// if err := d.Save(&u).Error; err != nil {
// t.Errorf("No error should raise, but got", err)
// }
func TestTransaction(t *testing.T) {
d := db.Begin()
u := User{Name: "transcation"}
if err := d.Save(&u).Error; err != nil {
t.Errorf("No error should raise, but got", err)
}
// if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
// t.Errorf("Should find saved record, but got", err)
// }
if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record, but got", err)
}
// d.Rollback()
d.Rollback()
// if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
// t.Errorf("Should not find record after rollback")
// }
if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback")
}
// d2 := db.Begin()
// u2 := User{Name: "transcation-2"}
// if err := d2.Save(&u2).Error; err != nil {
// t.Errorf("No error should raise, but got", err)
// }
// d2.Update("age", 90)
d2 := db.Begin()
u2 := User{Name: "transcation-2"}
if err := d2.Save(&u2).Error; err != nil {
t.Errorf("No error should raise, but got", err)
}
d2.Update("age", 90)
// if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
// t.Errorf("Should find saved record, but got", err)
// }
if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should find saved record, but got", err)
}
// d2.Commit()
d2.Commit()
// if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
// t.Errorf("Should be able to find committed record")
// }
// }
if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record")
}
}
// func (s *CreditCard) BeforeSave() (err error) {
// if s.Number == "0000" {
// err = errors.New("invalid credit card")
// }
// return
// }
func (s *CreditCard) BeforeSave() (err error) {
if s.Number == "0000" {
err = errors.New("invalid credit card")
}
return
}
// func BenchmarkGorm(b *testing.B) {
// b.N = 5000
// for x := 0; x < b.N; x++ {
// e := strconv.Itoa(x) + "benchmark@example.org"
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// // Insert
// db.Save(&email)
// // Query
// db.First(&BigEmail{}, "email = ?", e)
// // Update
// db.Model(&email).Update("email", "new-"+e)
// // Delete
// db.Delete(&email)
// }
// }
func BenchmarkGorm(b *testing.B) {
b.N = 5000
for x := 0; x < b.N; x++ {
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// Insert
db.Save(&email)
// Query
db.First(&BigEmail{}, "email = ?", e)
// Update
db.Model(&email).Update("email", "new-"+e)
// Delete
db.Delete(&email)
}
}
// func BenchmarkRawSql(b *testing.B) {
// db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db.SetMaxIdleConns(10)
// insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
// query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
// update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
// delete_sql := "DELETE FROM orders WHERE id = $1"
func BenchmarkRawSql(b *testing.B) {
db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
db.SetMaxIdleConns(10)
insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
delete_sql := "DELETE FROM orders WHERE id = $1"
// b.N = 5000
// for x := 0; x < b.N; x++ {
// var id int64
// e := strconv.Itoa(x) + "benchmark@example.org"
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// // Insert
// db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
// // Query
// rows, _ := db.Query(query_sql, email.Email)
// rows.Close()
// // Update
// db.Exec(update_sql, "new-"+e, time.Now(), id)
// // Delete
// db.Exec(delete_sql, id)
// }
// }
b.N = 5000
for x := 0; x < b.N; x++ {
var id int64
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// Insert
db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
// Query
rows, _ := db.Query(query_sql, email.Email)
rows.Close()
// Update
db.Exec(update_sql, "new-"+e, time.Now(), id)
// Delete
db.Exec(delete_sql, id)
}
}

53
main.go
View File

@ -2,10 +2,9 @@ package gorm
import (
"database/sql"
"errors"
)
import "github.com/jinzhu/gorm/dialect"
"github.com/jinzhu/gorm/dialect"
)
type DB struct {
db sqlCommon
@ -23,12 +22,13 @@ type DB struct {
func Open(driver, source string) (db DB, err error) {
db.db, err = sql.Open(driver, source)
db.dialect = dialect.New(driver)
db.tagIdentifier = "sql"
db.parent = &db
return
}
func (s *DB) SetPool(n int) {
if db, ok := s.db.(sqlDb); ok {
if db, ok := s.parent.db.(sqlDb); ok {
db.SetMaxIdleConns(n)
}
}
@ -103,25 +103,23 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
}
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
if s.clone().First(out, where...).Error != nil {
return s.clone().do(out).where(where).initialize().db
} else {
if len(s.search.assignAttrs) > 0 {
return s.clone().do(out).updateAttrs(s.search.assignAttrs).db
}
c := s.clone()
if c.First(out, where...).Error == RecordNotFound {
return c.do(out).where(where).initialize().db
} else if len(s.search.assignAttrs) > 0 {
return c.do(out).updateAttrs(s.search.assignAttrs).db
}
return s
return c
}
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
if s.clone().First(out, where...).Error != nil {
return s.clone().do(out).where(where...).initialize().db.Save(out)
} else {
if len(s.search.assignAttrs) > 0 {
return s.clone().do(out).updateAttrs(s.search.assignAttrs).update().db
}
c := s.clone()
if c.First(out, where...).Error == RecordNotFound {
return c.do(out).where(where...).initialize().db.Save(out)
} else if len(s.search.assignAttrs) > 0 {
return c.do(out).updateAttrs(s.search.assignAttrs).update().db
}
return s
return c
}
func (s *DB) Update(attrs ...interface{}) *DB {
@ -167,12 +165,10 @@ func (s *DB) Table(name string) *DB {
return s.clone().search.table(name).db
}
// Debug
func (s *DB) Debug() *DB {
return s.clone().LogMode(true)
}
// Transactions
func (s *DB) Begin() *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok {
@ -180,7 +176,7 @@ func (s *DB) Begin() *DB {
c.db = interface{}(tx).(sqlCommon)
c.err(err)
} else {
c.err(errors.New("Can't start a transaction."))
c.err(CantStartTransaction)
}
return c
}
@ -205,22 +201,19 @@ func (s *DB) Rollback() *DB {
// Migrations
func (s *DB) CreateTable(value interface{}) *DB {
s.do(value).createTable()
return s
return s.clone().do(value).createTable().db
}
func (s *DB) DropTable(value interface{}) *DB {
s.do(value).dropTable()
return s
return s.clone().do(value).dropTable().db
}
func (s *DB) AutoMigrate(value interface{}) *DB {
s.do(value).autoMigrate()
return s
return s.clone().do(value).autoMigrate().db
}
func (s *DB) UpdateColumn(column string, typ string) *DB {
s.do(s.data).updateColumn(column, typ)
s.clone().do(s.data).updateColumn(column, typ)
return s
}
@ -230,11 +223,11 @@ func (s *DB) DropColumn(column string) *DB {
}
func (s *DB) AddIndex(column string, index_name ...string) *DB {
s.do(s.data).addIndex(column, index_name...)
s.clone().do(s.data).addIndex(column, index_name...)
return s
}
func (s *DB) RemoveIndex(column string) *DB {
s.do(s.data).removeIndex(column)
s.clone().do(s.data).removeIndex(column)
return s
}