diff --git a/schema/schema.go b/schema/schema.go index 60e621de..9e05303a 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -82,6 +82,10 @@ func (schema Schema) LookUpField(name string) *Field { return nil } +type Tabler interface { + TableName() string +} + // get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() @@ -100,10 +104,16 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return v.(*Schema), nil } + modelValue := reflect.New(modelType) + tableName := namer.TableName(modelType.Name()) + if tabler, ok := modelValue.Interface().(Tabler); ok { + tableName = tabler.TableName() + } + schema := &Schema{ Name: modelType.Name(), ModelType: modelType, - Table: namer.TableName(modelType.Name()), + Table: tableName, FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, Relationships: Relationships{Relations: map[string]*Relationship{}}, @@ -200,10 +210,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - reflectValue := reflect.New(modelType) callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} for _, name := range callbacks { - if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { + if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { switch methodValue.Type().String() { case "func(*gorm.DB) error": // TODO hack reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) diff --git a/schema/schema_test.go b/schema/schema_test.go index 1029f74f..82f07fa8 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -142,3 +142,21 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { }) } } + +type CustomizeTable struct { +} + +func (CustomizeTable) TableName() string { + return "customize" +} + +func TestCustomizeTableName(t *testing.T) { + customize, err := schema.Parse(&CustomizeTable{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + if customize.Table != "customize" { + t.Errorf("Failed to customize table with TableName method") + } +}