gorm/do.go

449 lines
9.5 KiB
Go
Raw Normal View History

2013-10-26 05:49:40 +04:00
package gorm
import (
2013-10-27 15:41:58 +04:00
"database/sql"
"errors"
2013-10-26 08:33:05 +04:00
"fmt"
2013-10-26 07:59:58 +04:00
"reflect"
2013-10-27 16:54:23 +04:00
"regexp"
"strconv"
"time"
2013-10-27 15:41:58 +04:00
2013-10-26 05:49:40 +04:00
"strings"
)
2013-10-27 15:41:58 +04:00
type Do struct {
2013-10-28 16:27:25 +04:00
chain *Chain
db *sql.DB
driver string
guessedTableName string
specifiedTableName string
Errors []error
2013-10-27 15:41:58 +04:00
model *Model
value interface{}
2013-10-29 03:39:26 +04:00
sqlResult sql.Result
sql string
sqlVars []interface{}
2013-10-27 15:41:58 +04:00
whereClause []map[string]interface{}
orClause []map[string]interface{}
selectStr string
orderStrs []string
offsetStr string
limitStr string
unscoped bool
2013-10-28 17:52:22 +04:00
updateAttrs map[string]interface{}
ignoreProtectedAttrs bool
}
2013-10-28 16:27:25 +04:00
func (s *Do) tableName() string {
if s.specifiedTableName == "" {
2013-10-29 03:39:26 +04:00
var err error
s.guessedTableName, err = s.model.tableName()
s.err(err)
2013-10-28 16:27:25 +04:00
return s.guessedTableName
} else {
return s.specifiedTableName
}
}
2013-10-29 03:39:26 +04:00
func (s *Do) err(err error) error {
2013-10-27 15:41:58 +04:00
if err != nil {
s.Errors = append(s.Errors, err)
s.chain.err(err)
2013-10-26 07:59:58 +04:00
}
2013-10-29 03:39:26 +04:00
return err
2013-10-26 07:59:58 +04:00
}
2013-10-28 08:12:12 +04:00
// hasError reports whether any error has been recorded on this Do.
func (s *Do) hasError() bool {
	return len(s.Errors) != 0
}
2013-10-27 15:41:58 +04:00
func (s *Do) setModel(value interface{}) {
2013-10-29 03:39:26 +04:00
s.model = &Model{data: value, driver: s.driver}
2013-10-27 15:41:58 +04:00
s.value = value
2013-10-26 07:59:58 +04:00
}
2013-10-27 15:41:58 +04:00
func (s *Do) addToVars(value interface{}) string {
2013-10-29 03:39:26 +04:00
s.sqlVars = append(s.sqlVars, value)
return fmt.Sprintf("$%d", len(s.sqlVars))
2013-10-27 15:41:58 +04:00
}
2013-10-26 08:33:05 +04:00
func (s *Do) exec(sql ...string) {
2013-10-28 08:12:12 +04:00
if s.hasError() {
return
}
2013-10-27 15:41:58 +04:00
var err error
if len(sql) == 0 {
2013-10-29 03:39:26 +04:00
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
2013-10-27 15:41:58 +04:00
} else {
2013-10-29 03:39:26 +04:00
s.sqlResult, err = s.db.Exec(sql[0])
2013-10-26 08:33:05 +04:00
}
s.err(err)
2013-10-26 05:49:40 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) save() {
2013-10-28 11:55:41 +04:00
if s.model.primaryKeyZero() {
2013-10-27 15:41:58 +04:00
s.create()
} else {
s.update()
2013-10-27 05:32:49 +04:00
}
2013-10-29 03:39:26 +04:00
return
2013-10-27 05:32:49 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) prepareCreateSql() {
2013-10-28 17:52:22 +04:00
var sqls, columns []string
2013-10-26 19:30:17 +04:00
2013-10-28 17:52:22 +04:00
for key, value := range s.model.columnsAndValues("create") {
columns = append(columns, key)
2013-10-26 19:30:17 +04:00
sqls = append(sqls, s.addToVars(value))
}
2013-10-29 03:39:26 +04:00
s.sql = fmt.Sprintf(
2013-10-26 16:20:49 +04:00
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
2013-10-28 16:27:25 +04:00
s.tableName(),
2013-10-29 03:39:26 +04:00
strings.Join(columns, ","),
2013-10-26 19:30:17 +04:00
strings.Join(sqls, ","),
2013-10-28 11:55:41 +04:00
s.model.returningStr(),
2013-10-26 05:49:40 +04:00
)
2013-10-29 03:39:26 +04:00
return
2013-10-26 05:49:40 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) create() {
2013-10-27 10:51:23 +04:00
s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave"))
2013-10-27 15:41:58 +04:00
s.prepareCreateSql()
2013-10-26 16:20:49 +04:00
2013-10-29 03:39:26 +04:00
if !s.hasError() {
2013-10-27 15:41:58 +04:00
var id int64
2013-10-27 12:06:45 +04:00
if s.driver == "postgres" {
2013-10-29 03:39:26 +04:00
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
2013-10-27 12:06:45 +04:00
} else {
var err error
2013-10-29 03:39:26 +04:00
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
2013-10-27 12:06:45 +04:00
s.err(err)
2013-10-29 03:39:26 +04:00
id, err = s.sqlResult.LastInsertId()
2013-10-27 12:06:45 +04:00
s.err(err)
}
2013-10-29 03:39:26 +04:00
if !s.hasError() {
result := reflect.ValueOf(s.value).Elem()
result.FieldByName(s.model.primaryKey()).SetInt(id)
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
}
2013-10-27 12:06:45 +04:00
}
2013-10-27 15:41:58 +04:00
2013-10-29 03:39:26 +04:00
return
2013-10-26 16:20:49 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) prepareUpdateSql() {
2013-10-28 17:52:22 +04:00
update_attrs := s.updateAttrs
if len(update_attrs) == 0 {
update_attrs = s.model.columnsAndValues("update")
}
var sqls []string
for key, value := range update_attrs {
2013-10-29 03:39:26 +04:00
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
2013-10-26 17:37:42 +04:00
}
2013-10-29 03:39:26 +04:00
s.sql = fmt.Sprintf(
2013-10-26 19:30:17 +04:00
"UPDATE %v SET %v %v",
2013-10-28 16:27:25 +04:00
s.tableName(),
2013-10-28 17:52:22 +04:00
strings.Join(sqls, ", "),
2013-10-27 07:21:33 +04:00
s.combinedSql(),
2013-10-26 17:37:42 +04:00
)
2013-10-29 03:39:26 +04:00
return
2013-10-26 16:20:49 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) update() {
2013-10-27 10:51:23 +04:00
s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave"))
2013-10-29 03:39:26 +04:00
s.prepareUpdateSql()
if !s.hasError() {
s.exec()
if !s.hasError() {
s.err(s.model.callMethod("AfterUpdate"))
s.err(s.model.callMethod("AfterSave"))
}
2013-10-27 12:06:45 +04:00
}
2013-10-29 03:39:26 +04:00
return
2013-10-26 16:20:49 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) prepareDeleteSql() {
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
return
2013-10-26 05:49:40 +04:00
}
2013-10-27 10:51:23 +04:00
2013-10-29 03:39:26 +04:00
func (s *Do) delete() {
2013-10-27 10:51:23 +04:00
s.err(s.model.callMethod("BeforeDelete"))
2013-10-29 03:39:26 +04:00
if !s.hasError() {
2013-10-29 06:19:20 +04:00
if !s.unscoped && s.model.hasColumn("DeletedAt") {
delete_sql := "deleted_at=" + s.addToVars(time.Now())
s.sql = fmt.Sprintf("UPDATE %v SET %v %v", s.tableName(), delete_sql, s.combinedSql())
s.exec()
} else {
s.prepareDeleteSql()
s.exec()
}
2013-10-29 03:39:26 +04:00
if !s.hasError() {
s.err(s.model.callMethod("AfterDelete"))
}
2013-10-27 12:06:45 +04:00
}
2013-10-29 03:39:26 +04:00
return
2013-10-27 15:41:58 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) prepareQuerySql() {
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
return
2013-10-27 15:41:58 +04:00
}
2013-10-27 16:54:23 +04:00
func (s *Do) query(where ...interface{}) {
if len(where) > 0 {
s.where(where[0], where[1:len(where)]...)
}
2013-10-27 15:41:58 +04:00
var (
is_slice bool
dest_type reflect.Type
)
dest_out := reflect.Indirect(reflect.ValueOf(s.value))
2013-10-29 06:19:20 +04:00
if dest_out.Kind() == reflect.Slice {
2013-10-27 15:41:58 +04:00
is_slice = true
dest_type = dest_out.Type().Elem()
}
s.prepareQuerySql()
2013-10-29 03:39:26 +04:00
if !s.hasError() {
rows, err := s.db.Query(s.sql, s.sqlVars...)
if s.err(err) != nil {
return
}
2013-10-28 08:12:12 +04:00
2013-10-29 03:39:26 +04:00
defer rows.Close()
2013-10-28 08:12:12 +04:00
2013-10-29 03:39:26 +04:00
if rows.Err() != nil {
s.err(rows.Err())
}
2013-10-28 08:12:12 +04:00
2013-10-29 03:39:26 +04:00
counts := 0
for rows.Next() {
counts += 1
var dest reflect.Value
if is_slice {
dest = reflect.New(dest_type).Elem()
} else {
dest = reflect.ValueOf(s.value).Elem()
}
2013-10-27 15:41:58 +04:00
2013-10-29 03:39:26 +04:00
columns, _ := rows.Columns()
var values []interface{}
for _, value := range columns {
field := dest.FieldByName(snakeToUpperCamel(value))
if field.IsValid() {
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
}
}
s.err(rows.Scan(values...))
2013-10-27 15:41:58 +04:00
2013-10-29 03:39:26 +04:00
if is_slice {
dest_out.Set(reflect.Append(dest_out, dest))
2013-10-28 16:27:25 +04:00
}
2013-10-27 15:41:58 +04:00
}
2013-10-29 03:39:26 +04:00
if (counts == 0) && !is_slice {
s.err(errors.New("Record not found!"))
2013-10-27 15:41:58 +04:00
}
}
2013-10-27 10:51:23 +04:00
}
2013-10-27 16:07:13 +04:00
func (s *Do) count(value interface{}) {
dest_out := reflect.Indirect(reflect.ValueOf(value))
s.prepareQuerySql()
2013-10-29 03:39:26 +04:00
if !s.hasError() {
rows, err := s.db.Query(s.sql, s.sqlVars...)
if s.err(err) != nil {
return
}
defer rows.Close()
for rows.Next() {
var dest int64
if s.err(rows.Scan(&dest)) == nil {
dest_out.Set(reflect.ValueOf(dest))
}
}
2013-10-27 16:07:13 +04:00
}
return
}
2013-10-29 03:39:26 +04:00
func (s *Do) pluck(column string, value interface{}) {
s.selectStr = column
2013-10-27 15:41:58 +04:00
dest_out := reflect.Indirect(reflect.ValueOf(value))
dest_type := dest_out.Type().Elem()
s.prepareQuerySql()
2013-10-29 03:39:26 +04:00
if !s.hasError() {
rows, err := s.db.Query(s.sql, s.sqlVars...)
if s.err(err) != nil {
return
}
defer rows.Close()
for rows.Next() {
dest := reflect.New(dest_type).Elem().Interface()
s.err(rows.Scan(&dest))
switch dest.(type) {
case []uint8:
if dest_type.String() == "string" {
dest = string(dest.([]uint8))
}
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
default:
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
2013-10-27 15:41:58 +04:00
}
}
}
2013-10-29 03:39:26 +04:00
return
2013-10-27 15:41:58 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) where(querystring interface{}, args ...interface{}) {
2013-10-27 16:54:23 +04:00
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
2013-10-29 03:39:26 +04:00
return
2013-10-27 16:54:23 +04:00
}
func (s *Do) primaryCondiation(value interface{}) string {
2013-10-29 03:39:26 +04:00
return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value)
2013-10-27 16:54:23 +04:00
}
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
switch clause["query"].(type) {
case string:
value := clause["query"].(string)
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return s.primaryCondiation(s.addToVars(id))
} else {
str = "( " + value + " )"
}
case int, int64, int32:
return s.primaryCondiation(s.addToVars(clause["query"]))
}
2013-10-27 08:00:39 +04:00
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
v := reflect.ValueOf(arg)
var temp_marks []string
for i := 0; i < v.Len(); i++ {
2013-10-29 03:39:26 +04:00
temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface()))
2013-10-27 08:00:39 +04:00
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
str = strings.Replace(str, "?", s.addToVars(arg), 1)
}
}
2013-10-27 16:54:23 +04:00
return
2013-10-27 08:00:39 +04:00
}
2013-10-26 05:49:40 +04:00
2013-10-27 15:41:58 +04:00
func (s *Do) whereSql() (sql string) {
2013-10-29 06:19:20 +04:00
var primary_condiations, and_conditions, or_conditions []string
if !s.unscoped && s.model.hasColumn("DeletedAt") {
primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')")
}
2013-10-27 08:00:39 +04:00
2013-10-28 11:55:41 +04:00
if !s.model.primaryKeyZero() {
2013-10-29 06:19:20 +04:00
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
2013-10-27 08:00:39 +04:00
}
for _, clause := range s.whereClause {
and_conditions = append(and_conditions, s.buildWhereCondition(clause))
}
for _, clause := range s.orClause {
or_conditions = append(or_conditions, s.buildWhereCondition(clause))
}
and_sql := strings.Join(and_conditions, " AND ")
2013-10-27 08:31:51 +04:00
or_sql := strings.Join(or_conditions, " OR ")
2013-10-27 08:00:39 +04:00
combined_conditions := and_sql
if len(combined_conditions) > 0 {
if len(or_sql) > 0 {
combined_conditions = combined_conditions + " OR " + or_sql
2013-10-26 10:10:47 +04:00
}
2013-10-27 08:00:39 +04:00
} else {
combined_conditions = or_sql
2013-10-26 10:10:47 +04:00
}
2013-10-26 19:30:17 +04:00
2013-10-29 06:19:20 +04:00
if len(primary_condiations) > 0 {
sql = "WHERE " + strings.Join(primary_condiations, " AND ")
2013-10-27 08:00:39 +04:00
if len(combined_conditions) > 0 {
sql = sql + " AND ( " + combined_conditions + " )"
}
} else if len(combined_conditions) > 0 {
sql = "WHERE " + combined_conditions
2013-10-26 19:30:17 +04:00
}
2013-10-26 05:49:40 +04:00
return
}
2013-10-26 19:30:17 +04:00
2013-10-27 15:41:58 +04:00
func (s *Do) selectSql() string {
2013-10-27 05:50:11 +04:00
if len(s.selectStr) == 0 {
return " * "
} else {
return s.selectStr
}
}
2013-10-27 15:41:58 +04:00
func (s *Do) orderSql() string {
2013-10-27 07:21:33 +04:00
if len(s.orderStrs) == 0 {
2013-10-27 07:38:05 +04:00
return ""
2013-10-27 07:21:33 +04:00
} else {
return " ORDER BY " + strings.Join(s.orderStrs, ",")
}
}
2013-10-27 15:41:58 +04:00
func (s *Do) limitSql() string {
2013-10-27 07:38:05 +04:00
if len(s.limitStr) == 0 {
return ""
} else {
return " LIMIT " + s.limitStr
}
}
2013-10-27 15:41:58 +04:00
func (s *Do) offsetSql() string {
2013-10-27 07:44:47 +04:00
if len(s.offsetStr) == 0 {
return ""
} else {
return " OFFSET " + s.offsetStr
}
}
2013-10-27 15:41:58 +04:00
func (s *Do) combinedSql() string {
2013-10-27 07:44:47 +04:00
return s.whereSql() + s.orderSql() + s.limitSql() + s.offsetSql()
2013-10-27 07:21:33 +04:00
}
2013-10-27 15:41:58 +04:00
func (s *Do) createTable() *Do {
2013-10-28 08:12:12 +04:00
var sqls []string
2013-10-28 11:55:41 +04:00
for _, field := range s.model.fields("null") {
2013-10-28 08:12:12 +04:00
sqls = append(sqls, field.DbName+" "+field.SqlType)
}
2013-10-29 03:39:26 +04:00
s.sql = fmt.Sprintf(
2013-10-28 08:12:12 +04:00
"CREATE TABLE \"%v\" (%v)",
2013-10-28 16:27:25 +04:00
s.tableName(),
2013-10-28 08:12:12 +04:00
strings.Join(sqls, ","),
)
2013-10-27 15:41:58 +04:00
return s
2013-10-26 19:30:17 +04:00
}