gorm/do.go

756 lines
18 KiB
Go
Raw Normal View History

2013-10-26 05:49:40 +04:00
package gorm
import (
2013-10-27 15:41:58 +04:00
"database/sql"
2013-11-10 14:33:37 +04:00
"database/sql/driver"
"errors"
2013-10-26 08:33:05 +04:00
"fmt"
2013-11-16 09:28:22 +04:00
"github.com/jinzhu/gorm/dialect"
2013-10-26 07:59:58 +04:00
"reflect"
2013-10-27 16:54:23 +04:00
"regexp"
"strconv"
2013-10-26 05:49:40 +04:00
"strings"
2013-11-04 16:32:46 +04:00
"time"
2013-10-26 05:49:40 +04:00
)
2013-10-27 15:41:58 +04:00
type Do struct {
2013-11-16 08:45:29 +04:00
db *DB
2013-11-16 09:19:29 +04:00
search *search
2013-11-16 08:45:29 +04:00
model *Model
tableName string
2013-11-16 11:01:31 +04:00
usingUpdate bool
2013-11-16 08:45:29 +04:00
value interface{}
update_attrs map[string]interface{}
hasUpdate bool
ignoreProtectedAttrs bool
sql string
sqlVars []interface{}
startedTransaction bool
}
2013-11-16 08:19:35 +04:00
func (s *Do) table() string {
if len(s.tableName) == 0 {
2013-11-16 09:19:29 +04:00
if len(s.search.tableName) == 0 {
2013-11-16 08:19:35 +04:00
s.tableName = s.model.tableName()
} else {
2013-11-16 09:19:29 +04:00
s.tableName = s.search.tableName
2013-11-16 08:19:35 +04:00
}
2013-10-28 16:27:25 +04:00
}
2013-11-16 08:19:35 +04:00
return s.tableName
2013-10-28 16:27:25 +04:00
}
2013-11-16 09:28:22 +04:00
// dialect returns the SQL dialect stored on the parent DB handle.
func (s *Do) dialect() dialect.Dialect {
	parent := s.db.parent
	return parent.dialect
}
2013-10-29 03:39:26 +04:00
func (s *Do) err(err error) error {
2013-10-27 15:41:58 +04:00
if err != nil {
2013-11-16 08:19:35 +04:00
s.db.err(err)
2013-10-26 07:59:58 +04:00
}
2013-10-29 03:39:26 +04:00
return err
2013-10-26 07:59:58 +04:00
}
2013-11-02 13:29:56 +04:00
func (s *Do) setModel(value interface{}) *Do {
2013-11-11 09:40:35 +04:00
s.model = &Model{data: value, do: s}
2013-10-27 15:41:58 +04:00
s.value = value
2013-11-16 09:19:29 +04:00
if s.db.search == nil {
s.search = &search{}
} else {
s.search = s.db.search
}
2013-11-02 13:29:56 +04:00
return s
2013-10-26 07:59:58 +04:00
}
2013-10-27 15:41:58 +04:00
func (s *Do) addToVars(value interface{}) string {
2013-10-29 03:39:26 +04:00
s.sqlVars = append(s.sqlVars, value)
2013-11-16 09:28:22 +04:00
return fmt.Sprintf(s.dialect().BinVar(), len(s.sqlVars))
2013-10-27 15:41:58 +04:00
}
2013-10-26 08:33:05 +04:00
2013-11-16 08:31:05 +04:00
// trace hands the current statement and its bind variables to the DB
// logger, together with start (presumably used for elapsed time — see
// DB.slog).
func (s *Do) trace(start time.Time) {
	statement, vars := s.sql, s.sqlVars
	s.db.slog(statement, start, vars...)
}
2013-11-16 11:01:31 +04:00
func (s *Do) exec(sqls ...string) *Do {
2013-11-16 08:31:05 +04:00
defer s.trace(time.Now())
2013-11-16 08:19:35 +04:00
if !s.db.hasError() {
2013-11-11 07:53:56 +04:00
if len(sqls) > 0 {
s.sql = sqls[0]
}
2013-11-16 08:31:05 +04:00
_, err := s.db.db.Exec(s.sql, s.sqlVars...)
s.err(err)
2013-10-28 08:12:12 +04:00
}
2013-11-16 11:01:31 +04:00
return s
2013-10-26 05:49:40 +04:00
}
2013-11-11 18:27:17 +04:00
func (s *Do) save() *Do {
2013-10-28 11:55:41 +04:00
if s.model.primaryKeyZero() {
2013-11-11 18:27:17 +04:00
s.create()
2013-10-27 15:41:58 +04:00
} else {
2013-11-11 18:27:17 +04:00
s.update()
2013-10-27 05:32:49 +04:00
}
2013-11-11 18:27:17 +04:00
return s
2013-10-27 05:32:49 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) prepareCreateSql() {
2013-10-28 17:52:22 +04:00
var sqls, columns []string
2013-10-26 19:30:17 +04:00
2013-10-28 17:52:22 +04:00
for key, value := range s.model.columnsAndValues("create") {
columns = append(columns, key)
2013-10-26 19:30:17 +04:00
sqls = append(sqls, s.addToVars(value))
}
2013-10-29 03:39:26 +04:00
s.sql = fmt.Sprintf(
2013-11-03 17:19:38 +04:00
"INSERT INTO %v (%v) VALUES (%v) %v",
2013-11-16 08:19:35 +04:00
s.table(),
2013-10-29 03:39:26 +04:00
strings.Join(columns, ","),
2013-10-26 19:30:17 +04:00
strings.Join(sqls, ","),
2013-11-16 09:28:22 +04:00
s.dialect().ReturningStr(s.model.primaryKeyDb()),
2013-10-26 05:49:40 +04:00
)
2013-10-29 03:39:26 +04:00
return
2013-10-26 05:49:40 +04:00
}
2013-11-02 13:29:56 +04:00
func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() {
2013-11-16 08:19:35 +04:00
do := &Do{db: s.db}
reflect_value := reflect.ValueOf(field.Value)
if reflect_value.CanAddr() {
2013-11-11 18:27:17 +04:00
do.setModel(reflect_value.Addr().Interface()).save()
} else {
2013-11-11 18:27:17 +04:00
// If can't take address, then clone the value and set it back
dest_value := reflect.New(reflect_value.Type()).Elem()
2013-11-11 09:40:35 +04:00
m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
2013-11-11 18:27:17 +04:00
do.setModel(dest_value.Addr().Interface()).save()
m.setValueByColumn(field.Name, dest_value.Interface(), s.value)
}
2013-11-02 16:05:05 +04:00
if len(field.foreignKey) > 0 {
2013-11-11 18:27:17 +04:00
s.model.setValueByColumn(field.foreignKey, do.model.primaryKeyValue(), s.model.data)
2013-11-02 16:05:05 +04:00
}
2013-11-02 13:29:56 +04:00
}
}
func (s *Do) saveAfterAssociations() {
for _, field := range s.model.afterAssociations() {
reflect_value := reflect.ValueOf(field.Value)
2013-11-05 18:34:49 +04:00
2013-11-11 18:27:17 +04:00
switch reflect_value.Kind() {
2013-11-02 13:29:56 +04:00
case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ {
2013-11-16 08:19:35 +04:00
do := &Do{db: s.db}
2013-11-11 18:27:17 +04:00
value := reflect_value.Index(i).Addr().Interface()
2013-11-02 16:05:05 +04:00
if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
}
do.setModel(value).save()
2013-11-02 13:29:56 +04:00
}
default:
2013-11-16 08:19:35 +04:00
do := &Do{db: s.db}
2013-11-05 18:34:49 +04:00
if reflect_value.CanAddr() {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
do.setModel(field.Value).save()
} else {
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
2013-11-11 09:40:35 +04:00
m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") {
2013-11-05 18:34:49 +04:00
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
setFieldValue(dest_value.FieldByName(field.foreignKey), s.model.primaryKeyValue())
do.setModel(dest_value.Addr().Interface()).save()
m.setValueByColumn(field.Name, dest_value.Interface(), s.value)
2013-11-05 18:34:49 +04:00
}
2013-11-02 13:29:56 +04:00
}
2013-11-02 11:17:11 +04:00
}
}
2013-11-10 06:09:57 +04:00
func (s *Do) create() (i interface{}) {
2013-11-11 09:53:04 +04:00
s.model.callMethod("BeforeCreate")
s.model.callMethod("BeforeSave")
2013-10-27 15:41:58 +04:00
2013-11-02 13:29:56 +04:00
s.saveBeforeAssociations()
2013-10-27 15:41:58 +04:00
s.prepareCreateSql()
2013-10-26 16:20:49 +04:00
2013-11-16 08:19:35 +04:00
if !s.db.hasError() {
2013-11-10 06:09:57 +04:00
var id interface{}
2013-11-11 18:27:17 +04:00
now := time.Now()
2013-11-16 09:28:22 +04:00
if s.dialect().SupportLastInsertId() {
2013-11-16 08:19:35 +04:00
if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
2013-11-10 19:07:09 +04:00
id, err = sql_result.LastInsertId()
2013-11-04 16:32:46 +04:00
s.err(err)
}
2013-11-14 14:59:11 +04:00
} else {
2013-11-16 08:19:35 +04:00
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
2013-10-27 12:06:45 +04:00
}
2013-11-16 08:19:35 +04:00
s.db.slog(s.sql, now, s.sqlVars...)
2013-10-27 12:06:45 +04:00
2013-11-16 08:19:35 +04:00
if !s.db.hasError() {
2013-11-11 18:27:17 +04:00
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
2013-10-29 03:39:26 +04:00
2013-11-11 18:27:17 +04:00
s.saveAfterAssociations()
2013-11-11 09:53:04 +04:00
s.model.callMethod("AfterCreate")
s.model.callMethod("AfterSave")
2013-10-29 03:39:26 +04:00
}
2013-11-02 16:05:05 +04:00
return id
2013-10-27 12:06:45 +04:00
}
2013-10-27 15:41:58 +04:00
2013-10-29 03:39:26 +04:00
return
2013-10-26 16:20:49 +04:00
}
2013-11-16 08:45:29 +04:00
func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
2013-11-16 09:19:29 +04:00
ignore_protected := len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
2013-11-16 11:01:31 +04:00
s.usingUpdate = true
2013-11-16 09:19:29 +04:00
2013-11-16 08:45:29 +04:00
switch value := values.(type) {
case map[string]interface{}:
2013-11-16 08:45:29 +04:00
if len(value) > 0 {
2013-11-16 09:19:29 +04:00
results, has_update := s.model.updatedColumnsAndValues(value, ignore_protected)
2013-11-16 08:45:29 +04:00
if len(results) > 0 {
s.update_attrs = results
}
2013-11-16 11:01:31 +04:00
s.hasUpdate = has_update
2013-11-16 08:45:29 +04:00
}
2013-10-31 05:34:27 +04:00
case []interface{}:
2013-11-16 08:45:29 +04:00
for _, v := range value {
s.updateAttrs(v)
2013-10-31 05:34:27 +04:00
}
case interface{}:
2013-11-11 09:40:35 +04:00
m := &Model{data: values, do: s}
2013-11-16 08:45:29 +04:00
attrs := map[string]interface{}{}
2013-11-11 18:27:17 +04:00
for _, field := range m.columnsHasValue("other") {
2013-11-16 08:45:29 +04:00
attrs[field.dbName] = field.Value
}
2013-11-16 08:45:29 +04:00
s.updateAttrs(attrs)
}
2013-10-31 04:15:19 +04:00
return s
}
2013-11-16 09:19:29 +04:00
func (s *Do) prepareUpdateSql() {
2013-10-28 17:52:22 +04:00
var sqls []string
2013-11-16 09:19:29 +04:00
for key, value := range s.update_attrs {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
}
2013-11-11 18:27:17 +04:00
for key, value := range s.model.columnsAndValues("update") {
2013-10-29 03:39:26 +04:00
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
2013-10-26 17:37:42 +04:00
}
2013-10-29 03:39:26 +04:00
s.sql = fmt.Sprintf(
2013-10-26 19:30:17 +04:00
"UPDATE %v SET %v %v",
2013-11-16 08:19:35 +04:00
s.table(),
2013-10-28 17:52:22 +04:00
strings.Join(sqls, ", "),
2013-10-27 07:21:33 +04:00
s.combinedSql(),
2013-10-26 17:37:42 +04:00
)
2013-10-29 03:39:26 +04:00
return
2013-10-26 16:20:49 +04:00
}
2013-11-11 18:27:17 +04:00
func (s *Do) update() *Do {
2013-11-16 11:01:31 +04:00
if s.usingUpdate && !s.hasUpdate {
return s
}
2013-11-11 09:53:04 +04:00
s.model.callMethod("BeforeUpdate")
s.model.callMethod("BeforeSave")
2013-11-02 13:29:56 +04:00
s.saveBeforeAssociations()
2013-11-16 09:19:29 +04:00
s.prepareUpdateSql()
2013-11-16 08:19:35 +04:00
if !s.db.hasError() {
2013-10-29 03:39:26 +04:00
s.exec()
2013-11-02 13:29:56 +04:00
s.saveAfterAssociations()
2013-10-29 03:39:26 +04:00
2013-11-11 09:53:04 +04:00
s.model.callMethod("AfterUpdate")
s.model.callMethod("AfterSave")
2013-10-27 12:06:45 +04:00
}
2013-11-02 16:05:05 +04:00
2013-11-11 18:27:17 +04:00
return s
2013-10-26 05:49:40 +04:00
}
2013-10-27 10:51:23 +04:00
2013-11-11 18:27:17 +04:00
func (s *Do) delete() *Do {
2013-11-11 09:53:04 +04:00
s.model.callMethod("BeforeDelete")
2013-10-29 03:39:26 +04:00
2013-11-16 08:19:35 +04:00
if !s.db.hasError() {
2013-11-16 09:19:29 +04:00
if !s.search.unscope && s.model.hasColumn("DeletedAt") {
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql())
} else {
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql())
}
2013-11-11 18:27:17 +04:00
s.exec()
2013-11-11 09:53:04 +04:00
s.model.callMethod("AfterDelete")
2013-10-27 12:06:45 +04:00
}
2013-11-11 18:27:17 +04:00
return s
2013-10-27 15:41:58 +04:00
}
2013-10-29 03:39:26 +04:00
func (s *Do) prepareQuerySql() {
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.table(), s.combinedSql())
2013-10-29 03:39:26 +04:00
return
2013-10-27 15:41:58 +04:00
}
2013-11-16 11:01:31 +04:00
func (s *Do) first() *Do {
s.search = s.search.clone().order(s.model.primaryKeyDb()).limit(1)
s.query()
2013-11-16 11:01:31 +04:00
return s
}
2013-11-16 11:01:31 +04:00
func (s *Do) last() *Do {
s.search = s.search.clone().order(s.model.primaryKeyDb() + " DESC").limit(1)
s.query()
2013-11-16 11:01:31 +04:00
return s
}
func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err error, from_from bool, foreign_value interface{}) {
2013-11-15 09:30:31 +04:00
if has_column, is_slice, value := from.columnAndValue(foreign_key); has_column {
from_from = true
if is_slice {
foreign_value = to.primaryKeyValue()
} else {
foreign_value = value
}
2013-11-15 09:30:31 +04:00
} else if has_column, _, _ := to.columnAndValue(foreign_key); has_column {
2013-11-05 18:34:49 +04:00
foreign_value = from.primaryKeyValue()
} else {
err = errors.New("Can't find valid foreign Key")
}
return
}
2013-11-16 11:01:31 +04:00
func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
var foreign_value interface{}
var from_from bool
var foreign_key string
var err error
2013-11-11 09:40:35 +04:00
from := &Model{data: value, do: s}
to := &Model{data: s.value, do: s}
foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id")
for _, fk := range foreign_keys {
err, from_from, foreign_value = s.getForeignKey(from, to, snakeToUpperCamel(fk))
if err == nil {
foreign_key = fk
break
}
}
if from_from {
s.where(foreign_value).query()
} else {
query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value))
s.where(query).query()
}
2013-11-16 11:01:31 +04:00
return s
}
2013-11-16 11:01:31 +04:00
func (s *Do) query() *Do {
2013-10-27 15:41:58 +04:00
var (
is_slice bool
dest_type reflect.Type
)
dest_out := reflect.Indirect(reflect.ValueOf(s.value))
2013-10-29 06:19:20 +04:00
if dest_out.Kind() == reflect.Slice {
2013-10-27 15:41:58 +04:00
is_slice = true
dest_type = dest_out.Type().Elem()
} else {
2013-11-16 11:01:31 +04:00
s.search = s.search.clone().limit(1)
2013-10-27 15:41:58 +04:00
}
s.prepareQuerySql()
2013-11-16 08:19:35 +04:00
if !s.db.hasError() {
2013-11-11 11:48:31 +04:00
now := time.Now()
2013-11-16 08:19:35 +04:00
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
2013-11-11 18:27:17 +04:00
2013-10-29 03:39:26 +04:00
if s.err(err) != nil {
2013-11-16 11:01:31 +04:00
return s
2013-10-29 03:39:26 +04:00
}
2013-10-28 08:12:12 +04:00
2013-10-29 03:39:26 +04:00
defer rows.Close()
2013-11-11 18:27:17 +04:00
var has_record bool
2013-10-29 03:39:26 +04:00
for rows.Next() {
2013-11-11 18:27:17 +04:00
has_record = true
dest := dest_out
2013-10-29 03:39:26 +04:00
if is_slice {
dest = reflect.New(dest_type).Elem()
}
2013-10-27 15:41:58 +04:00
2013-10-29 03:39:26 +04:00
columns, _ := rows.Columns()
var values []interface{}
for _, value := range columns {
field := dest.FieldByName(snakeToUpperCamel(value))
if field.IsValid() {
2013-11-11 18:27:17 +04:00
values = append(values, field.Addr().Interface())
} else {
2013-11-11 18:27:17 +04:00
var ignore interface{}
values = append(values, &ignore)
2013-10-29 03:39:26 +04:00
}
}
s.err(rows.Scan(values...))
2013-10-27 15:41:58 +04:00
2013-10-29 03:39:26 +04:00
if is_slice {
dest_out.Set(reflect.Append(dest_out, dest))
2013-10-28 16:27:25 +04:00
}
2013-10-27 15:41:58 +04:00
}
2013-11-11 18:27:17 +04:00
if !has_record && !is_slice {
2013-11-15 14:36:39 +04:00
s.err(RecordNotFound)
2013-10-27 15:41:58 +04:00
}
}
2013-11-16 11:01:31 +04:00
return s
2013-10-27 10:51:23 +04:00
}
2013-11-16 11:01:31 +04:00
func (s *Do) count(value interface{}) *Do {
s.search = s.search.clone().selects("count(*)")
2013-10-27 16:07:13 +04:00
s.prepareQuerySql()
2013-11-16 08:19:35 +04:00
if !s.db.hasError() {
2013-11-11 11:48:31 +04:00
now := time.Now()
2013-11-16 08:19:35 +04:00
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
s.db.slog(s.sql, now, s.sqlVars...)
2013-10-27 16:07:13 +04:00
}
2013-11-16 11:01:31 +04:00
return s
2013-10-27 16:07:13 +04:00
}
2013-11-16 11:01:31 +04:00
func (s *Do) pluck(column string, value interface{}) *Do {
2013-10-27 15:41:58 +04:00
dest_out := reflect.Indirect(reflect.ValueOf(value))
2013-11-16 11:01:31 +04:00
s.search = s.search.clone().selects(column)
2013-10-29 18:00:06 +04:00
if dest_out.Kind() != reflect.Slice {
2013-11-11 17:55:44 +04:00
s.err(errors.New("Results should be a slice"))
2013-11-16 11:01:31 +04:00
return s
2013-10-29 18:00:06 +04:00
}
2013-10-27 15:41:58 +04:00
s.prepareQuerySql()
2013-11-16 08:19:35 +04:00
if !s.db.hasError() {
2013-11-11 11:48:31 +04:00
now := time.Now()
2013-11-16 08:19:35 +04:00
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
2013-10-29 03:39:26 +04:00
2013-11-11 17:55:44 +04:00
if s.err(err) == nil {
defer rows.Close()
for rows.Next() {
dest := reflect.New(dest_out.Type().Elem()).Interface()
s.err(rows.Scan(dest))
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest).Elem()))
2013-10-27 15:41:58 +04:00
}
}
}
2013-11-16 11:01:31 +04:00
return s
2013-10-27 15:41:58 +04:00
}
2013-10-27 16:54:23 +04:00
func (s *Do) primaryCondiation(value interface{}) string {
2013-10-29 03:39:26 +04:00
return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value)
2013-10-27 16:54:23 +04:00
}
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
2013-11-12 11:21:21 +04:00
switch value := clause["query"].(type) {
2013-10-27 16:54:23 +04:00
case string:
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return s.primaryCondiation(s.addToVars(id))
} else {
2013-11-11 11:48:31 +04:00
str = "(" + value + ")"
2013-10-27 16:54:23 +04:00
}
case int, int64, int32:
2013-11-12 11:21:21 +04:00
return s.primaryCondiation(s.addToVars(value))
case sql.NullInt64:
2013-11-11 18:27:17 +04:00
return s.primaryCondiation(s.addToVars(value.Int64))
2013-10-29 11:37:58 +04:00
case []int64, []int, []int32, []string:
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb())
2013-11-11 18:27:17 +04:00
clause["args"] = []interface{}{value}
2013-10-29 13:52:37 +04:00
case map[string]interface{}:
var sqls []string
2013-11-11 18:27:17 +04:00
for key, value := range value {
2013-11-11 11:48:31 +04:00
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value)))
2013-10-29 13:52:37 +04:00
}
return strings.Join(sqls, " AND ")
2013-10-29 13:37:45 +04:00
case interface{}:
2013-11-11 18:27:17 +04:00
m := &Model{data: value, do: s}
2013-10-29 13:37:45 +04:00
var sqls []string
2013-11-02 11:17:11 +04:00
for _, field := range m.columnsHasValue("other") {
2013-11-15 07:36:27 +04:00
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.dbName, s.addToVars(field.Value)))
2013-10-29 13:37:45 +04:00
}
return strings.Join(sqls, " AND ")
2013-10-27 16:54:23 +04:00
}
2013-10-27 08:00:39 +04:00
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
2013-10-29 18:00:06 +04:00
values := reflect.ValueOf(arg)
2013-10-27 08:00:39 +04:00
var temp_marks []string
2013-10-29 18:00:06 +04:00
for i := 0; i < values.Len(); i++ {
2013-11-16 11:01:31 +04:00
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface()))
2013-10-27 08:00:39 +04:00
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
2013-11-10 15:38:28 +04:00
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
2013-11-10 14:33:37 +04:00
str = strings.Replace(str, "?", s.addToVars(arg), 1)
2013-10-27 08:00:39 +04:00
}
}
2013-10-27 16:54:23 +04:00
return
2013-10-27 08:00:39 +04:00
}
2013-10-26 05:49:40 +04:00
2013-10-31 13:31:00 +04:00
func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
var not_equal_sql string
2013-11-12 11:21:21 +04:00
switch value := clause["query"].(type) {
2013-10-31 13:31:00 +04:00
case string:
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id)
2013-10-31 18:49:48 +04:00
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value)
2013-11-11 13:16:15 +04:00
not_equal_sql = fmt.Sprintf("NOT (%v)", value)
2013-10-31 13:31:00 +04:00
} else {
2013-11-11 13:16:15 +04:00
str = fmt.Sprintf("(%v NOT IN (?))", value)
not_equal_sql = fmt.Sprintf("(%v <> ?)", value)
2013-10-31 13:31:00 +04:00
}
case int, int64, int32:
2013-11-12 11:21:21 +04:00
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value)
2013-10-31 13:31:00 +04:00
case []int64, []int, []int32, []string:
2013-11-12 11:21:21 +04:00
if reflect.ValueOf(value).Len() > 0 {
2013-10-31 13:31:00 +04:00
str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb())
2013-11-12 11:21:21 +04:00
clause["args"] = []interface{}{value}
2013-10-31 13:31:00 +04:00
} else {
2013-10-31 14:12:18 +04:00
return ""
2013-10-31 13:31:00 +04:00
}
case map[string]interface{}:
var sqls []string
2013-11-11 18:27:17 +04:00
for key, value := range value {
2013-11-11 11:48:31 +04:00
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value)))
2013-10-31 13:31:00 +04:00
}
return strings.Join(sqls, " AND ")
case interface{}:
2013-11-12 11:21:21 +04:00
m := &Model{data: value, do: s}
2013-10-31 13:31:00 +04:00
var sqls []string
2013-11-02 11:17:11 +04:00
for _, field := range m.columnsHasValue("other") {
2013-11-15 07:36:27 +04:00
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.dbName, s.addToVars(field.Value)))
2013-10-31 13:31:00 +04:00
}
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
values := reflect.ValueOf(arg)
var temp_marks []string
for i := 0; i < values.Len(); i++ {
2013-11-16 11:01:31 +04:00
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface()))
2013-10-31 13:31:00 +04:00
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
2013-11-10 14:33:37 +04:00
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value()
}
2013-11-10 14:33:37 +04:00
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
2013-10-31 13:31:00 +04:00
}
}
return
}
2013-11-16 11:01:31 +04:00
// where adds a where clause (where[0] as the query, the rest as its args)
// to a copy of the current search. It does nothing when called without
// arguments.
func (s *Do) where(where ...interface{}) *Do {
	if len(where) == 0 {
		return s
	}
	query, args := where[0], where[1:]
	s.search = s.search.clone().where(query, args...)
	return s
}
2013-10-27 15:41:58 +04:00
func (s *Do) whereSql() (sql string) {
2013-10-29 06:19:20 +04:00
var primary_condiations, and_conditions, or_conditions []string
2013-11-16 09:19:29 +04:00
if !s.search.unscope && s.model.hasColumn("DeletedAt") {
2013-11-11 13:16:15 +04:00
primary_condiations = append(primary_condiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')")
2013-10-29 06:19:20 +04:00
}
2013-10-29 18:00:06 +04:00
2013-10-28 11:55:41 +04:00
if !s.model.primaryKeyZero() {
2013-10-29 06:19:20 +04:00
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
2013-10-27 08:00:39 +04:00
}
2013-11-16 09:19:29 +04:00
for _, clause := range s.search.whereClause {
2013-10-27 08:00:39 +04:00
and_conditions = append(and_conditions, s.buildWhereCondition(clause))
}
2013-11-16 09:19:29 +04:00
for _, clause := range s.search.orClause {
2013-10-27 08:00:39 +04:00
or_conditions = append(or_conditions, s.buildWhereCondition(clause))
}
2013-11-16 09:19:29 +04:00
for _, clause := range s.search.notClause {
2013-10-31 13:31:00 +04:00
and_conditions = append(and_conditions, s.buildNotCondition(clause))
}
2013-10-27 08:31:51 +04:00
or_sql := strings.Join(or_conditions, " OR ")
2013-11-12 11:21:21 +04:00
combined_sql := strings.Join(and_conditions, " AND ")
if len(combined_sql) > 0 {
2013-10-27 08:00:39 +04:00
if len(or_sql) > 0 {
2013-11-12 11:21:21 +04:00
combined_sql = combined_sql + " OR " + or_sql
2013-10-26 10:10:47 +04:00
}
2013-10-27 08:00:39 +04:00
} else {
2013-11-12 11:21:21 +04:00
combined_sql = or_sql
2013-10-26 10:10:47 +04:00
}
2013-10-26 19:30:17 +04:00
2013-10-29 06:19:20 +04:00
if len(primary_condiations) > 0 {
sql = "WHERE " + strings.Join(primary_condiations, " AND ")
2013-11-12 11:21:21 +04:00
if len(combined_sql) > 0 {
sql = sql + " AND (" + combined_sql + ")"
2013-10-27 08:00:39 +04:00
}
2013-11-12 11:21:21 +04:00
} else if len(combined_sql) > 0 {
sql = "WHERE " + combined_sql
2013-10-26 19:30:17 +04:00
}
2013-10-26 05:49:40 +04:00
return
}
2013-10-26 19:30:17 +04:00
2013-10-27 15:41:58 +04:00
func (s *Do) selectSql() string {
2013-11-16 09:19:29 +04:00
if len(s.search.selectStr) == 0 {
2013-11-11 11:48:31 +04:00
return "*"
2013-10-27 05:50:11 +04:00
} else {
2013-11-16 09:19:29 +04:00
return s.search.selectStr
2013-10-27 05:50:11 +04:00
}
}
2013-10-27 15:41:58 +04:00
func (s *Do) orderSql() string {
2013-11-16 09:19:29 +04:00
if len(s.search.orders) == 0 {
2013-10-27 07:38:05 +04:00
return ""
2013-10-27 07:21:33 +04:00
} else {
2013-11-16 09:19:29 +04:00
return " ORDER BY " + strings.Join(s.search.orders, ",")
2013-10-27 07:21:33 +04:00
}
}
2013-10-27 15:41:58 +04:00
func (s *Do) limitSql() string {
2013-11-16 09:19:29 +04:00
if len(s.search.limitStr) == 0 {
2013-10-27 07:38:05 +04:00
return ""
} else {
2013-11-16 09:19:29 +04:00
return " LIMIT " + s.search.limitStr
2013-10-27 07:38:05 +04:00
}
}
2013-10-27 15:41:58 +04:00
func (s *Do) offsetSql() string {
2013-11-16 09:19:29 +04:00
if len(s.search.offsetStr) == 0 {
2013-10-27 07:44:47 +04:00
return ""
} else {
2013-11-16 09:19:29 +04:00
return " OFFSET " + s.search.offsetStr
2013-10-27 07:44:47 +04:00
}
}
2013-10-27 15:41:58 +04:00
func (s *Do) combinedSql() string {
2013-10-27 07:44:47 +04:00
return s.whereSql() + s.orderSql() + s.limitSql() + s.offsetSql()
2013-10-27 07:21:33 +04:00
}
2013-10-27 15:41:58 +04:00
func (s *Do) createTable() *Do {
2013-10-28 08:12:12 +04:00
var sqls []string
2013-11-12 11:21:21 +04:00
for _, field := range s.model.fields("migration") {
2013-11-14 17:38:35 +04:00
if len(field.sqlTag()) > 0 {
2013-11-15 07:36:27 +04:00
sqls = append(sqls, field.dbName+" "+field.sqlTag())
2013-11-02 10:12:18 +04:00
}
2013-10-28 08:12:12 +04:00
}
2013-10-29 18:00:06 +04:00
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))
2013-11-10 15:38:28 +04:00
2013-11-07 07:42:36 +04:00
s.exec()
2013-10-27 15:41:58 +04:00
return s
2013-10-26 19:30:17 +04:00
}
2013-11-01 11:01:39 +04:00
func (s *Do) dropTable() *Do {
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("DROP TABLE %v", s.table())
2013-11-07 07:42:36 +04:00
s.exec()
2013-11-01 11:01:39 +04:00
return s
}
func (s *Do) updateColumn(column string, typ string) {
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ)
s.exec()
}
func (s *Do) dropColumn(column string) {
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column)
s.exec()
}
func (s *Do) addIndex(column string, names ...string) {
var index_name string
if len(names) > 0 {
index_name = names[0]
} else {
2013-11-16 08:19:35 +04:00
index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column)
}
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column)
s.exec()
}
func (s *Do) removeIndex(index_name string) {
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.table())
s.exec()
}
2013-11-07 05:09:54 +04:00
func (s *Do) autoMigrate() *Do {
2013-11-07 07:42:36 +04:00
var table_name string
2013-11-16 08:19:35 +04:00
sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.table()))
s.db.db.QueryRow(sql, s.sqlVars...).Scan(&table_name)
2013-11-07 07:42:36 +04:00
s.sqlVars = []interface{}{}
// If table doesn't exist
if len(table_name) == 0 {
s.createTable()
} else {
2013-11-12 11:21:21 +04:00
for _, field := range s.model.fields("migration") {
2013-11-07 07:42:36 +04:00
var column_name, data_type string
2013-11-16 08:19:35 +04:00
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.table()))
s.db.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type)
2013-11-07 07:42:36 +04:00
s.sqlVars = []interface{}{}
// If column doesn't exist
2013-11-14 17:38:35 +04:00
if len(column_name) == 0 && len(field.sqlTag()) > 0 {
2013-11-16 08:19:35 +04:00
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.table(), field.dbName, field.sqlTag())
2013-11-07 07:42:36 +04:00
s.exec()
}
}
2013-11-07 05:09:54 +04:00
}
return s
}
2013-11-11 15:06:26 +04:00
func (s *Do) begin() *Do {
2013-11-16 08:19:35 +04:00
if db, ok := s.db.db.(sqlDb); ok {
2013-11-12 11:21:21 +04:00
if tx, err := db.Begin(); err == nil {
2013-11-16 08:19:35 +04:00
s.db.db = interface{}(tx).(sqlCommon)
2013-11-11 15:06:26 +04:00
s.startedTransaction = true
}
}
2013-11-11 15:06:26 +04:00
return s
}
2013-11-16 11:01:31 +04:00
func (s *Do) commit_or_rollback() *Do {
2013-11-11 15:06:26 +04:00
if s.startedTransaction {
2013-11-16 08:19:35 +04:00
if db, ok := s.db.db.(sqlTx); ok {
if s.db.hasError() {
2013-11-11 15:06:26 +04:00
db.Rollback()
} else {
db.Commit()
}
2013-11-16 11:01:31 +04:00
s.db.db = s.db.parent.db
2013-11-11 15:06:26 +04:00
}
}
2013-11-16 09:19:29 +04:00
return s
2013-11-16 07:36:30 +04:00
}
2013-11-16 11:01:31 +04:00
func (s *Do) initialize() *Do {
2013-11-16 09:19:29 +04:00
for _, clause := range s.search.whereClause {
2013-11-16 11:01:31 +04:00
s.updateAttrs(clause["query"])
}
for _, attrs := range s.search.initAttrs {
s.updateAttrs(attrs)
}
for _, attrs := range s.search.assignAttrs {
s.updateAttrs(attrs)
}
2013-11-16 11:01:31 +04:00
return s
}