From 6b2f37189ee1cc1e46cdad9ef6b7f98c69748f0b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Jun 2020 08:20:41 +0800 Subject: [PATCH] Fix few cases with postgres --- migrator/migrator.go | 2 +- schema/field.go | 9 ++++++++- tests/go.mod | 2 ++ tests/postgres_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 tests/postgres_test.go diff --git a/migrator/migrator.go b/migrator/migrator.go index 6baa9dc3..955cc6bb 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -74,7 +74,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { } if field.HasDefaultValue && field.DefaultValue != "" { - if field.DataType == schema.String { + if field.DataType == schema.String && field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue) expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValue) diff --git a/schema/field.go b/schema/field.go index e0d49e2f..ea6dcd25 100644 --- a/schema/field.go +++ b/schema/field.go @@ -203,7 +203,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } case reflect.String: field.DataType = String - if field.HasDefaultValue { + isFunc := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") + + if field.HasDefaultValue && !isFunc { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValueInterface = field.DefaultValue @@ -253,6 +256,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if field.DataType == "" && field.DBDataType != "" { + field.DataType = String + } + // setup permission if _, ok := field.TagSettings["-"]; ok { field.Creatable = false diff --git a/tests/go.mod b/tests/go.mod index e5e181d4..e500edd7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,7 +3,9 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 + github.com/lib/pq v1.6.0 gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v1.0.0 diff --git a/tests/postgres_test.go b/tests/postgres_test.go new file mode 100644 index 00000000..98302d87 --- /dev/null +++ b/tests/postgres_test.go @@ -0,0 +1,39 @@ +package tests_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/lib/pq" + "gorm.io/gorm" +) + +func TestPostgres(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Harumph struct { + gorm.Model + Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` + Things pq.StringArray `gorm:"type:text[]"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Harumph{}) + + if err := DB.AutoMigrate(&Harumph{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + harumph := Harumph{} + DB.Create(&harumph) + + var result Harumph + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil { + t.Errorf("No error should happen, but got %v", err) + } +}