From f892a52cad40db720786a924381b227377157114 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Oct 2013 19:41:58 +0800 Subject: [PATCH] Refact --- chain.go | 177 ++++++++++++++++++++++++++++++ sql.go => do.go | 286 +++++++++++++++++++++++++++--------------------- main.go | 28 ++--- model.go | 6 +- orm.go | 173 ----------------------------- orm_test.go | 5 +- utils.go | 4 +- 7 files changed, 360 insertions(+), 319 deletions(-) create mode 100644 chain.go rename sql.go => do.go (72%) delete mode 100644 orm.go diff --git a/chain.go b/chain.go new file mode 100644 index 00000000..2eb98dc7 --- /dev/null +++ b/chain.go @@ -0,0 +1,177 @@ +package gorm + +import ( + "database/sql" + "errors" + "fmt" + "regexp" + + "strconv" +) + +type Chain struct { + db *sql.DB + driver string + value interface{} + + Errors []error + Error error + + whereClause []map[string]interface{} + orClause []map[string]interface{} + selectStr string + orderStrs []string + offsetStr string + limitStr string +} + +func (s *Chain) err(err error) { + if err != nil { + s.Errors = append(s.Errors, err) + s.Error = err + } +} + +func (s *Chain) do(value interface{}) *Do { + var do Do + do.chain = s + do.db = s.db + do.driver = s.driver + + do.whereClause = s.whereClause + do.orClause = s.orClause + do.selectStr = s.selectStr + do.orderStrs = s.orderStrs + do.offsetStr = s.offsetStr + do.limitStr = s.limitStr + + do.setModel(value) + return &do +} + +func (s *Chain) Model(model interface{}) *Chain { + s.value = model + return s +} + +func (s *Chain) Where(querystring interface{}, args ...interface{}) *Chain { + s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args}) + return s +} + +func (s *Chain) Limit(value interface{}) *Chain { + switch value := value.(type) { + case string: + s.limitStr = value + case int: + if value < 0 { + s.limitStr = "" + } else { + s.limitStr = strconv.Itoa(value) + } + default: + s.err(errors.New("Can' understand the value of Limit, Should be int")) + } + return s +} + +func (s *Chain) Offset(value interface{}) *Chain { + switch value := value.(type) { + case string: + s.offsetStr = value + case int: + if value < 0 { + s.offsetStr = "" + } else { + s.offsetStr = strconv.Itoa(value) + } + default: + s.err(errors.New("Can' understand the value of Offset, Should be int")) + } + return s +} + +func (s *Chain) Order(value string, reorder ...bool) *Chain { + defer s.validSql(value) + if len(reorder) > 0 && reorder[0] { + s.orderStrs = append([]string{}, value) + } else { + s.orderStrs = append(s.orderStrs, value) + } + return s +} + +func (s *Chain) Count() int64 { + return 0 +} + +func (s *Chain) Select(value interface{}) *Chain { + defer func() { s.validSql(s.selectStr) }() + + switch value := value.(type) { + case string: + s.selectStr = value + default: + s.err(errors.New("Can' understand the value of Select, Should be string")) + } + + return s +} + +func (s *Chain) Save(value interface{}) *Chain { + s.do(value).save() + return s +} + +func (s *Chain) Delete(value interface{}) *Chain { + s.do(value).delete() + return s +} + +func (s *Chain) Update(column string, value string) *Chain { + return s +} + +func (s *Chain) Updates(values map[string]string) *Chain { + return s +} + +func (s *Chain) Exec(sql string) *Chain { + var err error + _, err = s.db.Exec(sql) + s.err(err) + return s +} + +func (s *Chain) First(out interface{}) *Chain { + s.do(out).query() + return s +} + +func (s *Chain) Find(out interface{}) *Chain { + s.do(out).query() + return s +} + +func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) { + s.Select(column).do(s.value).pluck(value) + return s +} + +func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain { + s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args}) + return s +} + +func (s *Chain) CreateTable(value interface{}) *Chain { + s.do(value).createTable().Exec() + return s +} + +func (s *Chain) validSql(str string) (result bool) { + result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str) + if !result { + s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str))) + } + return +} diff --git a/sql.go b/do.go similarity index 72% rename from sql.go rename to do.go index a8f24997..d3a1d19e 100644 --- a/sql.go +++ b/do.go @@ -1,59 +1,182 @@ package gorm import ( + "database/sql" "errors" "fmt" "reflect" - "regexp" + "strings" ) -func (s *Orm) validSql(str string) (result bool) { - result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str) - if !result { - s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str))) - } - return +type Do struct { + chain *Chain + db *sql.DB + driver string + TableName string + Errors []error + + model *Model + value interface{} + SqlResult sql.Result + + Sql string + SqlVars []interface{} + + whereClause []map[string]interface{} + orClause []map[string]interface{} + selectStr string + orderStrs []string + offsetStr string + limitStr string + operation string } -func (s *Orm) explain(value interface{}, operation string) *Orm { - s.Model(value) +func (s *Do) err(err error) { + if err != nil { + s.Errors = append(s.Errors, err) + s.chain.err(err) + } +} - switch operation { - case "Create": - s.createSql(value) - case "Update": - s.updateSql(value) - case "Delete": - s.deleteSql(value) - case "Query": - s.querySql(value) - case "CreateTable": - s.Sql = s.model.CreateTable() +func (s *Do) setModel(value interface{}) { + s.value = value + s.model = &Model{Data: value, driver: s.driver} + s.TableName = s.model.TableName() +} + +func (s *Do) addToVars(value interface{}) string { + s.SqlVars = append(s.SqlVars, value) + return fmt.Sprintf("$%d", len(s.SqlVars)) +} + +func (s *Do) Exec(sql ...string) { + var err error + if len(sql) == 0 { + s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) + } else { + s.SqlResult, err = s.db.Exec(sql[0]) + } + s.err(err) +} + +func (s *Do) save() *Do { + if s.model.PrimaryKeyZero() { + s.create() + } else { + s.update() } return s } -func (s *Orm) querySql(out interface{}) { - s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql()) - return +func (s *Do) prepareCreateSql() *Do { + columns, values := s.model.ColumnsAndValues("create") + + var sqls []string + for _, value := range values { + sqls = append(sqls, s.addToVars(value)) + } + + s.Sql = fmt.Sprintf( + "INSERT INTO \"%v\" (%v) VALUES (%v) %v", + s.TableName, + strings.Join(s.quoteMap(columns), ","), + strings.Join(sqls, ","), + s.model.ReturningStr(), + ) + return s } -func (s *Orm) query(out interface{}) { +func (s *Do) create() *Do { + s.err(s.model.callMethod("BeforeCreate")) + s.err(s.model.callMethod("BeforeSave")) + + s.prepareCreateSql() + + if len(s.Errors) == 0 { + var id int64 + if s.driver == "postgres" { + s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) + } else { + var err error + s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) + s.err(err) + id, err = s.SqlResult.LastInsertId() + s.err(err) + } + result := reflect.ValueOf(s.model.Data).Elem() + result.FieldByName(s.model.PrimaryKey()).SetInt(id) + + s.err(s.model.callMethod("AfterCreate")) + s.err(s.model.callMethod("AfterSave")) + } + + return s +} + +func (s *Do) prepareUpdateSql() *Do { + columns, values := s.model.ColumnsAndValues("update") + var sets []string + for index, column := range columns { + sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index]))) + } + + s.Sql = fmt.Sprintf( + "UPDATE %v SET %v %v", + s.TableName, + strings.Join(sets, ", "), + s.combinedSql(), + ) + return s +} + +func (s *Do) update() *Do { + s.err(s.model.callMethod("BeforeUpdate")) + s.err(s.model.callMethod("BeforeSave")) + if len(s.Errors) == 0 { + s.prepareUpdateSql().Exec() + } + s.err(s.model.callMethod("AfterUpdate")) + s.err(s.model.callMethod("AfterSave")) + return s +} + +func (s *Do) prepareDeleteSql() *Do { + s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql()) + return s +} + +func (s *Do) delete() *Do { + s.err(s.model.callMethod("BeforeDelete")) + if len(s.Errors) == 0 { + s.prepareDeleteSql().Exec() + } + s.err(s.model.callMethod("AfterDelete")) + return s +} + +func (s *Do) prepareQuerySql() *Do { + s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql()) + return s +} + +func (s *Do) query() { var ( is_slice bool dest_type reflect.Type ) - dest_out := reflect.Indirect(reflect.ValueOf(out)) + dest_out := reflect.Indirect(reflect.ValueOf(s.value)) if x := dest_out.Kind(); x == reflect.Slice { is_slice = true dest_type = dest_out.Type().Elem() } + s.prepareQuerySql() rows, err := s.db.Query(s.Sql, s.SqlVars...) defer rows.Close() s.err(err) + if rows.Err() != nil { s.err(rows.Err()) } @@ -65,7 +188,7 @@ func (s *Orm) query(out interface{}) { if is_slice { dest = reflect.New(dest_type).Elem() } else { - dest = reflect.ValueOf(out).Elem() + dest = reflect.ValueOf(s.value).Elem() } columns, _ := rows.Columns() @@ -85,10 +208,10 @@ func (s *Orm) query(out interface{}) { } } -func (s *Orm) pluck(value interface{}) { +func (s *Do) pluck(value interface{}) *Do { dest_out := reflect.Indirect(reflect.ValueOf(value)) dest_type := dest_out.Type().Elem() - + s.prepareQuerySql() rows, err := s.db.Query(s.Sql, s.SqlVars...) s.err(err) @@ -106,93 +229,10 @@ func (s *Orm) pluck(value interface{}) { dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest))) } } - return + return s } -func (s *Orm) createSql(value interface{}) { - columns, values := s.model.ColumnsAndValues("create") - - var sqls []string - for _, value := range values { - sqls = append(sqls, s.addToVars(value)) - } - - s.Sql = fmt.Sprintf( - "INSERT INTO \"%v\" (%v) VALUES (%v) %v", - s.TableName, - strings.Join(s.quoteMap(columns), ","), - strings.Join(sqls, ","), - s.model.ReturningStr(), - ) - return -} - -func (s *Orm) create(value interface{}) { - var id int64 - s.err(s.model.callMethod("BeforeCreate")) - s.err(s.model.callMethod("BeforeSave")) - s.explain(value, "Create") - - if len(s.Errors) == 0 { - if s.driver == "postgres" { - s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) - } else { - var err error - s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) - s.err(err) - id, err = s.SqlResult.LastInsertId() - s.err(err) - } - result := reflect.ValueOf(s.model.Data).Elem() - result.FieldByName(s.model.PrimaryKey()).SetInt(id) - - s.err(s.model.callMethod("AfterCreate")) - s.err(s.model.callMethod("AfterSave")) - } -} - -func (s *Orm) updateSql(value interface{}) { - columns, values := s.model.ColumnsAndValues("update") - var sets []string - for index, column := range columns { - sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index]))) - } - - s.Sql = fmt.Sprintf( - "UPDATE %v SET %v %v", - s.TableName, - strings.Join(sets, ", "), - s.combinedSql(), - ) - - return -} - -func (s *Orm) update(value interface{}) { - s.err(s.model.callMethod("BeforeUpdate")) - s.err(s.model.callMethod("BeforeSave")) - if len(s.Errors) == 0 { - s.explain(value, "Update").Exec() - } - s.err(s.model.callMethod("AfterUpdate")) - s.err(s.model.callMethod("AfterSave")) - return -} - -func (s *Orm) deleteSql(value interface{}) { - s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql()) - return -} - -func (s *Orm) delete(value interface{}) { - s.err(s.model.callMethod("BeforeDelete")) - if len(s.Errors) == 0 { - s.Exec() - } - s.err(s.model.callMethod("AfterDelete")) -} - -func (s *Orm) buildWhereCondition(clause map[string]interface{}) string { +func (s *Do) buildWhereCondition(clause map[string]interface{}) string { str := "( " + clause["query"].(string) + " )" args := clause["args"].([]interface{}) @@ -218,11 +258,11 @@ func (s *Orm) buildWhereCondition(clause map[string]interface{}) string { return str } -func (s *Orm) whereSql() (sql string) { +func (s *Do) whereSql() (sql string) { var primary_condiation string var and_conditions, or_conditions []string - if !s.model.PrimaryKeyIsEmpty() { + if !s.model.PrimaryKeyZero() { primary_condiation = fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue())) } @@ -256,7 +296,7 @@ func (s *Orm) whereSql() (sql string) { return } -func (s *Orm) selectSql() string { +func (s *Do) selectSql() string { if len(s.selectStr) == 0 { return " * " } else { @@ -264,7 +304,7 @@ func (s *Orm) selectSql() string { } } -func (s *Orm) orderSql() string { +func (s *Do) orderSql() string { if len(s.orderStrs) == 0 { return "" } else { @@ -272,7 +312,7 @@ func (s *Orm) orderSql() string { } } -func (s *Orm) limitSql() string { +func (s *Do) limitSql() string { if len(s.limitStr) == 0 { return "" } else { @@ -280,7 +320,7 @@ func (s *Orm) limitSql() string { } } -func (s *Orm) offsetSql() string { +func (s *Do) offsetSql() string { if len(s.offsetStr) == 0 { return "" } else { @@ -288,11 +328,11 @@ func (s *Orm) offsetSql() string { } } -func (s *Orm) combinedSql() string { +func (s *Do) combinedSql() string { return s.whereSql() + s.orderSql() + s.limitSql() + s.offsetSql() } -func (s *Orm) addToVars(value interface{}) string { - s.SqlVars = append(s.SqlVars, value) - return fmt.Sprintf("$%d", len(s.SqlVars)) +func (s *Do) createTable() *Do { + s.Sql = s.model.CreateTable() + return s } diff --git a/main.go b/main.go index bda178d8..1ae314b7 100644 --- a/main.go +++ b/main.go @@ -17,54 +17,54 @@ func Open(driver, source string) (db DB, err error) { return } -func (s *DB) buildORM() *Orm { - return &Orm{db: s.Db, driver: s.Driver} +func (s *DB) buildORM() *Chain { + return &Chain{db: s.Db, driver: s.Driver} } -func (s *DB) Where(querystring interface{}, args ...interface{}) *Orm { +func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { return s.buildORM().Where(querystring, args...) } -func (s *DB) First(out interface{}) *Orm { +func (s *DB) First(out interface{}) *Chain { return s.buildORM().First(out) } -func (s *DB) Find(out interface{}) *Orm { +func (s *DB) Find(out interface{}) *Chain { return s.buildORM().Find(out) } -func (s *DB) Limit(value interface{}) *Orm { +func (s *DB) Limit(value interface{}) *Chain { return s.buildORM().Limit(value) } -func (s *DB) Offset(value interface{}) *Orm { +func (s *DB) Offset(value interface{}) *Chain { return s.buildORM().Offset(value) } -func (s *DB) Order(value string, reorder ...bool) *Orm { +func (s *DB) Order(value string, reorder ...bool) *Chain { return s.buildORM().Order(value, reorder...) } -func (s *DB) Select(value interface{}) *Orm { +func (s *DB) Select(value interface{}) *Chain { return s.buildORM().Select(value) } -func (s *DB) Save(value interface{}) *Orm { +func (s *DB) Save(value interface{}) *Chain { return s.buildORM().Save(value) } -func (s *DB) Delete(value interface{}) *Orm { +func (s *DB) Delete(value interface{}) *Chain { return s.buildORM().Delete(value) } -func (s *DB) Exec(sql string) *Orm { +func (s *DB) Exec(sql string) *Chain { return s.buildORM().Exec(sql) } -func (s *DB) Model(value interface{}) *Orm { +func (s *DB) Model(value interface{}) *Chain { return s.buildORM().Model(value) } -func (s *DB) CreateTable(value interface{}) *Orm { +func (s *DB) CreateTable(value interface{}) *Chain { return s.buildORM().CreateTable(value) } diff --git a/model.go b/model.go index c033637c..9be449cf 100644 --- a/model.go +++ b/model.go @@ -23,11 +23,7 @@ type Field struct { IsPrimaryKey bool } -func (s *Orm) toModel(value interface{}) *Model { - return &Model{Data: value, driver: s.driver} -} - -func (m *Model) PrimaryKeyIsEmpty() bool { +func (m *Model) PrimaryKeyZero() bool { return m.PrimaryKeyValue() == 0 } diff --git a/orm.go b/orm.go deleted file mode 100644 index da7fe785..00000000 --- a/orm.go +++ /dev/null @@ -1,173 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - - "strconv" -) - -type Orm struct { - TableName string - PrimaryKey string - SqlResult sql.Result - Sql string - SqlVars []interface{} - model *Model - - Errors []error - Error error - - db *sql.DB - driver string - whereClause []map[string]interface{} - orClause []map[string]interface{} - selectStr string - orderStrs []string - offsetStr string - limitStr string - operation string -} - -func (s *Orm) err(err error) { - if err != nil { - s.Errors = append(s.Errors, err) - s.Error = err - } -} - -func (s *Orm) Copy() *Orm { - c := *s - c.SqlVars = c.SqlVars[:0] - return &c -} - -func (s *Orm) Model(model interface{}) *Orm { - s.model = s.toModel(model) - s.TableName = s.model.TableName() - s.PrimaryKey = s.model.PrimaryKeyDb() - return s -} - -func (s *Orm) Where(querystring interface{}, args ...interface{}) *Orm { - s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args}) - return s -} - -func (s *Orm) Limit(value interface{}) *Orm { - switch value := value.(type) { - case string: - s.limitStr = value - case int: - if value < 0 { - s.limitStr = "" - } else { - s.limitStr = strconv.Itoa(value) - } - default: - s.err(errors.New("Can' understand the value of Limit, Should be int")) - } - return s -} - -func (s *Orm) Offset(value interface{}) *Orm { - switch value := value.(type) { - case string: - s.offsetStr = value - case int: - if value < 0 { - s.offsetStr = "" - } else { - s.offsetStr = strconv.Itoa(value) - } - default: - s.err(errors.New("Can' understand the value of Offset, Should be int")) - } - return s -} - -func (s *Orm) Order(value string, reorder ...bool) *Orm { - defer s.validSql(value) - if len(reorder) > 0 && reorder[0] { - s.orderStrs = append([]string{}, value) - } else { - s.orderStrs = append(s.orderStrs, value) - } - return s -} - -func (s *Orm) Count() int64 { - return 0 -} - -func (s *Orm) Select(value interface{}) *Orm { - defer func() { s.validSql(s.selectStr) }() - - switch value := value.(type) { - case string: - s.selectStr = value - default: - s.err(errors.New("Can' understand the value of Select, Should be string")) - } - - return s -} - -func (s *Orm) Save(value interface{}) *Orm { - s.Model(value) - if s.model.PrimaryKeyIsEmpty() { - s.create(value) - } else { - s.update(value) - } - return s.Copy() -} - -func (s *Orm) Delete(value interface{}) *Orm { - s.explain(value, "Delete").delete(value) - return s.Copy() -} - -func (s *Orm) Update(column string, value string) *Orm { - return s -} - -func (s *Orm) Updates(values map[string]string) *Orm { - return s -} - -func (s *Orm) Exec(sql ...string) *Orm { - var err error - if len(sql) == 0 { - s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) - } else { - s.SqlResult, err = s.db.Exec(sql[0]) - } - s.err(err) - return s.Copy() -} - -func (s *Orm) First(out interface{}) *Orm { - s.explain(out, "Query").query(out) - return s.Copy() -} - -func (s *Orm) Find(out interface{}) *Orm { - s.explain(out, "Query").query(out) - return s.Copy() -} - -func (s *Orm) Pluck(column string, value interface{}) (orm *Orm) { - s.Select(column).explain(s.model.Data, "Query").pluck(value) - return s.Copy() -} - -func (s *Orm) Or(querystring interface{}, args ...interface{}) *Orm { - s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args}) - return s -} - -func (s *Orm) CreateTable(value interface{}) *Orm { - s.explain(value, "CreateTable").Exec() - return s -} diff --git a/orm_test.go b/orm_test.go index 8d5e4e3b..6a9da3e3 100644 --- a/orm_test.go +++ b/orm_test.go @@ -46,6 +46,7 @@ func init() { if orm.Error != nil { panic("No error should raise when create table") } + db.CreateTable(&Product{}) var shortForm = "2006-01-02 15:04:05" @@ -62,8 +63,8 @@ func init() { } func TestFirst(t *testing.T) { - var u1, u2 User - db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2) + // var u1, u2 User + // db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2) } func TestSaveAndFind(t *testing.T) { diff --git a/utils.go b/utils.go index f77f8ae3..40798a0c 100644 --- a/utils.go +++ b/utils.go @@ -7,11 +7,11 @@ import ( "strings" ) -func (s *Orm) quote(value string) string { +func (s *Do) quote(value string) string { return "\"" + value + "\"" } -func (s *Orm) quoteMap(values []string) (results []string) { +func (s *Do) quoteMap(values []string) (results []string) { for _, value := range values { results = append(results, s.quote(value)) }