diff --git a/do.go b/do.go index cd003656..ebb5aa78 100644 --- a/do.go +++ b/do.go @@ -372,7 +372,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) { if from_from { s.where(foreign_value).query() } else { - query := fmt.Sprintf("%v = %v", toSnake(foreign_key), foreign_value) + query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value)) s.where(query).query() } } @@ -529,6 +529,8 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { } case int, int64, int32: return s.primaryCondiation(s.addToVars(query)) + case sql.NullInt64: + return s.primaryCondiation(s.addToVars(query.(sql.NullInt64).Int64)) case []int64, []int, []int32, []string: str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb()) clause["args"] = []interface{}{query} @@ -558,7 +560,13 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { } str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) default: - str = strings.Replace(str, "?", s.addToVars(arg), 1) + switch arg.(type) { + case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: + value := reflect.ValueOf(arg).Field(0).Interface() + str = strings.Replace(str, "?", s.addToVars(value), 1) + default: + str = strings.Replace(str, "?", s.addToVars(arg), 1) + } } } return @@ -616,7 +624,13 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { } str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) default: - str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1) + switch arg.(type) { + case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: + value := reflect.ValueOf(arg).Field(0).Interface() + str = strings.Replace(not_equal_sql, "?", s.addToVars(value), 1) + default: + str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1) + } } } return diff --git a/gorm_test.go b/gorm_test.go index d119258a..d855f598 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -18,21 +18,21 @@ type User struct { Birthday time.Time // Time Age int64 Name string - CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically - UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically - DeletedAt time.Time // DeletedAt: Time of record is deleted, refer Soft Delete for more - Emails []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressId int64 // Embedded struct's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct's foreign key + CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically + UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically + DeletedAt time.Time // DeletedAt: Time of record is deleted, refer Soft Delete for more + Emails []Email // Embedded structs + BillingAddress Address // Embedded struct + BillingAddressId sql.NullInt64 // Embedded struct's foreign key + ShippingAddress Address // Embedded struct + ShippingAddressId int64 // Embedded struct's foreign key CreditCard CreditCard } type CreditCard struct { Id int64 Number string - UserId int64 + UserId sql.NullInt64 CreatedAt time.Time UpdatedAt time.Time DeletedAt time.Time diff --git a/model.go b/model.go index 2e3b3f11..90a9d591 100644 --- a/model.go +++ b/model.go @@ -108,7 +108,7 @@ func (m *Model) fields(operation string) (fields []Field) { } else { switch value.Interface().(type) { case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: - field.IsBlank = value.FieldByName("Valid").Interface().(bool) + field.IsBlank = !value.FieldByName("Valid").Interface().(bool) default: m := &Model{data: value.Interface(), driver: m.driver} fields := m.columnsHasValue("other")