// Mirror of https://github.com/go-gorm/gorm.git (274 lines, 5.8 KiB, Go).
package gorm
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"regexp"
|
|
"strconv"
|
|
)
|
|
|
|
// Chain carries the state built up by chained query-builder calls: the
// database handle, the accumulated conditions and modifiers, and every error
// recorded along the way.
type Chain struct {
	// Connection details; do copies both into each Do it creates.
	db     *sql.DB
	driver string

	// debug turns on the debug output produced by msg and err.
	debug bool

	// value is the model the chain currently operates on; it is set by
	// Model and overwritten by do.
	value interface{}

	// Errors holds every error recorded through err, in order.
	Errors []error
	// Error is the most recently recorded error, or nil.
	Error error

	// Conditions added by Where, Or and Not. Each entry is a map with the
	// keys "query" and "args".
	whereClause, orClause, notClause []map[string]interface{}

	// Attribute sets collected by Attrs and Assign.
	initAttrs, assignAttrs []interface{}

	// Query modifiers set by Select, Order, Offset, Limit, Table and Unscoped.
	selectStr           string
	orderStrs           []string
	offsetStr, limitStr string
	specifiedTableName  string
	unscoped            bool
}
|
|
|
|
func (s *Chain) msg(str string) {
|
|
if s.debug {
|
|
debug(str)
|
|
}
|
|
}
|
|
|
|
func (s *Chain) err(err error) error {
|
|
if err != nil {
|
|
s.Errors = append(s.Errors, err)
|
|
s.Error = err
|
|
|
|
if s.debug {
|
|
debug(err)
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Chain) deleteLastError() {
|
|
s.Error = nil
|
|
s.Errors = s.Errors[:len(s.Errors)-1]
|
|
}
|
|
|
|
func (s *Chain) do(value interface{}) *Do {
|
|
var do Do
|
|
do.chain = s
|
|
do.db = s.db
|
|
do.driver = s.driver
|
|
|
|
do.whereClause = s.whereClause
|
|
do.orClause = s.orClause
|
|
do.notClause = s.notClause
|
|
do.selectStr = s.selectStr
|
|
do.orderStrs = s.orderStrs
|
|
do.offsetStr = s.offsetStr
|
|
do.limitStr = s.limitStr
|
|
do.specifiedTableName = s.specifiedTableName
|
|
do.unscoped = s.unscoped
|
|
|
|
s.value = value
|
|
do.setModel(value)
|
|
return &do
|
|
}
|
|
|
|
func (s *Chain) Model(model interface{}) *Chain {
|
|
s.value = model
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Where(querystring interface{}, args ...interface{}) *Chain {
|
|
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Not(querystring interface{}, args ...interface{}) *Chain {
|
|
s.notClause = append(s.notClause, map[string]interface{}{"query": querystring, "args": args})
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Limit(value interface{}) *Chain {
|
|
switch value := value.(type) {
|
|
case string:
|
|
s.limitStr = value
|
|
case int:
|
|
if value < 0 {
|
|
s.limitStr = ""
|
|
} else {
|
|
s.limitStr = strconv.Itoa(value)
|
|
}
|
|
default:
|
|
s.err(errors.New("Can' understand the value of Limit, Should be int"))
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Offset(value interface{}) *Chain {
|
|
switch value := value.(type) {
|
|
case string:
|
|
s.offsetStr = value
|
|
case int:
|
|
if value < 0 {
|
|
s.offsetStr = ""
|
|
} else {
|
|
s.offsetStr = strconv.Itoa(value)
|
|
}
|
|
default:
|
|
s.err(errors.New("Can' understand the value of Offset, Should be int"))
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Order(value string, reorder ...bool) *Chain {
|
|
defer s.validSql(value)
|
|
if len(reorder) > 0 && reorder[0] {
|
|
s.orderStrs = append([]string{}, value)
|
|
} else {
|
|
s.orderStrs = append(s.orderStrs, value)
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Count(value interface{}) *Chain {
|
|
s.Select("count(*)").do(s.value).count(value)
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Select(value interface{}) *Chain {
|
|
defer func() { s.validSql(s.selectStr) }()
|
|
|
|
switch value := value.(type) {
|
|
case string:
|
|
s.selectStr = value
|
|
default:
|
|
s.err(errors.New("Can' understand the value of Select, Should be string"))
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Save(value interface{}) *Chain {
|
|
s.do(value).save()
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Delete(value interface{}) *Chain {
|
|
s.do(value).delete()
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Update(attrs ...interface{}) *Chain {
|
|
return s.Updates(toSearchableMap(attrs...), true)
|
|
}
|
|
|
|
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
|
|
s.do(s.value).setUpdateAttrs(values, ignore_protected_attrs...).update()
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Exec(sql string) *Chain {
|
|
s.do(nil).exec(sql)
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) First(out interface{}, where ...interface{}) *Chain {
|
|
s.do(out).where(where...).first()
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Last(out interface{}, where ...interface{}) *Chain {
|
|
s.do(out).where(where...).last()
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Attrs(attrs ...interface{}) *Chain {
|
|
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Assign(attrs ...interface{}) *Chain {
|
|
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain {
|
|
if s.First(out, where...).Error != nil {
|
|
s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
|
|
s.deleteLastError()
|
|
} else {
|
|
if len(s.assignAttrs) > 0 {
|
|
s.do(out).setUpdateAttrs(s.assignAttrs).prepareUpdateAttrs()
|
|
}
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain {
|
|
if s.First(out, where...).Error != nil {
|
|
s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
|
|
s.deleteLastError()
|
|
s.Save(out)
|
|
} else {
|
|
if len(s.assignAttrs) > 0 {
|
|
s.do(out).setUpdateAttrs(s.assignAttrs).update()
|
|
}
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Find(out interface{}, where ...interface{}) *Chain {
|
|
s.do(out).where(where...).query()
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) {
|
|
s.do(s.value).pluck(column, value)
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain {
|
|
s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args})
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) CreateTable(value interface{}) *Chain {
|
|
s.do(value).createTable().exec()
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) DropTable(value interface{}) *Chain {
|
|
s.do(value).dropTable().exec()
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Unscoped() *Chain {
|
|
s.unscoped = true
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Table(name string) *Chain {
|
|
s.specifiedTableName = name
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain {
|
|
original_value := s.value
|
|
s.do(value).related(original_value, foreign_keys...)
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) Debug() *Chain {
|
|
s.debug = true
|
|
return s
|
|
}
|
|
|
|
func (s *Chain) validSql(str string) (result bool) {
|
|
result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str)
|
|
if !result {
|
|
s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str)))
|
|
}
|
|
return
|
|
}
|