diff --git a/README.md b/README.md index 8558d4d4..36edc38c 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,8 @@ db, err := Open("postgres", "user=gorm dbname=gorm sslmode=disable") // Set the maximum idle database connections db.SetPool(100) +// By default, table name is plural of struct type, if you like singular table name +db.SingularTable(true) // Gorm is goroutines friendly, so you can create a global variable to keep the connection and use it everywhere like this diff --git a/chain.go b/chain.go index 79a8db0e..8ccf4d7b 100644 --- a/chain.go +++ b/chain.go @@ -9,10 +9,11 @@ import ( ) type Chain struct { - db *sql.DB - driver string - debug bool - value interface{} + db *sql.DB + driver string + debug bool + singularTableName bool + value interface{} Errors []error Error error @@ -68,6 +69,7 @@ func (s *Chain) do(value interface{}) *Do { do.limitStr = s.limitStr do.specifiedTableName = s.specifiedTableName do.unscoped = s.unscoped + do.singularTableName = s.singularTableName s.value = value do.setModel(value) diff --git a/do.go b/do.go index 47dbaa18..44bedf22 100644 --- a/do.go +++ b/do.go @@ -35,12 +35,13 @@ type Do struct { unscoped bool updateAttrs map[string]interface{} ignoreProtectedAttrs bool + singularTableName bool } func (s *Do) tableName() string { if s.specifiedTableName == "" { var err error - s.guessedTableName, err = s.model.tableName() + s.guessedTableName, err = s.model.tableName(s.singularTableName) s.err(err) return s.guessedTableName } else { diff --git a/gorm_test.go b/gorm_test.go index 9b263018..d7d62fd1 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -81,7 +81,6 @@ var ( func init() { var err error db, err = Open("postgres", "user=gorm dbname=gorm sslmode=disable") - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; // CREATE DATABASE 'gorm'; // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; diff --git a/main.go b/main.go index 9eb91f66..8bf63aea 100644 --- a/main.go +++ b/main.go @@ -3,9 +3,10 @@ package gorm import "database/sql" type DB struct { - db *sql.DB - driver string - DebugMode bool + db *sql.DB + driver string + DebugMode bool + SingularTableName bool } func Open(driver, source string) (db DB, err error) { @@ -18,8 +19,12 @@ func (s *DB) SetPool(n int) { s.db.SetMaxIdleConns(n) } +func (s *DB) SingularTable(result bool) { + s.SingularTableName = result +} + func (s *DB) buildChain() *Chain { - return &Chain{db: s.db, driver: s.driver, debug: s.DebugMode} + return &Chain{db: s.db, driver: s.driver, debug: s.DebugMode, singularTableName: s.SingularTableName} } func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { diff --git a/model.go b/model.go index a30cc2d9..df0b1598 100644 --- a/model.go +++ b/model.go @@ -276,7 +276,7 @@ func (m *Model) typeName() string { return typ.Name() } -func (m *Model) tableName() (str string, err error) { +func (m *Model) tableName(singularTableName bool) (str string, err error) { if m.data == nil { err = errors.New("Model haven't been set") return @@ -284,11 +284,13 @@ func (m *Model) tableName() (str string, err error) { str = toSnake(m.typeName()) - pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"} - for key, value := range pluralMap { - reg := regexp.MustCompile(key + "$") - if reg.MatchString(str) { - return reg.ReplaceAllString(str, value), err + if !singularTableName { + pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"} + for key, value := range pluralMap { + reg := regexp.MustCompile(key + "$") + if reg.MatchString(str) { + return reg.ReplaceAllString(str, value), err + } } } return