Fix tests for postgres

This commit is contained in:
Jinzhu 2014-08-30 22:39:28 +08:00
parent e9ecf9c1aa
commit 6271cf0e55
6 changed files with 57 additions and 55 deletions

View File

@ -26,7 +26,7 @@ func Create(scope *Scope) {
var sqls, columns []string var sqls, columns []string
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if len(field.SqlTag) > 0 && !field.IsIgnored && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) { if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
columns = append(columns, scope.Quote(field.DBName)) columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Value)) sqls = append(sqls, scope.AddToVars(field.Value))
} }

View File

@ -25,16 +25,14 @@ func AssignUpdateAttributes(scope *Scope) {
} }
func BeforeUpdate(scope *Scope) { func BeforeUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column") if _, ok := scope.Get("gorm:update_column"); !ok {
if !ok {
scope.CallMethod("BeforeSave") scope.CallMethod("BeforeSave")
scope.CallMethod("BeforeUpdate") scope.CallMethod("BeforeUpdate")
} }
} }
func UpdateTimeStampWhenUpdate(scope *Scope) { func UpdateTimeStampWhenUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column") if _, ok := scope.Get("gorm:update_column"); !ok {
if !ok {
scope.SetColumn("UpdatedAt", NowFunc()) scope.SetColumn("UpdatedAt", NowFunc())
} }
} }
@ -50,7 +48,7 @@ func Update(scope *Scope) {
} }
} else { } else {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if !field.IsPrimaryKey && len(field.SqlTag) > 0 && !field.IsIgnored { if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
} }
} }

View File

@ -19,8 +19,8 @@ type Field struct {
Field reflect.Value Field reflect.Value
Value interface{} Value interface{}
Tag reflect.StructTag Tag reflect.StructTag
SqlTag string
Relationship *relationship Relationship *relationship
IsNormal bool
IsBlank bool IsBlank bool
IsIgnored bool IsIgnored bool
IsPrimaryKey bool IsPrimaryKey bool

View File

@ -95,7 +95,20 @@ func (scope *Scope) PrimaryKey() string {
return scope.primaryKey return scope.primaryKey
} }
scope.primaryKey = ToSnake(GetPrimaryKey(scope.Value)) var indirectValue = scope.IndirectValue()
clone := scope
if indirectValue.Kind() == reflect.Slice {
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface())
}
for _, field := range clone.Fields() {
if field.IsPrimaryKey {
scope.primaryKey = field.DBName
break
}
}
return scope.primaryKey return scope.primaryKey
} }
@ -130,8 +143,12 @@ func (scope *Scope) SetColumn(column string, value interface{}) bool {
if scope.Value == nil { if scope.Value == nil {
return false return false
} }
for _, field := range scope.Fields() {
return setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value) if field.Name == column || field.DBName == column {
return setFieldValue(field.Field, value)
}
}
return false
} }
// CallMethod invoke method with necessary argument // CallMethod invoke method with necessary argument
@ -262,13 +279,19 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
// Search for primary key tag identifier // Search for primary key tag identifier
settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
if _, ok := settings["PRIMARY_KEY"]; ok {
if scope.PrimaryKey() == field.DBName {
field.IsPrimaryKey = true field.IsPrimaryKey = true
} }
field.Tag = fieldStruct.Tag field.Tag = fieldStruct.Tag
field.SqlTag = scope.sqlTagForField(&field)
tagIdentifier := "sql"
if scope.db != nil {
tagIdentifier = scope.db.parent.tagIdentifier
}
if fieldStruct.Tag.Get(tagIdentifier) == "-" {
field.IsIgnored = true
}
if !field.IsIgnored { if !field.IsIgnored {
// parse association // parse association
@ -311,6 +334,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
if many2many != "" { if many2many != "" {
field.Relationship.Kind = "many_to_many" field.Relationship.Kind = "many_to_many"
} }
} else {
field.IsNormal = true
} }
case reflect.Struct: case reflect.Struct:
embedded := settings["EMBEDDED"] embedded := settings["EMBEDDED"]
@ -321,7 +346,9 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
fields = append(fields, field) fields = append(fields, field)
} }
return fields return fields
} else if !field.IsTime() && !field.IsScanner() { } else if field.IsTime() || field.IsScanner() {
field.IsNormal = true
} else {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") { if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"} field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
} else if scope.HasColumn(foreignKey) { } else if scope.HasColumn(foreignKey) {
@ -335,6 +362,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
} }
} }
} }
default:
field.IsNormal = true
} }
} }
return []*Field{&field} return []*Field{&field}
@ -345,12 +374,16 @@ func (scope *Scope) Fields() map[string]*Field {
var fields = map[string]*Field{} var fields = map[string]*Field{}
if scope.IndirectValue().IsValid() { if scope.IndirectValue().IsValid() {
scopeTyp := scope.IndirectValue().Type() scopeTyp := scope.IndirectValue().Type()
var hasPrimaryKey = false
for i := 0; i < scopeTyp.NumField(); i++ { for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i) fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) { if !ast.IsExported(fieldStruct.Name) {
continue continue
} }
for _, field := range scope.fieldFromStruct(fieldStruct) { for _, field := range scope.fieldFromStruct(fieldStruct) {
if field.IsPrimaryKey {
hasPrimaryKey = true
}
if _, ok := fields[field.DBName]; ok { if _, ok := fields[field.DBName]; ok {
panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum())) panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum()))
} else { } else {
@ -358,6 +391,12 @@ func (scope *Scope) Fields() map[string]*Field {
} }
} }
} }
if !hasPrimaryKey {
if field, ok := fields["id"]; ok {
field.IsPrimaryKey = true
}
}
} }
return fields return fields
} }

