This commit is contained in:
Jinzhu 2013-10-27 19:41:58 +08:00
parent f9658716e4
commit f892a52cad
7 changed files with 360 additions and 319 deletions

177
chain.go Normal file
View File

@ -0,0 +1,177 @@
package gorm
import (
"database/sql"
"errors"
"fmt"
"regexp"
"strconv"
)
type Chain struct {
db *sql.DB
driver string
value interface{}
Errors []error
Error error
whereClause []map[string]interface{}
orClause []map[string]interface{}
selectStr string
orderStrs []string
offsetStr string
limitStr string
}
func (s *Chain) err(err error) {
if err != nil {
s.Errors = append(s.Errors, err)
s.Error = err
}
}
func (s *Chain) do(value interface{}) *Do {
var do Do
do.chain = s
do.db = s.db
do.driver = s.driver
do.whereClause = s.whereClause
do.orClause = s.orClause
do.selectStr = s.selectStr
do.orderStrs = s.orderStrs
do.offsetStr = s.offsetStr
do.limitStr = s.limitStr
do.setModel(value)
return &do
}
func (s *Chain) Model(model interface{}) *Chain {
s.value = model
return s
}
func (s *Chain) Where(querystring interface{}, args ...interface{}) *Chain {
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
return s
}
func (s *Chain) Limit(value interface{}) *Chain {
switch value := value.(type) {
case string:
s.limitStr = value
case int:
if value < 0 {
s.limitStr = ""
} else {
s.limitStr = strconv.Itoa(value)
}
default:
s.err(errors.New("Can' understand the value of Limit, Should be int"))
}
return s
}
func (s *Chain) Offset(value interface{}) *Chain {
switch value := value.(type) {
case string:
s.offsetStr = value
case int:
if value < 0 {
s.offsetStr = ""
} else {
s.offsetStr = strconv.Itoa(value)
}
default:
s.err(errors.New("Can' understand the value of Offset, Should be int"))
}
return s
}
func (s *Chain) Order(value string, reorder ...bool) *Chain {
defer s.validSql(value)
if len(reorder) > 0 && reorder[0] {
s.orderStrs = append([]string{}, value)
} else {
s.orderStrs = append(s.orderStrs, value)
}
return s
}
func (s *Chain) Count() int64 {
return 0
}
func (s *Chain) Select(value interface{}) *Chain {
defer func() { s.validSql(s.selectStr) }()
switch value := value.(type) {
case string:
s.selectStr = value
default:
s.err(errors.New("Can' understand the value of Select, Should be string"))
}
return s
}
func (s *Chain) Save(value interface{}) *Chain {
s.do(value).save()
return s
}
func (s *Chain) Delete(value interface{}) *Chain {
s.do(value).delete()
return s
}
func (s *Chain) Update(column string, value string) *Chain {
return s
}
func (s *Chain) Updates(values map[string]string) *Chain {
return s
}
func (s *Chain) Exec(sql string) *Chain {
var err error
_, err = s.db.Exec(sql)
s.err(err)
return s
}
func (s *Chain) First(out interface{}) *Chain {
s.do(out).query()
return s
}
func (s *Chain) Find(out interface{}) *Chain {
s.do(out).query()
return s
}
func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) {
s.Select(column).do(s.value).pluck(value)
return s
}
func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain {
s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args})
return s
}
func (s *Chain) CreateTable(value interface{}) *Chain {
s.do(value).createTable().Exec()
return s
}
func (s *Chain) validSql(str string) (result bool) {
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
if !result {
s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str)))
}
return
}

View File

