Yay, all tests passed

This commit is contained in:
Jinzhu 2013-11-16 18:01:44 +08:00
parent 38f7ecdf15
commit 0ad707b410
4 changed files with 150 additions and 155 deletions

View File

@ -7,4 +7,5 @@ var (
InvalidSql = errors.New("Invalid SQL") InvalidSql = errors.New("Invalid SQL")
NoNewAttrs = errors.New("No new Attributes") NoNewAttrs = errors.New("No new Attributes")
NoValidTransaction = errors.New("No valid transaction") NoValidTransaction = errors.New("No valid transaction")
CantStartTransaction = errors.New("Can't start transaction")
) )

View File

@ -53,7 +53,7 @@ func (f *Field) sqlTag() (str string) {
} }
} }
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.tagIdentifier)) typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.parent.tagIdentifier))
if typ == "-" { if typ == "-" {
return return

View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@ -1289,145 +1290,145 @@ func TestAutoMigration(t *testing.T) {
} }
} }
// type NullTime struct { type NullTime struct {
// Time time.Time Time time.Time
// Valid bool Valid bool
// } }
// func (nt *NullTime) Scan(value interface{}) error { func (nt *NullTime) Scan(value interface{}) error {
// if value == nil { if value == nil {
// nt.Valid = false nt.Valid = false
// return nil return nil
// } }
// nt.Time, nt.Valid = value.(time.Time), true nt.Time, nt.Valid = value.(time.Time), true
// return nil return nil
// } }
// func (nt NullTime) Value() (driver.Value, error) { func (nt NullTime) Value() (driver.Value, error) {
// if !nt.Valid { if !nt.Valid {
// return nil, nil return nil, nil
// } }
// return nt.Time, nil return nt.Time, nil
// } }
// type NullValue struct { type NullValue struct {
// Id int64 Id int64
// Name sql.NullString `sql:"not null"` Name sql.NullString `sql:"not null"`
// Age sql.NullInt64 Age sql.NullInt64
// Male sql.NullBool Male sql.NullBool
// Height sql.NullFloat64 Height sql.NullFloat64
// AddedAt NullTime AddedAt NullTime
// } }
// func TestSqlNullValue(t *testing.T) { func TestSqlNullValue(t *testing.T) {
// db.DropTable(&NullValue{}) db.DropTable(&NullValue{})
// db.AutoMigrate(&NullValue{}) db.AutoMigrate(&NullValue{})
// if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil { if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil {
// t.Errorf("Not error should raise when test null value", err) t.Errorf("Not error should raise when test null value", err)
// } }
// var nv NullValue var nv NullValue
// db.First(&nv, "name = ?", "hello") db.First(&nv, "name = ?", "hello")
// if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
// t.Errorf("Should be able to fetch null value") t.Errorf("Should be able to fetch null value")
// } }
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil { if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil {
// t.Errorf("Not error should raise when test null value", err) t.Errorf("Not error should raise when test null value", err)
// } }
// var nv2 NullValue var nv2 NullValue
// db.First(&nv2, "name = ?", "hello-2") db.First(&nv2, "name = ?", "hello-2")
// if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
// t.Errorf("Should be able to fetch null value") t.Errorf("Should be able to fetch null value")
// } }
// if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil { if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
// t.Errorf("Can't save because of name can't be null", err) t.Errorf("Can't save because of name can't be null", err)
// } }
// } }
// func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
// d := db.Begin() d := db.Begin()
// u := User{Name: "transcation"} u := User{Name: "transcation"}
// if err := d.Save(&u).Error; err != nil { if err := d.Save(&u).Error; err != nil {
// t.Errorf("No error should raise, but got", err) t.Errorf("No error should raise, but got", err)
// } }
// if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil { if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
// t.Errorf("Should find saved record, but got", err) t.Errorf("Should find saved record, but got", err)
// } }
// d.Rollback() d.Rollback()
// if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil { if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
// t.Errorf("Should not find record after rollback") t.Errorf("Should not find record after rollback")
// } }
// d2 := db.Begin() d2 := db.Begin()
// u2 := User{Name: "transcation-2"} u2 := User{Name: "transcation-2"}
// if err := d2.Save(&u2).Error; err != nil { if err := d2.Save(&u2).Error; err != nil {
// t.Errorf("No error should raise, but got", err) t.Errorf("No error should raise, but got", err)
// } }
// d2.Update("age", 90) d2.Update("age", 90)
// if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
// t.Errorf("Should find saved record, but got", err) t.Errorf("Should find saved record, but got", err)
// } }
// d2.Commit() d2.Commit()
// if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil { if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
// t.Errorf("Should be able to find committed record") t.Errorf("Should be able to find committed record")
// } }
// } }
// func (s *CreditCard) BeforeSave() (err error) { func (s *CreditCard) BeforeSave() (err error) {
// if s.Number == "0000" { if s.Number == "0000" {
// err = errors.New("invalid credit card") err = errors.New("invalid credit card")
// } }
// return return
// } }
// func BenchmarkGorm(b *testing.B) { func BenchmarkGorm(b *testing.B) {
// b.N = 5000 b.N = 5000
// for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
// e := strconv.Itoa(x) + "benchmark@example.org" e := strconv.Itoa(x) + "benchmark@example.org"
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// // Insert // Insert
// db.Save(&email) db.Save(&email)
// // Query // Query
// db.First(&BigEmail{}, "email = ?", e) db.First(&BigEmail{}, "email = ?", e)
// // Update // Update
// db.Model(&email).Update("email", "new-"+e) db.Model(&email).Update("email", "new-"+e)
// // Delete // Delete
// db.Delete(&email) db.Delete(&email)
// } }
// } }
// func BenchmarkRawSql(b *testing.B) { func BenchmarkRawSql(b *testing.B) {
// db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db.SetMaxIdleConns(10) db.SetMaxIdleConns(10)
// insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
// query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
// update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
// delete_sql := "DELETE FROM orders WHERE id = $1" delete_sql := "DELETE FROM orders WHERE id = $1"
// b.N = 5000 b.N = 5000
// for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
// var id int64 var id int64
// e := strconv.Itoa(x) + "benchmark@example.org" e := strconv.Itoa(x) + "benchmark@example.org"
// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// // Insert // Insert
// db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
// // Query // Query
// rows, _ := db.Query(query_sql, email.Email) rows, _ := db.Query(query_sql, email.Email)
// rows.Close() rows.Close()
// // Update // Update
// db.Exec(update_sql, "new-"+e, time.Now(), id) db.Exec(update_sql, "new-"+e, time.Now(), id)
// // Delete // Delete
// db.Exec(delete_sql, id) db.Exec(delete_sql, id)
// } }
// } }

53
main.go
View File

@ -2,10 +2,9 @@ package gorm
import ( import (
"database/sql" "database/sql"
"errors"
)
import "github.com/jinzhu/gorm/dialect" "github.com/jinzhu/gorm/dialect"
)
type DB struct { type DB struct {
db sqlCommon db sqlCommon
@ -23,12 +22,13 @@ type DB struct {
func Open(driver, source string) (db DB, err error) { func Open(driver, source string) (db DB, err error) {
db.db, err = sql.Open(driver, source) db.db, err = sql.Open(driver, source)
db.dialect = dialect.New(driver) db.dialect = dialect.New(driver)
db.tagIdentifier = "sql"
db.parent = &db db.parent = &db
return return
} }
func (s *DB) SetPool(n int) { func (s *DB) SetPool(n int) {
if db, ok := s.db.(sqlDb); ok { if db, ok := s.parent.db.(sqlDb); ok {
db.SetMaxIdleConns(n) db.SetMaxIdleConns(n)
} }
} }
@ -103,25 +103,23 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
} }
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
if s.clone().First(out, where...).Error != nil { c := s.clone()
return s.clone().do(out).where(where).initialize().db if c.First(out, where...).Error == RecordNotFound {
} else { return c.do(out).where(where).initialize().db
if len(s.search.assignAttrs) > 0 { } else if len(s.search.assignAttrs) > 0 {
return s.clone().do(out).updateAttrs(s.search.assignAttrs).db return c.do(out).updateAttrs(s.search.assignAttrs).db
} }
} return c
return s
} }
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
if s.clone().First(out, where...).Error != nil { c := s.clone()
return s.clone().do(out).where(where...).initialize().db.Save(out) if c.First(out, where...).Error == RecordNotFound {
} else { return c.do(out).where(where...).initialize().db.Save(out)
if len(s.search.assignAttrs) > 0 { } else if len(s.search.assignAttrs) > 0 {
return s.clone().do(out).updateAttrs(s.search.assignAttrs).update().db return c.do(out).updateAttrs(s.search.assignAttrs).update().db
} }
} return c
return s
} }
func (s *DB) Update(attrs ...interface{}) *DB { func (s *DB) Update(attrs ...interface{}) *DB {
@ -167,12 +165,10 @@ func (s *DB) Table(name string) *DB {
return s.clone().search.table(name).db return s.clone().search.table(name).db
} }
// Debug
func (s *DB) Debug() *DB { func (s *DB) Debug() *DB {
return s.clone().LogMode(true) return s.clone().LogMode(true)
} }
// Transactions
func (s *DB) Begin() *DB { func (s *DB) Begin() *DB {
c := s.clone() c := s.clone()
if db, ok := c.db.(sqlDb); ok { if db, ok := c.db.(sqlDb); ok {
@ -180,7 +176,7 @@ func (s *DB) Begin() *DB {
c.db = interface{}(tx).(sqlCommon) c.db = interface{}(tx).(sqlCommon)
c.err(err) c.err(err)
} else { } else {
c.err(errors.New("Can't start a transaction.")) c.err(CantStartTransaction)
} }
return c return c
} }
@ -205,22 +201,19 @@ func (s *DB) Rollback() *DB {
// Migrations // Migrations
func (s *DB) CreateTable(value interface{}) *DB { func (s *DB) CreateTable(value interface{}) *DB {
s.do(value).createTable() return s.clone().do(value).createTable().db
return s
} }
func (s *DB) DropTable(value interface{}) *DB { func (s *DB) DropTable(value interface{}) *DB {
s.do(value).dropTable() return s.clone().do(value).dropTable().db
return s
} }
func (s *DB) AutoMigrate(value interface{}) *DB { func (s *DB) AutoMigrate(value interface{}) *DB {
s.do(value).autoMigrate() return s.clone().do(value).autoMigrate().db
return s
} }
func (s *DB) UpdateColumn(column string, typ string) *DB { func (s *DB) UpdateColumn(column string, typ string) *DB {
s.do(s.data).updateColumn(column, typ) s.clone().do(s.data).updateColumn(column, typ)
return s return s
} }
@ -230,11 +223,11 @@ func (s *DB) DropColumn(column string) *DB {
} }
func (s *DB) AddIndex(column string, index_name ...string) *DB { func (s *DB) AddIndex(column string, index_name ...string) *DB {
s.do(s.data).addIndex(column, index_name...) s.clone().do(s.data).addIndex(column, index_name...)
return s return s
} }
func (s *DB) RemoveIndex(column string) *DB { func (s *DB) RemoveIndex(column string) *DB {
s.do(s.data).removeIndex(column) s.clone().do(s.data).removeIndex(column)
return s return s
} }