diff --git a/schema/schema.go b/schema/schema.go index 60a434fa..c8d79ddc 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -73,6 +73,15 @@ type Tabler interface { // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + return parse(dest, cacheStore, namer, "") +} + +// ParseWithSchemaTable get data type from dialector with extra schema table +func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { + return parse(dest, cacheStore, namer, schemaTable) +} + +func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } @@ -107,6 +116,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) modelValue := reflect.New(modelType) tableName := namer.TableName(modelType.Name()) + if schemaTable != "" { + tableName = schemaTable + } if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } @@ -235,11 +247,13 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { - s := v.(*Schema) - // Wait for the initialization of other goroutines to complete - <-s.initialized - return s, s.err + if schemaTable == "" { + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } } defer func() { diff --git a/statement.go b/statement.go index c631031e..bbe00106 100644 --- a/statement.go +++ b/statement.go @@ -456,7 +456,7 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] diff --git a/tests/migrate_test.go b/tests/migrate_test.go index ba271478..06eb96b3 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -381,3 +381,33 @@ func TestMigrateConstraint(t *testing.T) { } } } + +type MigrateUser struct { + gorm.Model + Name string `gorm:"index"` +} + +// https://github.com/go-gorm/gorm/issues/4752 +func TestMigrateIndexesWithDynamicTableName(t *testing.T) { + tableNameSuffixes := []string{"01", "02", "03"} + for _, v := range tableNameSuffixes { + tableName := "migrate_user_" + v + m := DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table(tableName) + }).Migrator() + + if err := m.AutoMigrate(&MigrateUser{}); err != nil { + t.Fatalf("Failed to create table for %#v", tableName) + } + + if !m.HasTable(tableName) { + t.Fatalf("Failed to create table for %#v", tableName) + } + if !m.HasIndex(&MigrateUser{}, "Name") { + t.Fatalf("Should find index for %s's name after AutoMigrate", tableName) + } + if !m.HasIndex(&MigrateUser{}, "DeletedAt") { + t.Fatalf("Should find index for %s's deleted_at after AutoMigrate", tableName) + } + } +}