diff --git a/main.go b/main.go index bf8acbae..181722fd 100644 --- a/main.go +++ b/main.go @@ -485,7 +485,3 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join } } } - -func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) { - s.NewScope(source).GetModelStruct().TableName = handler -} diff --git a/model_struct.go b/model_struct.go index a70489fc..10423ae2 100644 --- a/model_struct.go +++ b/model_struct.go @@ -13,11 +13,19 @@ import ( var modelStructs = map[reflect.Type]*ModelStruct{} +var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { + return defaultTableName +} + type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - TableName func(*DB) string + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type + defaultTableName string +} + +func (s ModelStruct) TableName(db *DB) string { + return DefaultTableNameHandler(db, s.defaultTableName) } type StructField struct { @@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Set tablename - if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { - if results := fm.Call([]reflect.Value{}); len(results) > 0 { - if name, ok := results[0].Interface().(string); ok { - modelStruct.TableName = func(*DB) string { - return name - } - } - } + type tabler interface { + TableName() string + } + + if tabler, ok := reflect.New(scopeType).Interface().(interface { + TableName() string + }); ok { + modelStruct.defaultTableName = tabler.TableName() } else { name := ToDBName(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { @@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStruct.TableName = func(*DB) string { - return name - } + modelStruct.defaultTableName = name } // Get all fields diff --git a/scope.go b/scope.go index 54bf5c84..960a653c 100644 --- a/scope.go +++ b/scope.go @@ -251,12 +251,7 @@ func (scope *Scope) TableName() string { return tabler.TableName(scope.db) } - if scope.GetModelStruct().TableName != nil { - return scope.GetModelStruct().TableName(scope.db) - } - - scope.Err(errors.New("wrong table name")) - return "" + return scope.GetModelStruct().TableName(scope.db) } func (scope *Scope) QuotedTableName() (name string) {