package gorm import ( "bytes" "database/sql/driver" "fmt" "reflect" "regexp" "runtime" "strings" "sync" ) // Copied from golint var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer func init() { var commonInitialismsForReplacer []string for _, initialism := range commonInitialisms { commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) } commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) } type safeMap struct { m map[string]string l *sync.RWMutex } func (s *safeMap) Set(key string, value string) { s.l.Lock() defer s.l.Unlock() s.m[key] = value } func (s *safeMap) Get(key string) string { s.l.RLock() defer s.l.RUnlock() return s.m[key] } func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } var smap = newSafeMap() type strCase bool const ( lower strCase = false upper strCase = true ) // ToDBName convert string to db name func ToDBName(name string) string { if v := smap.Get(name); v != "" { return v } if name == "" { return "" } var ( value = commonInitialismsReplacer.Replace(name) buf = bytes.NewBufferString("") lastCase, currCase, nextCase strCase ) for i, v := range value[:len(value)-1] { nextCase = value[i+1] >= 'A' && value[i+1] <= 'Z' if i > 0 { if currCase == upper { if lastCase == upper && nextCase == upper { buf.WriteRune(v) } else { buf.WriteRune('_') buf.WriteRune(v) } } else { buf.WriteRune(v) } } else { currCase = upper buf.WriteRune(v) } lastCase = currCase currCase = nextCase } buf.WriteByte(value[len(value)-1]) s := strings.Replace(strings.ToLower(buf.String()), "__", "_", -1) smap.Set(name, s) return s } // SQL expression type expr struct { expr string args []interface{} } // Expr generate raw SQL expression for SQL, for example: // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) func Expr(expression string, args ...interface{}) *expr { return &expr{expr: expression, args: args} } func indirect(reflectValue reflect.Value) reflect.Value { for reflectValue.Kind() == reflect.Ptr { reflectValue = reflectValue.Elem() } return reflectValue } func toQueryMarks(primaryValues [][]interface{}) string { var results []string for _, primaryValue := range primaryValues { var marks []string for _ = range primaryValue { marks = append(marks, "?") } if len(marks) > 1 { results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) } else { results = append(results, strings.Join(marks, "")) } } return strings.Join(results, ",") } func toQueryCondition(scope *Scope, columns []string) string { var newColumns []string for _, column := range columns { newColumns = append(newColumns, scope.Quote(column)) } if len(columns) > 1 { return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) } return strings.Join(newColumns, ",") } func toQueryValues(values [][]interface{}) (results []interface{}) { for _, value := range values { for _, v := range value { results = append(results, v) } } return } func fileWithLineNum() string { for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { return fmt.Sprintf("%v:%v", file, line) } } return "" } func isBlank(value reflect.Value) bool { return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) } func toSearchableMap(attrs ...interface{}) (result interface{}) { if len(attrs) > 1 { if str, ok := attrs[0].(string); ok { result = map[string]interface{}{str: attrs[1]} } } else if len(attrs) == 1 { if attr, ok := attrs[0].(map[string]interface{}); ok { result = attr } if attr, ok := attrs[0].(interface{}); ok { result = attr } } return } func convertInterfaceToMap(values interface{}) map[string]interface{} { attrs := map[string]interface{}{} switch value := values.(type) { case map[string]interface{}: for k, v := range value { attrs[k] = v } case []interface{}: for _, v := range value { for key, value := range convertInterfaceToMap(v) { attrs[key] = value } } case interface{}: reflectValue := reflect.ValueOf(values) switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: for _, field := range (&Scope{Value: values}).Fields() { if !field.IsBlank && !field.IsIgnored { attrs[field.DBName] = field.Field.Interface() } } } } return attrs } func equalAsString(a interface{}, b interface{}) bool { return toString(a) == toString(b) } func toString(str interface{}) string { if values, ok := str.([]interface{}); ok { var results []string for _, value := range values { results = append(results, toString(value)) } return strings.Join(results, "_") } else if bytes, ok := str.([]byte); ok { return string(bytes) } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { return fmt.Sprintf("%v", reflectValue.Interface()) } return "" } func makeSlice(elemType reflect.Type) interface{} { if elemType.Kind() == reflect.Slice { elemType = elemType.Elem() } sliceType := reflect.SliceOf(elemType) slice := reflect.New(sliceType) slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) return slice.Interface() } func strInSlice(a string, list []string) bool { for _, b := range list { if b == a { return true } } return false } // getValueFromFields return given fields's value func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { // If value is a nil pointer, Indirect returns a zero Value! // Therefor we need to check for a zero value, // as FieldByName could panic if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { for _, fieldName := range fieldNames { if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { result := fieldValue.Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() } results = append(results, result) } } } return } func addExtraSpaceIfExist(str string) string { if str != "" { return " " + str } return "" }