Support NullFloat64, NullInt64, NullBool, NullString

This commit is contained in:
Jinzhu 2013-11-10 08:57:11 +08:00
parent 562bca71e4
commit 8d97fdb172
3 changed files with 68 additions and 30 deletions

View File

@ -1124,6 +1124,7 @@ func TestSubStruct(t *testing.T) {
var p Post var p Post
db.First(&p, post.Id) db.First(&p, post.Id)
if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 { if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
t.Errorf("Category Id should exist") t.Errorf("Category Id should exist")
} }
@ -1265,6 +1266,40 @@ func TestAutoMigration(t *testing.T) {
} }
} }
type NullValue struct {
Id int64
Name sql.NullString
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
}
func TestSqlNullValue(t *testing.T) {
db.DropTable(&NullValue{})
db.AutoMigrate(&NullValue{})
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
var nv NullValue
db.First(&nv, "name = ?", "hello")
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 {
t.Errorf("Should be able to fetch null value")
}
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
t.Errorf("Not error should raise when test null value", err)
}
var nv2 NullValue
db.First(&nv2, "name = ?", "hello-2")
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 {
t.Errorf("Should be able to fetch null value")
}
}
func BenchmarkGorm(b *testing.B) { func BenchmarkGorm(b *testing.B) {
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()} email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}

View File

@ -106,10 +106,15 @@ func (m *Model) fields(operation string) (fields []Field) {
if is_time { if is_time {
field.IsBlank = time_value.IsZero() field.IsBlank = time_value.IsZero()
} else { } else {
m := &Model{data: value.Interface(), driver: m.driver} switch value.Interface().(type) {
fields := m.columnsHasValue("other") case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
if len(fields) == 0 { field.IsBlank = value.FieldByName("Valid").Interface().(bool)
field.IsBlank = true default:
m := &Model{data: value.Interface(), driver: m.driver}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
}
} }
} }
} }
@ -151,22 +156,19 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
field.afterAssociation = true field.afterAssociation = true
case reflect.Struct: case reflect.Struct:
if is_time { switch value.Interface().(type) {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time:
field.SqlType = getSqlType(m.driver, field.Value, 0) field.SqlType = getSqlType(m.driver, field.Value, 0)
} else { default:
switch value.Interface().(type) { if indirect_value.FieldByName(p.Name + "Id").IsValid() {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: field.foreignKey = p.Name + "Id"
default: field.beforeAssociation = true
if indirect_value.FieldByName(p.Name + "Id").IsValid() { } else {
field.foreignKey = p.Name + "Id" foreign_key := typ.Name() + "Id"
field.beforeAssociation = true if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
} else { field.foreignKey = foreign_key
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
} }
field.afterAssociation = true
} }
} }
case reflect.Ptr: case reflect.Ptr:

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"database/sql"
"fmt" "fmt"
"time" "time"
) )
@ -34,15 +35,15 @@ func getSqlType(adaptor string, column interface{}, size int) string {
switch column.(type) { switch column.(type) {
case time.Time: case time.Time:
return "datetime" return "datetime"
case bool: case bool, sql.NullBool:
return "bool" return "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32: case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer" return "integer"
case int64, uint64: case int64, uint64, sql.NullInt64:
return "bigint" return "bigint"
case float32, float64: case float32, float64, sql.NullFloat64:
return "real" return "real"
case string: case string, sql.NullString:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} }
@ -54,20 +55,20 @@ func getSqlType(adaptor string, column interface{}, size int) string {
switch column.(type) { switch column.(type) {
case time.Time: case time.Time:
return "timestamp" return "timestamp"
case bool: case bool, sql.NullBool:
return "boolean" return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32: case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "int" return "int"
case int64, uint64: case int64, uint64, sql.NullInt64:
return "bigint" return "bigint"
case float32, float64: case float32, float64, sql.NullFloat64:
return "double" return "double"
case []byte: case []byte:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size) return fmt.Sprintf("varbinary(%d)", size)
} }
return "longblob" return "longblob"
case string: case string, sql.NullString:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} }
@ -80,17 +81,17 @@ func getSqlType(adaptor string, column interface{}, size int) string {
switch column.(type) { switch column.(type) {
case time.Time: case time.Time:
return "timestamp with time zone" return "timestamp with time zone"
case bool: case bool, sql.NullBool:
return "boolean" return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32: case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer" return "integer"
case int64, uint64: case int64, uint64, sql.NullInt64:
return "bigint" return "bigint"
case float32, float64: case float32, float64, sql.NullFloat64:
return "double precision" return "double precision"
case []byte: case []byte:
return "bytea" return "bytea"
case string: case string, sql.NullString:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} }