diff --git a/do.go b/do.go index cfc8a181..d7518487 100644 --- a/do.go +++ b/do.go @@ -14,6 +14,7 @@ import ( type Do struct { db *DB + search *search model *Model tableName string value interface{} @@ -27,10 +28,10 @@ type Do struct { func (s *Do) table() string { if len(s.tableName) == 0 { - if len(s.db.search.tableName) == 0 { + if len(s.search.tableName) == 0 { s.tableName = s.model.tableName() } else { - s.tableName = s.db.search.tableName + s.tableName = s.search.tableName } } return s.tableName @@ -46,6 +47,11 @@ func (s *Do) err(err error) error { func (s *Do) setModel(value interface{}) *Do { s.model = &Model{data: value, do: s} s.value = value + if s.db.search == nil { + s.search = &search{} + } else { + s.search = s.db.search + } return s } @@ -191,10 +197,12 @@ func (s *Do) create() (i interface{}) { } func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do { + ignore_protected := len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0] + switch value := values.(type) { case map[string]interface{}: if len(value) > 0 { - results, has_update := s.model.updatedColumnsAndValues(value) + results, has_update := s.model.updatedColumnsAndValues(value, ignore_protected) if len(results) > 0 { s.update_attrs = results } else if has_update { @@ -218,16 +226,9 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do return s } -func (s *Do) prepareUpdateAttrs() (results map[string]interface{}, update bool) { - if len(s.updateAttrs) > 0 { - results, update = s.model.updatedColumnsAndValues(s.updateAttrs) - } - return -} - -func (s *Do) prepareUpdateSql(results map[string]interface{}) { +func (s *Do) prepareUpdateSql() { var sqls []string - for key, value := range results { + for key, value := range s.update_attrs { sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) } @@ -245,19 +246,11 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) { } func (s *Do) update() *Do { - update_attrs := s.updateAttrs - if len(update_attrs) > 0 { - var need_update bool - if update_attrs, need_update = s.prepareUpdateAttrs(); !need_update { - return s - } - } - s.model.callMethod("BeforeUpdate") s.model.callMethod("BeforeSave") - s.saveBeforeAssociations() - s.prepareUpdateSql(update_attrs) + + s.prepareUpdateSql() if !s.db.hasError() { s.exec() @@ -274,7 +267,7 @@ func (s *Do) delete() *Do { s.model.callMethod("BeforeDelete") if !s.db.hasError() { - if !s.unscoped && s.model.hasColumn("DeletedAt") { + if !s.search.unscope && s.model.hasColumn("DeletedAt") { s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql()) } else { s.sql = fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql()) @@ -291,14 +284,12 @@ func (s *Do) prepareQuerySql() { } func (s *Do) first() { - s.limitStr = "1" - s.orderStrs = append(s.orderStrs, s.model.primaryKeyDb()) + s.search.order(s.model.primaryKeyDb()).limit(1) s.query() } func (s *Do) last() { - s.limitStr = "1" - s.orderStrs = append(s.orderStrs, s.model.primaryKeyDb()+" DESC") + s.search.order(s.model.primaryKeyDb() + " DESC").limit(1) s.query() } @@ -355,7 +346,7 @@ func (s *Do) query() { is_slice = true dest_type = dest_out.Type().Elem() } else { - s.limitStr = "1" + s.search.limit(1) } s.prepareQuerySql() @@ -411,7 +402,6 @@ func (s *Do) count(value interface{}) { } func (s *Do) pluck(column string, value interface{}) { - s.selectStr = column dest_out := reflect.Indirect(reflect.ValueOf(value)) if dest_out.Kind() != reflect.Slice { @@ -437,13 +427,6 @@ func (s *Do) pluck(column string, value interface{}) { } } -func (s *Do) where(where ...interface{}) *Do { - if len(where) > 0 { - s.whereClause = append(s.whereClause, map[string]interface{}{"query": where[0], "args": where[1:len(where)]}) - } - return s -} - func (s *Do) primaryCondiation(value interface{}) string { return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value) } @@ -561,7 +544,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { func (s *Do) whereSql() (sql string) { var primary_condiations, and_conditions, or_conditions []string - if !s.unscoped && s.model.hasColumn("DeletedAt") { + if !s.search.unscope && s.model.hasColumn("DeletedAt") { primary_condiations = append(primary_condiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')") } @@ -569,15 +552,15 @@ func (s *Do) whereSql() (sql string) { primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue()))) } - for _, clause := range s.whereClause { + for _, clause := range s.search.whereClause { and_conditions = append(and_conditions, s.buildWhereCondition(clause)) } - for _, clause := range s.orClause { + for _, clause := range s.search.orClause { or_conditions = append(or_conditions, s.buildWhereCondition(clause)) } - for _, clause := range s.notClause { + for _, clause := range s.search.notClause { and_conditions = append(and_conditions, s.buildNotCondition(clause)) } @@ -603,34 +586,34 @@ func (s *Do) whereSql() (sql string) { } func (s *Do) selectSql() string { - if len(s.selectStr) == 0 { + if len(s.search.selectStr) == 0 { return "*" } else { - return s.selectStr + return s.search.selectStr } } func (s *Do) orderSql() string { - if len(s.orderStrs) == 0 { + if len(s.search.orders) == 0 { return "" } else { - return " ORDER BY " + strings.Join(s.orderStrs, ",") + return " ORDER BY " + strings.Join(s.search.orders, ",") } } func (s *Do) limitSql() string { - if len(s.limitStr) == 0 { + if len(s.search.limitStr) == 0 { return "" } else { - return " LIMIT " + s.limitStr + return " LIMIT " + s.search.limitStr } } func (s *Do) offsetSql() string { - if len(s.offsetStr) == 0 { + if len(s.search.offsetStr) == 0 { return "" } else { - return " OFFSET " + s.offsetStr + return " OFFSET " + s.search.offsetStr } } @@ -733,16 +716,19 @@ func (s *Do) commit_or_rollback() { } } +func (s *Do) where(where ...interface{}) *Do { + if len(where) > 0 { + s.search.where(where[0], where[1:]) + } + return s +} + func (s *Do) initialize() { // TODO initializeWithSearchCondition } -func (s *Do) updateAttrs_() { - // TODO return false if no updates -} - func (s *Do) initializeWithSearchCondition() { - for _, clause := range s.whereClause { + for _, clause := range s.search.whereClause { switch value := clause["query"].(type) { case map[string]interface{}: for k, v := range value { diff --git a/main.go b/main.go index 512a6f48..7978eb4b 100644 --- a/main.go +++ b/main.go @@ -82,17 +82,17 @@ func (s *DB) Unscoped() *DB { } func (s *DB) First(out interface{}, where ...interface{}) *DB { - s.clone().search.limit(1).where(where[0], where[1:]).db.do(out).first() + s.clone().do(out).where(where...).first() return s } func (s *DB) Last(out interface{}, where ...interface{}) *DB { - s.clone().search.limit(1).where(where[0], where[1:]).db.do(out).last() + s.clone().do(out).where(where...).last() return s } func (s *DB) Find(out interface{}, where ...interface{}) *DB { - s.clone().search.where(where[0], where[1:]).db.do(out).query() + s.clone().do(out).where(where...).query() return s } diff --git a/model.go b/model.go index 71f17429..3aa39287 100644 --- a/model.go +++ b/model.go @@ -100,7 +100,7 @@ func (m *Model) columnsHasValue(operation string) (fields []*Field) { return } -func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results map[string]interface{}, any_updated bool) { +func (m *Model) updatedColumnsAndValues(values map[string]interface{}, ignore_protected_attrs bool) (results map[string]interface{}, any_updated bool) { if m.data == nil { return values, true } diff --git a/private.go b/private.go index 565c7665..82248ad0 100644 --- a/private.go +++ b/private.go @@ -13,7 +13,9 @@ func (s *DB) clone() *DB { func (s *DB) do(data interface{}) *Do { s.data = data - return &Do{db: s}.setModel(data) + do := Do{db: s} + do.setModel(data) + return &do } func (s *DB) err(err error) error {