forked from mirror/gorm
Support NullFloat64, NullInt64, NullBool, NullString
This commit is contained in:
parent
562bca71e4
commit
8d97fdb172
35
gorm_test.go
35
gorm_test.go
|
@ -1124,6 +1124,7 @@ func TestSubStruct(t *testing.T) {
|
||||||
|
|
||||||
var p Post
|
var p Post
|
||||||
db.First(&p, post.Id)
|
db.First(&p, post.Id)
|
||||||
|
|
||||||
if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
|
if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
|
||||||
t.Errorf("Category Id should exist")
|
t.Errorf("Category Id should exist")
|
||||||
}
|
}
|
||||||
|
@ -1265,6 +1266,40 @@ func TestAutoMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NullValue struct {
|
||||||
|
Id int64
|
||||||
|
Name sql.NullString
|
||||||
|
Age sql.NullInt64
|
||||||
|
Male sql.NullBool
|
||||||
|
Height sql.NullFloat64
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlNullValue(t *testing.T) {
|
||||||
|
db.DropTable(&NullValue{})
|
||||||
|
db.AutoMigrate(&NullValue{})
|
||||||
|
|
||||||
|
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
|
||||||
|
t.Errorf("Not error should raise when test null value", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var nv NullValue
|
||||||
|
db.First(&nv, "name = ?", "hello")
|
||||||
|
|
||||||
|
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 {
|
||||||
|
t.Errorf("Should be able to fetch null value")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
|
||||||
|
t.Errorf("Not error should raise when test null value", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var nv2 NullValue
|
||||||
|
db.First(&nv2, "name = ?", "hello-2")
|
||||||
|
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 {
|
||||||
|
t.Errorf("Should be able to fetch null value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
func BenchmarkGorm(b *testing.B) {
|
||||||
for x := 0; x < b.N; x++ {
|
for x := 0; x < b.N; x++ {
|
||||||
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
|
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
|
||||||
|
|
38
model.go
38
model.go
|
@ -106,10 +106,15 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
if is_time {
|
if is_time {
|
||||||
field.IsBlank = time_value.IsZero()
|
field.IsBlank = time_value.IsZero()
|
||||||
} else {
|
} else {
|
||||||
m := &Model{data: value.Interface(), driver: m.driver}
|
switch value.Interface().(type) {
|
||||||
fields := m.columnsHasValue("other")
|
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
||||||
if len(fields) == 0 {
|
field.IsBlank = value.FieldByName("Valid").Interface().(bool)
|
||||||
field.IsBlank = true
|
default:
|
||||||
|
m := &Model{data: value.Interface(), driver: m.driver}
|
||||||
|
fields := m.columnsHasValue("other")
|
||||||
|
if len(fields) == 0 {
|
||||||
|
field.IsBlank = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -151,22 +156,19 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
}
|
}
|
||||||
field.afterAssociation = true
|
field.afterAssociation = true
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if is_time {
|
switch value.Interface().(type) {
|
||||||
|
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time:
|
||||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||||
} else {
|
default:
|
||||||
switch value.Interface().(type) {
|
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
field.foreignKey = p.Name + "Id"
|
||||||
default:
|
field.beforeAssociation = true
|
||||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
} else {
|
||||||
field.foreignKey = p.Name + "Id"
|
foreign_key := typ.Name() + "Id"
|
||||||
field.beforeAssociation = true
|
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
||||||
} else {
|
field.foreignKey = foreign_key
|
||||||
foreign_key := typ.Name() + "Id"
|
|
||||||
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
|
||||||
field.foreignKey = foreign_key
|
|
||||||
}
|
|
||||||
field.afterAssociation = true
|
|
||||||
}
|
}
|
||||||
|
field.afterAssociation = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
|
|
25
sql_type.go
25
sql_type.go
|
@ -1,6 +1,7 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -34,15 +35,15 @@ func getSqlType(adaptor string, column interface{}, size int) string {
|
||||||
switch column.(type) {
|
switch column.(type) {
|
||||||
case time.Time:
|
case time.Time:
|
||||||
return "datetime"
|
return "datetime"
|
||||||
case bool:
|
case bool, sql.NullBool:
|
||||||
return "bool"
|
return "bool"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||||
return "integer"
|
return "integer"
|
||||||
case int64, uint64:
|
case int64, uint64, sql.NullInt64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64:
|
case float32, float64, sql.NullFloat64:
|
||||||
return "real"
|
return "real"
|
||||||
case string:
|
case string, sql.NullString:
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
}
|
}
|
||||||
|
@ -54,20 +55,20 @@ func getSqlType(adaptor string, column interface{}, size int) string {
|
||||||
switch column.(type) {
|
switch column.(type) {
|
||||||
case time.Time:
|
case time.Time:
|
||||||
return "timestamp"
|
return "timestamp"
|
||||||
case bool:
|
case bool, sql.NullBool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||||
return "int"
|
return "int"
|
||||||
case int64, uint64:
|
case int64, uint64, sql.NullInt64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64:
|
case float32, float64, sql.NullFloat64:
|
||||||
return "double"
|
return "double"
|
||||||
case []byte:
|
case []byte:
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varbinary(%d)", size)
|
return fmt.Sprintf("varbinary(%d)", size)
|
||||||
}
|
}
|
||||||
return "longblob"
|
return "longblob"
|
||||||
case string:
|
case string, sql.NullString:
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
}
|
}
|
||||||
|
@ -80,17 +81,17 @@ func getSqlType(adaptor string, column interface{}, size int) string {
|
||||||
switch column.(type) {
|
switch column.(type) {
|
||||||
case time.Time:
|
case time.Time:
|
||||||
return "timestamp with time zone"
|
return "timestamp with time zone"
|
||||||
case bool:
|
case bool, sql.NullBool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||||
return "integer"
|
return "integer"
|
||||||
case int64, uint64:
|
case int64, uint64, sql.NullInt64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64:
|
case float32, float64, sql.NullFloat64:
|
||||||
return "double precision"
|
return "double precision"
|
||||||
case []byte:
|
case []byte:
|
||||||
return "bytea"
|
return "bytea"
|
||||||
case string:
|
case string, sql.NullString:
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue