Clean up more code

This commit is contained in:
Jinzhu 2013-11-11 22:27:17 +08:00
parent 788da015d1
commit b22289b249
2 changed files with 54 additions and 86 deletions

View File

@ -132,16 +132,12 @@ func (s *Chain) Select(value interface{}) *Chain {
}
func (s *Chain) Save(value interface{}) *Chain {
do := s.do(value).begin()
do.save()
do.commit_or_rollback()
s.do(value).begin().save().commit_or_rollback()
return s
}
func (s *Chain) Delete(value interface{}) *Chain {
do := s.do(value).begin()
do.delete()
do.commit_or_rollback()
s.do(value).begin().delete().commit_or_rollback()
return s
}
@ -150,9 +146,7 @@ func (s *Chain) Update(attrs ...interface{}) *Chain {
}
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...)
do.update()
do.commit_or_rollback()
s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
return s
}

128
do.go
View File

@ -68,12 +68,11 @@ func (s *Do) addToVars(value interface{}) string {
}
func (s *Do) exec(sqls ...string) (err error) {
if s.chain.hasError() {
return
} else {
if !s.chain.hasError() {
if len(sqls) > 0 {
s.sql = sqls[0]
}
now := time.Now()
_, err = s.db.Exec(s.sql, s.sqlVars...)
s.chain.slog(s.sql, now, s.sqlVars...)
@ -81,13 +80,13 @@ func (s *Do) exec(sqls ...string) (err error) {
return s.err(err)
}
func (s *Do) save() (value interface{}) {
func (s *Do) save() *Do {
if s.model.primaryKeyZero() {
value = s.create()
s.create()
} else {
value = s.update()
s.update()
}
return
return s
}
func (s *Do) prepareCreateSql() {
@ -110,24 +109,24 @@ func (s *Do) prepareCreateSql() {
func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() {
var id interface{}
do := &Do{chain: s.chain, db: s.db}
reflect_value := reflect.ValueOf(field.Value)
if reflect_value.CanAddr() {
id = do.setModel(reflect_value.Addr().Interface()).save()
do.setModel(reflect_value.Addr().Interface()).save()
} else {
// If can't take address, then clone the value and set it back
dest_value := reflect.New(reflect_value.Type()).Elem()
m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
id = do.setModel(dest_value.Addr().Interface()).save()
do.setModel(dest_value.Addr().Interface()).save()
m.setValueByColumn(field.Name, dest_value.Interface(), s.value)
}
if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, id, s.model.data)
s.model.setValueByColumn(field.foreignKey, do.model.primaryKeyValue(), s.model.data)
}
}
}
@ -136,11 +135,12 @@ func (s *Do) saveAfterAssociations() {
for _, field := range s.model.afterAssociations() {
reflect_value := reflect.ValueOf(field.Value)
switch reflect.TypeOf(field.Value).Kind() {
switch reflect_value.Kind() {
case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ {
value := reflect_value.Index(i).Addr().Interface()
do := &Do{chain: s.chain, db: s.db}
value := reflect_value.Index(i).Addr().Interface()
if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
}
@ -175,8 +175,9 @@ func (s *Do) create() (i interface{}) {
s.prepareCreateSql()
if !s.chain.hasError() {
now := time.Now()
var id interface{}
now := time.Now()
if s.chain.driver() == "postgres" {
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
} else {
@ -188,12 +189,9 @@ func (s *Do) create() (i interface{}) {
s.chain.slog(s.sql, now, s.sqlVars...)
if !s.chain.hasError() {
result := reflect.Indirect(reflect.ValueOf(s.value))
if !setFieldValue(result.FieldByName(s.model.primaryKey()), id) {
fmt.Printf("Can't set primary key for %#v\n", result.Interface())
}
s.saveAfterAssociations()
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
s.saveAfterAssociations()
s.model.callMethod("AfterCreate")
s.model.callMethod("AfterSave")
}
@ -204,19 +202,17 @@ func (s *Do) create() (i interface{}) {
}
func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
switch values.(type) {
switch vs := values.(type) {
case map[string]interface{}:
s.updateAttrs = values.(map[string]interface{})
s.updateAttrs = vs
case []interface{}:
for _, value := range values.([]interface{}) {
for _, value := range vs {
s.setUpdateAttrs(value, ignore_protected_attrs...)
}
case interface{}:
m := &Model{data: values, do: s}
fields := m.columnsHasValue("other")
s.updateAttrs = make(map[string]interface{}, len(fields))
for _, field := range fields {
s.updateAttrs = map[string]interface{}{}
for _, field := range m.columnsHasValue("other") {
s.updateAttrs[field.DbName] = field.Value
}
}
@ -238,8 +234,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
}
update_attrs := s.model.columnsAndValues("update")
for key, value := range update_attrs {
for key, value := range s.model.columnsAndValues("update") {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
}
@ -252,13 +247,12 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
return
}
func (s *Do) update() (i interface{}) {
func (s *Do) update() *Do {
update_attrs := s.updateAttrs
if len(update_attrs) > 0 {
var need_update bool
update_attrs, need_update = s.prepareUpdateAttrs()
if !need_update {
return
if update_attrs, need_update = s.prepareUpdateAttrs(); !need_update {
return s
}
}
@ -276,29 +270,22 @@ func (s *Do) update() (i interface{}) {
s.model.callMethod("AfterSave")
}
return s.model.primaryKeyValue()
return s
}
func (s *Do) prepareDeleteSql() {
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
return
}
func (s *Do) delete() {
func (s *Do) delete() *Do {
s.model.callMethod("BeforeDelete")
if !s.chain.hasError() {
if !s.unscoped && s.model.hasColumn("DeletedAt") {
delete_sql := "deleted_at=" + s.addToVars(time.Now())
s.sql = fmt.Sprintf("UPDATE %v SET %v %v", s.tableName(), delete_sql, s.combinedSql())
s.exec()
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql())
} else {
s.prepareDeleteSql()
s.exec()
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
}
s.exec()
s.model.callMethod("AfterDelete")
}
return
return s
}
func (s *Do) prepareQuerySql() {
@ -379,24 +366,18 @@ func (s *Do) query() {
now := time.Now()
rows, err := s.db.Query(s.sql, s.sqlVars...)
s.chain.slog(s.sql, now, s.sqlVars...)
if s.err(err) != nil {
return
}
defer rows.Close()
if rows.Err() != nil {
s.err(rows.Err())
}
counts := 0
var has_record bool
for rows.Next() {
counts += 1
var dest reflect.Value
has_record = true
dest := dest_out
if is_slice {
dest = reflect.New(dest_type).Elem()
} else {
dest = dest_out
}
columns, _ := rows.Columns()
@ -404,15 +385,10 @@ func (s *Do) query() {
for _, value := range columns {
field := dest.FieldByName(snakeToUpperCamel(value))
if field.IsValid() {
if field.CanAddr() {
values = append(values, field.Addr().Interface())
} else {
s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest)))
return
}
values = append(values, field.Addr().Interface())
} else {
var null interface{}
values = append(values, &null)
var ignore interface{}
values = append(values, &ignore)
}
}
s.err(rows.Scan(values...))
@ -422,7 +398,7 @@ func (s *Do) query() {
}
}
if (counts == 0) && !is_slice {
if !has_record && !is_slice {
s.err(errors.New("Record not found!"))
}
}
@ -477,9 +453,8 @@ func (s *Do) primaryCondiation(value interface{}) string {
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
query := clause["query"]
switch query.(type) {
switch value := query.(type) {
case string:
value := query.(string)
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return s.primaryCondiation(s.addToVars(id))
@ -489,18 +464,18 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
case int, int64, int32:
return s.primaryCondiation(s.addToVars(query))
case sql.NullInt64:
return s.primaryCondiation(s.addToVars(query.(sql.NullInt64).Int64))
return s.primaryCondiation(s.addToVars(value.Int64))
case []int64, []int, []int32, []string:
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb())
clause["args"] = []interface{}{query}
clause["args"] = []interface{}{value}
case map[string]interface{}:
var sqls []string
for key, value := range query.(map[string]interface{}) {
for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value)))
}
return strings.Join(sqls, " AND ")
case interface{}:
m := &Model{data: query, do: s}
m := &Model{data: value, do: s}
var sqls []string
for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value)))
@ -532,9 +507,8 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
query := clause["query"]
var not_equal_sql string
switch query.(type) {
switch value := query.(type) {
case string:
value := query.(string)
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id)
@ -556,7 +530,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
}
case map[string]interface{}:
var sqls []string
for key, value := range query.(map[string]interface{}) {
for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value)))
}
return strings.Join(sqls, " AND ")
@ -751,13 +725,13 @@ func (s *Do) initializeWithSearchCondition() {
for _, clause := range s.whereClause {
query := clause["query"]
switch query.(type) {
switch value := query.(type) {
case map[string]interface{}:
for key, value := range query.(map[string]interface{}) {
m.setValueByColumn(key, value, s.value)
for k, v := range value {
m.setValueByColumn(k, v, s.value)
}
case []interface{}:
for _, obj := range query.([]interface{}) {
for _, obj := range value {
switch reflect.ValueOf(obj).Kind() {
case reflect.Struct:
m := &Model{data: obj, do: s}