package gorm

import (
	"bytes"
	"database/sql/driver"
	"fmt"
	"reflect"
	"regexp"
	"runtime"
	"strings"
	"sync"
	"time"
	"unicode/utf8"
)

// NowFunc returns current time, this function is exported in order to be able
// to give the flexibility to the developer to customize it according to their
// needs, e.g:
//    gorm.NowFunc = func() time.Time {
//      return time.Now().UTC()
//    }
var NowFunc = func() time.Time {
	return time.Now()
}

// commonInitialisms lists the upper-case initialisms that ToDBName treats as
// a single word. Copied from golint.
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}

// commonInitialismsReplacer rewrites every initialism to its title-cased form
// (e.g. "HTTP" -> "Http"); it is built once in init.
var commonInitialismsReplacer *strings.Replacer

// goSrcRegexp and goTestRegexp are used by fileWithLineNum to tell this
// package's own source files apart from its test files.
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)

// init builds commonInitialismsReplacer from commonInitialisms, pairing each
// initialism with its title-cased replacement.
func init() {
	pairs := make([]string, 0, 2*len(commonInitialisms))
	for _, initialism := range commonInitialisms {
		pairs = append(pairs, initialism, strings.Title(strings.ToLower(initialism)))
	}
	commonInitialismsReplacer = strings.NewReplacer(pairs...)
}

// safeMap is a string-to-string map whose accesses are guarded by an RWMutex.
type safeMap struct {
	m map[string]string
	l *sync.RWMutex
}

// Set stores value under key while holding the write lock.
func (s *safeMap) Set(key string, value string) {
	s.l.Lock()
	defer s.l.Unlock()
	s.m[key] = value
}

// Get returns the value stored under key while holding the read lock.
// A missing key yields the empty string.
func (s *safeMap) Get(key string) string {
	s.l.RLock()
	defer s.l.RUnlock()
	return s.m[key]
}

// newSafeMap returns an empty, ready to use safeMap.
func newSafeMap() *safeMap {
	return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
}

// smap caches the results of ToDBName, keyed by the original name.
var smap = newSafeMap()

// strCase records whether a character is an upper-case ASCII letter.
type strCase bool

const (
	lower strCase = false
	upper strCase = true
)

// ToDBName convert string to db name.
//
// Common initialisms are first collapsed to title case, an underscore is then
// inserted in front of an upper-case letter that starts a new word, and the
// result is lower-cased (e.g. "UserID" becomes "user_id"). Results are cached
// in smap; an empty name yields an empty string and is not cached.
func ToDBName(name string) string {
	if v := smap.Get(name); v != "" {
		return v
	}

	if name == "" {
		return ""
	}

	var (
		value                        = commonInitialismsReplacer.Replace(name)
		buf                          = bytes.NewBufferString("")
		lastCase, currCase, nextCase strCase
	)

	// The final character is copied verbatim after the loop. Split on a rune
	// boundary rather than on the last byte, so that a name ending in a
	// multi-byte character is not cut in half (which would make the range
	// loop emit U+FFFD for the partial bytes). For an invalid trailing byte
	// the reported size is 1, i.e. a plain byte split.
	_, lastSize := utf8.DecodeLastRuneInString(value)
	head, tail := value[:len(value)-lastSize], value[len(value)-lastSize:]

	for i, v := range head {
		// i+1 <= len(head) < len(value), so indexing value is always in range.
		nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
		switch {
		case i == 0:
			// The first character never gets a leading underscore.
			currCase = upper
		case currCase == upper && !(lastCase == upper && nextCase == upper):
			// An upper-case letter that is not in the middle of an upper-case
			// run starts a new word, unless an underscore is already adjacent.
			if value[i-1] != '_' && value[i+1] != '_' {
				buf.WriteRune('_')
			}
		}
		buf.WriteRune(v)
		lastCase = currCase
		currCase = nextCase
	}

	buf.WriteString(tail)

	s := strings.ToLower(buf.String())
	smap.Set(name, s)
	return s
}

// expr is a raw SQL expression together with its bind arguments.
type expr struct {
	expr string
	args []interface{}
}

// Expr generate raw SQL expression, for example:
//     DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
func Expr(expression string, args ...interface{}) *expr {
	return &expr{expr: expression, args: args}
}

// indirect dereferences reflectValue until it is no longer a pointer.
func indirect(reflectValue reflect.Value) reflect.Value {
	for reflectValue.Kind() == reflect.Ptr {
		reflectValue = reflectValue.Elem()
	}
	return reflectValue
}

