diff --git a/do.go b/do.go index 0b159302..a624fffa 100644 --- a/do.go +++ b/do.go @@ -452,8 +452,7 @@ func (s *Do) primaryCondiation(value interface{}) string { } func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { - query := clause["query"] - switch value := query.(type) { + switch value := clause["query"].(type) { case string: if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) @@ -462,7 +461,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { str = "(" + value + ")" } case int, int64, int32: - return s.primaryCondiation(s.addToVars(query)) + return s.primaryCondiation(s.addToVars(value)) case sql.NullInt64: return s.primaryCondiation(s.addToVars(value.Int64)) case []int64, []int, []int32, []string: @@ -504,10 +503,9 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { } func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { - query := clause["query"] var not_equal_sql string - switch value := query.(type) { + switch value := clause["query"].(type) { case string: if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) @@ -520,11 +518,11 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { not_equal_sql = fmt.Sprintf("(%v <> ?)", value) } case int, int64, int32: - return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), query) + return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value) case []int64, []int, []int32, []string: - if reflect.ValueOf(query).Len() > 0 { + if reflect.ValueOf(value).Len() > 0 { str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb()) - clause["args"] = []interface{}{query} + clause["args"] = []interface{}{value} } else { return "" } @@ -535,7 +533,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { } return strings.Join(sqls, " AND ") case interface{}: - m := &Model{data: query, do: s} + m := &Model{data: value, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.DbName, s.addToVars(field.Value))) @@ -586,24 +584,23 @@ func (s *Do) whereSql() (sql string) { and_conditions = append(and_conditions, s.buildNotCondition(clause)) } - and_sql := strings.Join(and_conditions, " AND ") or_sql := strings.Join(or_conditions, " OR ") - combined_conditions := and_sql - if len(combined_conditions) > 0 { + combined_sql := strings.Join(and_conditions, " AND ") + if len(combined_sql) > 0 { if len(or_sql) > 0 { - combined_conditions = combined_conditions + " OR " + or_sql + combined_sql = combined_sql + " OR " + or_sql } } else { - combined_conditions = or_sql + combined_sql = or_sql } if len(primary_condiations) > 0 { sql = "WHERE " + strings.Join(primary_condiations, " AND ") - if len(combined_conditions) > 0 { - sql = sql + " AND (" + combined_conditions + ")" + if len(combined_sql) > 0 { + sql = sql + " AND (" + combined_sql + ")" } - } else if len(combined_conditions) > 0 { - sql = "WHERE " + combined_conditions + } else if len(combined_sql) > 0 { + sql = "WHERE " + combined_sql } return } @@ -646,7 +643,7 @@ func (s *Do) combinedSql() string { func (s *Do) createTable() *Do { var sqls []string - for _, field := range s.model.fields("other") { + for _, field := range s.model.fields("migration") { if len(field.SqlType) > 0 { sqls = append(sqls, field.DbName+" "+field.SqlType) } @@ -681,7 +678,7 @@ func (s *Do) autoMigrate() *Do { if len(table_name) == 0 { s.createTable() } else { - for _, field := range s.model.fields("other") { + for _, field := range s.model.fields("migration") { var column_name, data_type string sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName())) s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.DbName)), s.sqlVars...).Scan(&column_name, &data_type) @@ -699,8 +696,7 @@ func (s *Do) autoMigrate() *Do { func (s *Do) begin() *Do { if db, ok := s.db.(sql_db); ok { - tx, err := db.Begin() - if err == nil { + if tx, err := db.Begin(); err == nil { s.db = interface{}(tx).(sql_common) s.startedTransaction = true } @@ -721,14 +717,11 @@ func (s *Do) commit_or_rollback() { } func (s *Do) initializeWithSearchCondition() { - m := Model{data: s.value, do: s} - for _, clause := range s.whereClause { - query := clause["query"] - switch value := query.(type) { + switch value := clause["query"].(type) { case map[string]interface{}: for k, v := range value { - m.setValueByColumn(k, v, s.value) + s.model.setValueByColumn(k, v, s.value) } case []interface{}: for _, obj := range value { @@ -736,18 +729,18 @@ func (s *Do) initializeWithSearchCondition() { case reflect.Struct: m := &Model{data: obj, do: s} for _, field := range m.columnsHasValue("other") { - m.setValueByColumn(field.DbName, field.Value, s.value) + s.model.setValueByColumn(field.DbName, field.Value, s.value) } case reflect.Map: for key, value := range obj.(map[string]interface{}) { - m.setValueByColumn(key, value, s.value) + s.model.setValueByColumn(key, value, s.value) } } } case interface{}: - m := &Model{data: query, do: s} + m := &Model{data: value, do: s} for _, field := range m.columnsHasValue("other") { - m.setValueByColumn(field.DbName, field.Value, s.value) + s.model.setValueByColumn(field.DbName, field.Value, s.value) } } } diff --git a/model.go b/model.go index 17c872fd..cc93a4e1 100644 --- a/model.go +++ b/model.go @@ -87,8 +87,6 @@ func (m *Model) fields(operation string) (fields []Field) { field.Name = p.Name field.DbName = toSnake(p.Name) field.IsPrimaryKey = m.primaryKeyDb() == field.DbName - field.AutoCreateTime = "created_at" == field.DbName - field.AutoUpdateTime = "updated_at" == field.DbName value := indirect_value.FieldByName(p.Name) time_value, is_time := value.Interface().(time.Time) @@ -98,9 +96,7 @@ func (m *Model) fields(operation string) (fields []Field) { case reflect.String: field.IsBlank = value.String() == "" case reflect.Slice: - if value.Len() == 0 { - field.IsBlank = true - } + field.IsBlank = value.Len() == 0 case reflect.Struct: if is_time { field.IsBlank = time_value.IsZero() @@ -111,7 +107,6 @@ func (m *Model) fields(operation string) (fields []Field) { field.IsBlank = !value.FieldByName("Valid").Interface().(bool) } else { m := &Model{data: value.Interface(), do: m.do} - fields := m.columnsHasValue("other") if len(fields) == 0 { field.IsBlank = true @@ -121,20 +116,20 @@ func (m *Model) fields(operation string) (fields []Field) { } if is_time { + field.AutoCreateTime = "created_at" == field.DbName + field.AutoUpdateTime = "updated_at" == field.DbName + switch operation { case "create": if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() { value.Set(reflect.ValueOf(time.Now())) } case "update": - if field.AutoCreateTime && time_value.IsZero() { - value.Set(reflect.ValueOf(time.Now())) - } - if field.AutoUpdateTime { value.Set(reflect.ValueOf(time.Now())) } } + field.SqlType = getSqlType(m.do.chain.driver(), value, 0) } else if field.IsPrimaryKey { field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0) @@ -176,7 +171,7 @@ func (m *Model) fields(operation string) (fields []Field) { } if len(m._cache_fields) == 0 { - m._cache_fields = make(map[string][]Field) + m._cache_fields = map[string][]Field{} } m._cache_fields[operation] = fields return @@ -191,14 +186,12 @@ func (m *Model) columnsHasValue(operation string) (fields []Field) { return } -func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[string]interface{}, bool) { +func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results map[string]interface{}, any_updated bool) { if m.data == nil { return values, true } data := reflect.Indirect(reflect.ValueOf(m.data)) - results := map[string]interface{}{} - for key, value := range values { field := data.FieldByName(snakeToUpperCamel(key)) if field.IsValid() { @@ -206,33 +199,31 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[stri switch field.Kind() { case reflect.Int, reflect.Int32, reflect.Int64: if field.Int() != reflect.ValueOf(value).Int() { - results[key] = value + any_updated = true } field.SetInt(reflect.ValueOf(value).Int()) default: - results[key] = value + any_updated = true field.Set(reflect.ValueOf(value)) } } } } - if values["updated_at"] != nil && len(results) > 0 { + if values["updated_at"] != nil && any_updated { setFieldValue(data.FieldByName("UpdatedAt"), time.Now()) } - result := len(results) > 0 - return map[string]interface{}{}, result + return } func (m *Model) columnsAndValues(operation string) map[string]interface{} { - if m.data == nil { - return map[string]interface{}{} - } - results := map[string]interface{}{} - for _, field := range m.fields(operation) { - if !field.IsPrimaryKey && (len(field.SqlType) > 0) { - results[field.DbName] = field.Value + + if m.data != nil { + for _, field := range m.fields(operation) { + if !field.IsPrimaryKey && (len(field.SqlType) > 0) { + results[field.DbName] = field.Value + } } } return results @@ -252,17 +243,15 @@ func (m *Model) hasColumn(name string) bool { } func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) { - if m.data == nil { - return - } - - data := reflect.Indirect(reflect.ValueOf(m.data)) - if data.Kind() == reflect.Slice { - has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() - is_slice = true - } else { - if has_column = data.FieldByName(name).IsValid(); has_column { - value = data.FieldByName(name).Interface() + if m.data != nil { + data := reflect.Indirect(reflect.ValueOf(m.data)) + if data.Kind() == reflect.Slice { + has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() + is_slice = true + } else { + if has_column = data.FieldByName(name).IsValid(); has_column { + value = data.FieldByName(name).Interface() + } } } return @@ -285,8 +274,7 @@ func (m *Model) tableName() (str string) { fm := reflect.Indirect(reflect.ValueOf(m.data)).MethodByName("TableName") if fm.IsValid() { - v := fm.Call([]reflect.Value{}) - if len(v) > 0 { + if v := fm.Call([]reflect.Value{}); len(v) > 0 { if result, ok := v[0].Interface().(string); ok { return result } @@ -315,8 +303,7 @@ func (m *Model) callMethod(method string) { fm := reflect.ValueOf(m.data).MethodByName(method) if fm.IsValid() { - v := fm.Call([]reflect.Value{}) - if len(v) > 0 { + if v := fm.Call([]reflect.Value{}); len(v) > 0 { if verr, ok := v[0].Interface().(error); ok { m.do.err(verr) }