diff --git a/field.go b/field.go index 1d20e4d9..4575bd6c 100644 --- a/field.go +++ b/field.go @@ -23,13 +23,13 @@ type Field struct { } func (f *Field) IsScanner() bool { - _, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner) - return is_scanner + _, isScanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner) + return isScanner } func (f *Field) IsTime() bool { - _, is_time := f.Value.(time.Time) - return is_time + _, isTime := f.Value.(time.Time) + return isTime } func parseSqlTag(str string) (typ string, additionalType string, size int) { diff --git a/scope.go b/scope.go index 05e45fc7..a77baa14 100644 --- a/scope.go +++ b/scope.go @@ -82,7 +82,7 @@ func (scope *Scope) PrimaryKey() string { return scope.primaryKey } - scope.primaryKey = scope.getPrimaryKey() + scope.primaryKey = ToSnake(GetPrimaryKey(scope.Value)) return scope.primaryKey } diff --git a/scope_private.go b/scope_private.go index 6298c3d7..0fe406c0 100644 --- a/scope_private.go +++ b/scope_private.go @@ -5,7 +5,6 @@ import ( "database/sql/driver" "errors" "fmt" - "go/ast" "reflect" "regexp" "strconv" @@ -466,33 +465,3 @@ func (scope *Scope) autoMigrate() *Scope { } return scope } - -func (scope *Scope) getPrimaryKey() string { - var indirectValue reflect.Value - - indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value)) - - if indirectValue.Kind() == reflect.Slice { - indirectValue = reflect.New(indirectValue.Type().Elem()).Elem() - } - - if !indirectValue.IsValid() { - return "id" - } - - scopeTyp := indirectValue.Type() - for i := 0; i < scopeTyp.NumField(); i++ { - fieldStruct := scopeTyp.Field(i) - if !ast.IsExported(fieldStruct.Name) { - continue - } - - // if primaryKey tag found, return column name - if fieldStruct.Tag.Get("primaryKey") != "" { - return ToSnake(fieldStruct.Name) - } - } - - //If primaryKey tag not found, fallback to id - return "id" -} diff --git a/utils.go b/utils.go index a9091aeb..afef469e 100644 --- a/utils.go +++ b/utils.go @@ -2,12 +2,8 @@ package gorm import ( "bytes" - "database/sql" - "fmt" - "os" + "go/ast" "reflect" - "regexp" - "runtime" "strings" "sync" ) @@ -95,82 +91,26 @@ func SnakeToUpperCamel(s string) string { return u } -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } +func GetPrimaryKey(value interface{}) string { + var indirectValue = reflect.Indirect(reflect.ValueOf(value)) - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } + if indirectValue.Kind() == reflect.Slice { + indirectValue = reflect.New(indirectValue.Type().Elem()).Elem() } - return -} -func fileWithLineNum() string { - for i := 1; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { - return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) - } - } - return "" -} - -func setFieldValue(field reflect.Value, value interface{}) (result bool) { - result = false - if field.IsValid() && field.CanAddr() { - result = true - if scanner, ok := field.Addr().Interface().(sql.Scanner); ok { - scanner.Scan(value) - } else if reflect.TypeOf(value).ConvertibleTo(field.Type()) { - field.Set(reflect.ValueOf(value).Convert(field.Type())) - } else { - result = false - } - } - return -} - -func isBlank(value reflect.Value) bool { - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func convertInterfaceToMap(values interface{}) map[string]interface{} { - attrs := map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - for k, v := range value { - attrs[ToSnake(k)] = v - } - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v) { - attrs[key] = value + if indirectValue.IsValid() { + scopeTyp := indirectValue.Type() + for i := 0; i < scopeTyp.NumField(); i++ { + fieldStruct := scopeTyp.Field(i) + if !ast.IsExported(fieldStruct.Name) { + continue } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToSnake(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - scope := Scope{Value: values} - for _, field := range scope.Fields() { - if !field.IsBlank { - attrs[field.DBName] = field.Value - } + if fieldStruct.Tag.Get("primaryKey") != "" { + return fieldStruct.Name } } } - return attrs + + return "Id" } diff --git a/utils_private.go b/utils_private.go new file mode 100644 index 00000000..75c38fcd --- /dev/null +++ b/utils_private.go @@ -0,0 +1,91 @@ +package gorm + +import ( + "database/sql" + "fmt" + "os" + "reflect" + "regexp" + "runtime" + "strings" +) + +func fileWithLineNum() string { + for i := 1; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { + return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) + } + } + return "" +} + +func setFieldValue(field reflect.Value, value interface{}) (result bool) { + result = false + if field.IsValid() && field.CanAddr() { + result = true + if scanner, ok := field.Addr().Interface().(sql.Scanner); ok { + scanner.Scan(value) + } else if reflect.TypeOf(value).ConvertibleTo(field.Type()) { + field.Set(reflect.ValueOf(value).Convert(field.Type())) + } else { + result = false + } + } + return +} + +func isBlank(value reflect.Value) bool { + return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) +} + +func toSearchableMap(attrs ...interface{}) (result interface{}) { + if len(attrs) > 1 { + if str, ok := attrs[0].(string); ok { + result = map[string]interface{}{str: attrs[1]} + } + } else if len(attrs) == 1 { + if attr, ok := attrs[0].(map[string]interface{}); ok { + result = attr + } + + if attr, ok := attrs[0].(interface{}); ok { + result = attr + } + } + return +} + +func convertInterfaceToMap(values interface{}) map[string]interface{} { + attrs := map[string]interface{}{} + + switch value := values.(type) { + case map[string]interface{}: + for k, v := range value { + attrs[ToSnake(k)] = v + } + case []interface{}: + for _, v := range value { + for key, value := range convertInterfaceToMap(v) { + attrs[key] = value + } + } + case interface{}: + reflectValue := reflect.ValueOf(values) + + switch reflectValue.Kind() { + case reflect.Map: + for _, key := range reflectValue.MapKeys() { + attrs[ToSnake(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + } + default: + scope := Scope{Value: values} + for _, field := range scope.Fields() { + if !field.IsBlank { + attrs[field.DBName] = field.Value + } + } + } + } + return attrs +}