clean up code

This commit is contained in:
Jinzhu 2013-11-12 15:21:21 +08:00
parent 6b31a9a41f
commit 5341a5ed43
2 changed files with 52 additions and 72 deletions

55
do.go
View File

@ -452,8 +452,7 @@ func (s *Do) primaryCondiation(value interface{}) string {
}
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
query := clause["query"]
switch value := query.(type) {
switch value := clause["query"].(type) {
case string:
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
@ -462,7 +461,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
str = "(" + value + ")"
}
case int, int64, int32:
return s.primaryCondiation(s.addToVars(query))
return s.primaryCondiation(s.addToVars(value))
case sql.NullInt64:
return s.primaryCondiation(s.addToVars(value.Int64))
case []int64, []int, []int32, []string:
@ -504,10 +503,9 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
}
func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
query := clause["query"]
var not_equal_sql string
switch value := query.(type) {
switch value := clause["query"].(type) {
case string:
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
@ -520,11 +518,11 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
not_equal_sql = fmt.Sprintf("(%v <> ?)", value)
}
case int, int64, int32:
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), query)
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value)
case []int64, []int, []int32, []string:
if reflect.ValueOf(query).Len() > 0 {
if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb())
clause["args"] = []interface{}{query}
clause["args"] = []interface{}{value}
} else {
return ""
}
@ -535,7 +533,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
}
return strings.Join(sqls, " AND ")
case interface{}:
m := &Model{data: query, do: s}
m := &Model{data: value, do: s}
var sqls []string
for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.DbName, s.addToVars(field.Value)))
@ -586,24 +584,23 @@ func (s *Do) whereSql() (sql string) {
and_conditions = append(and_conditions, s.buildNotCondition(clause))
}
and_sql := strings.Join(and_conditions, " AND ")
or_sql := strings.Join(or_conditions, " OR ")
combined_conditions := and_sql
if len(combined_conditions) > 0 {
combined_sql := strings.Join(and_conditions, " AND ")
if len(combined_sql) > 0 {
if len(or_sql) > 0 {
combined_conditions = combined_conditions + " OR " + or_sql
combined_sql = combined_sql + " OR " + or_sql
}
} else {
combined_conditions = or_sql
combined_sql = or_sql
}
if len(primary_condiations) > 0 {
sql = "WHERE " + strings.Join(primary_condiations, " AND ")
if len(combined_conditions) > 0 {
sql = sql + " AND (" + combined_conditions + ")"
if len(combined_sql) > 0 {
sql = sql + " AND (" + combined_sql + ")"
}
} else if len(combined_conditions) > 0 {
sql = "WHERE " + combined_conditions
} else if len(combined_sql) > 0 {
sql = "WHERE " + combined_sql
}
return
}
@ -646,7 +643,7 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do {
var sqls []string
for _, field := range s.model.fields("other") {
for _, field := range s.model.fields("migration") {
if len(field.SqlType) > 0 {
sqls = append(sqls, field.DbName+" "+field.SqlType)
}
@ -681,7 +678,7 @@ func (s *Do) autoMigrate() *Do {
if len(table_name) == 0 {
s.createTable()
} else {
for _, field := range s.model.fields("other") {
for _, field := range s.model.fields("migration") {
var column_name, data_type string
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName()))
s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.DbName)), s.sqlVars...).Scan(&column_name, &data_type)
@ -699,8 +696,7 @@ func (s *Do) autoMigrate() *Do {
func (s *Do) begin() *Do {
if db, ok := s.db.(sql_db); ok {
tx, err := db.Begin()
if err == nil {
if tx, err := db.Begin(); err == nil {
s.db = interface{}(tx).(sql_common)
s.startedTransaction = true
}
@ -721,14 +717,11 @@ func (s *Do) commit_or_rollback() {
}
func (s *Do) initializeWithSearchCondition() {
m := Model{data: s.value, do: s}
for _, clause := range s.whereClause {
query := clause["query"]
switch value := query.(type) {
switch value := clause["query"].(type) {
case map[string]interface{}:
for k, v := range value {
m.setValueByColumn(k, v, s.value)
s.model.setValueByColumn(k, v, s.value)
}
case []interface{}:
for _, obj := range value {
@ -736,18 +729,18 @@ func (s *Do) initializeWithSearchCondition() {
case reflect.Struct:
m := &Model{data: obj, do: s}
for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value)
s.model.setValueByColumn(field.DbName, field.Value, s.value)
}
case reflect.Map:
for key, value := range obj.(map[string]interface{}) {
m.setValueByColumn(key, value, s.value)
s.model.setValueByColumn(key, value, s.value)
}
}
}
case interface{}:
m := &Model{data: query, do: s}
m := &Model{data: value, do: s}
for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value)
s.model.setValueByColumn(field.DbName, field.Value, s.value)
}
}
}

View File

@ -87,8 +87,6 @@ func (m *Model) fields(operation string) (fields []Field) {
field.Name = p.Name
field.DbName = toSnake(p.Name)
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName
value := indirect_value.FieldByName(p.Name)
time_value, is_time := value.Interface().(time.Time)
@ -98,9 +96,7 @@ func (m *Model) fields(operation string) (fields []Field) {
case reflect.String:
field.IsBlank = value.String() == ""
case reflect.Slice:
if value.Len() == 0 {
field.IsBlank = true
}
field.IsBlank = value.Len() == 0
case reflect.Struct:
if is_time {
field.IsBlank = time_value.IsZero()
@ -111,7 +107,6 @@ func (m *Model) fields(operation string) (fields []Field) {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), do: m.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
@ -121,20 +116,20 @@ func (m *Model) fields(operation string) (fields []Field) {
}
if is_time {
field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName
switch operation {
case "create":
if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() {
value.Set(reflect.ValueOf(time.Now()))
}
case "update":
if field.AutoCreateTime && time_value.IsZero() {
value.Set(reflect.ValueOf(time.Now()))
}
if field.AutoUpdateTime {
value.Set(reflect.ValueOf(time.Now()))
}
}
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
} else if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
@ -176,7 +171,7 @@ func (m *Model) fields(operation string) (fields []Field) {
}
if len(m._cache_fields) == 0 {
m._cache_fields = make(map[string][]Field)
m._cache_fields = map[string][]Field{}
}
m._cache_fields[operation] = fields
return
@ -191,14 +186,12 @@ func (m *Model) columnsHasValue(operation string) (fields []Field) {
return
}
func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[string]interface{}, bool) {
func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results map[string]interface{}, any_updated bool) {
if m.data == nil {
return values, true
}
data := reflect.Indirect(reflect.ValueOf(m.data))
results := map[string]interface{}{}
for key, value := range values {
field := data.FieldByName(snakeToUpperCamel(key))
if field.IsValid() {
@ -206,33 +199,31 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[stri
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
if field.Int() != reflect.ValueOf(value).Int() {
results[key] = value
any_updated = true
}
field.SetInt(reflect.ValueOf(value).Int())
default:
results[key] = value
any_updated = true
field.Set(reflect.ValueOf(value))
}
}
}
}
if values["updated_at"] != nil && len(results) > 0 {
if values["updated_at"] != nil && any_updated {
setFieldValue(data.FieldByName("UpdatedAt"), time.Now())
}
result := len(results) > 0
return map[string]interface{}{}, result
return
}
func (m *Model) columnsAndValues(operation string) map[string]interface{} {
if m.data == nil {
return map[string]interface{}{}
}
results := map[string]interface{}{}
for _, field := range m.fields(operation) {
if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
results[field.DbName] = field.Value
if m.data != nil {
for _, field := range m.fields(operation) {
if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
results[field.DbName] = field.Value
}
}
}
return results
@ -252,17 +243,15 @@ func (m *Model) hasColumn(name string) bool {
}
func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) {
if m.data == nil {
return
}
data := reflect.Indirect(reflect.ValueOf(m.data))
if data.Kind() == reflect.Slice {
has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
is_slice = true
} else {
if has_column = data.FieldByName(name).IsValid(); has_column {
value = data.FieldByName(name).Interface()
if m.data != nil {
data := reflect.Indirect(reflect.ValueOf(m.data))
if data.Kind() == reflect.Slice {
has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
is_slice = true
} else {
if has_column = data.FieldByName(name).IsValid(); has_column {
value = data.FieldByName(name).Interface()
}
}
}
return
@ -285,8 +274,7 @@ func (m *Model) tableName() (str string) {
fm := reflect.Indirect(reflect.ValueOf(m.data)).MethodByName("TableName")
if fm.IsValid() {
v := fm.Call([]reflect.Value{})
if len(v) > 0 {
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if result, ok := v[0].Interface().(string); ok {
return result
}
@ -315,8 +303,7 @@ func (m *Model) callMethod(method string) {
fm := reflect.ValueOf(m.data).MethodByName(method)
if fm.IsValid() {
v := fm.Call([]reflect.Value{})
if len(v) > 0 {
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if verr, ok := v[0].Interface().(error); ok {
m.do.err(verr)
}