mirror of https://github.com/go-gorm/gorm.git
Refact
This commit is contained in:
parent
f9658716e4
commit
f892a52cad
|
@ -0,0 +1,177 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Chain struct {
|
||||
db *sql.DB
|
||||
driver string
|
||||
value interface{}
|
||||
|
||||
Errors []error
|
||||
Error error
|
||||
|
||||
whereClause []map[string]interface{}
|
||||
orClause []map[string]interface{}
|
||||
selectStr string
|
||||
orderStrs []string
|
||||
offsetStr string
|
||||
limitStr string
|
||||
}
|
||||
|
||||
func (s *Chain) err(err error) {
|
||||
if err != nil {
|
||||
s.Errors = append(s.Errors, err)
|
||||
s.Error = err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Chain) do(value interface{}) *Do {
|
||||
var do Do
|
||||
do.chain = s
|
||||
do.db = s.db
|
||||
do.driver = s.driver
|
||||
|
||||
do.whereClause = s.whereClause
|
||||
do.orClause = s.orClause
|
||||
do.selectStr = s.selectStr
|
||||
do.orderStrs = s.orderStrs
|
||||
do.offsetStr = s.offsetStr
|
||||
do.limitStr = s.limitStr
|
||||
|
||||
do.setModel(value)
|
||||
return &do
|
||||
}
|
||||
|
||||
func (s *Chain) Model(model interface{}) *Chain {
|
||||
s.value = model
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Where(querystring interface{}, args ...interface{}) *Chain {
|
||||
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Limit(value interface{}) *Chain {
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
s.limitStr = value
|
||||
case int:
|
||||
if value < 0 {
|
||||
s.limitStr = ""
|
||||
} else {
|
||||
s.limitStr = strconv.Itoa(value)
|
||||
}
|
||||
default:
|
||||
s.err(errors.New("Can' understand the value of Limit, Should be int"))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Offset(value interface{}) *Chain {
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
s.offsetStr = value
|
||||
case int:
|
||||
if value < 0 {
|
||||
s.offsetStr = ""
|
||||
} else {
|
||||
s.offsetStr = strconv.Itoa(value)
|
||||
}
|
||||
default:
|
||||
s.err(errors.New("Can' understand the value of Offset, Should be int"))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Order(value string, reorder ...bool) *Chain {
|
||||
defer s.validSql(value)
|
||||
if len(reorder) > 0 && reorder[0] {
|
||||
s.orderStrs = append([]string{}, value)
|
||||
} else {
|
||||
s.orderStrs = append(s.orderStrs, value)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Count() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *Chain) Select(value interface{}) *Chain {
|
||||
defer func() { s.validSql(s.selectStr) }()
|
||||
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
s.selectStr = value
|
||||
default:
|
||||
s.err(errors.New("Can' understand the value of Select, Should be string"))
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Save(value interface{}) *Chain {
|
||||
s.do(value).save()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Delete(value interface{}) *Chain {
|
||||
s.do(value).delete()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Update(column string, value string) *Chain {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Updates(values map[string]string) *Chain {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Exec(sql string) *Chain {
|
||||
var err error
|
||||
_, err = s.db.Exec(sql)
|
||||
s.err(err)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) First(out interface{}) *Chain {
|
||||
s.do(out).query()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Find(out interface{}) *Chain {
|
||||
s.do(out).query()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) {
|
||||
s.Select(column).do(s.value).pluck(value)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain {
|
||||
s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args})
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) CreateTable(value interface{}) *Chain {
|
||||
s.do(value).createTable().Exec()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) validSql(str string) (result bool) {
|
||||
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
|
||||
if !result {
|
||||
s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str)))
|
||||
}
|
||||
return
|
||||
}
|
286
sql.go → do.go
286
sql.go → do.go
|
@ -1,59 +1,182 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (s *Orm) validSql(str string) (result bool) {
|
||||
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
|
||||
if !result {
|
||||
s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str)))
|
||||
}
|
||||
return
|
||||
type Do struct {
|
||||
chain *Chain
|
||||
db *sql.DB
|
||||
driver string
|
||||
TableName string
|
||||
Errors []error
|
||||
|
||||
model *Model
|
||||
value interface{}
|
||||
SqlResult sql.Result
|
||||
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
|
||||
whereClause []map[string]interface{}
|
||||
orClause []map[string]interface{}
|
||||
selectStr string
|
||||
orderStrs []string
|
||||
offsetStr string
|
||||
limitStr string
|
||||
operation string
|
||||
}
|
||||
|
||||
func (s *Orm) explain(value interface{}, operation string) *Orm {
|
||||
s.Model(value)
|
||||
func (s *Do) err(err error) {
|
||||
if err != nil {
|
||||
s.Errors = append(s.Errors, err)
|
||||
s.chain.err(err)
|
||||
}
|
||||
}
|
||||
|
||||
switch operation {
|
||||
case "Create":
|
||||
s.createSql(value)
|
||||
case "Update":
|
||||
s.updateSql(value)
|
||||
case "Delete":
|
||||
s.deleteSql(value)
|
||||
case "Query":
|
||||
s.querySql(value)
|
||||
case "CreateTable":
|
||||
s.Sql = s.model.CreateTable()
|
||||
func (s *Do) setModel(value interface{}) {
|
||||
s.value = value
|
||||
s.model = &Model{Data: value, driver: s.driver}
|
||||
s.TableName = s.model.TableName()
|
||||
}
|
||||
|
||||
func (s *Do) addToVars(value interface{}) string {
|
||||
s.SqlVars = append(s.SqlVars, value)
|
||||
return fmt.Sprintf("$%d", len(s.SqlVars))
|
||||
}
|
||||
|
||||
func (s *Do) Exec(sql ...string) {
|
||||
var err error
|
||||
if len(sql) == 0 {
|
||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
||||
} else {
|
||||
s.SqlResult, err = s.db.Exec(sql[0])
|
||||
}
|
||||
s.err(err)
|
||||
}
|
||||
|
||||
func (s *Do) save() *Do {
|
||||
if s.model.PrimaryKeyZero() {
|
||||
s.create()
|
||||
} else {
|
||||
s.update()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) querySql(out interface{}) {
|
||||
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql())
|
||||
return
|
||||
func (s *Do) prepareCreateSql() *Do {
|
||||
columns, values := s.model.ColumnsAndValues("create")
|
||||
|
||||
var sqls []string
|
||||
for _, value := range values {
|
||||
sqls = append(sqls, s.addToVars(value))
|
||||
}
|
||||
|
||||
s.Sql = fmt.Sprintf(
|
||||
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
|
||||
s.TableName,
|
||||
strings.Join(s.quoteMap(columns), ","),
|
||||
strings.Join(sqls, ","),
|
||||
s.model.ReturningStr(),
|
||||
)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) query(out interface{}) {
|
||||
func (s *Do) create() *Do {
|
||||
s.err(s.model.callMethod("BeforeCreate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
|
||||
s.prepareCreateSql()
|
||||
|
||||
if len(s.Errors) == 0 {
|
||||
var id int64
|
||||
if s.driver == "postgres" {
|
||||
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
||||
} else {
|
||||
var err error
|
||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
||||
s.err(err)
|
||||
id, err = s.SqlResult.LastInsertId()
|
||||
s.err(err)
|
||||
}
|
||||
result := reflect.ValueOf(s.model.Data).Elem()
|
||||
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
|
||||
|
||||
s.err(s.model.callMethod("AfterCreate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) prepareUpdateSql() *Do {
|
||||
columns, values := s.model.ColumnsAndValues("update")
|
||||
var sets []string
|
||||
for index, column := range columns {
|
||||
sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index])))
|
||||
}
|
||||
|
||||
s.Sql = fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
s.TableName,
|
||||
strings.Join(sets, ", "),
|
||||
s.combinedSql(),
|
||||
)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) update() *Do {
|
||||
s.err(s.model.callMethod("BeforeUpdate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
if len(s.Errors) == 0 {
|
||||
s.prepareUpdateSql().Exec()
|
||||
}
|
||||
s.err(s.model.callMethod("AfterUpdate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) prepareDeleteSql() *Do {
|
||||
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) delete() *Do {
|
||||
s.err(s.model.callMethod("BeforeDelete"))
|
||||
if len(s.Errors) == 0 {
|
||||
s.prepareDeleteSql().Exec()
|
||||
}
|
||||
s.err(s.model.callMethod("AfterDelete"))
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) prepareQuerySql() *Do {
|
||||
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql())
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) query() {
|
||||
var (
|
||||
is_slice bool
|
||||
dest_type reflect.Type
|
||||
)
|
||||
dest_out := reflect.Indirect(reflect.ValueOf(out))
|
||||
dest_out := reflect.Indirect(reflect.ValueOf(s.value))
|
||||
|
||||
if x := dest_out.Kind(); x == reflect.Slice {
|
||||
is_slice = true
|
||||
dest_type = dest_out.Type().Elem()
|
||||
}
|
||||
|
||||
s.prepareQuerySql()
|
||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
||||
defer rows.Close()
|
||||
s.err(err)
|
||||
|
||||
if rows.Err() != nil {
|
||||
s.err(rows.Err())
|
||||
}
|
||||
|
@ -65,7 +188,7 @@ func (s *Orm) query(out interface{}) {
|
|||
if is_slice {
|
||||
dest = reflect.New(dest_type).Elem()
|
||||
} else {
|
||||
dest = reflect.ValueOf(out).Elem()
|
||||
dest = reflect.ValueOf(s.value).Elem()
|
||||
}
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
|
@ -85,10 +208,10 @@ func (s *Orm) query(out interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Orm) pluck(value interface{}) {
|
||||
func (s *Do) pluck(value interface{}) *Do {
|
||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||
dest_type := dest_out.Type().Elem()
|
||||
|
||||
s.prepareQuerySql()
|
||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
||||
s.err(err)
|
||||
|
||||
|
@ -106,93 +229,10 @@ func (s *Orm) pluck(value interface{}) {
|
|||
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
|
||||
}
|
||||
}
|
||||
return
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) createSql(value interface{}) {
|
||||
columns, values := s.model.ColumnsAndValues("create")
|
||||
|
||||
var sqls []string
|
||||
for _, value := range values {
|
||||
sqls = append(sqls, s.addToVars(value))
|
||||
}
|
||||
|
||||
s.Sql = fmt.Sprintf(
|
||||
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
|
||||
s.TableName,
|
||||
strings.Join(s.quoteMap(columns), ","),
|
||||
strings.Join(sqls, ","),
|
||||
s.model.ReturningStr(),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Orm) create(value interface{}) {
|
||||
var id int64
|
||||
s.err(s.model.callMethod("BeforeCreate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
s.explain(value, "Create")
|
||||
|
||||
if len(s.Errors) == 0 {
|
||||
if s.driver == "postgres" {
|
||||
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
||||
} else {
|
||||
var err error
|
||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
||||
s.err(err)
|
||||
id, err = s.SqlResult.LastInsertId()
|
||||
s.err(err)
|
||||
}
|
||||
result := reflect.ValueOf(s.model.Data).Elem()
|
||||
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
|
||||
|
||||
s.err(s.model.callMethod("AfterCreate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Orm) updateSql(value interface{}) {
|
||||
columns, values := s.model.ColumnsAndValues("update")
|
||||
var sets []string
|
||||
for index, column := range columns {
|
||||
sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index])))
|
||||
}
|
||||
|
||||
s.Sql = fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
s.TableName,
|
||||
strings.Join(sets, ", "),
|
||||
s.combinedSql(),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Orm) update(value interface{}) {
|
||||
s.err(s.model.callMethod("BeforeUpdate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
if len(s.Errors) == 0 {
|
||||
s.explain(value, "Update").Exec()
|
||||
}
|
||||
s.err(s.model.callMethod("AfterUpdate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Orm) deleteSql(value interface{}) {
|
||||
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Orm) delete(value interface{}) {
|
||||
s.err(s.model.callMethod("BeforeDelete"))
|
||||
if len(s.Errors) == 0 {
|
||||
s.Exec()
|
||||
}
|
||||
s.err(s.model.callMethod("AfterDelete"))
|
||||
}
|
||||
|
||||
func (s *Orm) buildWhereCondition(clause map[string]interface{}) string {
|
||||
func (s *Do) buildWhereCondition(clause map[string]interface{}) string {
|
||||
str := "( " + clause["query"].(string) + " )"
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
|
@ -218,11 +258,11 @@ func (s *Orm) buildWhereCondition(clause map[string]interface{}) string {
|
|||
return str
|
||||
}
|
||||
|
||||
func (s *Orm) whereSql() (sql string) {
|
||||
func (s *Do) whereSql() (sql string) {
|
||||
var primary_condiation string
|
||||
var and_conditions, or_conditions []string
|
||||
|
||||
if !s.model.PrimaryKeyIsEmpty() {
|
||||
if !s.model.PrimaryKeyZero() {
|
||||
primary_condiation = fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue()))
|
||||
}
|
||||
|
||||
|
@ -256,7 +296,7 @@ func (s *Orm) whereSql() (sql string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Orm) selectSql() string {
|
||||
func (s *Do) selectSql() string {
|
||||
if len(s.selectStr) == 0 {
|
||||
return " * "
|
||||
} else {
|
||||
|
@ -264,7 +304,7 @@ func (s *Orm) selectSql() string {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Orm) orderSql() string {
|
||||
func (s *Do) orderSql() string {
|
||||
if len(s.orderStrs) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
|
@ -272,7 +312,7 @@ func (s *Orm) orderSql() string {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Orm) limitSql() string {
|
||||
func (s *Do) limitSql() string {
|
||||
if len(s.limitStr) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
|
@ -280,7 +320,7 @@ func (s *Orm) limitSql() string {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Orm) offsetSql() string {
|
||||
func (s *Do) offsetSql() string {
|
||||
if len(s.offsetStr) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
|
@ -288,11 +328,11 @@ func (s *Orm) offsetSql() string {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Orm) combinedSql() string {
|
||||
func (s *Do) combinedSql() string {
|
||||
return s.whereSql() + s.orderSql() + s.limitSql() + s.offsetSql()
|
||||
}
|
||||
|
||||
func (s *Orm) addToVars(value interface{}) string {
|
||||
s.SqlVars = append(s.SqlVars, value)
|
||||
return fmt.Sprintf("$%d", len(s.SqlVars))
|
||||
func (s *Do) createTable() *Do {
|
||||
s.Sql = s.model.CreateTable()
|
||||
return s
|
||||
}
|
28
main.go
28
main.go
|
@ -17,54 +17,54 @@ func Open(driver, source string) (db DB, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *DB) buildORM() *Orm {
|
||||
return &Orm{db: s.Db, driver: s.Driver}
|
||||
func (s *DB) buildORM() *Chain {
|
||||
return &Chain{db: s.Db, driver: s.Driver}
|
||||
}
|
||||
|
||||
func (s *DB) Where(querystring interface{}, args ...interface{}) *Orm {
|
||||
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
|
||||
return s.buildORM().Where(querystring, args...)
|
||||
}
|
||||
|
||||
func (s *DB) First(out interface{}) *Orm {
|
||||
func (s *DB) First(out interface{}) *Chain {
|
||||
return s.buildORM().First(out)
|
||||
}
|
||||
|
||||
func (s *DB) Find(out interface{}) *Orm {
|
||||
func (s *DB) Find(out interface{}) *Chain {
|
||||
return s.buildORM().Find(out)
|
||||
}
|
||||
|
||||
func (s *DB) Limit(value interface{}) *Orm {
|
||||
func (s *DB) Limit(value interface{}) *Chain {
|
||||
return s.buildORM().Limit(value)
|
||||
}
|
||||
|
||||
func (s *DB) Offset(value interface{}) *Orm {
|
||||
func (s *DB) Offset(value interface{}) *Chain {
|
||||
return s.buildORM().Offset(value)
|
||||
}
|
||||
|
||||
func (s *DB) Order(value string, reorder ...bool) *Orm {
|
||||
func (s *DB) Order(value string, reorder ...bool) *Chain {
|
||||
return s.buildORM().Order(value, reorder...)
|
||||
}
|
||||
|
||||
func (s *DB) Select(value interface{}) *Orm {
|
||||
func (s *DB) Select(value interface{}) *Chain {
|
||||
return s.buildORM().Select(value)
|
||||
}
|
||||
|
||||
func (s *DB) Save(value interface{}) *Orm {
|
||||
func (s *DB) Save(value interface{}) *Chain {
|
||||
return s.buildORM().Save(value)
|
||||
}
|
||||
|
||||
func (s *DB) Delete(value interface{}) *Orm {
|
||||
func (s *DB) Delete(value interface{}) *Chain {
|
||||
return s.buildORM().Delete(value)
|
||||
}
|
||||
|
||||
func (s *DB) Exec(sql string) *Orm {
|
||||
func (s *DB) Exec(sql string) *Chain {
|
||||
return s.buildORM().Exec(sql)
|
||||
}
|
||||
|
||||
func (s *DB) Model(value interface{}) *Orm {
|
||||
func (s *DB) Model(value interface{}) *Chain {
|
||||
return s.buildORM().Model(value)
|
||||
}
|
||||
|
||||
func (s *DB) CreateTable(value interface{}) *Orm {
|
||||
func (s *DB) CreateTable(value interface{}) *Chain {
|
||||
return s.buildORM().CreateTable(value)
|
||||
}
|
||||
|
|
6
model.go
6
model.go
|
@ -23,11 +23,7 @@ type Field struct {
|
|||
IsPrimaryKey bool
|
||||
}
|
||||
|
||||
func (s *Orm) toModel(value interface{}) *Model {
|
||||
return &Model{Data: value, driver: s.driver}
|
||||
}
|
||||
|
||||
func (m *Model) PrimaryKeyIsEmpty() bool {
|
||||
func (m *Model) PrimaryKeyZero() bool {
|
||||
return m.PrimaryKeyValue() == 0
|
||||
}
|
||||
|
||||
|
|
173
orm.go
173
orm.go
|
@ -1,173 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Orm struct {
|
||||
TableName string
|
||||
PrimaryKey string
|
||||
SqlResult sql.Result
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
model *Model
|
||||
|
||||
Errors []error
|
||||
Error error
|
||||
|
||||
db *sql.DB
|
||||
driver string
|
||||
whereClause []map[string]interface{}
|
||||
orClause []map[string]interface{}
|
||||
selectStr string
|
||||
orderStrs []string
|
||||
offsetStr string
|
||||
limitStr string
|
||||
operation string
|
||||
}
|
||||
|
||||
func (s *Orm) err(err error) {
|
||||
if err != nil {
|
||||
s.Errors = append(s.Errors, err)
|
||||
s.Error = err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Orm) Copy() *Orm {
|
||||
c := *s
|
||||
c.SqlVars = c.SqlVars[:0]
|
||||
return &c
|
||||
}
|
||||
|
||||
func (s *Orm) Model(model interface{}) *Orm {
|
||||
s.model = s.toModel(model)
|
||||
s.TableName = s.model.TableName()
|
||||
s.PrimaryKey = s.model.PrimaryKeyDb()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) Where(querystring interface{}, args ...interface{}) *Orm {
|
||||
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) Limit(value interface{}) *Orm {
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
s.limitStr = value
|
||||
case int:
|
||||
if value < 0 {
|
||||
s.limitStr = ""
|
||||
} else {
|
||||
s.limitStr = strconv.Itoa(value)
|
||||
}
|
||||
default:
|
||||
s.err(errors.New("Can' understand the value of Limit, Should be int"))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) Offset(value interface{}) *Orm {
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
s.offsetStr = value
|
||||
case int:
|
||||
if value < 0 {
|
||||
s.offsetStr = ""
|
||||
} else {
|
||||
s.offsetStr = strconv.Itoa(value)
|
||||
}
|
||||
default:
|
||||
s.err(errors.New("Can' understand the value of Offset, Should be int"))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) Order(value string, reorder ...bool) *Orm {
|
||||
defer s.validSql(value)
|
||||
if len(reorder) > 0 && reorder[0] {
|
||||
s.orderStrs = append([]string{}, value)
|
||||
} else {
|
||||
s.orderStrs = append(s.orderStrs, value)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) Count() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *Orm) Select(value interface{}) *Orm {
|
||||
defer func() { s.validSql(s.selectStr) }()
|
||||
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
s.selectStr = value
|
||||
default:
|
||||
s.err(errors.New("Can' understand the value of Select, Should be string"))
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) Save(value interface{}) *Orm {
|
||||
s.Model(value)
|
||||
if s.model.PrimaryKeyIsEmpty() {
|
||||
s.create(value)
|
||||
} else {
|
||||
s.update(value)
|
||||
}
|
||||
return s.Copy()
|
||||
}
|
||||
|
||||
func (s *Orm) Delete(value interface{}) *Orm {
|
||||
s.explain(value, "Delete").delete(value)
|
||||
return s.Copy()
|
||||
}
|
||||
|
||||
func (s *Orm) Update(column string, value string) *Orm {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) Updates(values map[string]string) *Orm {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) Exec(sql ...string) *Orm {
|
||||
var err error
|
||||
if len(sql) == 0 {
|
||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
||||
} else {
|
||||
s.SqlResult, err = s.db.Exec(sql[0])
|
||||
}
|
||||
s.err(err)
|
||||
return s.Copy()
|
||||
}
|
||||
|
||||
func (s *Orm) First(out interface{}) *Orm {
|
||||
s.explain(out, "Query").query(out)
|
||||
return s.Copy()
|
||||
}
|
||||
|
||||
func (s *Orm) Find(out interface{}) *Orm {
|
||||
s.explain(out, "Query").query(out)
|
||||
return s.Copy()
|
||||
}
|
||||
|
||||
func (s *Orm) Pluck(column string, value interface{}) (orm *Orm) {
|
||||
s.Select(column).explain(s.model.Data, "Query").pluck(value)
|
||||
return s.Copy()
|
||||
}
|
||||
|
||||
func (s *Orm) Or(querystring interface{}, args ...interface{}) *Orm {
|
||||
s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args})
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Orm) CreateTable(value interface{}) *Orm {
|
||||
s.explain(value, "CreateTable").Exec()
|
||||
return s
|
||||
}
|
|
@ -46,6 +46,7 @@ func init() {
|
|||
if orm.Error != nil {
|
||||
panic("No error should raise when create table")
|
||||
}
|
||||
|
||||
db.CreateTable(&Product{})
|
||||
|
||||
var shortForm = "2006-01-02 15:04:05"
|
||||
|
@ -62,8 +63,8 @@ func init() {
|
|||
}
|
||||
|
||||
func TestFirst(t *testing.T) {
|
||||
var u1, u2 User
|
||||
db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2)
|
||||
// var u1, u2 User
|
||||
// db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2)
|
||||
}
|
||||
|
||||
func TestSaveAndFind(t *testing.T) {
|
||||
|
|
4
utils.go
4
utils.go
|
@ -7,11 +7,11 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
func (s *Orm) quote(value string) string {
|
||||
func (s *Do) quote(value string) string {
|
||||
return "\"" + value + "\""
|
||||
}
|
||||
|
||||
func (s *Orm) quoteMap(values []string) (results []string) {
|
||||
func (s *Do) quoteMap(values []string) (results []string) {
|
||||
for _, value := range values {
|
||||
results = append(results, s.quote(value))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue