diff --git a/chain.go b/chain.go index 771a481e..18c24584 100644 --- a/chain.go +++ b/chain.go @@ -70,6 +70,7 @@ func (s *Chain) do(value interface{}) *Do { do.specifiedTableName = s.specifiedTableName do.unscoped = s.unscoped do.singularTableName = s.singularTableName + do.debug = s.debug s.value = value do.setModel(value) diff --git a/do.go b/do.go index ae8b4fb2..cd003656 100644 --- a/do.go +++ b/do.go @@ -17,6 +17,7 @@ type Do struct { driver string guessedTableName string specifiedTableName string + debug bool Errors []error model *Model diff --git a/gorm_test.go b/gorm_test.go index 1c3fb24d..2a7db8df 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1081,7 +1081,7 @@ type Category struct { type Post struct { Id int64 - CategoryId int64 + CategoryId sql.NullInt64 MainCategoryId int64 Title string Body string @@ -1124,7 +1124,7 @@ func TestSubStruct(t *testing.T) { var p Post db.First(&p, post.Id) - if post.CategoryId == 0 || p.CategoryId == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 { + if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 { t.Errorf("Category Id should exist") } diff --git a/model.go b/model.go index 9751908e..da507b6c 100644 --- a/model.go +++ b/model.go @@ -1,6 +1,7 @@ package gorm import ( + "database/sql" "errors" "fmt" "go/ast" @@ -153,15 +154,19 @@ func (m *Model) fields(operation string) (fields []Field) { if is_time { field.SqlType = getSqlType(m.driver, field.Value, 0) } else { - if indirect_value.FieldByName(p.Name + "Id").IsValid() { - field.foreignKey = p.Name + "Id" - field.beforeAssociation = true - } else { - foreign_key := typ.Name() + "Id" - if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { - field.foreignKey = foreign_key + switch value.Interface().(type) { + case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: + default: + if indirect_value.FieldByName(p.Name + "Id").IsValid() { + field.foreignKey = p.Name + "Id" + field.beforeAssociation = true + } else { + foreign_key := typ.Name() + "Id" + if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { + field.foreignKey = foreign_key + } + field.afterAssociation = true } - field.afterAssociation = true } } case reflect.Ptr: @@ -363,7 +368,22 @@ func setFieldValue(field reflect.Value, value interface{}) bool { } field.SetInt(reflect.ValueOf(value).Int()) default: - field.Set(reflect.ValueOf(value)) + field_type := field.Type() + if field_type == reflect.TypeOf(value) { + field.Set(reflect.ValueOf(value)) + } else if value == nil { + field.Set(reflect.Zero(field.Type())) + } else if field_type == reflect.TypeOf(sql.NullBool{}) { + field.Set(reflect.ValueOf(sql.NullBool{value.(bool), true})) + } else if field_type == reflect.TypeOf(sql.NullFloat64{}) { + field.Set(reflect.ValueOf(sql.NullFloat64{value.(float64), true})) + } else if field_type == reflect.TypeOf(sql.NullInt64{}) { + field.Set(reflect.ValueOf(sql.NullInt64{value.(int64), true})) + } else if field_type == reflect.TypeOf(sql.NullString{}) { + field.Set(reflect.ValueOf(sql.NullString{value.(string), true})) + } else { + field.Set(reflect.ValueOf(value)) + } } return true } else {