// toQueryMarks returns one comma separated group of "?" placeholders per entry
// of primaryValues. A group with more than one placeholder is wrapped in
// parentheses, e.g. two entries of two values each give "(?,?),(?,?)".
func toQueryMarks(primaryValues [][]interface{}) string {
	results := make([]string, 0, len(primaryValues))
	for _, primaryValue := range primaryValues {
		marks := make([]string, len(primaryValue))
		for i := range marks {
			marks[i] = "?"
		}

		if len(marks) > 1 {
			results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
		} else {
			results = append(results, strings.Join(marks, ""))
		}
	}
	return strings.Join(results, ",")
}

// toQueryCondition quotes every column with scope.Quote and joins them with
// commas; more than one column is wrapped in parentheses.
func toQueryCondition(scope *Scope, columns []string) string {
	newColumns := make([]string, 0, len(columns))
	for _, column := range columns {
		newColumns = append(newColumns, scope.Quote(column))
	}

	if len(columns) > 1 {
		return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
	}
	return strings.Join(newColumns, ",")
}

// toQueryValues flattens values into a single slice, preserving order.
func toQueryValues(values [][]interface{}) (results []interface{}) {
	for _, value := range values {
		results = append(results, value...)
	}
	return
}

// fileWithLineNum walks the call stack from depth 2 up to depth 14 and returns
// "file:line" for the first frame whose file does not match goSrcRegexp, or
// does match goTestRegexp. It returns an empty string if no such frame exists.
func fileWithLineNum() string {
	for i := 2; i < 15; i++ {
		_, file, line, ok := runtime.Caller(i)
		if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
			return fmt.Sprintf("%v:%v", file, line)
		}
	}
	return ""
}

// isBlank reports whether value holds the zero value of its kind: an empty
// string, false, a zero number, or a nil interface/pointer. Every other kind
// is compared against the zero value of its type with reflect.DeepEqual.
func isBlank(value reflect.Value) bool {
	switch value.Kind() {
	case reflect.String:
		return value.Len() == 0
	case reflect.Bool:
		return !value.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return value.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return value.Uint() == 0
	case reflect.Float32, reflect.Float64:
		return value.Float() == 0
	case reflect.Interface, reflect.Ptr:
		return value.IsNil()
	}

	return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
}

// toSearchableMap normalises attrs into a single value: with two or more
// arguments whose first is a string it returns a one-entry
// map[string]interface{} of attrs[0] to attrs[1]; with exactly one argument
// it returns that argument unchanged. In every other case it returns nil.
func toSearchableMap(attrs ...interface{}) (result interface{}) {
	if len(attrs) > 1 {
		if str, ok := attrs[0].(string); ok {
			result = map[string]interface{}{str: attrs[1]}
		}
	} else if len(attrs) == 1 {
		// Any single argument, a map[string]interface{} included, is passed
		// through as is; a nil argument leaves result nil.
		result = attrs[0]
	}
	return
}

// equalAsString reports whether a and b have the same toString representation.
func equalAsString(a interface{}, b interface{}) bool {
	return toString(a) == toString(b)
}

// toString converts str to a string: a []interface{} is converted element by
// element and joined with "_", a []byte is converted directly, and any other
// value is dereferenced and formatted with %v. A nil value or nil pointer
// gives the empty string.
func toString(str interface{}) string {
	if values, ok := str.([]interface{}); ok {
		var results []string
		for _, value := range values {
			results = append(results, toString(value))
		}
		return strings.Join(results, "_")
	} else if b, ok := str.([]byte); ok {
		return string(b)
	} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
		return fmt.Sprintf("%v", reflectValue.Interface())
	}
	return ""
}

// makeSlice returns a pointer to a new empty slice of elemType. If elemType is
// itself a slice type, its element type is used instead.
func makeSlice(elemType reflect.Type) interface{} {
	if elemType.Kind() == reflect.Slice {
		elemType = elemType.Elem()
	}
	sliceType := reflect.SliceOf(elemType)
	slice := reflect.New(sliceType)
	slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
	return slice.Interface()
}

// strInSlice reports whether a is an element of list.
func strInSlice(a string, list []string) bool {
	for _, b := range list {
		if b == a {
			return true
		}
	}
	return false
}

// getValueFromFields return given fields's value.
//
// Fields that cannot be found are skipped. A field value implementing
// driver.Valuer is replaced by the result of its Value method.
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
	// If value is a nil pointer, Indirect returns a zero Value!
	// Therefore we need to check for a zero value,
	// as FieldByName could panic
	if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
		for _, fieldName := range fieldNames {
			if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
				result := fieldValue.Interface()
				if r, ok := result.(driver.Valuer); ok {
					// The error from Value is discarded; whatever value it
					// returned is used as is.
					result, _ = r.Value()
				}
				results = append(results, result)
			}
		}
	}
	return
}

// addExtraSpaceIfExist prefixes a non-empty str with a single space and
// returns an empty string otherwise.
func addExtraSpaceIfExist(str string) string {
	if str != "" {
		return " " + str
	}
	return ""
}