mirror of https://github.com/go-gorm/gorm.git
Fix tests for postgres
This commit is contained in:
parent
e9ecf9c1aa
commit
6271cf0e55
|
@ -26,7 +26,7 @@ func Create(scope *Scope) {
|
|||
var sqls, columns []string
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if len(field.SqlTag) > 0 && !field.IsIgnored && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
|
||||
if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(field.Value))
|
||||
}
|
||||
|
|
|
@ -25,16 +25,14 @@ func AssignUpdateAttributes(scope *Scope) {
|
|||
}
|
||||
|
||||
func BeforeUpdate(scope *Scope) {
|
||||
_, ok := scope.Get("gorm:update_column")
|
||||
if !ok {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.CallMethod("BeforeSave")
|
||||
scope.CallMethod("BeforeUpdate")
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateTimeStampWhenUpdate(scope *Scope) {
|
||||
_, ok := scope.Get("gorm:update_column")
|
||||
if !ok {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.SetColumn("UpdatedAt", NowFunc())
|
||||
}
|
||||
}
|
||||
|
@ -50,7 +48,7 @@ func Update(scope *Scope) {
|
|||
}
|
||||
} else {
|
||||
for _, field := range scope.Fields() {
|
||||
if !field.IsPrimaryKey && len(field.SqlTag) > 0 && !field.IsIgnored {
|
||||
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
|
||||
}
|
||||
}
|
||||
|
|
2
field.go
2
field.go
|
@ -19,8 +19,8 @@ type Field struct {
|
|||
Field reflect.Value
|
||||
Value interface{}
|
||||
Tag reflect.StructTag
|
||||
SqlTag string
|
||||
Relationship *relationship
|
||||
IsNormal bool
|
||||
IsBlank bool
|
||||
IsIgnored bool
|
||||
IsPrimaryKey bool
|
||||
|
|
53
scope.go
53
scope.go
|
@ -95,7 +95,20 @@ func (scope *Scope) PrimaryKey() string {
|
|||
return scope.primaryKey
|
||||
}
|
||||
|
||||
scope.primaryKey = ToSnake(GetPrimaryKey(scope.Value))
|
||||
var indirectValue = scope.IndirectValue()
|
||||
|
||||
clone := scope
|
||||
if indirectValue.Kind() == reflect.Slice {
|
||||
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface())
|
||||
}
|
||||
|
||||
for _, field := range clone.Fields() {
|
||||
if field.IsPrimaryKey {
|
||||
scope.primaryKey = field.DBName
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return scope.primaryKey
|
||||
}
|
||||
|
||||
|
@ -130,8 +143,12 @@ func (scope *Scope) SetColumn(column string, value interface{}) bool {
|
|||
if scope.Value == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value)
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == column || field.DBName == column {
|
||||
return setFieldValue(field.Field, value)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CallMethod invoke method with necessary argument
|
||||
|
@ -262,13 +279,19 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
|||
|
||||
// Search for primary key tag identifier
|
||||
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
|
||||
|
||||
if scope.PrimaryKey() == field.DBName {
|
||||
if _, ok := settings["PRIMARY_KEY"]; ok {
|
||||
field.IsPrimaryKey = true
|
||||
}
|
||||
|
||||
field.Tag = fieldStruct.Tag
|
||||
field.SqlTag = scope.sqlTagForField(&field)
|
||||
|
||||
tagIdentifier := "sql"
|
||||
if scope.db != nil {
|
||||
tagIdentifier = scope.db.parent.tagIdentifier
|
||||
}
|
||||
if fieldStruct.Tag.Get(tagIdentifier) == "-" {
|
||||
field.IsIgnored = true
|
||||
}
|
||||
|
||||
if !field.IsIgnored {
|
||||
// parse association
|
||||
|
@ -311,6 +334,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
|||
if many2many != "" {
|
||||
field.Relationship.Kind = "many_to_many"
|
||||
}
|
||||
} else {
|
||||
field.IsNormal = true
|
||||
}
|
||||
case reflect.Struct:
|
||||
embedded := settings["EMBEDDED"]
|
||||
|
@ -321,7 +346,9 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
|||
fields = append(fields, field)
|
||||
}
|
||||
return fields
|
||||
} else if !field.IsTime() && !field.IsScanner() {
|
||||
} else if field.IsTime() || field.IsScanner() {
|
||||
field.IsNormal = true
|
||||
} else {
|
||||
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
||||
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
||||
} else if scope.HasColumn(foreignKey) {
|
||||
|
@ -335,6 +362,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
|||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
field.IsNormal = true
|
||||
}
|
||||
}
|
||||
return []*Field{&field}
|
||||
|
@ -345,12 +374,16 @@ func (scope *Scope) Fields() map[string]*Field {
|
|||
var fields = map[string]*Field{}
|
||||
if scope.IndirectValue().IsValid() {
|
||||
scopeTyp := scope.IndirectValue().Type()
|
||||
var hasPrimaryKey = false
|
||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||
fieldStruct := scopeTyp.Field(i)
|
||||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
for _, field := range scope.fieldFromStruct(fieldStruct) {
|
||||
if field.IsPrimaryKey {
|
||||
hasPrimaryKey = true
|
||||
}
|
||||
if _, ok := fields[field.DBName]; ok {
|
||||
panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum()))
|
||||
} else {
|
||||
|
@ -358,6 +391,12 @@ func (scope *Scope) Fields() map[string]*Field {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasPrimaryKey {
|
||||
if field, ok := fields["id"]; ok {
|
||||
field.IsPrimaryKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
|
|
@ -305,11 +305,6 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) {
|
|||
var size = 255
|
||||
|
||||
fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier)
|
||||
if fieldTag == "-" {
|
||||
field.IsIgnored = true
|
||||
return
|
||||
}
|
||||
|
||||
var setting = parseTagSetting(fieldTag)
|
||||
|
||||
if value, ok := setting["SIZE"]; ok {
|
||||
|
@ -481,8 +476,9 @@ func (scope *Scope) createJoinTable(field *Field) {
|
|||
func (scope *Scope) createTable() *Scope {
|
||||
var sqls []string
|
||||
for _, field := range scope.Fields() {
|
||||
if !field.IsIgnored && len(field.SqlTag) > 0 {
|
||||
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag)
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.sqlTagForField(field)
|
||||
sqls = append(sqls, scope.Quote(field.DBName)+" "+sqlTag)
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
}
|
||||
|
@ -535,8 +531,9 @@ func (scope *Scope) autoMigrate() *Scope {
|
|||
} else {
|
||||
for _, field := range scope.Fields() {
|
||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
||||
if len(field.SqlTag) > 0 && !field.IsIgnored {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec()
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.sqlTagForField(field)
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
|
||||
}
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
|
|
32
utils.go
32
utils.go
|
@ -2,7 +2,6 @@ package gorm
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -91,37 +90,6 @@ func SnakeToUpperCamel(s string) string {
|
|||
return u
|
||||
}
|
||||
|
||||
func GetPrimaryKey(value interface{}) string {
|
||||
var indirectValue = reflect.Indirect(reflect.ValueOf(value))
|
||||
|
||||
if indirectValue.Kind() == reflect.Slice {
|
||||
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
|
||||
}
|
||||
|
||||
if indirectValue.IsValid() {
|
||||
hasId := false
|
||||
scopeTyp := indirectValue.Type()
|
||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||
fieldStruct := scopeTyp.Field(i)
|
||||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
|
||||
if _, ok := settings["PRIMARY_KEY"]; ok {
|
||||
return fieldStruct.Name
|
||||
} else if fieldStruct.Name == "Id" {
|
||||
hasId = true
|
||||
}
|
||||
}
|
||||
if hasId {
|
||||
return "Id"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseTagSetting(str string) map[string]string {
|
||||
tags := strings.Split(str, ";")
|
||||
setting := map[string]string{}
|
||||
|
|
Loading…
Reference in New Issue