gorm/sql.go

280 lines
6.1 KiB
Go
Raw Normal View History

2013-10-26 05:49:40 +04:00
package gorm
import (
"errors"
2013-10-26 08:33:05 +04:00
"fmt"
2013-10-26 07:59:58 +04:00
"reflect"
"regexp"
2013-10-26 05:49:40 +04:00
"strings"
)
// validSqlRegexp matches a conservative fragment made of word characters,
// whitespace, commas and dots that starts and ends with a word character,
// e.g. "name", "users.name" or "name, age". Compiled once at package init
// rather than on every validSql call.
var validSqlRegexp = regexp.MustCompile(`^\s*\w(?:[\w\s,.]*\w)?\s*$`)

// validSql reports whether str is an acceptable SQL fragment according to
// validSqlRegexp. On rejection it records an error in s.Error; on success
// s.Error is left untouched.
func (s *Orm) validSql(str string) bool {
	if validSqlRegexp.MatchString(str) {
		return true
	}
	s.Error = fmt.Errorf("SQL is not valid, %s", str)
	return false
}
2013-10-26 07:59:58 +04:00
func (s *Orm) explain(value interface{}, operation string) *Orm {
2013-10-27 03:52:04 +04:00
s.Model(value)
2013-10-27 05:32:49 +04:00
2013-10-26 05:49:40 +04:00
switch operation {
2013-10-26 16:20:49 +04:00
case "Create":
s.createSql(value)
case "Update":
s.updateSql(value)
2013-10-26 05:49:40 +04:00
case "Delete":
s.deleteSql(value)
2013-10-26 07:59:58 +04:00
case "Query":
s.querySql(value)
2013-10-26 11:47:30 +04:00
case "CreateTable":
2013-10-27 03:52:04 +04:00
s.Sql = s.model.CreateTable()
2013-10-26 07:59:58 +04:00
}
return s
}
func (s *Orm) querySql(out interface{}) {
2013-10-27 07:21:33 +04:00
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql())
2013-10-26 07:59:58 +04:00
return
}
func (s *Orm) query(out interface{}) {
2013-10-26 08:33:05 +04:00
var (
is_slice bool
dest_type reflect.Type
)
dest_out := reflect.Indirect(reflect.ValueOf(out))
if x := dest_out.Kind(); x == reflect.Slice {
is_slice = true
dest_type = dest_out.Type().Elem()
}
2013-10-27 08:00:39 +04:00
debug(s.Sql)
debug(s.SqlVars)
2013-10-26 08:33:05 +04:00
2013-10-26 10:10:47 +04:00
rows, err := s.db.Query(s.Sql, s.SqlVars...)
defer rows.Close()
2013-10-26 07:59:58 +04:00
s.Error = err
if rows.Err() != nil {
s.Error = rows.Err()
}
2013-10-26 08:33:05 +04:00
counts := 0
2013-10-26 07:59:58 +04:00
for rows.Next() {
counts += 1
2013-10-26 08:33:05 +04:00
var dest reflect.Value
if is_slice {
dest = reflect.New(dest_type).Elem()
} else {
dest = reflect.ValueOf(out).Elem()
}
2013-10-26 07:59:58 +04:00
columns, _ := rows.Columns()
var values []interface{}
for _, value := range columns {
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
2013-10-26 07:59:58 +04:00
}
s.Error = rows.Scan(values...)
2013-10-26 20:36:56 +04:00
if is_slice {
dest_out.Set(reflect.Append(dest_out, dest))
}
2013-10-26 05:49:40 +04:00
}
2013-10-26 20:36:56 +04:00
if (counts == 0) && !is_slice {
s.Error = errors.New("Record not found!")
}
2013-10-26 05:49:40 +04:00
}
2013-10-27 05:32:49 +04:00
func (s *Orm) pluck(value interface{}) {
dest_out := reflect.Indirect(reflect.ValueOf(value))
dest_type := dest_out.Type().Elem()
rows, err := s.db.Query(s.Sql, s.SqlVars...)
s.Error = err
2013-10-27 07:21:33 +04:00
2013-10-27 05:32:49 +04:00
defer rows.Close()
for rows.Next() {
2013-10-27 06:28:47 +04:00
dest := reflect.New(dest_type).Elem().Interface()
s.Error = rows.Scan(&dest)
2013-10-27 07:21:33 +04:00
switch dest.(type) {
case []uint8:
if dest_type.String() == "string" {
dest = string(dest.([]uint8))
}
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
default:
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
}
2013-10-27 05:32:49 +04:00
}
return
}
2013-10-26 16:20:49 +04:00
func (s *Orm) createSql(value interface{}) {
2013-10-27 03:52:04 +04:00
columns, values := s.model.ColumnsAndValues()
2013-10-26 19:30:17 +04:00
var sqls []string
for _, value := range values {
sqls = append(sqls, s.addToVars(value))
}
2013-10-26 05:49:40 +04:00
s.Sql = fmt.Sprintf(
2013-10-26 16:20:49 +04:00
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
2013-10-26 05:49:40 +04:00
s.TableName,
2013-10-26 17:37:42 +04:00
strings.Join(s.quoteMap(columns), ","),
2013-10-26 19:30:17 +04:00
strings.Join(sqls, ","),
2013-10-27 03:52:04 +04:00
s.model.ReturningStr(),
2013-10-26 05:49:40 +04:00
)
return
}
2013-10-26 16:20:49 +04:00
func (s *Orm) create(value interface{}) {
var id int64
if s.driver == "postgres" {
s.Error = s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)
} else {
s.SqlResult, s.Error = s.db.Exec(s.Sql, s.SqlVars...)
id, s.Error = s.SqlResult.LastInsertId()
}
2013-10-27 03:52:04 +04:00
result := reflect.ValueOf(s.model.Data).Elem()
result.FieldByName(s.model.PrimaryKey()).SetInt(id)
2013-10-26 16:20:49 +04:00
}
func (s *Orm) updateSql(value interface{}) {
2013-10-27 03:52:04 +04:00
columns, values := s.model.ColumnsAndValues()
2013-10-26 17:37:42 +04:00
var sets []string
for index, column := range columns {
2013-10-26 19:30:17 +04:00
sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index])))
2013-10-26 17:37:42 +04:00
}
s.Sql = fmt.Sprintf(
2013-10-26 19:30:17 +04:00
"UPDATE %v SET %v %v",
2013-10-26 17:37:42 +04:00
s.TableName,
strings.Join(sets, ", "),
2013-10-27 07:21:33 +04:00
s.combinedSql(),
2013-10-26 17:37:42 +04:00
)
2013-10-26 19:30:17 +04:00
2013-10-26 16:20:49 +04:00
return
}
func (s *Orm) update(value interface{}) {
2013-10-26 17:37:42 +04:00
s.Exec()
2013-10-26 16:20:49 +04:00
return
}
2013-10-26 05:49:40 +04:00
func (s *Orm) deleteSql(value interface{}) {
2013-10-27 07:21:33 +04:00
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
2013-10-26 05:49:40 +04:00
return
}
2013-10-27 08:00:39 +04:00
// buildWhereCondition renders one where/or/not clause as a parenthesised SQL
// fragment. Each "?" in clause["query"] is replaced, in order, by a
// positional placeholder whose value is registered in s.SqlVars through
// addToVars.
//
// clause["query"] must be a string and clause["args"] a []interface{}.
// A slice argument is expanded to one placeholder per element, so
// `id in (?)` with []int64{1, 2} becomes `id in ($1,$2)`. Byte slices are
// the exception: they are passed through as a single value. An empty slice
// expands to nothing (e.g. `id in ()`). A nil argument is passed through
// as a single value.
func (s *Orm) buildWhereCondition(clause map[string]interface{}) string {
	str := "( " + clause["query"].(string) + " )"
	args := clause["args"].([]interface{})
	for _, arg := range args {
		// reflect.ValueOf(nil).Kind() is Invalid, so nil args are safe here.
		v := reflect.ValueOf(arg)
		if v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 {
			marks := make([]string, v.Len())
			for i := range marks {
				// Store the element value, not a pointer into the caller's slice.
				marks[i] = s.addToVars(v.Index(i).Interface())
			}
			str = strings.Replace(str, "?", strings.Join(marks, ","), 1)
			continue
		}
		str = strings.Replace(str, "?", s.addToVars(arg), 1)
	}
	return str
}
2013-10-26 05:49:40 +04:00
func (s *Orm) whereSql() (sql string) {
2013-10-27 08:00:39 +04:00
var primary_condiation string
var and_conditions, or_conditions, not_conditions []string
2013-10-27 03:52:04 +04:00
if !s.model.PrimaryKeyIsEmpty() {
2013-10-27 08:00:39 +04:00
primary_condiation = fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue()))
}
for _, clause := range s.whereClause {
and_conditions = append(and_conditions, s.buildWhereCondition(clause))
}
for _, clause := range s.notClause {
and_conditions = append(and_conditions, "!"+s.buildWhereCondition(clause))
}
for _, clause := range s.orClause {
or_conditions = append(or_conditions, s.buildWhereCondition(clause))
}
and_sql := strings.Join(and_conditions, " AND ")
or_sql := strings.Join(not_conditions, " OR ")
combined_conditions := and_sql
if len(combined_conditions) > 0 {
if len(or_sql) > 0 {
combined_conditions = combined_conditions + " OR " + or_sql
2013-10-26 10:10:47 +04:00
}
2013-10-27 08:00:39 +04:00
} else {
combined_conditions = or_sql
2013-10-26 10:10:47 +04:00
}
2013-10-26 19:30:17 +04:00
2013-10-27 08:00:39 +04:00
if len(primary_condiation) > 0 {
sql = "WHERE " + primary_condiation
if len(combined_conditions) > 0 {
sql = sql + " AND ( " + combined_conditions + " )"
}
} else if len(combined_conditions) > 0 {
sql = "WHERE " + combined_conditions
2013-10-26 19:30:17 +04:00
}
2013-10-27 08:00:39 +04:00
debug(sql)
2013-10-26 05:49:40 +04:00
return
}
2013-10-26 19:30:17 +04:00
2013-10-27 05:50:11 +04:00
// selectSql returns the column list for a SELECT: the explicit selection when
// one has been set, otherwise " * " (padded with spaces).
func (s *Orm) selectSql() string {
	if s.selectStr != "" {
		return s.selectStr
	}
	return " * "
}
2013-10-27 07:38:05 +04:00
func (s *Orm) orderSql() string {
2013-10-27 07:21:33 +04:00
if len(s.orderStrs) == 0 {
2013-10-27 07:38:05 +04:00
return ""
2013-10-27 07:21:33 +04:00
} else {
return " ORDER BY " + strings.Join(s.orderStrs, ",")
}
}
2013-10-27 07:38:05 +04:00
// limitSql returns " LIMIT <n>" when a limit has been set, otherwise "".
func (s *Orm) limitSql() string {
	if s.limitStr != "" {
		return " LIMIT " + s.limitStr
	}
	return ""
}
2013-10-27 07:44:47 +04:00
// offsetSql returns " OFFSET <n>" when an offset has been set, otherwise "".
func (s *Orm) offsetSql() string {
	if s.offsetStr != "" {
		return " OFFSET " + s.offsetStr
	}
	return ""
}
2013-10-27 07:21:33 +04:00
func (s *Orm) combinedSql() string {
2013-10-27 07:44:47 +04:00
return s.whereSql() + s.orderSql() + s.limitSql() + s.offsetSql()
2013-10-27 07:21:33 +04:00
}
2013-10-26 19:30:17 +04:00
// addToVars appends value to s.SqlVars and returns its 1-based positional
// placeholder ("$1", "$2", ...).
// NOTE(review): "$N" is PostgreSQL-style; confirm that every other driver in
// use (create() branches on s.driver) accepts this placeholder syntax.
func (s *Orm) addToVars(value interface{}) string {
	s.SqlVars = append(s.SqlVars, value)
	position := len(s.SqlVars)
	return "$" + fmt.Sprint(position)
}