From 93986de8e43bc9af6864621c9a4855f0f860cde2 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 28 May 2022 23:09:13 +0800 Subject: [PATCH] fix: migrate column default value (#5359) Co-authored-by: Jinzhu --- migrator/migrator.go | 16 ++++- tests/migrate_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 757ab949..4acc9df6 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -448,10 +448,20 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } // check default value - if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { - // not primary key - if !field.PrimaryKey { + if !field.PrimaryKey { + dv, dvNotNull := columnType.DefaultValue() + if dvNotNull && field.DefaultValueInterface == nil { + // defalut value -> null alterColumn = true + } else if !dvNotNull && field.DefaultValueInterface != nil { + // null -> default value + alterColumn = true + } else if dv != field.DefaultValue { + // default value not equal + // not both null + if !(field.DefaultValueInterface == nil && !dvNotNull) { + alterColumn = true + } } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 2b5d7ecd..9e7caec9 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "fmt" "math/rand" "reflect" "strings" @@ -714,6 +715,141 @@ func TestPrimarykeyID(t *testing.T) { } } +func TestUniqueColumn(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + + type UniqueTest struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique"` + } + + type UniqueTest2 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:NULL"` + } + + type UniqueTest3 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:''"` + } + + type UniqueTest4 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:'123'"` + } + + var err error + err = DB.Migrator().DropTable(&UniqueTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // null -> null + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok := ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + // null -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // not trigger alert column + AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2")) + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + // null -> empty string + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, true, ok) + + // empty string -> 123 + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest4{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "123", value) + AssertEqual(t, true, ok) + + // 123 -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + +} + +func findColumnType(dest interface{}, columnName string) ( + foundColumn gorm.ColumnType, err error) { + columnTypes, err := DB.Migrator().ColumnTypes(dest) + if err != nil { + err = fmt.Errorf("ColumnTypes err:%v", err) + return + } + + for _, c := range columnTypes { + if c.Name() == columnName { + foundColumn = c + break + } + } + return +} + func TestInvalidCachedPlan(t *testing.T) { if DB.Dialector.Name() != "postgres" { return