diff --git a/errors.go b/errors.go index e99a7f23..d91fbeb5 100644 --- a/errors.go +++ b/errors.go @@ -3,8 +3,9 @@ package gorm import "errors" var ( - RecordNotFound = errors.New("Record Not Found") - InvalidSql = errors.New("Invalid SQL") - NoNewAttrs = errors.New("No new Attributes") - NoValidTransaction = errors.New("No valid transaction") + RecordNotFound = errors.New("Record Not Found") + InvalidSql = errors.New("Invalid SQL") + NoNewAttrs = errors.New("No new Attributes") + NoValidTransaction = errors.New("No valid transaction") + CantStartTransaction = errors.New("Can't start transaction") ) diff --git a/field.go b/field.go index f5190c39..8455c184 100644 --- a/field.go +++ b/field.go @@ -53,7 +53,7 @@ func (f *Field) sqlTag() (str string) { } } - typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.tagIdentifier)) + typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.parent.tagIdentifier)) if typ == "-" { return diff --git a/gorm_test.go b/gorm_test.go index 0c27e77c..8db3598e 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" _ "github.com/go-sql-driver/mysql" @@ -1289,145 +1290,145 @@ func TestAutoMigration(t *testing.T) { } } -// type NullTime struct { -// Time time.Time -// Valid bool -// } +type NullTime struct { + Time time.Time + Valid bool +} -// func (nt *NullTime) Scan(value interface{}) error { -// if value == nil { -// nt.Valid = false -// return nil -// } -// nt.Time, nt.Valid = value.(time.Time), true -// return nil -// } +func (nt *NullTime) Scan(value interface{}) error { + if value == nil { + nt.Valid = false + return nil + } + nt.Time, nt.Valid = value.(time.Time), true + return nil +} -// func (nt NullTime) Value() (driver.Value, error) { -// if !nt.Valid { -// return nil, nil -// } -// return nt.Time, nil -// } +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} -// type NullValue struct { -// Id int64 -// Name sql.NullString `sql:"not null"` -// Age sql.NullInt64 -// Male sql.NullBool -// Height sql.NullFloat64 -// AddedAt NullTime -// } +type NullValue struct { + Id int64 + Name sql.NullString `sql:"not null"` + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + AddedAt NullTime +} -// func TestSqlNullValue(t *testing.T) { -// db.DropTable(&NullValue{}) -// db.AutoMigrate(&NullValue{}) +func TestSqlNullValue(t *testing.T) { + db.DropTable(&NullValue{}) + db.AutoMigrate(&NullValue{}) -// if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil { -// t.Errorf("Not error should raise when test null value", err) -// } + if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), true}}).Error; err != nil { + t.Errorf("Not error should raise when test null value", err) + } -// var nv NullValue -// db.First(&nv, "name = ?", "hello") + var nv NullValue + db.First(&nv, "name = ?", "hello") -// if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { -// t.Errorf("Should be able to fetch null value") -// } + if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { + t.Errorf("Should be able to fetch null value") + } -// if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil { -// t.Errorf("Not error should raise when test null value", err) -// } + if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err != nil { + t.Errorf("Not error should raise when test null value", err) + } -// var nv2 NullValue -// db.First(&nv2, "name = ?", "hello-2") -// if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { -// t.Errorf("Should be able to fetch null value") -// } + var nv2 NullValue + db.First(&nv2, "name = ?", "hello-2") + if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { + t.Errorf("Should be able to fetch null value") + } -// if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil { -// t.Errorf("Can't save because of name can't be null", err) -// } -// } + if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil { + t.Errorf("Can't save because of name can't be null", err) + } +} -// func TestTransaction(t *testing.T) { -// d := db.Begin() -// u := User{Name: "transcation"} -// if err := d.Save(&u).Error; err != nil { -// t.Errorf("No error should raise, but got", err) -// } +func TestTransaction(t *testing.T) { + d := db.Begin() + u := User{Name: "transcation"} + if err := d.Save(&u).Error; err != nil { + t.Errorf("No error should raise, but got", err) + } -// if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil { -// t.Errorf("Should find saved record, but got", err) -// } + if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record, but got", err) + } -// d.Rollback() + d.Rollback() -// if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil { -// t.Errorf("Should not find record after rollback") -// } + if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } -// d2 := db.Begin() -// u2 := User{Name: "transcation-2"} -// if err := d2.Save(&u2).Error; err != nil { -// t.Errorf("No error should raise, but got", err) -// } -// d2.Update("age", 90) + d2 := db.Begin() + u2 := User{Name: "transcation-2"} + if err := d2.Save(&u2).Error; err != nil { + t.Errorf("No error should raise, but got", err) + } + d2.Update("age", 90) -// if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { -// t.Errorf("Should find saved record, but got", err) -// } + if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should find saved record, but got", err) + } -// d2.Commit() + d2.Commit() -// if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil { -// t.Errorf("Should be able to find committed record") -// } -// } + if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should be able to find committed record") + } +} -// func (s *CreditCard) BeforeSave() (err error) { -// if s.Number == "0000" { -// err = errors.New("invalid credit card") -// } -// return -// } +func (s *CreditCard) BeforeSave() (err error) { + if s.Number == "0000" { + err = errors.New("invalid credit card") + } + return +} -// func BenchmarkGorm(b *testing.B) { -// b.N = 5000 -// for x := 0; x < b.N; x++ { -// e := strconv.Itoa(x) + "benchmark@example.org" -// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} -// // Insert -// db.Save(&email) -// // Query -// db.First(&BigEmail{}, "email = ?", e) -// // Update -// db.Model(&email).Update("email", "new-"+e) -// // Delete -// db.Delete(&email) -// } -// } +func BenchmarkGorm(b *testing.B) { + b.N = 5000 + for x := 0; x < b.N; x++ { + e := strconv.Itoa(x) + "benchmark@example.org" + email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} + // Insert + db.Save(&email) + // Query + db.First(&BigEmail{}, "email = ?", e) + // Update + db.Model(&email).Update("email", "new-"+e) + // Delete + db.Delete(&email) + } +} -// func BenchmarkRawSql(b *testing.B) { -// db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") -// db.SetMaxIdleConns(10) -// insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" -// query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" -// update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" -// delete_sql := "DELETE FROM orders WHERE id = $1" +func BenchmarkRawSql(b *testing.B) { + db, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") + db.SetMaxIdleConns(10) + insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" + query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" + update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" + delete_sql := "DELETE FROM orders WHERE id = $1" -// b.N = 5000 -// for x := 0; x < b.N; x++ { -// var id int64 -// e := strconv.Itoa(x) + "benchmark@example.org" -// email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} -// // Insert -// db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) -// // Query -// rows, _ := db.Query(query_sql, email.Email) -// rows.Close() -// // Update -// db.Exec(update_sql, "new-"+e, time.Now(), id) -// // Delete -// db.Exec(delete_sql, id) -// } -// } + b.N = 5000 + for x := 0; x < b.N; x++ { + var id int64 + e := strconv.Itoa(x) + "benchmark@example.org" + email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} + // Insert + db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) + // Query + rows, _ := db.Query(query_sql, email.Email) + rows.Close() + // Update + db.Exec(update_sql, "new-"+e, time.Now(), id) + // Delete + db.Exec(delete_sql, id) + } +} diff --git a/main.go b/main.go index 98ba7a21..89279783 100644 --- a/main.go +++ b/main.go @@ -2,10 +2,9 @@ package gorm import ( "database/sql" - "errors" -) -import "github.com/jinzhu/gorm/dialect" + "github.com/jinzhu/gorm/dialect" +) type DB struct { db sqlCommon @@ -23,12 +22,13 @@ type DB struct { func Open(driver, source string) (db DB, err error) { db.db, err = sql.Open(driver, source) db.dialect = dialect.New(driver) + db.tagIdentifier = "sql" db.parent = &db return } func (s *DB) SetPool(n int) { - if db, ok := s.db.(sqlDb); ok { + if db, ok := s.parent.db.(sqlDb); ok { db.SetMaxIdleConns(n) } } @@ -103,25 +103,23 @@ func (s *DB) Assign(attrs ...interface{}) *DB { } func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - if s.clone().First(out, where...).Error != nil { - return s.clone().do(out).where(where).initialize().db - } else { - if len(s.search.assignAttrs) > 0 { - return s.clone().do(out).updateAttrs(s.search.assignAttrs).db - } + c := s.clone() + if c.First(out, where...).Error == RecordNotFound { + return c.do(out).where(where).initialize().db + } else if len(s.search.assignAttrs) > 0 { + return c.do(out).updateAttrs(s.search.assignAttrs).db } - return s + return c } func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - if s.clone().First(out, where...).Error != nil { - return s.clone().do(out).where(where...).initialize().db.Save(out) - } else { - if len(s.search.assignAttrs) > 0 { - return s.clone().do(out).updateAttrs(s.search.assignAttrs).update().db - } + c := s.clone() + if c.First(out, where...).Error == RecordNotFound { + return c.do(out).where(where...).initialize().db.Save(out) + } else if len(s.search.assignAttrs) > 0 { + return c.do(out).updateAttrs(s.search.assignAttrs).update().db } - return s + return c } func (s *DB) Update(attrs ...interface{}) *DB { @@ -167,12 +165,10 @@ func (s *DB) Table(name string) *DB { return s.clone().search.table(name).db } -// Debug func (s *DB) Debug() *DB { return s.clone().LogMode(true) } -// Transactions func (s *DB) Begin() *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok { @@ -180,7 +176,7 @@ func (s *DB) Begin() *DB { c.db = interface{}(tx).(sqlCommon) c.err(err) } else { - c.err(errors.New("Can't start a transaction.")) + c.err(CantStartTransaction) } return c } @@ -205,22 +201,19 @@ func (s *DB) Rollback() *DB { // Migrations func (s *DB) CreateTable(value interface{}) *DB { - s.do(value).createTable() - return s + return s.clone().do(value).createTable().db } func (s *DB) DropTable(value interface{}) *DB { - s.do(value).dropTable() - return s + return s.clone().do(value).dropTable().db } func (s *DB) AutoMigrate(value interface{}) *DB { - s.do(value).autoMigrate() - return s + return s.clone().do(value).autoMigrate().db } func (s *DB) UpdateColumn(column string, typ string) *DB { - s.do(s.data).updateColumn(column, typ) + s.clone().do(s.data).updateColumn(column, typ) return s } @@ -230,11 +223,11 @@ func (s *DB) DropColumn(column string) *DB { } func (s *DB) AddIndex(column string, index_name ...string) *DB { - s.do(s.data).addIndex(column, index_name...) + s.clone().do(s.data).addIndex(column, index_name...) return s } func (s *DB) RemoveIndex(column string) *DB { - s.do(s.data).removeIndex(column) + s.clone().do(s.data).removeIndex(column) return s }