package gorm

import (
	"database/sql"
	"errors"
	"fmt"
	"regexp"
	"strconv"
)

// Chain accumulates query state (conditions, select list, ordering, paging,
// target table) through chainable method calls and hands a copy of that state
// to a Do value whenever an operation such as Find, Save or Delete runs.
type Chain struct {
	db     *sql.DB
	driver string
	debug  bool
	value  interface{}

	// Errors holds every error recorded through err; Error is the most
	// recently recorded one.
	Errors []error
	Error  error

	// whereClause, orClause and notClause entries are maps with "query" and
	// "args" keys (see Where, Or and Not).
	whereClause []map[string]interface{}
	orClause    []map[string]interface{}
	notClause   []map[string]interface{}
	initAttrs   []interface{}
	assignAttrs []interface{}

	selectStr          string
	orderStrs          []string
	offsetStr          string
	limitStr           string
	specifiedTableName string
	unscoped           bool
}

// msg prints str through debug when debug mode is enabled on the chain.
func (s *Chain) msg(str string) {
	if s.debug {
		debug(str)
	}
}

// err records a non-nil err in Errors and Error (printing it in debug mode)
// and returns err unchanged so callers can use it inline.
func (s *Chain) err(err error) error {
	if err != nil {
		s.Errors = append(s.Errors, err)
		s.Error = err
		if s.debug {
			debug(err)
		}
	}
	return err
}

// deleteLastError clears Error and drops the most recently appended entry
// from Errors. It is a no-op on Errors when no error has been recorded, so it
// cannot panic on an empty slice.
func (s *Chain) deleteLastError() {
	s.Error = nil
	if n := len(s.Errors); n > 0 {
		s.Errors = s.Errors[:n-1]
	}
}

// do builds a Do that carries the chain's query state, records value as the
// chain's current value and calls setModel(value) on the new Do.
// Slice fields (clauses, orderStrs) are shared with the chain, not cloned.
func (s *Chain) do(value interface{}) *Do {
	var do Do
	do.chain = s
	do.db = s.db
	do.driver = s.driver
	do.whereClause = s.whereClause
	do.orClause = s.orClause
	do.notClause = s.notClause
	do.selectStr = s.selectStr
	do.orderStrs = s.orderStrs
	do.offsetStr = s.offsetStr
	do.limitStr = s.limitStr
	do.specifiedTableName = s.specifiedTableName
	do.unscoped = s.unscoped

	s.value = value
	do.setModel(value)
	return &do
}

// Model sets the value that later operations such as Count, Updates and Pluck
// act on.
func (s *Chain) Model(model interface{}) *Chain {
	s.value = model
	return s
}

// Where appends a condition (query plus its args) to the WHERE clause list.
func (s *Chain) Where(querystring interface{}, args ...interface{}) *Chain {
	s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
	return s
}

// Not appends a condition (query plus its args) to the NOT clause list.
func (s *Chain) Not(querystring interface{}, args ...interface{}) *Chain {
	s.notClause = append(s.notClause, map[string]interface{}{"query": querystring, "args": args})
	return s
}

// Limit sets the query's limit. A string is stored verbatim (it is not
// checked with validSql), a non-negative int is stored in decimal form, and a
// negative int clears the limit. Any other type records an error.
func (s *Chain) Limit(value interface{}) *Chain {
	switch value := value.(type) {
	case string:
		s.limitStr = value
	case int:
		if value < 0 {
			s.limitStr = ""
		} else {
			s.limitStr = strconv.Itoa(value)
		}
	default:
		s.err(errors.New("Can't understand the value of Limit, Should be int"))
	}
	return s
}

// Offset sets the query's offset. A string is stored verbatim (it is not
// checked with validSql), a non-negative int is stored in decimal form, and a
// negative int clears the offset. Any other type records an error.
func (s *Chain) Offset(value interface{}) *Chain {
	switch value := value.(type) {
	case string:
		s.offsetStr = value
	case int:
		if value < 0 {
			s.offsetStr = ""
		} else {
			s.offsetStr = strconv.Itoa(value)
		}
	default:
		s.err(errors.New("Can't understand the value of Offset, Should be int"))
	}
	return s
}

// Order appends value to the ordering list. When reorder[0] is true, the
// list is replaced by value instead. value is then checked with validSql.
func (s *Chain) Order(value string, reorder ...bool) *Chain {
	if len(reorder) > 0 && reorder[0] {
		s.orderStrs = []string{value}
	} else {
		s.orderStrs = append(s.orderStrs, value)
	}
	// NOTE(review): an invalid value records an error but is still stored;
	// whether Do refuses to run with recorded errors is not visible here.
	s.validSql(value)
	return s
}

// Count runs a count query for the chain's current value through Do.count,
// which presumably stores the result into value (TODO confirm in Do.count).
func (s *Chain) Count(value interface{}) *Chain {
	// "count(*)" is a trusted literal, so it is set directly rather than via
	// Select: validSql's pattern rejects parentheses and would record a
	// spurious "SQL is not valid" error on every Count. The previous select
	// list is restored afterwards so it does not leak into later queries.
	prevSelect := s.selectStr
	s.selectStr = "count(*)"
	s.do(s.value).count(value)
	s.selectStr = prevSelect
	return s
}

