orm.Errors to hold all errors happened, orm.Error store the last error

This commit is contained in:
Jinzhu 2013-10-27 13:35:22 +08:00
parent ea4dee3ba8
commit c3950eef38
3 changed files with 33 additions and 20 deletions

View File

@ -3,12 +3,11 @@
Yet Another ORM library for Go, aims for developer friendly Yet Another ORM library for Go, aims for developer friendly
## TODO ## TODO
* CreatedAt, UpdatedAt
* After/Before Save/Update/Create/Delete
* Soft Delete
* Better First method (First(&user, primary_key, where conditions)) * Better First method (First(&user, primary_key, where conditions))
* Even more complex where query (with map or struct) * Even more complex where query (with map or struct)
* ORM.Errors
* After/Before Save/Update/Create/Delete
* CreatedAt, UpdatedAt
* Soft Delete
* FindOrInitialize / FindOrCreate * FindOrInitialize / FindOrCreate
* SQL Log * SQL Log
* Auto Migration * Auto Migration

23
orm.go
View File

@ -11,11 +11,13 @@ type Orm struct {
TableName string TableName string
PrimaryKey string PrimaryKey string
SqlResult sql.Result SqlResult sql.Result
Error error
Sql string Sql string
SqlVars []interface{} SqlVars []interface{}
model *Model model *Model
Errors []error
Error error
db *sql.DB db *sql.DB
driver string driver string
whereClause []map[string]interface{} whereClause []map[string]interface{}
@ -27,6 +29,13 @@ type Orm struct {
operation string operation string
} }
func (s *Orm) err(err error) {
if err != nil {
s.Errors = append(s.Errors, err)
s.Error = err
}
}
func (s *Orm) Model(model interface{}) *Orm { func (s *Orm) Model(model interface{}) *Orm {
s.model = s.toModel(model) s.model = s.toModel(model)
s.TableName = s.model.TableName() s.TableName = s.model.TableName()
@ -50,7 +59,7 @@ func (s *Orm) Limit(value interface{}) *Orm {
s.limitStr = strconv.Itoa(value) s.limitStr = strconv.Itoa(value)
} }
default: default:
s.Error = errors.New("Can' understand the value of Limit, Should be int") s.err(errors.New("Can' understand the value of Limit, Should be int"))
} }
return s return s
} }
@ -66,7 +75,7 @@ func (s *Orm) Offset(value interface{}) *Orm {
s.offsetStr = strconv.Itoa(value) s.offsetStr = strconv.Itoa(value)
} }
default: default:
s.Error = errors.New("Can' understand the value of Offset, Should be int") s.err(errors.New("Can' understand the value of Offset, Should be int"))
} }
return s return s
} }
@ -92,7 +101,7 @@ func (s *Orm) Select(value interface{}) *Orm {
case string: case string:
s.selectStr = value s.selectStr = value
default: default:
s.Error = errors.New("Can' understand the value of Select, Should be string") s.err(errors.New("Can' understand the value of Select, Should be string"))
} }
return s return s
@ -122,11 +131,13 @@ func (s *Orm) Updates(values map[string]string) *Orm {
} }
func (s *Orm) Exec(sql ...string) *Orm { func (s *Orm) Exec(sql ...string) *Orm {
var err error
if len(sql) == 0 { if len(sql) == 0 {
s.SqlResult, s.Error = s.db.Exec(s.Sql, s.SqlVars...) s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
} else { } else {
s.SqlResult, s.Error = s.db.Exec(sql[0]) s.SqlResult, err = s.db.Exec(sql[0])
} }
s.err(err)
return s return s
} }

23
sql.go
View File

@ -11,7 +11,7 @@ import (
func (s *Orm) validSql(str string) (result bool) { func (s *Orm) validSql(str string) (result bool) {
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str) result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
if !result { if !result {
s.Error = errors.New(fmt.Sprintf("SQL is not valid, %s", str)) s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str)))
} }
return return
} }
@ -53,9 +53,9 @@ func (s *Orm) query(out interface{}) {
rows, err := s.db.Query(s.Sql, s.SqlVars...) rows, err := s.db.Query(s.Sql, s.SqlVars...)
defer rows.Close() defer rows.Close()
s.Error = err s.err(err)
if rows.Err() != nil { if rows.Err() != nil {
s.Error = rows.Err() s.err(rows.Err())
} }
counts := 0 counts := 0
@ -73,7 +73,7 @@ func (s *Orm) query(out interface{}) {
for _, value := range columns { for _, value := range columns {
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface()) values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
} }
s.Error = rows.Scan(values...) s.err(rows.Scan(values...))
if is_slice { if is_slice {
dest_out.Set(reflect.Append(dest_out, dest)) dest_out.Set(reflect.Append(dest_out, dest))
@ -81,7 +81,7 @@ func (s *Orm) query(out interface{}) {
} }
if (counts == 0) && !is_slice { if (counts == 0) && !is_slice {
s.Error = errors.New("Record not found!") s.err(errors.New("Record not found!"))
} }
} }
@ -90,12 +90,12 @@ func (s *Orm) pluck(value interface{}) {
dest_type := dest_out.Type().Elem() dest_type := dest_out.Type().Elem()
rows, err := s.db.Query(s.Sql, s.SqlVars...) rows, err := s.db.Query(s.Sql, s.SqlVars...)
s.Error = err s.err(err)
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
dest := reflect.New(dest_type).Elem().Interface() dest := reflect.New(dest_type).Elem().Interface()
s.Error = rows.Scan(&dest) s.err(rows.Scan(&dest))
switch dest.(type) { switch dest.(type) {
case []uint8: case []uint8:
if dest_type.String() == "string" { if dest_type.String() == "string" {
@ -130,10 +130,13 @@ func (s *Orm) createSql(value interface{}) {
func (s *Orm) create(value interface{}) { func (s *Orm) create(value interface{}) {
var id int64 var id int64
if s.driver == "postgres" { if s.driver == "postgres" {
s.Error = s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id) s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
} else { } else {
s.SqlResult, s.Error = s.db.Exec(s.Sql, s.SqlVars...) var err error
id, s.Error = s.SqlResult.LastInsertId() s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
s.err(err)
id, err = s.SqlResult.LastInsertId()
s.err(err)
} }
result := reflect.ValueOf(s.model.Data).Elem() result := reflect.ValueOf(s.model.Data).Elem()