Clean up more code

This commit is contained in:
Jinzhu 2013-11-11 22:27:17 +08:00
parent 788da015d1
commit b22289b249
2 changed files with 54 additions and 86 deletions

View File

@ -132,16 +132,12 @@ func (s *Chain) Select(value interface{}) *Chain {
} }
func (s *Chain) Save(value interface{}) *Chain { func (s *Chain) Save(value interface{}) *Chain {
do := s.do(value).begin() s.do(value).begin().save().commit_or_rollback()
do.save()
do.commit_or_rollback()
return s return s
} }
func (s *Chain) Delete(value interface{}) *Chain { func (s *Chain) Delete(value interface{}) *Chain {
do := s.do(value).begin() s.do(value).begin().delete().commit_or_rollback()
do.delete()
do.commit_or_rollback()
return s return s
} }
@ -150,9 +146,7 @@ func (s *Chain) Update(attrs ...interface{}) *Chain {
} }
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain { func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...) s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
do.update()
do.commit_or_rollback()
return s return s
} }

128
do.go
View File

@ -68,12 +68,11 @@ func (s *Do) addToVars(value interface{}) string {
} }
func (s *Do) exec(sqls ...string) (err error) { func (s *Do) exec(sqls ...string) (err error) {
if s.chain.hasError() { if !s.chain.hasError() {
return
} else {
if len(sqls) > 0 { if len(sqls) > 0 {
s.sql = sqls[0] s.sql = sqls[0]
} }
now := time.Now() now := time.Now()
_, err = s.db.Exec(s.sql, s.sqlVars...) _, err = s.db.Exec(s.sql, s.sqlVars...)
s.chain.slog(s.sql, now, s.sqlVars...) s.chain.slog(s.sql, now, s.sqlVars...)
@ -81,13 +80,13 @@ func (s *Do) exec(sqls ...string) (err error) {
return s.err(err) return s.err(err)
} }
func (s *Do) save() (value interface{}) { func (s *Do) save() *Do {
if s.model.primaryKeyZero() { if s.model.primaryKeyZero() {
value = s.create() s.create()
} else { } else {
value = s.update() s.update()
} }
return return s
} }
func (s *Do) prepareCreateSql() { func (s *Do) prepareCreateSql() {
@ -110,24 +109,24 @@ func (s *Do) prepareCreateSql() {
func (s *Do) saveBeforeAssociations() { func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() { for _, field := range s.model.beforeAssociations() {
var id interface{}
do := &Do{chain: s.chain, db: s.db} do := &Do{chain: s.chain, db: s.db}
reflect_value := reflect.ValueOf(field.Value) reflect_value := reflect.ValueOf(field.Value)
if reflect_value.CanAddr() { if reflect_value.CanAddr() {
id = do.setModel(reflect_value.Addr().Interface()).save() do.setModel(reflect_value.Addr().Interface()).save()
} else { } else {
// If can't take address, then clone the value and set it back
dest_value := reflect.New(reflect_value.Type()).Elem() dest_value := reflect.New(reflect_value.Type()).Elem()
m := &Model{data: field.Value, do: s} m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") { for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
} }
id = do.setModel(dest_value.Addr().Interface()).save() do.setModel(dest_value.Addr().Interface()).save()
m.setValueByColumn(field.Name, dest_value.Interface(), s.value) m.setValueByColumn(field.Name, dest_value.Interface(), s.value)
} }
if len(field.foreignKey) > 0 { if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, id, s.model.data) s.model.setValueByColumn(field.foreignKey, do.model.primaryKeyValue(), s.model.data)
} }
} }
} }
@ -136,11 +135,12 @@ func (s *Do) saveAfterAssociations() {
for _, field := range s.model.afterAssociations() { for _, field := range s.model.afterAssociations() {
reflect_value := reflect.ValueOf(field.Value) reflect_value := reflect.ValueOf(field.Value)
switch reflect.TypeOf(field.Value).Kind() { switch reflect_value.Kind() {
case reflect.Slice: case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ { for i := 0; i < reflect_value.Len(); i++ {
value := reflect_value.Index(i).Addr().Interface()
do := &Do{chain: s.chain, db: s.db} do := &Do{chain: s.chain, db: s.db}
value := reflect_value.Index(i).Addr().Interface()
if len(field.foreignKey) > 0 { if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value) s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
} }
@ -175,8 +175,9 @@ func (s *Do) create() (i interface{}) {
s.prepareCreateSql() s.prepareCreateSql()
if !s.chain.hasError() { if !s.chain.hasError() {
now := time.Now()
var id interface{} var id interface{}
now := time.Now()
if s.chain.driver() == "postgres" { if s.chain.driver() == "postgres" {
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
} else { } else {
@ -188,12 +189,9 @@ func (s *Do) create() (i interface{}) {
s.chain.slog(s.sql, now, s.sqlVars...) s.chain.slog(s.sql, now, s.sqlVars...)
if !s.chain.hasError() { if !s.chain.hasError() {
result := reflect.Indirect(reflect.ValueOf(s.value)) s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
if !setFieldValue(result.FieldByName(s.model.primaryKey()), id) {
fmt.Printf("Can't set primary key for %#v\n", result.Interface())
}
s.saveAfterAssociations()
s.saveAfterAssociations()
s.model.callMethod("AfterCreate") s.model.callMethod("AfterCreate")
s.model.callMethod("AfterSave") s.model.callMethod("AfterSave")
} }
@ -204,19 +202,17 @@ func (s *Do) create() (i interface{}) {
} }
func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do { func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
switch values.(type) { switch vs := values.(type) {
case map[string]interface{}: case map[string]interface{}:
s.updateAttrs = values.(map[string]interface{}) s.updateAttrs = vs
case []interface{}: case []interface{}:
for _, value := range values.([]interface{}) { for _, value := range vs {
s.setUpdateAttrs(value, ignore_protected_attrs...) s.setUpdateAttrs(value, ignore_protected_attrs...)
} }
case interface{}: case interface{}:
m := &Model{data: values, do: s} m := &Model{data: values, do: s}
fields := m.columnsHasValue("other") s.updateAttrs = map[string]interface{}{}
for _, field := range m.columnsHasValue("other") {
s.updateAttrs = make(map[string]interface{}, len(fields))
for _, field := range fields {
s.updateAttrs[field.DbName] = field.Value s.updateAttrs[field.DbName] = field.Value
} }
} }
@ -238,8 +234,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
} }
update_attrs := s.model.columnsAndValues("update") for key, value := range s.model.columnsAndValues("update") {
for key, value := range update_attrs {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
} }
@ -252,13 +247,12 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
return return
} }
func (s *Do) update() (i interface{}) { func (s *Do) update() *Do {
update_attrs := s.updateAttrs update_attrs := s.updateAttrs
if len(update_attrs) > 0 { if len(update_attrs) > 0 {
var need_update bool var need_update bool
update_attrs, need_update = s.prepareUpdateAttrs() if update_attrs, need_update = s.prepareUpdateAttrs(); !need_update {
if !need_update { return s
return
} }
} }
@ -276,29 +270,22 @@ func (s *Do) update() (i interface{}) {
s.model.callMethod("AfterSave") s.model.callMethod("AfterSave")
} }
return s.model.primaryKeyValue() return s
} }
func (s *Do) prepareDeleteSql() { func (s *Do) delete() *Do {
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
return
}
func (s *Do) delete() {
s.model.callMethod("BeforeDelete") s.model.callMethod("BeforeDelete")
if !s.chain.hasError() { if !s.chain.hasError() {
if !s.unscoped && s.model.hasColumn("DeletedAt") { if !s.unscoped && s.model.hasColumn("DeletedAt") {
delete_sql := "deleted_at=" + s.addToVars(time.Now()) s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql())
s.sql = fmt.Sprintf("UPDATE %v SET %v %v", s.tableName(), delete_sql, s.combinedSql())
s.exec()
} else { } else {
s.prepareDeleteSql() s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
s.exec()
} }
s.exec()
s.model.callMethod("AfterDelete") s.model.callMethod("AfterDelete")
} }
return return s
} }
func (s *Do) prepareQuerySql() { func (s *Do) prepareQuerySql() {
@ -379,24 +366,18 @@ func (s *Do) query() {
now := time.Now() now := time.Now()
rows, err := s.db.Query(s.sql, s.sqlVars...) rows, err := s.db.Query(s.sql, s.sqlVars...)
s.chain.slog(s.sql, now, s.sqlVars...) s.chain.slog(s.sql, now, s.sqlVars...)
if s.err(err) != nil { if s.err(err) != nil {
return return
} }
defer rows.Close() defer rows.Close()
var has_record bool
if rows.Err() != nil {
s.err(rows.Err())
}
counts := 0
for rows.Next() { for rows.Next() {
counts += 1 has_record = true
var dest reflect.Value dest := dest_out
if is_slice { if is_slice {
dest = reflect.New(dest_type).Elem() dest = reflect.New(dest_type).Elem()
} else {
dest = dest_out
} }
columns, _ := rows.Columns() columns, _ := rows.Columns()
@ -404,15 +385,10 @@ func (s *Do) query() {
for _, value := range columns { for _, value := range columns {
field := dest.FieldByName(snakeToUpperCamel(value)) field := dest.FieldByName(snakeToUpperCamel(value))
if field.IsValid() { if field.IsValid() {
if field.CanAddr() { values = append(values, field.Addr().Interface())
values = append(values, field.Addr().Interface())
} else {
s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest)))
return
}
} else { } else {
var null interface{} var ignore interface{}
values = append(values, &null) values = append(values, &ignore)
} }
} }
s.err(rows.Scan(values...)) s.err(rows.Scan(values...))
@ -422,7 +398,7 @@ func (s *Do) query() {
} }
} }
if (counts == 0) && !is_slice { if !has_record && !is_slice {
s.err(errors.New("Record not found!")) s.err(errors.New("Record not found!"))
} }
} }
@ -477,9 +453,8 @@ func (s *Do) primaryCondiation(value interface{}) string {
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
query := clause["query"] query := clause["query"]
switch query.(type) { switch value := query.(type) {
case string: case string:
value := query.(string)
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value) id, _ := strconv.Atoi(value)
return s.primaryCondiation(s.addToVars(id)) return s.primaryCondiation(s.addToVars(id))
@ -489,18 +464,18 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
case int, int64, int32: case int, int64, int32:
return s.primaryCondiation(s.addToVars(query)) return s.primaryCondiation(s.addToVars(query))
case sql.NullInt64: case sql.NullInt64:
return s.primaryCondiation(s.addToVars(query.(sql.NullInt64).Int64)) return s.primaryCondiation(s.addToVars(value.Int64))
case []int64, []int, []int32, []string: case []int64, []int, []int32, []string:
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb()) str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb())
clause["args"] = []interface{}{query} clause["args"] = []interface{}{value}
case map[string]interface{}: case map[string]interface{}:
var sqls []string var sqls []string
for key, value := range query.(map[string]interface{}) { for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
m := &Model{data: query, do: s} m := &Model{data: value, do: s}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value)))
@ -532,9 +507,8 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
query := clause["query"] query := clause["query"]
var not_equal_sql string var not_equal_sql string
switch query.(type) { switch value := query.(type) {
case string: case string:
value := query.(string)
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value) id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id) return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id)
@ -556,7 +530,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
} }
case map[string]interface{}: case map[string]interface{}:
var sqls []string var sqls []string
for key, value := range query.(map[string]interface{}) { for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
@ -751,13 +725,13 @@ func (s *Do) initializeWithSearchCondition() {
for _, clause := range s.whereClause { for _, clause := range s.whereClause {
query := clause["query"] query := clause["query"]
switch query.(type) { switch value := query.(type) {
case map[string]interface{}: case map[string]interface{}:
for key, value := range query.(map[string]interface{}) { for k, v := range value {
m.setValueByColumn(key, value, s.value) m.setValueByColumn(k, v, s.value)
} }
case []interface{}: case []interface{}:
for _, obj := range query.([]interface{}) { for _, obj := range value {
switch reflect.ValueOf(obj).Kind() { switch reflect.ValueOf(obj).Kind() {
case reflect.Struct: case reflect.Struct:
m := &Model{data: obj, do: s} m := &Model{data: obj, do: s}