// Select sets the select list. Only strings are accepted; the string is
// checked with validSql. Any other type records a single error and leaves
// the select list unchanged.
func (s *Chain) Select(value interface{}) *Chain {
	switch value := value.(type) {
	case string:
		s.selectStr = value
		// Validate only what was actually set, so a non-string argument does
		// not also trigger a bogus error for the previous select list.
		s.validSql(value)
	default:
		s.err(errors.New("Can't understand the value of Select, Should be string"))
	}
	return s
}

// Save runs Do.save for value.
func (s *Chain) Save(value interface{}) *Chain {
	s.do(value).save()
	return s
}

// Delete runs Do.delete for value.
func (s *Chain) Delete(value interface{}) *Chain {
	s.do(value).delete()
	return s
}

// Update converts attrs with toSearchableMap and forwards them to Updates
// with ignoreProtectedAttrs set to true.
func (s *Chain) Update(attrs ...interface{}) *Chain {
	return s.Updates(toSearchableMap(attrs...), true)
}

// Updates applies values to the chain's current value via
// Do.setUpdateAttrs(...).update(), passing ignoreProtectedAttrs through.
func (s *Chain) Updates(values interface{}, ignoreProtectedAttrs ...bool) *Chain {
	s.do(s.value).setUpdateAttrs(values, ignoreProtectedAttrs...).update()
	return s
}

// Exec runs the raw SQL string sqlStr through Do.exec.
func (s *Chain) Exec(sqlStr string) *Chain {
	s.do(nil).exec(sqlStr)
	return s
}

// First queries a single record (limit 1) matching where into out.
func (s *Chain) First(out interface{}, where ...interface{}) *Chain {
	do := s.do(out)
	do.limitStr = "1"
	do.where(where...).query()
	return s
}

// Attrs stores attributes that FirstOrInit/FirstOrCreate use as extra search
// conditions when initializing a record that was not found.
func (s *Chain) Attrs(attrs ...interface{}) *Chain {
	s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
	return s
}

// Assign stores attributes that FirstOrInit/FirstOrCreate apply whether or
// not a record was found.
func (s *Chain) Assign(attrs ...interface{}) *Chain {
	s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
	return s
}

// FirstOrInit looks up the first record matching where into out.
// If First leaves an error on the chain, out is initialized from where plus
// the Attrs and Assign values (initializeWithSearchCondition) and the last
// recorded error is dropped. Otherwise, any Assign values are applied to out
// through setUpdateAttrs/prepareUpdateAttrs (no update() call is made here).
func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain {
	// NOTE(review): Error may be stale from an earlier call on this chain,
	// and deleteLastError drops whichever error was recorded last — confirm.
	if s.First(out, where...).Error != nil {
		s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
		s.deleteLastError()
	} else if len(s.assignAttrs) > 0 {
		s.do(out).setUpdateAttrs(s.assignAttrs).prepareUpdateAttrs()
	}
	return s
}

// FirstOrCreate looks up the first record matching where into out.
// If First leaves an error on the chain, out is initialized from where plus
// the Attrs and Assign values, the last recorded error is dropped and out is
// saved. Otherwise, any Assign values are written through update().
func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain {
	// NOTE(review): any error on the chain — not only "record not found",
	// and possibly one left by an earlier call — leads to a Save; confirm.
	if s.First(out, where...).Error != nil {
		s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
		s.deleteLastError()
		s.Save(out)
	} else if len(s.assignAttrs) > 0 {
		s.do(out).setUpdateAttrs(s.assignAttrs).update()
	}
	return s
}

// Find queries all records matching where into out.
func (s *Chain) Find(out interface{}, where ...interface{}) *Chain {
	s.do(out).where(where...).query()
	return s
}

// Pluck queries a single column of the chain's current value into value.
func (s *Chain) Pluck(column string, value interface{}) *Chain {
	s.do(s.value).pluck(column, value)
	return s
}

// Or appends a condition (query plus its args) to the OR clause list.
func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain {
	s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args})
	return s
}

// CreateTable runs Do.createTable for value and executes the result.
func (s *Chain) CreateTable(value interface{}) *Chain {
	s.do(value).createTable().exec()
	return s
}

// Unscoped sets the unscoped flag, which is passed on to every later Do.
func (s *Chain) Unscoped() *Chain {
	s.unscoped = true
	return s
}

// Table sets an explicit table name, which is passed on to every later Do.
func (s *Chain) Table(name string) *Chain {
	s.specifiedTableName = name
	return s
}

// Debug enables debug output for messages and errors on this chain.
func (s *Chain) Debug() *Chain {
	s.debug = true
	return s
}

// validSqlRegexp accepts a run of word characters, whitespace, commas and
// dots that begins and ends with a word character. A single word character is
// also accepted. It is compiled once here instead of on every validSql call.
var validSqlRegexp = regexp.MustCompile(`^\s*\w(?:[\w\s,.]*\w)?\s*$`)

// validSql reports whether str matches validSqlRegexp, recording an error on
// the chain when it does not.
func (s *Chain) validSql(str string) bool {
	if validSqlRegexp.MatchString(str) {
		return true
	}
	s.err(fmt.Errorf("SQL is not valid, %s", str))
	return false
}