From c3950eef386725876a2733423157a8616bce4e44 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Oct 2013 13:35:22 +0800 Subject: [PATCH] orm.Errors to hold all errors happened, orm.Error store the last error --- README.md | 7 +++---- orm.go | 23 +++++++++++++++++------ sql.go | 23 +++++++++++++---------- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index a24d14cd..e1e8357f 100644 --- a/README.md +++ b/README.md @@ -3,12 +3,11 @@ Yet Another ORM library for Go, aims for developer friendly ## TODO +* CreatedAt, UpdatedAt +* After/Before Save/Update/Create/Delete +* Soft Delete * Better First method (First(&user, primary_key, where conditions)) * Even more complex where query (with map or struct) -* ORM.Errors -* After/Before Save/Update/Create/Delete -* CreatedAt, UpdatedAt -* Soft Delete * FindOrInitialize / FindOrCreate * SQL Log * Auto Migration diff --git a/orm.go b/orm.go index 1c3558d0..fc282268 100644 --- a/orm.go +++ b/orm.go @@ -11,11 +11,13 @@ type Orm struct { TableName string PrimaryKey string SqlResult sql.Result - Error error Sql string SqlVars []interface{} model *Model + Errors []error + Error error + db *sql.DB driver string whereClause []map[string]interface{} @@ -27,6 +29,13 @@ type Orm struct { operation string } +func (s *Orm) err(err error) { + if err != nil { + s.Errors = append(s.Errors, err) + s.Error = err + } +} + func (s *Orm) Model(model interface{}) *Orm { s.model = s.toModel(model) s.TableName = s.model.TableName() @@ -50,7 +59,7 @@ func (s *Orm) Limit(value interface{}) *Orm { s.limitStr = strconv.Itoa(value) } default: - s.Error = errors.New("Can' understand the value of Limit, Should be int") + s.err(errors.New("Can' understand the value of Limit, Should be int")) } return s } @@ -66,7 +75,7 @@ func (s *Orm) Offset(value interface{}) *Orm { s.offsetStr = strconv.Itoa(value) } default: - s.Error = errors.New("Can' understand the value of Offset, Should be int") + s.err(errors.New("Can' understand the value of Offset, Should be int")) } return s } @@ -92,7 +101,7 @@ func (s *Orm) Select(value interface{}) *Orm { case string: s.selectStr = value default: - s.Error = errors.New("Can' understand the value of Select, Should be string") + s.err(errors.New("Can' understand the value of Select, Should be string")) } return s @@ -122,11 +131,13 @@ func (s *Orm) Updates(values map[string]string) *Orm { } func (s *Orm) Exec(sql ...string) *Orm { + var err error if len(sql) == 0 { - s.SqlResult, s.Error = s.db.Exec(s.Sql, s.SqlVars...) + s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) } else { - s.SqlResult, s.Error = s.db.Exec(sql[0]) + s.SqlResult, err = s.db.Exec(sql[0]) } + s.err(err) return s } diff --git a/sql.go b/sql.go index 5692ba9b..8fc23e96 100644 --- a/sql.go +++ b/sql.go @@ -11,7 +11,7 @@ import ( func (s *Orm) validSql(str string) (result bool) { result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str) if !result { - s.Error = errors.New(fmt.Sprintf("SQL is not valid, %s", str)) + s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str))) } return } @@ -53,9 +53,9 @@ func (s *Orm) query(out interface{}) { rows, err := s.db.Query(s.Sql, s.SqlVars...) defer rows.Close() - s.Error = err + s.err(err) if rows.Err() != nil { - s.Error = rows.Err() + s.err(rows.Err()) } counts := 0 @@ -73,7 +73,7 @@ func (s *Orm) query(out interface{}) { for _, value := range columns { values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface()) } - s.Error = rows.Scan(values...) + s.err(rows.Scan(values...)) if is_slice { dest_out.Set(reflect.Append(dest_out, dest)) @@ -81,7 +81,7 @@ func (s *Orm) query(out interface{}) { } if (counts == 0) && !is_slice { - s.Error = errors.New("Record not found!") + s.err(errors.New("Record not found!")) } } @@ -90,12 +90,12 @@ func (s *Orm) pluck(value interface{}) { dest_type := dest_out.Type().Elem() rows, err := s.db.Query(s.Sql, s.SqlVars...) - s.Error = err + s.err(err) defer rows.Close() for rows.Next() { dest := reflect.New(dest_type).Elem().Interface() - s.Error = rows.Scan(&dest) + s.err(rows.Scan(&dest)) switch dest.(type) { case []uint8: if dest_type.String() == "string" { @@ -130,10 +130,13 @@ func (s *Orm) createSql(value interface{}) { func (s *Orm) create(value interface{}) { var id int64 if s.driver == "postgres" { - s.Error = s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id) + s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) } else { - s.SqlResult, s.Error = s.db.Exec(s.Sql, s.SqlVars...) - id, s.Error = s.SqlResult.LastInsertId() + var err error + s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) + s.err(err) + id, err = s.SqlResult.LastInsertId() + s.err(err) } result := reflect.ValueOf(s.model.Data).Elem()