clean up code

This commit is contained in:
Jinzhu 2013-11-12 15:21:21 +08:00
parent 6b31a9a41f
commit 5341a5ed43
2 changed files with 52 additions and 72 deletions

55
do.go
View File

@ -452,8 +452,7 @@ func (s *Do) primaryCondiation(value interface{}) string {
} }
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
query := clause["query"] switch value := clause["query"].(type) {
switch value := query.(type) {
case string: case string:
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value) id, _ := strconv.Atoi(value)
@ -462,7 +461,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
str = "(" + value + ")" str = "(" + value + ")"
} }
case int, int64, int32: case int, int64, int32:
return s.primaryCondiation(s.addToVars(query)) return s.primaryCondiation(s.addToVars(value))
case sql.NullInt64: case sql.NullInt64:
return s.primaryCondiation(s.addToVars(value.Int64)) return s.primaryCondiation(s.addToVars(value.Int64))
case []int64, []int, []int32, []string: case []int64, []int, []int32, []string:
@ -504,10 +503,9 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
} }
func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
query := clause["query"]
var not_equal_sql string var not_equal_sql string
switch value := query.(type) { switch value := clause["query"].(type) {
case string: case string:
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value) id, _ := strconv.Atoi(value)
@ -520,11 +518,11 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
not_equal_sql = fmt.Sprintf("(%v <> ?)", value) not_equal_sql = fmt.Sprintf("(%v <> ?)", value)
} }
case int, int64, int32: case int, int64, int32:
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), query) return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value)
case []int64, []int, []int32, []string: case []int64, []int, []int32, []string:
if reflect.ValueOf(query).Len() > 0 { if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb()) str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb())
clause["args"] = []interface{}{query} clause["args"] = []interface{}{value}
} else { } else {
return "" return ""
} }
@ -535,7 +533,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
m := &Model{data: query, do: s} m := &Model{data: value, do: s}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.DbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.DbName, s.addToVars(field.Value)))
@ -586,24 +584,23 @@ func (s *Do) whereSql() (sql string) {
and_conditions = append(and_conditions, s.buildNotCondition(clause)) and_conditions = append(and_conditions, s.buildNotCondition(clause))
} }
and_sql := strings.Join(and_conditions, " AND ")
or_sql := strings.Join(or_conditions, " OR ") or_sql := strings.Join(or_conditions, " OR ")
combined_conditions := and_sql combined_sql := strings.Join(and_conditions, " AND ")
if len(combined_conditions) > 0 { if len(combined_sql) > 0 {
if len(or_sql) > 0 { if len(or_sql) > 0 {
combined_conditions = combined_conditions + " OR " + or_sql combined_sql = combined_sql + " OR " + or_sql
} }
} else { } else {
combined_conditions = or_sql combined_sql = or_sql
} }
if len(primary_condiations) > 0 { if len(primary_condiations) > 0 {
sql = "WHERE " + strings.Join(primary_condiations, " AND ") sql = "WHERE " + strings.Join(primary_condiations, " AND ")
if len(combined_conditions) > 0 { if len(combined_sql) > 0 {
sql = sql + " AND (" + combined_conditions + ")" sql = sql + " AND (" + combined_sql + ")"
} }
} else if len(combined_conditions) > 0 { } else if len(combined_sql) > 0 {
sql = "WHERE " + combined_conditions sql = "WHERE " + combined_sql
} }
return return
} }
@ -646,7 +643,7 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do { func (s *Do) createTable() *Do {
var sqls []string var sqls []string
for _, field := range s.model.fields("other") { for _, field := range s.model.fields("migration") {
if len(field.SqlType) > 0 { if len(field.SqlType) > 0 {
sqls = append(sqls, field.DbName+" "+field.SqlType) sqls = append(sqls, field.DbName+" "+field.SqlType)
} }
@ -681,7 +678,7 @@ func (s *Do) autoMigrate() *Do {
if len(table_name) == 0 { if len(table_name) == 0 {
s.createTable() s.createTable()
} else { } else {
for _, field := range s.model.fields("other") { for _, field := range s.model.fields("migration") {
var column_name, data_type string var column_name, data_type string
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName())) sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName()))
s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.DbName)), s.sqlVars...).Scan(&column_name, &data_type) s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.DbName)), s.sqlVars...).Scan(&column_name, &data_type)
@ -699,8 +696,7 @@ func (s *Do) autoMigrate() *Do {
func (s *Do) begin() *Do { func (s *Do) begin() *Do {
if db, ok := s.db.(sql_db); ok { if db, ok := s.db.(sql_db); ok {
tx, err := db.Begin() if tx, err := db.Begin(); err == nil {
if err == nil {
s.db = interface{}(tx).(sql_common) s.db = interface{}(tx).(sql_common)
s.startedTransaction = true s.startedTransaction = true
} }
@ -721,14 +717,11 @@ func (s *Do) commit_or_rollback() {
} }
func (s *Do) initializeWithSearchCondition() { func (s *Do) initializeWithSearchCondition() {
m := Model{data: s.value, do: s}
for _, clause := range s.whereClause { for _, clause := range s.whereClause {
query := clause["query"] switch value := clause["query"].(type) {
switch value := query.(type) {
case map[string]interface{}: case map[string]interface{}:
for k, v := range value { for k, v := range value {
m.setValueByColumn(k, v, s.value) s.model.setValueByColumn(k, v, s.value)
} }
case []interface{}: case []interface{}:
for _, obj := range value { for _, obj := range value {
@ -736,18 +729,18 @@ func (s *Do) initializeWithSearchCondition() {
case reflect.Struct: case reflect.Struct:
m := &Model{data: obj, do: s} m := &Model{data: obj, do: s}
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value) s.model.setValueByColumn(field.DbName, field.Value, s.value)
} }
case reflect.Map: case reflect.Map:
for key, value := range obj.(map[string]interface{}) { for key, value := range obj.(map[string]interface{}) {
m.setValueByColumn(key, value, s.value) s.model.setValueByColumn(key, value, s.value)
} }
} }
} }
case interface{}: case interface{}:
m := &Model{data: query, do: s} m := &Model{data: value, do: s}
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value) s.model.setValueByColumn(field.DbName, field.Value, s.value)
} }
} }
} }