View File

@ -305,11 +305,6 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) {
var size = 255 var size = 255
fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier) fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier)
if fieldTag == "-" {
field.IsIgnored = true
return
}
var setting = parseTagSetting(fieldTag) var setting = parseTagSetting(fieldTag)
if value, ok := setting["SIZE"]; ok { if value, ok := setting["SIZE"]; ok {
@ -481,8 +476,9 @@ func (scope *Scope) createJoinTable(field *Field) {
func (scope *Scope) createTable() *Scope { func (scope *Scope) createTable() *Scope {
var sqls []string var sqls []string
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if !field.IsIgnored && len(field.SqlTag) > 0 { if field.IsNormal {
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) sqlTag := scope.sqlTagForField(field)
sqls = append(sqls, scope.Quote(field.DBName)+" "+sqlTag)
} }
scope.createJoinTable(field) scope.createJoinTable(field)
} }
@ -535,8 +531,9 @@ func (scope *Scope) autoMigrate() *Scope {
} else { } else {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if len(field.SqlTag) > 0 && !field.IsIgnored { if field.IsNormal {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec() sqlTag := scope.sqlTagForField(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
} }
} }
scope.createJoinTable(field) scope.createJoinTable(field)

View File

@ -2,7 +2,6 @@ package gorm
import ( import (
"bytes" "bytes"
"go/ast"
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
@ -91,37 +90,6 @@ func SnakeToUpperCamel(s string) string {
return u return u
} }
func GetPrimaryKey(value interface{}) string {
var indirectValue = reflect.Indirect(reflect.ValueOf(value))
if indirectValue.Kind() == reflect.Slice {
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
}
if indirectValue.IsValid() {
hasId := false
scopeTyp := indirectValue.Type()
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
if _, ok := settings["PRIMARY_KEY"]; ok {
return fieldStruct.Name
} else if fieldStruct.Name == "Id" {
hasId = true
}
}
if hasId {
return "Id"
}
}
return ""
}
func parseTagSetting(str string) map[string]string { func parseTagSetting(str string) map[string]string {
tags := strings.Split(str, ";") tags := strings.Split(str, ";")
setting := map[string]string{} setting := map[string]string{}