Try to add support for NullInt64

This commit is contained in:
Jinzhu 2013-11-10 08:14:41 +08:00
parent 99b87315c4
commit 562bca71e4
4 changed files with 33 additions and 11 deletions

View File

@ -70,6 +70,7 @@ func (s *Chain) do(value interface{}) *Do {
do.specifiedTableName = s.specifiedTableName
do.unscoped = s.unscoped
do.singularTableName = s.singularTableName
do.debug = s.debug
s.value = value
do.setModel(value)

1
do.go
View File

@ -17,6 +17,7 @@ type Do struct {
driver string
guessedTableName string
specifiedTableName string
debug bool
Errors []error
model *Model

View File

@ -1081,7 +1081,7 @@ type Category struct {
type Post struct {
Id int64
CategoryId int64
CategoryId sql.NullInt64
MainCategoryId int64
Title string
Body string
@ -1124,7 +1124,7 @@ func TestSubStruct(t *testing.T) {
var p Post
db.First(&p, post.Id)
if post.CategoryId == 0 || p.CategoryId == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
t.Errorf("Category Id should exist")
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"database/sql"
"errors"
"fmt"
"go/ast"
@ -153,15 +154,19 @@ func (m *Model) fields(operation string) (fields []Field) {
if is_time {
field.SqlType = getSqlType(m.driver, field.Value, 0)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true
} else {
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
switch value.Interface().(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
default:
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true
} else {
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
}
field.afterAssociation = true
}
}
case reflect.Ptr:
@ -363,7 +368,22 @@ func setFieldValue(field reflect.Value, value interface{}) bool {
}
field.SetInt(reflect.ValueOf(value).Int())
default:
field.Set(reflect.ValueOf(value))
field_type := field.Type()
if field_type == reflect.TypeOf(value) {
field.Set(reflect.ValueOf(value))
} else if value == nil {
field.Set(reflect.Zero(field.Type()))
} else if field_type == reflect.TypeOf(sql.NullBool{}) {
field.Set(reflect.ValueOf(sql.NullBool{value.(bool), true}))
} else if field_type == reflect.TypeOf(sql.NullFloat64{}) {
field.Set(reflect.ValueOf(sql.NullFloat64{value.(float64), true}))
} else if field_type == reflect.TypeOf(sql.NullInt64{}) {
field.Set(reflect.ValueOf(sql.NullInt64{value.(int64), true}))
} else if field_type == reflect.TypeOf(sql.NullString{}) {
field.Set(reflect.ValueOf(sql.NullString{value.(string), true}))
} else {
field.Set(reflect.ValueOf(value))
}
}
return true
} else {