mirror of https://github.com/go-gorm/gorm.git
Fix tests for postgres
This commit is contained in:
parent
e9ecf9c1aa
commit
6271cf0e55
|
@ -26,7 +26,7 @@ func Create(scope *Scope) {
|
||||||
var sqls, columns []string
|
var sqls, columns []string
|
||||||
|
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if len(field.SqlTag) > 0 && !field.IsIgnored && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
|
if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
|
||||||
columns = append(columns, scope.Quote(field.DBName))
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
sqls = append(sqls, scope.AddToVars(field.Value))
|
sqls = append(sqls, scope.AddToVars(field.Value))
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,16 +25,14 @@ func AssignUpdateAttributes(scope *Scope) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BeforeUpdate(scope *Scope) {
|
func BeforeUpdate(scope *Scope) {
|
||||||
_, ok := scope.Get("gorm:update_column")
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
if !ok {
|
|
||||||
scope.CallMethod("BeforeSave")
|
scope.CallMethod("BeforeSave")
|
||||||
scope.CallMethod("BeforeUpdate")
|
scope.CallMethod("BeforeUpdate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateTimeStampWhenUpdate(scope *Scope) {
|
func UpdateTimeStampWhenUpdate(scope *Scope) {
|
||||||
_, ok := scope.Get("gorm:update_column")
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
if !ok {
|
|
||||||
scope.SetColumn("UpdatedAt", NowFunc())
|
scope.SetColumn("UpdatedAt", NowFunc())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -50,7 +48,7 @@ func Update(scope *Scope) {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if !field.IsPrimaryKey && len(field.SqlTag) > 0 && !field.IsIgnored {
|
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
2
field.go
2
field.go
|
@ -19,8 +19,8 @@ type Field struct {
|
||||||
Field reflect.Value
|
Field reflect.Value
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Tag reflect.StructTag
|
Tag reflect.StructTag
|
||||||
SqlTag string
|
|
||||||
Relationship *relationship
|
Relationship *relationship
|
||||||
|
IsNormal bool
|
||||||
IsBlank bool
|
IsBlank bool
|
||||||
IsIgnored bool
|
IsIgnored bool
|
||||||
IsPrimaryKey bool
|
IsPrimaryKey bool
|
||||||
|
|
53
scope.go
53
scope.go
|
@ -95,7 +95,20 @@ func (scope *Scope) PrimaryKey() string {
|
||||||
return scope.primaryKey
|
return scope.primaryKey
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.primaryKey = ToSnake(GetPrimaryKey(scope.Value))
|
var indirectValue = scope.IndirectValue()
|
||||||
|
|
||||||
|
clone := scope
|
||||||
|
if indirectValue.Kind() == reflect.Slice {
|
||||||
|
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range clone.Fields() {
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
scope.primaryKey = field.DBName
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return scope.primaryKey
|
return scope.primaryKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,8 +143,12 @@ func (scope *Scope) SetColumn(column string, value interface{}) bool {
|
||||||
if scope.Value == nil {
|
if scope.Value == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
return setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value)
|
if field.Name == column || field.DBName == column {
|
||||||
|
return setFieldValue(field.Field, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallMethod invoke method with necessary argument
|
// CallMethod invoke method with necessary argument
|
||||||
|
@ -262,13 +279,19 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||||
|
|
||||||
// Search for primary key tag identifier
|
// Search for primary key tag identifier
|
||||||
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
|
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
|
||||||
|
if _, ok := settings["PRIMARY_KEY"]; ok {
|
||||||
if scope.PrimaryKey() == field.DBName {
|
|
||||||
field.IsPrimaryKey = true
|
field.IsPrimaryKey = true
|
||||||
}
|
}
|
||||||
|
|
||||||
field.Tag = fieldStruct.Tag
|
field.Tag = fieldStruct.Tag
|
||||||
field.SqlTag = scope.sqlTagForField(&field)
|
|
||||||
|
tagIdentifier := "sql"
|
||||||
|
if scope.db != nil {
|
||||||
|
tagIdentifier = scope.db.parent.tagIdentifier
|
||||||
|
}
|
||||||
|
if fieldStruct.Tag.Get(tagIdentifier) == "-" {
|
||||||
|
field.IsIgnored = true
|
||||||
|
}
|
||||||
|
|
||||||
if !field.IsIgnored {
|
if !field.IsIgnored {
|
||||||
// parse association
|
// parse association
|
||||||
|
@ -311,6 +334,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||||
if many2many != "" {
|
if many2many != "" {
|
||||||
field.Relationship.Kind = "many_to_many"
|
field.Relationship.Kind = "many_to_many"
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
field.IsNormal = true
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
embedded := settings["EMBEDDED"]
|
embedded := settings["EMBEDDED"]
|
||||||
|
@ -321,7 +346,9 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
}
|
}
|
||||||
return fields
|
return fields
|
||||||
} else if !field.IsTime() && !field.IsScanner() {
|
} else if field.IsTime() || field.IsScanner() {
|
||||||
|
field.IsNormal = true
|
||||||
|
} else {
|
||||||
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
||||||
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
||||||
} else if scope.HasColumn(foreignKey) {
|
} else if scope.HasColumn(foreignKey) {
|
||||||
|
@ -335,6 +362,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
field.IsNormal = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return []*Field{&field}
|
return []*Field{&field}
|
||||||
|
@ -345,12 +374,16 @@ func (scope *Scope) Fields() map[string]*Field {
|
||||||
var fields = map[string]*Field{}
|
var fields = map[string]*Field{}
|
||||||
if scope.IndirectValue().IsValid() {
|
if scope.IndirectValue().IsValid() {
|
||||||
scopeTyp := scope.IndirectValue().Type()
|
scopeTyp := scope.IndirectValue().Type()
|
||||||
|
var hasPrimaryKey = false
|
||||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||||
fieldStruct := scopeTyp.Field(i)
|
fieldStruct := scopeTyp.Field(i)
|
||||||
if !ast.IsExported(fieldStruct.Name) {
|
if !ast.IsExported(fieldStruct.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, field := range scope.fieldFromStruct(fieldStruct) {
|
for _, field := range scope.fieldFromStruct(fieldStruct) {
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
hasPrimaryKey = true
|
||||||
|
}
|
||||||
if _, ok := fields[field.DBName]; ok {
|
if _, ok := fields[field.DBName]; ok {
|
||||||
panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum()))
|
panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum()))
|
||||||
} else {
|
} else {
|
||||||
|
@ -358,6 +391,12 @@ func (scope *Scope) Fields() map[string]*Field {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !hasPrimaryKey {
|
||||||
|
if field, ok := fields["id"]; ok {
|
||||||
|
field.IsPrimaryKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
|
@ -305,11 +305,6 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) {
|
||||||
var size = 255
|
var size = 255
|
||||||
|
|
||||||
fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier)
|
fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier)
|
||||||
if fieldTag == "-" {
|
|
||||||
field.IsIgnored = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var setting = parseTagSetting(fieldTag)
|
var setting = parseTagSetting(fieldTag)
|
||||||
|
|
||||||
if value, ok := setting["SIZE"]; ok {
|
if value, ok := setting["SIZE"]; ok {
|
||||||
|
@ -481,8 +476,9 @@ func (scope *Scope) createJoinTable(field *Field) {
|
||||||
func (scope *Scope) createTable() *Scope {
|
func (scope *Scope) createTable() *Scope {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if !field.IsIgnored && len(field.SqlTag) > 0 {
|
if field.IsNormal {
|
||||||
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag)
|
sqlTag := scope.sqlTagForField(field)
|
||||||
|
sqls = append(sqls, scope.Quote(field.DBName)+" "+sqlTag)
|
||||||
}
|
}
|
||||||
scope.createJoinTable(field)
|
scope.createJoinTable(field)
|
||||||
}
|
}
|
||||||
|
@ -535,8 +531,9 @@ func (scope *Scope) autoMigrate() *Scope {
|
||||||
} else {
|
} else {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
||||||
if len(field.SqlTag) > 0 && !field.IsIgnored {
|
if field.IsNormal {
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec()
|
sqlTag := scope.sqlTagForField(field)
|
||||||
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scope.createJoinTable(field)
|
scope.createJoinTable(field)
|
||||||
|
|
32
utils.go
32
utils.go
|
@ -2,7 +2,6 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"go/ast"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -91,37 +90,6 @@ func SnakeToUpperCamel(s string) string {
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPrimaryKey(value interface{}) string {
|
|
||||||
var indirectValue = reflect.Indirect(reflect.ValueOf(value))
|
|
||||||
|
|
||||||
if indirectValue.Kind() == reflect.Slice {
|
|
||||||
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if indirectValue.IsValid() {
|
|
||||||
hasId := false
|
|
||||||
scopeTyp := indirectValue.Type()
|
|
||||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
|
||||||
fieldStruct := scopeTyp.Field(i)
|
|
||||||
if !ast.IsExported(fieldStruct.Name) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
|
|
||||||
if _, ok := settings["PRIMARY_KEY"]; ok {
|
|
||||||
return fieldStruct.Name
|
|
||||||
} else if fieldStruct.Name == "Id" {
|
|
||||||
hasId = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hasId {
|
|
||||||
return "Id"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTagSetting(str string) map[string]string {
|
func parseTagSetting(str string) map[string]string {
|
||||||
tags := strings.Split(str, ";")
|
tags := strings.Split(str, ";")
|
||||||
setting := map[string]string{}
|
setting := map[string]string{}
|
||||||
|
|
Loading…
Reference in New Issue