mirror of https://github.com/go-gorm/gorm.git
add postgres hstore type support
This commit is contained in:
parent
744cb7dfda
commit
5f0e640f3d
46
main_test.go
46
main_test.go
|
@ -145,6 +145,11 @@ type Animal struct {
|
|||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Details struct {
|
||||
Id int64
|
||||
Bulk gorm.Hstore
|
||||
}
|
||||
|
||||
var (
|
||||
db gorm.DB
|
||||
t1, t2, t3, t4, t5 time.Time
|
||||
|
@ -2037,3 +2042,44 @@ func TestIndices(t *testing.T) {
|
|||
t.Errorf("Got error when tried to create index: %+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHstore(t *testing.T) {
|
||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
if err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err))
|
||||
}
|
||||
|
||||
db.Exec("drop table details")
|
||||
|
||||
if err := db.CreateTable(&Details{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait"
|
||||
bulk := map[string]*string{
|
||||
"bankAccountId": &bankAccountId,
|
||||
"phoneNumber": &phoneNumber,
|
||||
"opinion": &opinion,
|
||||
}
|
||||
d := Details{Bulk: bulk}
|
||||
db.Save(&d)
|
||||
|
||||
var d2 Details
|
||||
if err := db.First(&d2).Error; err != nil {
|
||||
t.Errorf("Got error when tried to fetch details: %+v", err)
|
||||
}
|
||||
|
||||
for k := range bulk {
|
||||
r, ok := d2.Bulk[k]
|
||||
if !ok {
|
||||
t.Errorf("Details should be existed")
|
||||
}
|
||||
if res, _ := bulk[k]; *res != *r {
|
||||
t.Errorf("Details should be equal")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
48
postgres.go
48
postgres.go
|
@ -1,8 +1,12 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/lib/pq/hstore"
|
||||
)
|
||||
|
||||
type postgres struct {
|
||||
|
@ -35,6 +39,10 @@ func (d *postgres) SqlTag(value reflect.Value, size int) string {
|
|||
if value.Type() == timeType {
|
||||
return "timestamp with time zone"
|
||||
}
|
||||
case reflect.Map:
|
||||
if value.Type() == hstoreType {
|
||||
return "hstore"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "bytea"
|
||||
|
@ -80,3 +88,43 @@ func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string)
|
|||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
var hstoreType = reflect.TypeOf(Hstore{})
|
||||
|
||||
type Hstore map[string]*string
|
||||
|
||||
func (h Hstore) Value() (driver.Value, error) {
|
||||
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
||||
if len(h) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for key, value := range h {
|
||||
hstore.Map[key] = sql.NullString{*value, true}
|
||||
}
|
||||
return hstore.Value()
|
||||
}
|
||||
|
||||
func (h *Hstore) Scan(value interface{}) error {
|
||||
hstore := hstore.Hstore{}
|
||||
|
||||
if err := hstore.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hstore.Map) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
*h = Hstore{}
|
||||
for k := range hstore.Map {
|
||||
if hstore.Map[k].Valid {
|
||||
s := hstore.Map[k].String
|
||||
(*h)[k] = &s
|
||||
} else {
|
||||
(*h)[k] = nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue