diff --git a/main.go b/main.go index de6ce428..993e19b1 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" ) @@ -162,7 +163,7 @@ func (s *DB) HasBlockGlobalUpdate() bool { // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() + modelStructsMap = sync.Map{} s.parent.singularTable = enable } diff --git a/model_struct.go b/model_struct.go index f571e2e8..8506fe87 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,28 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() +var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { @@ -48,7 +27,7 @@ type ModelStruct struct { defaultTableName string } -// TableName get model's table name +// TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name @@ -152,8 +131,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + return value.(*ModelStruct) } modelStruct.ModelType = reflectType @@ -601,7 +580,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(reflectType, &modelStruct) return &modelStruct }