View File

@ -87,8 +87,6 @@ func (m *Model) fields(operation string) (fields []Field) {
field.Name = p.Name field.Name = p.Name
field.DbName = toSnake(p.Name) field.DbName = toSnake(p.Name)
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName
value := indirect_value.FieldByName(p.Name) value := indirect_value.FieldByName(p.Name)
time_value, is_time := value.Interface().(time.Time) time_value, is_time := value.Interface().(time.Time)
@ -98,9 +96,7 @@ func (m *Model) fields(operation string) (fields []Field) {
case reflect.String: case reflect.String:
field.IsBlank = value.String() == "" field.IsBlank = value.String() == ""
case reflect.Slice: case reflect.Slice:
if value.Len() == 0 { field.IsBlank = value.Len() == 0
field.IsBlank = true
}
case reflect.Struct: case reflect.Struct:
if is_time { if is_time {
field.IsBlank = time_value.IsZero() field.IsBlank = time_value.IsZero()
@ -111,7 +107,6 @@ func (m *Model) fields(operation string) (fields []Field) {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool) field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
} else { } else {
m := &Model{data: value.Interface(), do: m.do} m := &Model{data: value.Interface(), do: m.do}
fields := m.columnsHasValue("other") fields := m.columnsHasValue("other")
if len(fields) == 0 { if len(fields) == 0 {
field.IsBlank = true field.IsBlank = true
@ -121,20 +116,20 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
if is_time { if is_time {
field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName
switch operation { switch operation {
case "create": case "create":
if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() { if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
case "update": case "update":
if field.AutoCreateTime && time_value.IsZero() {
value.Set(reflect.ValueOf(time.Now()))
}
if field.AutoUpdateTime { if field.AutoUpdateTime {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
} }
field.SqlType = getSqlType(m.do.chain.driver(), value, 0) field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
} else if field.IsPrimaryKey { } else if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0) field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
@ -176,7 +171,7 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
if len(m._cache_fields) == 0 { if len(m._cache_fields) == 0 {
m._cache_fields = make(map[string][]Field) m._cache_fields = map[string][]Field{}
} }
m._cache_fields[operation] = fields m._cache_fields[operation] = fields
return return
@ -191,14 +186,12 @@ func (m *Model) columnsHasValue(operation string) (fields []Field) {
return return
} }
func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[string]interface{}, bool) { func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results map[string]interface{}, any_updated bool) {
if m.data == nil { if m.data == nil {
return values, true return values, true
} }
data := reflect.Indirect(reflect.ValueOf(m.data)) data := reflect.Indirect(reflect.ValueOf(m.data))
results := map[string]interface{}{}
for key, value := range values { for key, value := range values {
field := data.FieldByName(snakeToUpperCamel(key)) field := data.FieldByName(snakeToUpperCamel(key))
if field.IsValid() { if field.IsValid() {
@ -206,35 +199,33 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[stri
switch field.Kind() { switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32, reflect.Int64:
if field.Int() != reflect.ValueOf(value).Int() { if field.Int() != reflect.ValueOf(value).Int() {
results[key] = value any_updated = true
} }
field.SetInt(reflect.ValueOf(value).Int()) field.SetInt(reflect.ValueOf(value).Int())
default: default:
results[key] = value any_updated = true
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))
} }
} }
} }
} }
if values["updated_at"] != nil && len(results) > 0 { if values["updated_at"] != nil && any_updated {
setFieldValue(data.FieldByName("UpdatedAt"), time.Now()) setFieldValue(data.FieldByName("UpdatedAt"), time.Now())
} }
result := len(results) > 0 return
return map[string]interface{}{}, result
} }
func (m *Model) columnsAndValues(operation string) map[string]interface{} { func (m *Model) columnsAndValues(operation string) map[string]interface{} {
if m.data == nil {
return map[string]interface{}{}
}
results := map[string]interface{}{} results := map[string]interface{}{}
if m.data != nil {
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsPrimaryKey && (len(field.SqlType) > 0) { if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
results[field.DbName] = field.Value results[field.DbName] = field.Value
} }
} }
}
return results return results
} }
@ -252,10 +243,7 @@ func (m *Model) hasColumn(name string) bool {
} }
func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) { func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) {
if m.data == nil { if m.data != nil {
return
}
data := reflect.Indirect(reflect.ValueOf(m.data)) data := reflect.Indirect(reflect.ValueOf(m.data))
if data.Kind() == reflect.Slice { if data.Kind() == reflect.Slice {
has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
@ -265,6 +253,7 @@ func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, val
value = data.FieldByName(name).Interface() value = data.FieldByName(name).Interface()
} }
} }
}
return return
} }
@ -285,8 +274,7 @@ func (m *Model) tableName() (str string) {
fm := reflect.Indirect(reflect.ValueOf(m.data)).MethodByName("TableName") fm := reflect.Indirect(reflect.ValueOf(m.data)).MethodByName("TableName")
if fm.IsValid() { if fm.IsValid() {
v := fm.Call([]reflect.Value{}) if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if len(v) > 0 {
if result, ok := v[0].Interface().(string); ok { if result, ok := v[0].Interface().(string); ok {
return result return result
} }
@ -315,8 +303,7 @@ func (m *Model) callMethod(method string) {
fm := reflect.ValueOf(m.data).MethodByName(method) fm := reflect.ValueOf(m.data).MethodByName(method)
if fm.IsValid() { if fm.IsValid() {
v := fm.Call([]reflect.Value{}) if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if len(v) > 0 {
if verr, ok := v[0].Interface().(error); ok { if verr, ok := v[0].Interface().(error); ok {
m.do.err(verr) m.do.err(verr)
} }