forked from mirror/gorm
Support NullFloat64, NullInt64, NullBool, NullString
This commit is contained in:
parent
562bca71e4
commit
8d97fdb172
35
gorm_test.go
35
gorm_test.go
|
@ -1124,6 +1124,7 @@ func TestSubStruct(t *testing.T) {
|
|||
|
||||
var p Post
|
||||
db.First(&p, post.Id)
|
||||
|
||||
if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
|
||||
t.Errorf("Category Id should exist")
|
||||
}
|
||||
|
@ -1265,6 +1266,40 @@ func TestAutoMigration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type NullValue struct {
|
||||
Id int64
|
||||
Name sql.NullString
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
}
|
||||
|
||||
func TestSqlNullValue(t *testing.T) {
|
||||
db.DropTable(&NullValue{})
|
||||
db.AutoMigrate(&NullValue{})
|
||||
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
|
||||
t.Errorf("Not error should raise when test null value", err)
|
||||
}
|
||||
|
||||
var nv NullValue
|
||||
db.First(&nv, "name = ?", "hello")
|
||||
|
||||
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 {
|
||||
t.Errorf("Should be able to fetch null value")
|
||||
}
|
||||
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil {
|
||||
t.Errorf("Not error should raise when test null value", err)
|
||||
}
|
||||
|
||||
var nv2 NullValue
|
||||
db.First(&nv2, "name = ?", "hello-2")
|
||||
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 {
|
||||
t.Errorf("Should be able to fetch null value")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
for x := 0; x < b.N; x++ {
|
||||
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
|
|
38
model.go
38
model.go
|
@ -106,10 +106,15 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
if is_time {
|
||||
field.IsBlank = time_value.IsZero()
|
||||
} else {
|
||||
m := &Model{data: value.Interface(), driver: m.driver}
|
||||
fields := m.columnsHasValue("other")
|
||||
if len(fields) == 0 {
|
||||
field.IsBlank = true
|
||||
switch value.Interface().(type) {
|
||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
||||
field.IsBlank = value.FieldByName("Valid").Interface().(bool)
|
||||
default:
|
||||
m := &Model{data: value.Interface(), driver: m.driver}
|
||||
fields := m.columnsHasValue("other")
|
||||
if len(fields) == 0 {
|
||||
field.IsBlank = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -151,22 +156,19 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
field.afterAssociation = true
|
||||
case reflect.Struct:
|
||||
if is_time {
|
||||
switch value.Interface().(type) {
|
||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time:
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
} else {
|
||||
switch value.Interface().(type) {
|
||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
||||
default:
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.foreignKey = p.Name + "Id"
|
||||
field.beforeAssociation = true
|
||||
} else {
|
||||
foreign_key := typ.Name() + "Id"
|
||||
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
field.foreignKey = foreign_key
|
||||
}
|
||||
field.afterAssociation = true
|
||||
default:
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.foreignKey = p.Name + "Id"
|
||||
field.beforeAssociation = true
|
||||
} else {
|
||||
foreign_key := typ.Name() + "Id"
|
||||
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
field.foreignKey = foreign_key
|
||||
}
|
||||
field.afterAssociation = true
|
||||
}
|
||||
}
|
||||
case reflect.Ptr:
|
||||
|
|
25
sql_type.go
25
sql_type.go
|
@ -1,6 +1,7 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
@ -34,15 +35,15 @@ func getSqlType(adaptor string, column interface{}, size int) string {
|
|||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "datetime"
|
||||
case bool:
|
||||
case bool, sql.NullBool:
|
||||
return "bool"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
return "integer"
|
||||
case int64, uint64:
|
||||
case int64, uint64, sql.NullInt64:
|
||||
return "bigint"
|
||||
case float32, float64:
|
||||
case float32, float64, sql.NullFloat64:
|
||||
return "real"
|
||||
case string:
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
|
@ -54,20 +55,20 @@ func getSqlType(adaptor string, column interface{}, size int) string {
|
|||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "timestamp"
|
||||
case bool:
|
||||
case bool, sql.NullBool:
|
||||
return "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
return "int"
|
||||
case int64, uint64:
|
||||
case int64, uint64, sql.NullInt64:
|
||||
return "bigint"
|
||||
case float32, float64:
|
||||
case float32, float64, sql.NullFloat64:
|
||||
return "double"
|
||||
case []byte:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
}
|
||||
return "longblob"
|
||||
case string:
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
|
@ -80,17 +81,17 @@ func getSqlType(adaptor string, column interface{}, size int) string {
|
|||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "timestamp with time zone"
|
||||
case bool:
|
||||
case bool, sql.NullBool:
|
||||
return "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
return "integer"
|
||||
case int64, uint64:
|
||||
case int64, uint64, sql.NullInt64:
|
||||
return "bigint"
|
||||
case float32, float64:
|
||||
case float32, float64, sql.NullFloat64:
|
||||
return "double precision"
|
||||
case []byte:
|
||||
return "bytea"
|
||||
case string:
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue