Support NullFloat64, NullInt64, NullBool, NullString

This commit is contained in:
Jinzhu 2013-11-10 08:57:11 +08:00
parent 562bca71e4
commit 8d97fdb172
3 changed files with 68 additions and 30 deletions

View File

@ -1124,6 +1124,7 @@ func TestSubStruct(t *testing.T) {
var p Post
db.First(&p, post.Id)
if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
t.Errorf("Category Id should exist")
}
@ -1265,6 +1266,40 @@ func TestAutoMigration(t *testing.T) {
}
}
type NullValue struct {
Id int64
Name sql.NullString
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
}
func TestSqlNullValue(t *testing.T) {
db.DropTable(&NullValue{})
db.AutoMigrate(&NullValue{})
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
var nv NullValue
db.First(&nv, "name = ?", "hello")
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 {
t.Errorf("Should be able to fetch null value")
}
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
var nv2 NullValue
db.First(&nv2, "name = ?", "hello-2")
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 {
t.Errorf("Should be able to fetch null value")
}
}
func BenchmarkGorm(b *testing.B) {
for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}

View File

@ -106,10 +106,15 @@ func (m *Model) fields(operation string) (fields []Field) {
if is_time {
field.IsBlank = time_value.IsZero()
} else {
m := &Model{data: value.Interface(), driver: m.driver}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
switch value.Interface().(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
field.IsBlank = value.FieldByName("Valid").Interface().(bool)
default:
m := &Model{data: value.Interface(), driver: m.driver}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
}
}
}
}
@ -151,22 +156,19 @@ func (m *Model) fields(operation string) (fields []Field) {
}
field.afterAssociation = true
case reflect.Struct:
if is_time {
switch value.Interface().(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time:
field.SqlType = getSqlType(m.driver, field.Value, 0)
} else {
switch value.Interface().(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
default:
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true
} else {
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
default:
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true
} else {
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
}
}
case reflect.Ptr:

View File

@ -1,6 +1,7 @@
package gorm
import (
"database/sql"
"fmt"
"time"
)
@ -34,15 +35,15 @@ func getSqlType(adaptor string, column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "datetime"
case bool:
case bool, sql.NullBool:
return "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer"
case int64, uint64:
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64:
case float32, float64, sql.NullFloat64:
return "real"
case string:
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
@ -54,20 +55,20 @@ func getSqlType(adaptor string, column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "timestamp"
case bool:
case bool, sql.NullBool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "int"
case int64, uint64:
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64:
case float32, float64, sql.NullFloat64:
return "double"
case []byte:
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
}
return "longblob"
case string:
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
@ -80,17 +81,17 @@ func getSqlType(adaptor string, column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "timestamp with time zone"
case bool:
case bool, sql.NullBool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer"
case int64, uint64:
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64:
case float32, float64, sql.NullFloat64:
return "double precision"
case []byte:
return "bytea"
case string:
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}