diff --git a/gorm_test.go b/gorm_test.go index 2a7db8df..d119258a 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1124,6 +1124,7 @@ func TestSubStruct(t *testing.T) { var p Post db.First(&p, post.Id) + if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 { t.Errorf("Category Id should exist") } @@ -1265,6 +1266,40 @@ func TestAutoMigration(t *testing.T) { } } +type NullValue struct { + Id int64 + Name sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 +} + +func TestSqlNullValue(t *testing.T) { + db.DropTable(&NullValue{}) + db.AutoMigrate(&NullValue{}) + + if err := db.Save(&NullValue{Name: sql.NullString{"hello", true}, Age: sql.NullInt64{18, true}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil { + t.Errorf("Not error should raise when test null value", err) + } + + var nv NullValue + db.First(&nv, "name = ?", "hello") + + if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 { + t.Errorf("Should be able to fetch null value") + } + + if err := db.Save(&NullValue{Name: sql.NullString{"hello-2", true}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}}).Error; err != nil { + t.Errorf("Not error should raise when test null value", err) + } + + var nv2 NullValue + db.First(&nv2, "name = ?", "hello-2") + if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 { + t.Errorf("Should be able to fetch null value") + } +} + func BenchmarkGorm(b *testing.B) { for x := 0; x < b.N; x++ { email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()} diff --git a/model.go b/model.go index da507b6c..2e3b3f11 100644 --- a/model.go +++ b/model.go @@ -106,10 +106,15 @@ func (m *Model) fields(operation string) (fields []Field) { if is_time { field.IsBlank = time_value.IsZero() } else { - m := &Model{data: value.Interface(), driver: m.driver} - fields := m.columnsHasValue("other") - if len(fields) == 0 { - field.IsBlank = true + switch value.Interface().(type) { + case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: + field.IsBlank = value.FieldByName("Valid").Interface().(bool) + default: + m := &Model{data: value.Interface(), driver: m.driver} + fields := m.columnsHasValue("other") + if len(fields) == 0 { + field.IsBlank = true + } } } } @@ -151,22 +156,19 @@ func (m *Model) fields(operation string) (fields []Field) { } field.afterAssociation = true case reflect.Struct: - if is_time { + switch value.Interface().(type) { + case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString, time.Time: field.SqlType = getSqlType(m.driver, field.Value, 0) - } else { - switch value.Interface().(type) { - case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: - default: - if indirect_value.FieldByName(p.Name + "Id").IsValid() { - field.foreignKey = p.Name + "Id" - field.beforeAssociation = true - } else { - foreign_key := typ.Name() + "Id" - if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { - field.foreignKey = foreign_key - } - field.afterAssociation = true + default: + if indirect_value.FieldByName(p.Name + "Id").IsValid() { + field.foreignKey = p.Name + "Id" + field.beforeAssociation = true + } else { + foreign_key := typ.Name() + "Id" + if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { + field.foreignKey = foreign_key } + field.afterAssociation = true } } case reflect.Ptr: diff --git a/sql_type.go b/sql_type.go index 31ff880d..4db77102 100644 --- a/sql_type.go +++ b/sql_type.go @@ -1,6 +1,7 @@ package gorm import ( + "database/sql" "fmt" "time" ) @@ -34,15 +35,15 @@ func getSqlType(adaptor string, column interface{}, size int) string { switch column.(type) { case time.Time: return "datetime" - case bool: + case bool, sql.NullBool: return "bool" case int, int8, int16, int32, uint, uint8, uint16, uint32: return "integer" - case int64, uint64: + case int64, uint64, sql.NullInt64: return "bigint" - case float32, float64: + case float32, float64, sql.NullFloat64: return "real" - case string: + case string, sql.NullString: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } @@ -54,20 +55,20 @@ func getSqlType(adaptor string, column interface{}, size int) string { switch column.(type) { case time.Time: return "timestamp" - case bool: + case bool, sql.NullBool: return "boolean" case int, int8, int16, int32, uint, uint8, uint16, uint32: return "int" - case int64, uint64: + case int64, uint64, sql.NullInt64: return "bigint" - case float32, float64: + case float32, float64, sql.NullFloat64: return "double" case []byte: if size > 0 && size < 65532 { return fmt.Sprintf("varbinary(%d)", size) } return "longblob" - case string: + case string, sql.NullString: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } @@ -80,17 +81,17 @@ func getSqlType(adaptor string, column interface{}, size int) string { switch column.(type) { case time.Time: return "timestamp with time zone" - case bool: + case bool, sql.NullBool: return "boolean" case int, int8, int16, int32, uint, uint8, uint16, uint32: return "integer" - case int64, uint64: + case int64, uint64, sql.NullInt64: return "bigint" - case float32, float64: + case float32, float64, sql.NullFloat64: return "double precision" case []byte: return "bytea" - case string: + case string, sql.NullString: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) }