From b9a39be9c5e77bb0bfebd516114a8a4d605c645a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 8 Apr 2015 11:36:01 +0800 Subject: [PATCH] Add Tabler --- model_struct.go | 16 +++++++++++----- scope.go | 25 ++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/model_struct.go b/model_struct.go index f73c902b..a70489fc 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,7 +17,7 @@ type ModelStruct struct { PrimaryFields []*StructField StructFields []*StructField ModelType reflect.Type - TableName string + TableName func(*DB) string } type StructField struct { @@ -97,18 +97,24 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { if results := fm.Call([]reflect.Value{}); len(results) > 0 { if name, ok := results[0].Interface().(string); ok { - modelStruct.TableName = name + modelStruct.TableName = func(*DB) string { + return name + } } } } else { - modelStruct.TableName = ToDBName(scopeType.Name()) + name := ToDBName(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { for index, reg := range pluralMapKeys { - if reg.MatchString(modelStruct.TableName) { - modelStruct.TableName = reg.ReplaceAllString(modelStruct.TableName, pluralMapValues[index]) + if reg.MatchString(name) { + name = reg.ReplaceAllString(name, pluralMapValues[index]) } } } + + modelStruct.TableName = func(*DB) string { + return name + } } // Get all fields diff --git a/scope.go b/scope.go index d8e39348..86994a85 100644 --- a/scope.go +++ b/scope.go @@ -224,12 +224,35 @@ func (scope *Scope) AddToVars(value interface{}) string { } } +type tabler interface { + TableName() string +} + +type dbTabler interface { + TableName(*DB) string +} + // TableName get table name func (scope *Scope) TableName() string { if scope.Search != nil && len(scope.Search.tableName) > 0 { return scope.Search.tableName } - return scope.GetModelStruct().TableName + + if tabler, ok := scope.Value.(tabler); ok { + return tabler.TableName() + } + + if tabler, ok := scope.Value.(dbTabler); ok { + return tabler.TableName(scope.db) + } + + if scope.GetModelStruct().TableName != nil { + scope.Search.tableName = scope.GetModelStruct().TableName(scope.db) + return scope.Search.tableName + } + + scope.Err(errors.New("wrong table name")) + return "" } func (scope *Scope) QuotedTableName() (name string) {