// gorm/chain.go (311 lines, 6.9 KiB, Go)
// Captured from a repository blame view; the timestamp comments mark revision dates.
// 2013-10-27 15:41:58 +04:00
package gorm
import (
"errors"
"fmt"
"regexp"
)
type Chain struct {
2013-11-11 11:48:31 +04:00
d *DB
db sql_common
value interface{}
debug_mode bool
2013-10-27 15:41:58 +04:00
Errors []error
Error error
2013-10-28 16:27:25 +04:00
whereClause []map[string]interface{}
orClause []map[string]interface{}
2013-10-31 13:31:00 +04:00
notClause []map[string]interface{}
2013-10-30 11:21:58 +04:00
initAttrs []interface{}
2013-10-31 04:15:19 +04:00
assignAttrs []interface{}
2013-10-28 16:27:25 +04:00
selectStr string
orderStrs []string
offsetStr string
limitStr string
specifiedTableName string
unscoped bool
2013-10-27 15:41:58 +04:00
}
// 2013-11-11 09:40:35 +04:00
// driver returns the driver string stored on the chain's parent DB (s.d).
func (s *Chain) driver() string {
	parent := s.d
	return parent.driver
}
// 2013-11-11 09:53:04 +04:00
// 2013-10-29 03:39:26 +04:00
func (s *Chain) err(err error) error {
2013-10-27 15:41:58 +04:00
if err != nil {
s.Errors = append(s.Errors, err)
s.Error = err
2013-11-11 11:48:31 +04:00
s.warn(err)
2013-10-27 15:41:58 +04:00
}
2013-10-29 03:39:26 +04:00
return err
2013-10-27 15:41:58 +04:00
}
// 2013-11-11 07:53:56 +04:00
// hasError reports whether any error has been recorded on the chain.
func (s *Chain) hasError() bool {
	return len(s.Errors) != 0
}
// deleteLastError discards the most recently recorded error.
//
// Afterwards Error holds the error that preceded the discarded one, or nil
// when none remain, so Error and Errors stay consistent. Calling it when no
// errors are recorded is a no-op instead of a slice-bounds panic.
func (s *Chain) deleteLastError() {
	if len(s.Errors) == 0 {
		s.Error = nil
		return
	}
	s.Errors = s.Errors[:len(s.Errors)-1]
	if n := len(s.Errors); n > 0 {
		s.Error = s.Errors[n-1]
	} else {
		s.Error = nil
	}
}
func (s *Chain) do(value interface{}) *Do {
do := Do{
2013-11-10 19:07:09 +04:00
chain: s,
db: s.db,
whereClause: s.whereClause,
orClause: s.orClause,
notClause: s.notClause,
selectStr: s.selectStr,
orderStrs: s.orderStrs,
offsetStr: s.offsetStr,
limitStr: s.limitStr,
specifiedTableName: s.specifiedTableName,
unscoped: s.unscoped,
}
2013-10-27 15:41:58 +04:00
2013-10-27 16:07:13 +04:00
s.value = value
2013-10-27 15:41:58 +04:00
do.setModel(value)
return &do
2013-10-27 15:41:58 +04:00
}
// Model sets the value that model-less operations (Count, Updates, Pluck,
// column/index helpers) act on.
func (s *Chain) Model(model interface{}) *Chain {
	s.value = model
	return s
}
// Where adds a condition to the chain. The query and its arguments are stored
// together under the "query" and "args" keys.
func (s *Chain) Where(querystring interface{}, args ...interface{}) *Chain {
	clause := map[string]interface{}{
		"query": querystring,
		"args":  args,
	}
	s.whereClause = append(s.whereClause, clause)
	return s
}
// 2013-10-31 13:31:00 +04:00
// Not adds a negated condition to the chain. The query and its arguments are
// stored together under the "query" and "args" keys.
func (s *Chain) Not(querystring interface{}, args ...interface{}) *Chain {
	clause := map[string]interface{}{
		"query": querystring,
		"args":  args,
	}
	s.notClause = append(s.notClause, clause)
	return s
}
// 2013-10-27 15:41:58 +04:00
func (s *Chain) Limit(value interface{}) *Chain {
2013-11-10 19:07:09 +04:00
if str, err := getInterfaceAsString(value); err == nil {
s.limitStr = str
} else {
2013-10-27 15:41:58 +04:00
s.err(errors.New("Can' understand the value of Limit, Should be int"))
}
return s
}
func (s *Chain) Offset(value interface{}) *Chain {
2013-11-10 19:07:09 +04:00
if str, err := getInterfaceAsString(value); err == nil {
s.offsetStr = str
} else {
2013-10-27 15:41:58 +04:00
s.err(errors.New("Can' understand the value of Offset, Should be int"))
}
return s
}
func (s *Chain) Order(value string, reorder ...bool) *Chain {
defer s.validSql(value)
if len(reorder) > 0 && reorder[0] {
2013-11-10 19:07:09 +04:00
s.orderStrs = []string{value}
2013-10-27 15:41:58 +04:00
} else {
s.orderStrs = append(s.orderStrs, value)
}
return s
}
// 2013-10-27 16:07:13 +04:00
func (s *Chain) Count(value interface{}) *Chain {
s.Select("count(*)").do(s.value).count(value)
return s
2013-10-27 15:41:58 +04:00
}
// Select sets the SELECT expression. Only string values are accepted; any
// other type records an error and leaves the current expression unchanged.
// The resulting expression is checked with validSql before returning.
func (s *Chain) Select(value interface{}) *Chain {
	// Deferred in a closure so the check sees selectStr after the switch.
	defer func() { s.validSql(s.selectStr) }()
	switch value := value.(type) {
	case string:
		s.selectStr = value
	default:
		s.err(errors.New("Can't understand the value of Select, Should be string"))
	}
	return s
}
func (s *Chain) Save(value interface{}) *Chain {
2013-11-11 18:27:17 +04:00
s.do(value).begin().save().commit_or_rollback()
2013-10-27 15:41:58 +04:00
return s
}
func (s *Chain) Delete(value interface{}) *Chain {
2013-11-11 18:27:17 +04:00
s.do(value).begin().delete().commit_or_rollback()
2013-10-27 15:41:58 +04:00
return s
}
func (s *Chain) Update(attrs ...interface{}) *Chain {
return s.Updates(toSearchableMap(attrs...), true)
2013-10-27 15:41:58 +04:00
}
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
2013-11-11 18:27:17 +04:00
s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
2013-10-27 15:41:58 +04:00
return s
}
func (s *Chain) Exec(sql string) *Chain {
2013-10-29 03:39:26 +04:00
s.do(nil).exec(sql)
2013-10-27 15:41:58 +04:00
return s
}
// 2013-10-27 16:54:23 +04:00
// First loads the first record matching the chain's conditions plus the
// optional inline where conditions into out.
func (s *Chain) First(out interface{}, where ...interface{}) *Chain {
	op := s.do(out)
	op.where(where...).first()
	return s
}
func (s *Chain) Last(out interface{}, where ...interface{}) *Chain {
s.do(out).where(where...).last()
2013-10-27 15:41:58 +04:00
return s
}
func (s *Chain) Attrs(attrs ...interface{}) *Chain {
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
2013-10-30 11:21:58 +04:00
return s
}
func (s *Chain) Assign(attrs ...interface{}) *Chain {
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
2013-10-31 04:15:19 +04:00
return s
}
func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain {
2013-10-29 16:20:25 +04:00
if s.First(out, where...).Error != nil {
s.deleteLastError()
2013-11-10 19:07:09 +04:00
s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
2013-10-31 04:15:19 +04:00
} else {
2013-10-31 05:34:27 +04:00
if len(s.assignAttrs) > 0 {
s.do(out).setUpdateAttrs(s.assignAttrs).prepareUpdateAttrs()
}
}
return s
}
func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain {
2013-10-29 16:20:25 +04:00
if s.First(out, where...).Error != nil {
s.deleteLastError()
2013-11-10 19:07:09 +04:00
s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
2013-10-29 16:20:25 +04:00
s.Save(out)
2013-10-31 05:34:27 +04:00
} else {
if len(s.assignAttrs) > 0 {
s.do(out).setUpdateAttrs(s.assignAttrs).update()
}
2013-10-29 16:20:25 +04:00
}
return s
}
// 2013-10-27 16:54:23 +04:00
func (s *Chain) Find(out interface{}, where ...interface{}) *Chain {
2013-10-29 16:20:25 +04:00
s.do(out).where(where...).query()
2013-10-27 15:41:58 +04:00
return s
}
func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) {
2013-10-29 03:39:26 +04:00
s.do(s.value).pluck(column, value)
2013-10-27 15:41:58 +04:00
return s
}
// Or adds an alternative condition to the chain. The query and its arguments
// are stored together under the "query" and "args" keys.
func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain {
	clause := map[string]interface{}{
		"query": querystring,
		"args":  args,
	}
	s.orClause = append(s.orClause, clause)
	return s
}
// Unscoped sets the chain's unscoped flag, which is copied into every Do the
// chain creates.
func (s *Chain) Unscoped() *Chain {
	s.unscoped = true
	return s
}
// 2013-10-28 16:27:25 +04:00
// Table sets an explicit table name. It is copied into every Do as
// specifiedTableName.
func (s *Chain) Table(name string) *Chain {
	s.specifiedTableName = name
	return s
}
// Related loads records associated with the chain's current value into value,
// using the optional foreign_keys.
func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain {
	// Capture the owner first: do(value) overwrites s.value.
	owner := s.value
	op := s.do(value)
	op.related(owner, foreign_keys...)
	return s
}
// 2013-11-11 09:16:08 +04:00
// Begin starts a transaction when the chain's handle supports it (sql_db).
//
// On success, the chain's db is replaced by the transaction, so later
// operations and Commit/Rollback use it. On failure, the error is recorded and
// the original handle is kept. Previously a failed Begin swapped in the nil
// transaction and broke every later call on the chain.
func (s *Chain) Begin() *Chain {
	db, ok := s.db.(sql_db)
	if !ok {
		s.err(errors.New("Can't start a transaction."))
		return s
	}
	tx, err := db.Begin()
	if s.err(err) != nil {
		return s
	}
	s.db = interface{}(tx).(sql_common)
	return s
}
// 2013-11-11 11:48:31 +04:00
// Debug switches the chain into debug mode by setting debug_mode.
func (s *Chain) Debug() *Chain {
	s.debug_mode = true
	return s
}
// 2013-11-11 09:16:08 +04:00
// Commit commits the chain's transaction. If the chain's handle is not a
// transaction (sql_tx), an error is recorded instead.
func (s *Chain) Commit() *Chain {
	tx, ok := s.db.(sql_tx)
	if !ok {
		s.err(errors.New("Commit is not supported, no database transaction found."))
		return s
	}
	s.err(tx.Commit())
	return s
}
// Rollback aborts the chain's transaction. If the chain's handle is not a
// transaction (sql_tx), an error is recorded instead.
func (s *Chain) Rollback() *Chain {
	tx, ok := s.db.(sql_tx)
	if !ok {
		s.err(errors.New("Rollback is not supported, no database transaction found."))
		return s
	}
	s.err(tx.Rollback())
	return s
}
// CreateTable creates the table for value's model through a Do operation.
func (s *Chain) CreateTable(value interface{}) *Chain {
	op := s.do(value)
	op.createTable()
	return s
}
// DropTable drops the table for value's model through a Do operation.
func (s *Chain) DropTable(value interface{}) *Chain {
	op := s.do(value)
	op.dropTable()
	return s
}
// AutoMigrate delegates schema migration for value's model to the Do's
// autoMigrate helper.
func (s *Chain) AutoMigrate(value interface{}) *Chain {
	op := s.do(value)
	op.autoMigrate()
	return s
}
// UpdateColumn changes column on the chain's current model to type typ,
// through the Do's updateColumn helper.
func (s *Chain) UpdateColumn(column string, typ string) *Chain {
	op := s.do(s.value)
	op.updateColumn(column, typ)
	return s
}
// DropColumn removes column from the chain's current model's table through
// the Do's dropColumn helper.
func (s *Chain) DropColumn(column string) *Chain {
	op := s.do(s.value)
	op.dropColumn(column)
	return s
}
// AddIndex adds an index on column for the chain's current model. An optional
// index name is forwarded to the Do's addIndex helper.
func (s *Chain) AddIndex(column string, index_name ...string) *Chain {
	op := s.do(s.value)
	op.addIndex(column, index_name...)
	return s
}
// RemoveIndex removes the index identified by column from the chain's current
// model through the Do's removeIndex helper.
func (s *Chain) RemoveIndex(column string) *Chain {
	op := s.do(s.value)
	op.removeIndex(column)
	return s
}
// 2013-10-27 15:41:58 +04:00
func (s *Chain) validSql(str string) (result bool) {
2013-11-10 05:57:34 +04:00
result = regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str)
2013-10-27 15:41:58 +04:00
if !result {
s.err(errors.New(fmt.Sprintf("SQL is not valid, %s", str)))
}
return
}