diff --git a/main_test.go b/main_test.go index 2bdb8471..92d5a7e0 100644 --- a/main_test.go +++ b/main_test.go @@ -145,6 +145,11 @@ type Animal struct { UpdatedAt time.Time } +type Details struct { + Id int64 + Bulk gorm.Hstore +} + var ( db gorm.DB t1, t2, t3, t4, t5 time.Time @@ -2037,3 +2042,44 @@ func TestIndices(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } } + +func TestHstore(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { + t.Skip() + } + + if err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { + panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) + } + + db.Exec("drop table details") + + if err := db.CreateTable(&Details{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + + bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait" + bulk := map[string]*string{ + "bankAccountId": &bankAccountId, + "phoneNumber": &phoneNumber, + "opinion": &opinion, + } + d := Details{Bulk: bulk} + db.Save(&d) + + var d2 Details + if err := db.First(&d2).Error; err != nil { + t.Errorf("Got error when tried to fetch details: %+v", err) + } + + for k := range bulk { + r, ok := d2.Bulk[k] + if !ok { + t.Errorf("Details should be existed") + } + if res, _ := bulk[k]; *res != *r { + t.Errorf("Details should be equal") + } + } + +} diff --git a/postgres.go b/postgres.go index e54a8680..ef4c902d 100644 --- a/postgres.go +++ b/postgres.go @@ -1,8 +1,12 @@ package gorm import ( + "database/sql" + "database/sql/driver" "fmt" "reflect" + + "github.com/lib/pq/hstore" ) type postgres struct { @@ -35,6 +39,10 @@ func (d *postgres) SqlTag(value reflect.Value, size int) string { if value.Type() == timeType { return "timestamp with time zone" } + case reflect.Map: + if value.Type() == hstoreType { + return "hstore" + } default: if _, ok := value.Interface().([]byte); ok { return "bytea" @@ -80,3 +88,43 @@ func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } + +var hstoreType = reflect.TypeOf(Hstore{}) + +type Hstore map[string]*string + +func (h Hstore) Value() (driver.Value, error) { + hstore := hstore.Hstore{Map: map[string]sql.NullString{}} + if len(h) == 0 { + return nil, nil + } + + for key, value := range h { + hstore.Map[key] = sql.NullString{*value, true} + } + return hstore.Value() +} + +func (h *Hstore) Scan(value interface{}) error { + hstore := hstore.Hstore{} + + if err := hstore.Scan(value); err != nil { + return err + } + + if len(hstore.Map) == 0 { + return nil + } + + *h = Hstore{} + for k := range hstore.Map { + if hstore.Map[k].Valid { + s := hstore.Map[k].String + (*h)[k] = &s + } else { + (*h)[k] = nil + } + } + + return nil +}