From 4985d7bd96b66203dc18618c543b99a310595aae Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 26 Jan 2014 14:55:41 +0800 Subject: [PATCH] Add scope_condition.go --- scope.go | 10 +- scope_condition.go | 231 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 238 insertions(+), 3 deletions(-) create mode 100644 scope_condition.go diff --git a/scope.go b/scope.go index 0affd78f..d72388c3 100644 --- a/scope.go +++ b/scope.go @@ -49,7 +49,7 @@ func (scope *Scope) HasError() bool { } func (scope *Scope) PrimaryKey() string { - return "Id" + return "id" } func (scope *Scope) HasColumn(name string) bool { @@ -128,8 +128,12 @@ func (scope *Scope) TableName() string { } } -func (scope *Scope) CombinedConditionSql() string { - return "" +func (s *Scope) CombinedConditionSql() string { + return s.joinsSql() + s.whereSql() + s.groupSql() + s.havingSql() + s.orderSql() + s.limitSql() + s.offsetSql() +} + +func (scope *Scope) Fields() []*Field { + return []*Field{} } func (scope *Scope) Raw(sql string) { diff --git a/scope_condition.go b/scope_condition.go new file mode 100644 index 00000000..ceb159b7 --- /dev/null +++ b/scope_condition.go @@ -0,0 +1,231 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "strconv" + "strings" +) + +func (scope *Scope) quote(str string) string { + return scope.Dialect().Quote(str) +} + +func (scope *Scope) primaryCondiation(value interface{}) string { + return fmt.Sprintf("(%v = %v)", scope.quote(scope.PrimaryKey()), value) +} + +func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { + switch value := clause["query"].(type) { + case string: + // if string is number + if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + id, _ := strconv.Atoi(value) + return scope.primaryCondiation(scope.AddToVars(id)) + } else { + str = value + } + case int, int64, int32: + return scope.primaryCondiation(scope.AddToVars(value)) + case sql.NullInt64: + return scope.primaryCondiation(scope.AddToVars(value.Int64)) + case []int64, []int, []int32, []string: + str = fmt.Sprintf("(%v in (?))", scope.quote(scope.PrimaryKey())) + clause["args"] = []interface{}{value} + case map[string]interface{}: + var sqls []string + for key, value := range value { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(key), scope.AddToVars(value))) + } + return strings.Join(sqls, " AND ") + case interface{}: + var sqls []string + for _, field := range scope.Fields() { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.dbName), scope.AddToVars(field.Value))) + } + return strings.Join(sqls, " AND ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.TypeOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + values := reflect.ValueOf(arg) + var temp_marks []string + for i := 0; i < values.Len(); i++ { + temp_marks = append(temp_marks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } + } + return +} + +func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { + var not_equal_sql string + + switch value := clause["query"].(type) { + case string: + if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + id, _ := strconv.Atoi(value) + return fmt.Sprintf("(%v <> %v)", scope.quote(scope.PrimaryKey()), id) + } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { + str = fmt.Sprintf(" NOT (%v) ", value) + not_equal_sql = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v NOT IN (?))", scope.quote(value)) + not_equal_sql = fmt.Sprintf("(%v <> ?)", scope.quote(value)) + } + case int, int64, int32: + return fmt.Sprintf("(%v <> %v)", scope.quote(scope.PrimaryKey()), value) + case []int64, []int, []int32, []string: + if reflect.ValueOf(value).Len() > 0 { + str = fmt.Sprintf("(%v not in (?))", scope.quote(scope.PrimaryKey())) + clause["args"] = []interface{}{value} + } else { + return "" + } + case map[string]interface{}: + var sqls []string + for key, value := range value { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(key), scope.AddToVars(value))) + } + return strings.Join(sqls, " AND ") + case interface{}: + var sqls []string + for _, field := range scope.Fields() { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.dbName), scope.AddToVars(field.Value))) + } + return strings.Join(sqls, " AND ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.TypeOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + values := reflect.ValueOf(arg) + var temp_marks []string + for i := 0; i < values.Len(); i++ { + temp_marks = append(temp_marks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) + default: + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = scanner.Value() + } + str = strings.Replace(not_equal_sql, "?", scope.AddToVars(arg), 1) + } + } + return +} + +func (scope *Scope) where(where ...interface{}) { + if len(where) > 0 { + scope.Search = scope.Search.clone().where(where[0], where[1:]...) + } +} + +func (scope *Scope) whereSql() (sql string) { + var primary_condiations, and_conditions, or_conditions []string + + if !scope.Search.unscope && scope.HasColumn("DeletedAt") { + primary_condiations = append(primary_condiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')") + } + + if !scope.PrimaryKeyZero() { + primary_condiations = append(primary_condiations, scope.primaryCondiation(scope.AddToVars(scope.PrimaryKeyValue()))) + } + + for _, clause := range scope.Search.whereClause { + and_conditions = append(and_conditions, scope.buildWhereCondition(clause)) + } + + for _, clause := range scope.Search.orClause { + or_conditions = append(or_conditions, scope.buildWhereCondition(clause)) + } + + for _, clause := range scope.Search.notClause { + and_conditions = append(and_conditions, scope.buildNotCondition(clause)) + } + + or_sql := strings.Join(or_conditions, " OR ") + combined_sql := strings.Join(and_conditions, " AND ") + if len(combined_sql) > 0 { + if len(or_sql) > 0 { + combined_sql = combined_sql + " OR " + or_sql + } + } else { + combined_sql = or_sql + } + + if len(primary_condiations) > 0 { + sql = "WHERE " + strings.Join(primary_condiations, " AND ") + if len(combined_sql) > 0 { + sql = sql + " AND (" + combined_sql + ")" + } + } else if len(combined_sql) > 0 { + sql = "WHERE " + combined_sql + } + return +} + +func (s *Scope) selectSql() string { + if len(s.Search.selectStr) == 0 { + return "*" + } else { + return s.Search.selectStr + } +} + +func (s *Scope) orderSql() string { + if len(s.Search.orders) == 0 { + return "" + } else { + return " ORDER BY " + strings.Join(s.Search.orders, ",") + } +} + +func (s *Scope) limitSql() string { + if len(s.Search.limitStr) == 0 { + return "" + } else { + return " LIMIT " + s.Search.limitStr + } +} + +func (s *Scope) offsetSql() string { + if len(s.Search.offsetStr) == 0 { + return "" + } else { + return " OFFSET " + s.Search.offsetStr + } +} + +func (s *Scope) groupSql() string { + if len(s.Search.groupStr) == 0 { + return "" + } else { + return " GROUP BY " + s.Search.groupStr + } +} + +func (s *Scope) havingSql() string { + if s.Search.havingClause == nil { + return "" + } else { + return " HAVING " + s.buildWhereCondition(s.Search.havingClause) + } +} + +func (s *Scope) joinsSql() string { + return s.Search.joinsStr + " " +}