gorm/utils.go

288 lines
7.0 KiB
Go
Raw Normal View History

2013-10-25 14:04:48 +04:00
package gorm
2013-10-26 05:49:40 +04:00
import (
"bytes"
2016-02-14 18:29:06 +03:00
"database/sql/driver"
2016-01-16 07:18:04 +03:00
"fmt"
2016-01-18 07:20:27 +03:00
"reflect"
2016-02-14 18:29:06 +03:00
"regexp"
"runtime"
2013-10-26 05:49:40 +04:00
"strings"
2015-08-31 10:07:06 +03:00
"sync"
2016-03-07 09:54:20 +03:00
"time"
2013-10-26 05:49:40 +04:00
)
2016-03-07 09:54:20 +03:00
// NowFunc reports the current time. It is a variable rather than a plain
// function so developers can swap in their own clock, for example to always
// work in UTC:
//
//	gorm.NowFunc = func() time.Time {
//		return time.Now().UTC()
//	}
var NowFunc = func() time.Time {
	current := time.Now()
	return current
}
2015-02-18 05:19:34 +03:00
// commonInitialisms lists initialisms (copied from golint) that ToDBName
// treats as single words, so that e.g. "UserID" maps to "user_id".
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}

var (
	// commonInitialismsReplacer rewrites each initialism to title case
	// ("ID" -> "Id") before ToDBName splits a name into words. Pairs are
	// registered in commonInitialisms order, which strings.NewReplacer uses
	// to break ties between matches at the same position (so "HTTP" wins
	// over "HTTPS"); that order is kept so existing DB names do not change.
	commonInitialismsReplacer = newCommonInitialismsReplacer()

	// goSrcRegexp matches gorm's own .go source files and goTestRegexp its
	// test files; fileWithLineNum uses them to skip gorm's internal frames.
	// The dots are escaped and the patterns anchored so that only real
	// ".go" paths match.
	goSrcRegexp  = regexp.MustCompile(`jinzhu/gorm/.*\.go$`)
	goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test\.go$`)
)

// newCommonInitialismsReplacer builds the replacer stored in
// commonInitialismsReplacer, mapping every initialism to its title-cased form.
func newCommonInitialismsReplacer() *strings.Replacer {
	oldnew := make([]string, 0, 2*len(commonInitialisms))
	for _, initialism := range commonInitialisms {
		// Equivalent to the deprecated strings.Title(strings.ToLower(x)) for
		// these non-empty, ASCII-only initialisms.
		oldnew = append(oldnew, initialism, initialism[:1]+strings.ToLower(initialism[1:]))
	}
	return strings.NewReplacer(oldnew...)
}
2015-08-31 10:07:06 +03:00
// safeMap is a string-to-string map whose reads and writes are guarded by a
// read/write mutex, so one instance can be shared between goroutines.
type safeMap struct {
	m map[string]string
	l *sync.RWMutex
}

// Set stores value under key while holding the write lock.
func (sm *safeMap) Set(key string, value string) {
	sm.l.Lock()
	defer sm.l.Unlock()
	sm.m[key] = value
}

// Get returns the value stored under key, or "" when the key is absent,
// while holding the read lock.
func (sm *safeMap) Get(key string) (value string) {
	sm.l.RLock()
	defer sm.l.RUnlock()
	value = sm.m[key]
	return value
}

// newSafeMap returns an empty, ready-to-use safeMap.
func newSafeMap() *safeMap {
	sm := &safeMap{m: map[string]string{}}
	sm.l = &sync.RWMutex{}
	return sm
}

// smap is the package-wide cache used by ToDBName.
var smap = newSafeMap()
2013-12-15 06:09:19 +04:00
2016-01-15 16:03:35 +03:00
// strCase is the two-valued flag ToDBName's scanner uses: upper (true) marks
// a positive match — an ASCII capital, or for its digit check an ASCII
// digit — and lower (false) a negative one.
type strCase bool

// lower and upper are the two strCase values.
const lower, upper strCase = false, true
2016-02-14 18:29:06 +03:00
// ToDBName convert string to db name
2015-02-18 05:19:34 +03:00
func ToDBName(name string) string {
2015-08-31 10:07:06 +03:00
if v := smap.Get(name); v != "" {
2013-11-17 04:28:30 +04:00
return v
}
2016-01-16 04:40:13 +03:00
if name == "" {
return ""
}
var (
value = commonInitialismsReplacer.Replace(name)
buf = bytes.NewBufferString("")
lastCase, currCase, nextCase, nextNumber strCase
)
for i, v := range value[:len(value)-1] {
2016-03-09 15:44:24 +03:00
nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9')
if i > 0 {
if currCase == upper {
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
buf.WriteRune(v)
} else {
2016-02-20 17:26:41 +03:00
if value[i-1] != '_' && value[i+1] != '_' {
buf.WriteRune('_')
}
buf.WriteRune(v)
}
} else {
buf.WriteRune(v)
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
buf.WriteRune('_')
}
}
} else {
currCase = upper
buf.WriteRune(v)
}
lastCase = currCase
currCase = nextCase
}
2013-11-17 04:28:30 +04:00
buf.WriteByte(value[len(value)-1])
2016-02-20 17:26:41 +03:00
s := strings.ToLower(buf.String())
2015-08-31 10:07:06 +03:00
smap.Set(name, s)
2013-11-17 04:28:30 +04:00
return s
}
2016-02-14 18:29:06 +03:00
// expr is a raw SQL fragment together with the arguments bound to its
// placeholders.
type expr struct {
	expr string
	args []interface{}
}

// Expr builds a raw SQL expression, for example:
//
//	DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
func Expr(expression string, args ...interface{}) *expr {
	e := new(expr)
	e.expr = expression
	e.args = args
	return e
}
2016-01-16 07:18:04 +03:00
2016-01-18 07:20:27 +03:00
// indirect dereferences reflectValue through any number of pointer levels and
// returns the first non-pointer value reached. A nil pointer yields the zero
// reflect.Value.
func indirect(reflectValue reflect.Value) reflect.Value {
	if reflectValue.Kind() != reflect.Ptr {
		return reflectValue
	}
	return indirect(reflectValue.Elem())
}
2016-01-16 07:18:04 +03:00
// toQueryMarks renders one "?" placeholder per value in each row. Rows with
// more than one value are wrapped in parentheses, e.g. "(?,?)"; the rows are
// then joined with commas.
func toQueryMarks(primaryValues [][]interface{}) string {
	groups := make([]string, 0, len(primaryValues))
	for _, row := range primaryValues {
		marks := strings.TrimSuffix(strings.Repeat("?,", len(row)), ",")
		if len(row) > 1 {
			marks = fmt.Sprintf("(%v)", marks)
		}
		groups = append(groups, marks)
	}
	return strings.Join(groups, ",")
}
// toQueryCondition quotes each column with scope.Quote and joins them with
// commas, wrapping the list in parentheses when there is more than one column.
func toQueryCondition(scope *Scope, columns []string) string {
	quoted := make([]string, len(columns))
	for i, column := range columns {
		quoted[i] = scope.Quote(column)
	}
	joined := strings.Join(quoted, ",")
	if len(columns) <= 1 {
		return joined
	}
	return fmt.Sprintf("(%v)", joined)
}
2016-01-17 12:46:56 +03:00
// toQueryValues flattens the rows of values into a single slice, preserving
// order. It returns nil when there are no values at all.
func toQueryValues(values [][]interface{}) (results []interface{}) {
	for _, row := range values {
		results = append(results, row...)
	}
	return results
}
2016-02-14 18:29:06 +03:00
func fileWithLineNum() string {
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
2016-02-14 18:29:06 +03:00
return fmt.Sprintf("%v:%v", file, line)
}
}
return ""
}
// isBlank reports whether value holds the zero value of its type: "" for
// strings, false for bools, numeric zero, nil for interfaces and pointers,
// and, for every other kind, deep equality with reflect.Zero of its type (so
// a nil slice or map is blank but an empty non-nil one is not). The zero
// reflect.Value itself (e.g. reflect.Indirect of a nil pointer) is reported
// as blank instead of panicking.
func isBlank(value reflect.Value) bool {
	switch value.Kind() {
	case reflect.Invalid:
		// value.Interface() and value.Type() below would panic here.
		return true
	case reflect.String:
		return value.Len() == 0
	case reflect.Bool:
		return !value.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return value.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return value.Uint() == 0
	case reflect.Float32, reflect.Float64:
		return value.Float() == 0
	case reflect.Interface, reflect.Ptr:
		return value.IsNil()
	}

	return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
}
// toSearchableMap normalises the arguments of a search condition.
//
// With two or more arguments whose first element is a string, it returns the
// single-entry map {attrs[0]: attrs[1]} (further arguments are ignored). With
// exactly one argument, it returns that argument unchanged, whatever its
// type. In every other case (no arguments, or a non-string first argument
// followed by more arguments) it returns nil.
func toSearchableMap(attrs ...interface{}) (result interface{}) {
	switch {
	case len(attrs) > 1:
		if str, ok := attrs[0].(string); ok {
			return map[string]interface{}{str: attrs[1]}
		}
	case len(attrs) == 1:
		// A lone argument (a map, a struct, or anything else) is passed
		// through as-is; a separate map check would be redundant.
		return attrs[0]
	}
	return nil
}
// equalAsString reports whether a and b have the same toString rendering.
func equalAsString(a interface{}, b interface{}) bool {
	return toString(a) == toString(b)
}

// toString renders str as a string:
//   - []interface{}: each element rendered recursively, joined with "_";
//   - []byte: converted directly;
//   - anything else: formatted with %v after dereferencing one pointer level.
//
// nil and nil pointers render as "".
func toString(str interface{}) string {
	// A type switch avoids the previous local named "bytes", which shadowed
	// the imported bytes package.
	switch v := str.(type) {
	case []interface{}:
		parts := make([]string, 0, len(v))
		for _, elem := range v {
			parts = append(parts, toString(elem))
		}
		return strings.Join(parts, "_")
	case []byte:
		return string(v)
	}

	if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
		return fmt.Sprintf("%v", reflectValue.Interface())
	}
	return ""
}
// makeSlice returns a pointer to a new, empty slice. If elemType is itself a
// slice type, its element type is used, so both T and []T yield *[]T.
func makeSlice(elemType reflect.Type) interface{} {
	if elemType.Kind() == reflect.Slice {
		elemType = elemType.Elem()
	}
	ptr := reflect.New(reflect.SliceOf(elemType))
	slot := ptr.Elem()
	slot.Set(reflect.MakeSlice(slot.Type(), 0, 0))
	return ptr.Interface()
}
// strInSlice reports whether a occurs in list.
func strInSlice(a string, list []string) bool {
	found := false
	for i := 0; i < len(list) && !found; i++ {
		found = list[i] == a
	}
	return found
}
// getValueFromFields returns the values of the named fields of value, in the
// order given by fieldNames. value may be a struct or a pointer to one; names
// that do not exist on the struct are skipped. A field value implementing
// driver.Valuer is replaced by the result of its Value method. It returns nil
// when value is a nil pointer or does not hold a struct.
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
	// reflect.Indirect returns the zero Value (Kind Invalid) for a nil
	// pointer, and FieldByName panics on anything that is not a struct, so
	// a single kind check guards against both.
	indirectValue := reflect.Indirect(value)
	if indirectValue.Kind() != reflect.Struct {
		return nil
	}
	for _, fieldName := range fieldNames {
		fieldValue := indirectValue.FieldByName(fieldName)
		if !fieldValue.IsValid() {
			continue
		}
		result := fieldValue.Interface()
		if r, ok := result.(driver.Valuer); ok {
			// The error is deliberately dropped because this helper has no
			// error return; result becomes whatever Value returned with it.
			result, _ = r.Value()
		}
		results = append(results, result)
	}
	return results
}
// addExtraSpaceIfExist prefixes a non-empty str with a single space and
// leaves an empty str empty.
func addExtraSpaceIfExist(str string) string {
	if str == "" {
		return ""
	}
	return " " + str
}