Fix tests for postgres

This commit is contained in:
Jinzhu 2014-08-30 22:39:28 +08:00
parent e9ecf9c1aa
commit 6271cf0e55
6 changed files with 57 additions and 55 deletions

View File

@ -26,7 +26,7 @@ func Create(scope *Scope) {
var sqls, columns []string
for _, field := range scope.Fields() {
if len(field.SqlTag) > 0 && !field.IsIgnored && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Value))
}

View File

@ -25,16 +25,14 @@ func AssignUpdateAttributes(scope *Scope) {
}
func BeforeUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column")
if !ok {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethod("BeforeSave")
scope.CallMethod("BeforeUpdate")
}
}
func UpdateTimeStampWhenUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column")
if !ok {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
}
}
@ -50,7 +48,7 @@ func Update(scope *Scope) {
}
} else {
for _, field := range scope.Fields() {
if !field.IsPrimaryKey && len(field.SqlTag) > 0 && !field.IsIgnored {
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
}
}

View File

@ -19,8 +19,8 @@ type Field struct {
Field reflect.Value
Value interface{}
Tag reflect.StructTag
SqlTag string
Relationship *relationship
IsNormal bool
IsBlank bool
IsIgnored bool
IsPrimaryKey bool

View File

@ -95,7 +95,20 @@ func (scope *Scope) PrimaryKey() string {
return scope.primaryKey
}
scope.primaryKey = ToSnake(GetPrimaryKey(scope.Value))
var indirectValue = scope.IndirectValue()
clone := scope
if indirectValue.Kind() == reflect.Slice {
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface())
}
for _, field := range clone.Fields() {
if field.IsPrimaryKey {
scope.primaryKey = field.DBName
break
}
}
return scope.primaryKey
}
@ -130,8 +143,12 @@ func (scope *Scope) SetColumn(column string, value interface{}) bool {
if scope.Value == nil {
return false
}
return setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value)
for _, field := range scope.Fields() {
if field.Name == column || field.DBName == column {
return setFieldValue(field.Field, value)
}
}
return false
}
// CallMethod invoke method with necessary argument
@ -262,13 +279,19 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
// Search for primary key tag identifier
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
if scope.PrimaryKey() == field.DBName {
if _, ok := settings["PRIMARY_KEY"]; ok {
field.IsPrimaryKey = true
}
field.Tag = fieldStruct.Tag
field.SqlTag = scope.sqlTagForField(&field)
tagIdentifier := "sql"
if scope.db != nil {
tagIdentifier = scope.db.parent.tagIdentifier
}
if fieldStruct.Tag.Get(tagIdentifier) == "-" {
field.IsIgnored = true
}
if !field.IsIgnored {
// parse association
@ -311,6 +334,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
if many2many != "" {
field.Relationship.Kind = "many_to_many"
}
} else {
field.IsNormal = true
}
case reflect.Struct:
embedded := settings["EMBEDDED"]
@ -321,7 +346,9 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
fields = append(fields, field)
}
return fields
} else if !field.IsTime() && !field.IsScanner() {
} else if field.IsTime() || field.IsScanner() {
field.IsNormal = true
} else {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
} else if scope.HasColumn(foreignKey) {
@ -335,6 +362,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
}
}
}
default:
field.IsNormal = true
}
}
return []*Field{&field}
@ -345,12 +374,16 @@ func (scope *Scope) Fields() map[string]*Field {
var fields = map[string]*Field{}
if scope.IndirectValue().IsValid() {
scopeTyp := scope.IndirectValue().Type()
var hasPrimaryKey = false
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
for _, field := range scope.fieldFromStruct(fieldStruct) {
if field.IsPrimaryKey {
hasPrimaryKey = true
}
if _, ok := fields[field.DBName]; ok {
panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum()))
} else {
@ -358,6 +391,12 @@ func (scope *Scope) Fields() map[string]*Field {
}
}
}
if !hasPrimaryKey {
if field, ok := fields["id"]; ok {
field.IsPrimaryKey = true
}
}
}
return fields
}

View File

@ -305,11 +305,6 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) {
var size = 255
fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier)
if fieldTag == "-" {
field.IsIgnored = true
return
}
var setting = parseTagSetting(fieldTag)
if value, ok := setting["SIZE"]; ok {
@ -481,8 +476,9 @@ func (scope *Scope) createJoinTable(field *Field) {
func (scope *Scope) createTable() *Scope {
var sqls []string
for _, field := range scope.Fields() {
if !field.IsIgnored && len(field.SqlTag) > 0 {
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag)
if field.IsNormal {
sqlTag := scope.sqlTagForField(field)
sqls = append(sqls, scope.Quote(field.DBName)+" "+sqlTag)
}
scope.createJoinTable(field)
}
@ -535,8 +531,9 @@ func (scope *Scope) autoMigrate() *Scope {
} else {
for _, field := range scope.Fields() {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if len(field.SqlTag) > 0 && !field.IsIgnored {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec()
if field.IsNormal {
sqlTag := scope.sqlTagForField(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
}
}
scope.createJoinTable(field)

View File

@ -2,7 +2,6 @@ package gorm
import (
"bytes"
"go/ast"
"reflect"
"strings"
"sync"
@ -91,37 +90,6 @@ func SnakeToUpperCamel(s string) string {
return u
}
func GetPrimaryKey(value interface{}) string {
var indirectValue = reflect.Indirect(reflect.ValueOf(value))
if indirectValue.Kind() == reflect.Slice {
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
}
if indirectValue.IsValid() {
hasId := false
scopeTyp := indirectValue.Type()
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
if _, ok := settings["PRIMARY_KEY"]; ok {
return fieldStruct.Name
} else if fieldStruct.Name == "Id" {
hasId = true
}
}
if hasId {
return "Id"
}
}
return ""
}
func parseTagSetting(str string) map[string]string {
tags := strings.Split(str, ";")
setting := map[string]string{}