From b22289b2497706fcb99b7f017ae3e63f8952fa4f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Nov 2013 22:27:17 +0800 Subject: [PATCH] Clean up more code --- chain.go | 12 ++---- do.go | 128 ++++++++++++++++++++++--------------------------------- 2 files changed, 54 insertions(+), 86 deletions(-) diff --git a/chain.go b/chain.go index 1df4412c..a80eaaa1 100644 --- a/chain.go +++ b/chain.go @@ -132,16 +132,12 @@ func (s *Chain) Select(value interface{}) *Chain { } func (s *Chain) Save(value interface{}) *Chain { - do := s.do(value).begin() - do.save() - do.commit_or_rollback() + s.do(value).begin().save().commit_or_rollback() return s } func (s *Chain) Delete(value interface{}) *Chain { - do := s.do(value).begin() - do.delete() - do.commit_or_rollback() + s.do(value).begin().delete().commit_or_rollback() return s } @@ -150,9 +146,7 @@ func (s *Chain) Update(attrs ...interface{}) *Chain { } func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain { - do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...) - do.update() - do.commit_or_rollback() + s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback() return s } diff --git a/do.go b/do.go index 89525bf4..0b159302 100644 --- a/do.go +++ b/do.go @@ -68,12 +68,11 @@ func (s *Do) addToVars(value interface{}) string { } func (s *Do) exec(sqls ...string) (err error) { - if s.chain.hasError() { - return - } else { + if !s.chain.hasError() { if len(sqls) > 0 { s.sql = sqls[0] } + now := time.Now() _, err = s.db.Exec(s.sql, s.sqlVars...) s.chain.slog(s.sql, now, s.sqlVars...) @@ -81,13 +80,13 @@ func (s *Do) exec(sqls ...string) (err error) { return s.err(err) } -func (s *Do) save() (value interface{}) { +func (s *Do) save() *Do { if s.model.primaryKeyZero() { - value = s.create() + s.create() } else { - value = s.update() + s.update() } - return + return s } func (s *Do) prepareCreateSql() { @@ -110,24 +109,24 @@ func (s *Do) prepareCreateSql() { func (s *Do) saveBeforeAssociations() { for _, field := range s.model.beforeAssociations() { - var id interface{} do := &Do{chain: s.chain, db: s.db} reflect_value := reflect.ValueOf(field.Value) if reflect_value.CanAddr() { - id = do.setModel(reflect_value.Addr().Interface()).save() + do.setModel(reflect_value.Addr().Interface()).save() } else { + // If can't take address, then clone the value and set it back dest_value := reflect.New(reflect_value.Type()).Elem() m := &Model{data: field.Value, do: s} for _, f := range m.columnsHasValue("other") { dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) } - id = do.setModel(dest_value.Addr().Interface()).save() + do.setModel(dest_value.Addr().Interface()).save() m.setValueByColumn(field.Name, dest_value.Interface(), s.value) } if len(field.foreignKey) > 0 { - s.model.setValueByColumn(field.foreignKey, id, s.model.data) + s.model.setValueByColumn(field.foreignKey, do.model.primaryKeyValue(), s.model.data) } } } @@ -136,11 +135,12 @@ func (s *Do) saveAfterAssociations() { for _, field := range s.model.afterAssociations() { reflect_value := reflect.ValueOf(field.Value) - switch reflect.TypeOf(field.Value).Kind() { + switch reflect_value.Kind() { case reflect.Slice: for i := 0; i < reflect_value.Len(); i++ { - value := reflect_value.Index(i).Addr().Interface() do := &Do{chain: s.chain, db: s.db} + + value := reflect_value.Index(i).Addr().Interface() if len(field.foreignKey) > 0 { s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value) } @@ -175,8 +175,9 @@ func (s *Do) create() (i interface{}) { s.prepareCreateSql() if !s.chain.hasError() { - now := time.Now() var id interface{} + + now := time.Now() if s.chain.driver() == "postgres" { s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) } else { @@ -188,12 +189,9 @@ func (s *Do) create() (i interface{}) { s.chain.slog(s.sql, now, s.sqlVars...) if !s.chain.hasError() { - result := reflect.Indirect(reflect.ValueOf(s.value)) - if !setFieldValue(result.FieldByName(s.model.primaryKey()), id) { - fmt.Printf("Can't set primary key for %#v\n", result.Interface()) - } - s.saveAfterAssociations() + s.model.setValueByColumn(s.model.primaryKey(), id, s.value) + s.saveAfterAssociations() s.model.callMethod("AfterCreate") s.model.callMethod("AfterSave") } @@ -204,19 +202,17 @@ func (s *Do) create() (i interface{}) { } func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do { - switch values.(type) { + switch vs := values.(type) { case map[string]interface{}: - s.updateAttrs = values.(map[string]interface{}) + s.updateAttrs = vs case []interface{}: - for _, value := range values.([]interface{}) { + for _, value := range vs { s.setUpdateAttrs(value, ignore_protected_attrs...) } case interface{}: m := &Model{data: values, do: s} - fields := m.columnsHasValue("other") - - s.updateAttrs = make(map[string]interface{}, len(fields)) - for _, field := range fields { + s.updateAttrs = map[string]interface{}{} + for _, field := range m.columnsHasValue("other") { s.updateAttrs[field.DbName] = field.Value } } @@ -238,8 +234,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) { sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) } - update_attrs := s.model.columnsAndValues("update") - for key, value := range update_attrs { + for key, value := range s.model.columnsAndValues("update") { sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) } @@ -252,13 +247,12 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) { return } -func (s *Do) update() (i interface{}) { +func (s *Do) update() *Do { update_attrs := s.updateAttrs if len(update_attrs) > 0 { var need_update bool - update_attrs, need_update = s.prepareUpdateAttrs() - if !need_update { - return + if update_attrs, need_update = s.prepareUpdateAttrs(); !need_update { + return s } } @@ -276,29 +270,22 @@ func (s *Do) update() (i interface{}) { s.model.callMethod("AfterSave") } - return s.model.primaryKeyValue() + return s } -func (s *Do) prepareDeleteSql() { - s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql()) - return -} - -func (s *Do) delete() { +func (s *Do) delete() *Do { s.model.callMethod("BeforeDelete") if !s.chain.hasError() { if !s.unscoped && s.model.hasColumn("DeletedAt") { - delete_sql := "deleted_at=" + s.addToVars(time.Now()) - s.sql = fmt.Sprintf("UPDATE %v SET %v %v", s.tableName(), delete_sql, s.combinedSql()) - s.exec() + s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql()) } else { - s.prepareDeleteSql() - s.exec() + s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql()) } + s.exec() s.model.callMethod("AfterDelete") } - return + return s } func (s *Do) prepareQuerySql() { @@ -379,24 +366,18 @@ func (s *Do) query() { now := time.Now() rows, err := s.db.Query(s.sql, s.sqlVars...) s.chain.slog(s.sql, now, s.sqlVars...) + if s.err(err) != nil { return } defer rows.Close() - - if rows.Err() != nil { - s.err(rows.Err()) - } - - counts := 0 + var has_record bool for rows.Next() { - counts += 1 - var dest reflect.Value + has_record = true + dest := dest_out if is_slice { dest = reflect.New(dest_type).Elem() - } else { - dest = dest_out } columns, _ := rows.Columns() @@ -404,15 +385,10 @@ func (s *Do) query() { for _, value := range columns { field := dest.FieldByName(snakeToUpperCamel(value)) if field.IsValid() { - if field.CanAddr() { - values = append(values, field.Addr().Interface()) - } else { - s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest))) - return - } + values = append(values, field.Addr().Interface()) } else { - var null interface{} - values = append(values, &null) + var ignore interface{} + values = append(values, &ignore) } } s.err(rows.Scan(values...)) @@ -422,7 +398,7 @@ func (s *Do) query() { } } - if (counts == 0) && !is_slice { + if !has_record && !is_slice { s.err(errors.New("Record not found!")) } } @@ -477,9 +453,8 @@ func (s *Do) primaryCondiation(value interface{}) string { func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { query := clause["query"] - switch query.(type) { + switch value := query.(type) { case string: - value := query.(string) if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) return s.primaryCondiation(s.addToVars(id)) @@ -489,18 +464,18 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { case int, int64, int32: return s.primaryCondiation(s.addToVars(query)) case sql.NullInt64: - return s.primaryCondiation(s.addToVars(query.(sql.NullInt64).Int64)) + return s.primaryCondiation(s.addToVars(value.Int64)) case []int64, []int, []int32, []string: str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb()) - clause["args"] = []interface{}{query} + clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string - for key, value := range query.(map[string]interface{}) { + for key, value := range value { sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value))) } return strings.Join(sqls, " AND ") case interface{}: - m := &Model{data: query, do: s} + m := &Model{data: value, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value))) @@ -532,9 +507,8 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { query := clause["query"] var not_equal_sql string - switch query.(type) { + switch value := query.(type) { case string: - value := query.(string) if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id) @@ -556,7 +530,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { } case map[string]interface{}: var sqls []string - for key, value := range query.(map[string]interface{}) { + for key, value := range value { sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value))) } return strings.Join(sqls, " AND ") @@ -751,13 +725,13 @@ func (s *Do) initializeWithSearchCondition() { for _, clause := range s.whereClause { query := clause["query"] - switch query.(type) { + switch value := query.(type) { case map[string]interface{}: - for key, value := range query.(map[string]interface{}) { - m.setValueByColumn(key, value, s.value) + for k, v := range value { + m.setValueByColumn(k, v, s.value) } case []interface{}: - for _, obj := range query.([]interface{}) { + for _, obj := range value { switch reflect.ValueOf(obj).Kind() { case reflect.Struct: m := &Model{data: obj, do: s}