@ -1,59 +1,182 @@
package gorm package gorm
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings" "strings"
) )
func (s *Orm) validSql(str string) (result bool) { type Do struct {
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str) chain *Chain
if !result { db *sql.DB
s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str))) driver string
} TableName string
return Errors []error
model *Model
value interface{}
SqlResult sql.Result
Sql string
SqlVars []interface{}
whereClause []map[string]interface{}
orClause []map[string]interface{}
selectStr string
orderStrs []string
offsetStr string
limitStr string
operation string
} }
func (s *Orm) explain(value interface{}, operation string) *Orm { func (s *Do) err(err error) {
s.Model(value) if err != nil {
s.Errors = append(s.Errors, err)
s.chain.err(err)
}
}
switch operation { func (s *Do) setModel(value interface{}) {
case "Create": s.value = value
s.createSql(value) s.model = &Model{Data: value, driver: s.driver}
case "Update": s.TableName = s.model.TableName()
s.updateSql(value) }
case "Delete":
s.deleteSql(value) func (s *Do) addToVars(value interface{}) string {
case "Query": s.SqlVars = append(s.SqlVars, value)
s.querySql(value) return fmt.Sprintf("$%d", len(s.SqlVars))
case "CreateTable": }
s.Sql = s.model.CreateTable()
func (s *Do) Exec(sql ...string) {
var err error
if len(sql) == 0 {
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
} else {
s.SqlResult, err = s.db.Exec(sql[0])
}
s.err(err)
}
func (s *Do) save() *Do {
if s.model.PrimaryKeyZero() {
s.create()
} else {
s.update()
} }
return s return s
} }
func (s *Orm) querySql(out interface{}) { func (s *Do) prepareCreateSql() *Do {
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql()) columns, values := s.model.ColumnsAndValues("create")
return
var sqls []string
for _, value := range values {
sqls = append(sqls, s.addToVars(value))
}
s.Sql = fmt.Sprintf(
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
s.TableName,
strings.Join(s.quoteMap(columns), ","),
strings.Join(sqls, ","),
s.model.ReturningStr(),
)
return s
} }
func (s *Orm) query(out interface{}) { func (s *Do) create() *Do {
s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave"))
s.prepareCreateSql()
if len(s.Errors) == 0 {
var id int64
if s.driver == "postgres" {
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
} else {
var err error
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
s.err(err)
id, err = s.SqlResult.LastInsertId()
s.err(err)
}
result := reflect.ValueOf(s.model.Data).Elem()
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
}
return s
}
func (s *Do) prepareUpdateSql() *Do {
columns, values := s.model.ColumnsAndValues("update")
var sets []string
for index, column := range columns {
sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index])))
}
s.Sql = fmt.Sprintf(
"UPDATE %v SET %v %v",
s.TableName,
strings.Join(sets, ", "),
s.combinedSql(),
)
return s
}
func (s *Do) update() *Do {
s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave"))
if len(s.Errors) == 0 {
s.prepareUpdateSql().Exec()
}
s.err(s.model.callMethod("AfterUpdate"))
s.err(s.model.callMethod("AfterSave"))
return s
}
func (s *Do) prepareDeleteSql() *Do {
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
return s
}
func (s *Do) delete() *Do {
s.err(s.model.callMethod("BeforeDelete"))
if len(s.Errors) == 0 {
s.prepareDeleteSql().Exec()
}
s.err(s.model.callMethod("AfterDelete"))
return s
}
func (s *Do) prepareQuerySql() *Do {
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql())
return s
}
func (s *Do) query() {
var ( var (
is_slice bool is_slice bool
dest_type reflect.Type dest_type reflect.Type
) )
dest_out := reflect.Indirect(reflect.ValueOf(out)) dest_out := reflect.Indirect(reflect.ValueOf(s.value))
if x := dest_out.Kind(); x == reflect.Slice { if x := dest_out.Kind(); x == reflect.Slice {
is_slice = true is_slice = true
dest_type = dest_out.Type().Elem() dest_type = dest_out.Type().Elem()
} }
s.prepareQuerySql()
rows, err := s.db.Query(s.Sql, s.SqlVars...) rows, err := s.db.Query(s.Sql, s.SqlVars...)
defer rows.Close() defer rows.Close()
s.err(err) s.err(err)
if rows.Err() != nil { if rows.Err() != nil {
s.err(rows.Err()) s.err(rows.Err())
} }
@ -65,7 +188,7 @@ func (s *Orm) query(out interface{}) {
if is_slice { if is_slice {
dest = reflect.New(dest_type).Elem() dest = reflect.New(dest_type).Elem()
} else { } else {
dest = reflect.ValueOf(out).Elem() dest = reflect.ValueOf(s.value).Elem()
} }
columns, _ := rows.Columns() columns, _ := rows.Columns()
@ -85,10 +208,10 @@ func (s *Orm) query(out interface{}) {
} }
} }
func (s *Orm) pluck(value interface{}) { func (s *Do) pluck(value interface{}) *Do {
dest_out := reflect.Indirect(reflect.ValueOf(value)) dest_out := reflect.Indirect(reflect.ValueOf(value))
dest_type := dest_out.Type().Elem() dest_type := dest_out.Type().Elem()
s.prepareQuerySql()
rows, err := s.db.Query(s.Sql, s.SqlVars...) rows, err := s.db.Query(s.Sql, s.SqlVars...)
s.err(err) s.err(err)
@ -106,93 +229,10 @@ func (s *Orm) pluck(value interface{}) {
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest))) dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
} }
} }
return return s
} }
func (s *Orm) createSql(value interface{}) { func (s *Do) buildWhereCondition(clause map[string]interface{}) string {
columns, values := s.model.ColumnsAndValues("create")
var sqls []string
for _, value := range values {
sqls = append(sqls, s.addToVars(value))
}
s.Sql = fmt.Sprintf(
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
s.TableName,
strings.Join(s.quoteMap(columns), ","),
strings.Join(sqls, ","),
s.model.ReturningStr(),
)
return
}
func (s *Orm) create(value interface{}) {
var id int64
s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave"))
s.explain(value, "Create")
if len(s.Errors) == 0 {
if s.driver == "postgres" {
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
} else {
var err error
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
s.err(err)
id, err = s.SqlResult.LastInsertId()
s.err(err)
}
result := reflect.ValueOf(s.model.Data).Elem()
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
}
}
func (s *Orm) updateSql(value interface{}) {
columns, values := s.model.ColumnsAndValues("update")
var sets []string
for index, column := range columns {
sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index])))
}
s.Sql = fmt.Sprintf(
"UPDATE %v SET %v %v",
s.TableName,
strings.Join(sets, ", "),
s.combinedSql(),
)
return
}
func (s *Orm) update(value interface{}) {
s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave"))
if len(s.Errors) == 0 {
s.explain(value, "Update").Exec()
}
s.err(s.model.callMethod("AfterUpdate"))
s.err(s.model.callMethod("AfterSave"))
return
}
func (s *Orm) deleteSql(value interface{}) {
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
return
}
func (s *Orm) delete(value interface{}) {
s.err(s.model.callMethod("BeforeDelete"))
if len(s.Errors) == 0 {
s.Exec()
}
s.err(s.model.callMethod("AfterDelete"))
}
func (s *Orm) buildWhereCondition(clause map[string]interface{}) string {
str := "( " + clause["query"].(string) + " )" str := "( " + clause["query"].(string) + " )"
args := clause["args"].([]interface{}) args := clause["args"].([]interface{})
@ -218,11 +258,11 @@ func (s *Orm) buildWhereCondition(clause map[string]interface{}) string {
return str return str
} }
func (s *Orm) whereSql() (sql string) { func (s *Do) whereSql() (sql string) {
var primary_condiation string var primary_condiation string
var and_conditions, or_conditions []string var and_conditions, or_conditions []string
if !s.model.PrimaryKeyIsEmpty() { if !s.model.PrimaryKeyZero() {
primary_condiation = fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue())) primary_condiation = fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue()))
} }
@ -256,7 +296,7 @@ func (s *Orm) whereSql() (sql string) {
return return
} }
func (s *Orm) selectSql() string { func (s *Do) selectSql() string {
if len(s.selectStr) == 0 { if len(s.selectStr) == 0 {
return " * " return " * "
} else { } else {
@ -264,7 +304,7 @@ func (s *Orm) selectSql() string {
} }
} }
func (s *Orm) orderSql() string { func (s *Do) orderSql() string {
if len(s.orderStrs) == 0 { if len(s.orderStrs) == 0 {
return "" return ""
} else { } else {
@ -272,7 +312,7 @@ func (s *Orm) orderSql() string {
} }
} }
func (s *Orm) limitSql() string { func (s *Do) limitSql() string {
if len(s.limitStr) == 0 { if len(s.limitStr) == 0 {
return "" return ""
} else { } else {
@ -280,7 +320,7 @@ func (s *Orm) limitSql() string {
} }
} }
func (s *Orm) offsetSql() string { func (s *Do) offsetSql() string {
if len(s.offsetStr) == 0 { if len(s.offsetStr) == 0 {
return "" return ""
} else { } else {
@ -288,11 +328,11 @@ func (s *Orm) offsetSql() string {
} }
} }
func (s *Orm) combinedSql() string { func (s *Do) combinedSql() string {
return s.whereSql() + s.orderSql() + s.limitSql() + s.offsetSql() return s.whereSql() + s.orderSql() + s.limitSql() + s.offsetSql()
} }
func (s *Orm) addToVars(value interface{}) string { func (s *Do) createTable() *Do {
s.SqlVars = append(s.SqlVars, value) s.Sql = s.model.CreateTable()
return fmt.Sprintf("$%d", len(s.SqlVars)) return s
} }

28
main.go
View File

@ -17,54 +17,54 @@ func Open(driver, source string) (db DB, err error) {
return return
} }
func (s *DB) buildORM() *Orm { func (s *DB) buildORM() *Chain {
return &Orm{db: s.Db, driver: s.Driver} return &Chain{db: s.Db, driver: s.Driver}
} }
func (s *DB) Where(querystring interface{}, args ...interface{}) *Orm { func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
return s.buildORM().Where(querystring, args...) return s.buildORM().Where(querystring, args...)
} }
func (s *DB) First(out interface{}) *Orm { func (s *DB) First(out interface{}) *Chain {
return s.buildORM().First(out) return s.buildORM().First(out)
} }
func (s *DB) Find(out interface{}) *Orm { func (s *DB) Find(out interface{}) *Chain {
return s.buildORM().Find(out) return s.buildORM().Find(out)
} }
func (s *DB) Limit(value interface{}) *Orm { func (s *DB) Limit(value interface{}) *Chain {
return s.buildORM().Limit(value) return s.buildORM().Limit(value)
} }
func (s *DB) Offset(value interface{}) *Orm { func (s *DB) Offset(value interface{}) *Chain {
return s.buildORM().Offset(value) return s.buildORM().Offset(value)
} }
func (s *DB) Order(value string, reorder ...bool) *Orm { func (s *DB) Order(value string, reorder ...bool) *Chain {
return s.buildORM().Order(value, reorder...) return s.buildORM().Order(value, reorder...)
} }
func (s *DB) Select(value interface{}) *Orm { func (s *DB) Select(value interface{}) *Chain {
return s.buildORM().Select(value) return s.buildORM().Select(value)
} }
func (s *DB) Save(value interface{}) *Orm { func (s *DB) Save(value interface{}) *Chain {
return s.buildORM().Save(value) return s.buildORM().Save(value)
} }
func (s *DB) Delete(value interface{}) *Orm { func (s *DB) Delete(value interface{}) *Chain {
return s.buildORM().Delete(value) return s.buildORM().Delete(value)
} }
func (s *DB) Exec(sql string) *Orm { func (s *DB) Exec(sql string) *Chain {
return s.buildORM().Exec(sql) return s.buildORM().Exec(sql)
} }
func (s *DB) Model(value interface{}) *Orm { func (s *DB) Model(value interface{}) *Chain {
return s.buildORM().Model(value) return s.buildORM().Model(value)
} }
func (s *DB) CreateTable(value interface{}) *Orm { func (s *DB) CreateTable(value interface{}) *Chain {
return s.buildORM().CreateTable(value) return s.buildORM().CreateTable(value)
} }

View File

@ -23,11 +23,7 @@ type Field struct {
IsPrimaryKey bool IsPrimaryKey bool
} }
func (s *Orm) toModel(value interface{}) *Model { func (m *Model) PrimaryKeyZero() bool {
return &Model{Data: value, driver: s.driver}
}
func (m *Model) PrimaryKeyIsEmpty() bool {
return m.PrimaryKeyValue() == 0 return m.PrimaryKeyValue() == 0
} }

173
orm.go
View File

@ -1,173 +0,0 @@
package gorm
import (
"database/sql"
"errors"
"strconv"
)
type Orm struct {
TableName string
PrimaryKey string
SqlResult sql.Result
Sql string
SqlVars []interface{}
model *Model
Errors []error
Error error
db *sql.DB
driver string
whereClause []map[string]interface{}
orClause []map[string]interface{}
selectStr string
orderStrs []string
offsetStr string
limitStr string
operation string
}
func (s *Orm) err(err error) {
if err != nil {
s.Errors = append(s.Errors, err)
s.Error = err
}
}
func (s *Orm) Copy() *Orm {
c := *s
c.SqlVars = c.SqlVars[:0]
return &c
}
func (s *Orm) Model(model interface{}) *Orm {
s.model = s.toModel(model)
s.TableName = s.model.TableName()
s.PrimaryKey = s.model.PrimaryKeyDb()
return s
}
func (s *Orm) Where(querystring interface{}, args ...interface{}) *Orm {
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
return s
}
func (s *Orm) Limit(value interface{}) *Orm {
switch value := value.(type) {
case string:
s.limitStr = value
case int:
if value < 0 {
s.limitStr = ""
} else {
s.limitStr = strconv.Itoa(value)
}
default:
s.err(errors.New("Can' understand the value of Limit, Should be int"))
}
return s
}
func (s *Orm) Offset(value interface{}) *Orm {
switch value := value.(type) {
case string:
s.offsetStr = value
case int:
if value < 0 {
s.offsetStr = ""
} else {
s.offsetStr = strconv.Itoa(value)
}
default:
s.err(errors.New("Can' understand the value of Offset, Should be int"))
}
return s
}
func (s *Orm) Order(value string, reorder ...bool) *Orm {
defer s.validSql(value)
if len(reorder) > 0 && reorder[0] {
s.orderStrs = append([]string{}, value)
} else {
s.orderStrs = append(s.orderStrs, value)
}
return s
}
func (s *Orm) Count() int64 {
return 0
}
func (s *Orm) Select(value interface{}) *Orm {
defer func() { s.validSql(s.selectStr) }()
switch value := value.(type) {
case string:
s.selectStr = value
default:
s.err(errors.New("Can' understand the value of Select, Should be string"))
}
return s
}
func (s *Orm) Save(value interface{}) *Orm {
s.Model(value)
if s.model.PrimaryKeyIsEmpty() {
s.create(value)
} else {
s.update(value)
}
return s.Copy()
}
func (s *Orm) Delete(value interface{}) *Orm {
s.explain(value, "Delete").delete(value)
return s.Copy()
}
func (s *Orm) Update(column string, value string) *Orm {
return s
}
func (s *Orm) Updates(values map[string]string) *Orm {
return s
}
func (s *Orm) Exec(sql ...string) *Orm {
var err error
if len(sql) == 0 {
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
} else {
s.SqlResult, err = s.db.Exec(sql[0])
}
s.err(err)
return s.Copy()
}
func (s *Orm) First(out interface{}) *Orm {
s.explain(out, "Query").query(out)
return s.Copy()
}
func (s *Orm) Find(out interface{}) *Orm {
s.explain(out, "Query").query(out)
return s.Copy()
}
func (s *Orm) Pluck(column string, value interface{}) (orm *Orm) {
s.Select(column).explain(s.model.Data, "Query").pluck(value)
return s.Copy()
}
func (s *Orm) Or(querystring interface{}, args ...interface{}) *Orm {
s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args})
return s
}
func (s *Orm) CreateTable(value interface{}) *Orm {
s.explain(value, "CreateTable").Exec()
return s
}

View File

@ -46,6 +46,7 @@ func init() {
if orm.Error != nil { if orm.Error != nil {
panic("No error should raise when create table") panic("No error should raise when create table")
} }
db.CreateTable(&Product{}) db.CreateTable(&Product{})
var shortForm = "2006-01-02 15:04:05" var shortForm = "2006-01-02 15:04:05"
@ -62,8 +63,8 @@ func init() {
} }
func TestFirst(t *testing.T) { func TestFirst(t *testing.T) {
var u1, u2 User // var u1, u2 User
db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2) // db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2)
} }
func TestSaveAndFind(t *testing.T) { func TestSaveAndFind(t *testing.T) {

View File

@ -7,11 +7,11 @@ import (
"strings" "strings"
) )
func (s *Orm) quote(value string) string { func (s *Do) quote(value string) string {
return "\"" + value + "\"" return "\"" + value + "\""
} }
func (s *Orm) quoteMap(values []string) (results []string) { func (s *Do) quoteMap(values []string) (results []string) {
for _, value := range values { for _, value := range values {
results = append(results, s.quote(value)) results = append(results, s.quote(value))
} }