From 5c478b46e1b351551c9c1b6f326aa13fc461444f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 18 Feb 2015 10:19:34 +0800 Subject: [PATCH] Use Common Initialisms from golint --- README.md | 1 - association.go | 2 +- association_test.go | 4 +-- callback_create.go | 2 +- callback_update.go | 2 +- doc/development.md | 2 +- model_struct.go | 68 ++++++++++++++++++++++----------------------- scope.go | 2 +- scope_private.go | 8 +++--- structs_test.go | 4 +-- utils.go | 21 +++++++++++--- utils_private.go | 4 +-- 12 files changed, 66 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index 4da9a0fd..7879d6ee 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,6 @@ db.Save(&User{Name: "xxx"}) // table "users" * Column name is the snake case of field's name * Use `Id` field as primary key -* Use tag `sql` to change field's property, change the tag name with `db.SetTagIdentifier(new_name)` * Use `CreatedAt` to store record's created time if field exists * Use `UpdatedAt` to store record's updated time if field exists * Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete) diff --git a/association.go b/association.go index 24faf013..c52f3e58 100644 --- a/association.go +++ b/association.go @@ -157,7 +157,7 @@ func (association *Association) Count() int { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey) if relationship.ForeignType != "" { - countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToDBColumnName(relationship.ForeignType))), scope.TableName()) + countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToDBName(relationship.ForeignType))), scope.TableName()) } countScope.Count(&count) } else if relationship.Kind == "belongs_to" { diff --git a/association_test.go b/association_test.go index 6079454b..201eed53 100644 --- a/association_test.go +++ b/association_test.go @@ -65,11 +65,11 @@ func TestRelated(t *testing.T) { DB.Save(&user) - if user.CreditCard.Id == 0 { + if user.CreditCard.ID == 0 { t.Errorf("After user save, credit card should have id") } - if user.BillingAddress.Id == 0 { + if user.BillingAddress.ID == 0 { t.Errorf("After user save, billing address should have id") } diff --git a/callback_create.go b/callback_create.go index 1bf44a3d..a27a9eb4 100644 --- a/callback_create.go +++ b/callback_create.go @@ -26,7 +26,7 @@ func Create(scope *Scope) { var sqls, columns []string for _, field := range scope.Fields() { if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) { - if !field.IsBlank || field.DefaultValue == nil { + if !field.IsBlank || !field.HasDefaultValue { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Field.Interface())) } diff --git a/callback_update.go b/callback_update.go index fc8ee3ad..e7c7375a 100644 --- a/callback_update.go +++ b/callback_update.go @@ -48,7 +48,7 @@ func Update(scope *Scope) { } else { for _, field := range scope.Fields() { if !field.IsPrimaryKey && field.IsNormal { - if !field.IsBlank || field.DefaultValue == nil { + if !field.IsBlank || !field.HasDefaultValue { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } diff --git a/doc/development.md b/doc/development.md index de5bb9aa..674cfc43 100644 --- a/doc/development.md +++ b/doc/development.md @@ -2,7 +2,7 @@ ## Architecture -The most notable component of Gorm is `gorm.DB`, which hold database connection. It could be initialized like this: +The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this: db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") diff --git a/model_struct.go b/model_struct.go index ae91c229..421e1d35 100644 --- a/model_struct.go +++ b/model_struct.go @@ -18,36 +18,36 @@ type ModelStruct struct { } type StructField struct { - DBName string - Name string - Names []string - IsPrimaryKey bool - IsScanner bool - IsTime bool - IsNormal bool - IsIgnored bool - DefaultValue *string - SqlTag string - Tag reflect.StructTag - Struct reflect.StructField - Relationship *Relationship + DBName string + Name string + Names []string + IsPrimaryKey bool + IsNormal bool + IsIgnored bool + IsScanner bool + HasDefaultValue bool + SqlTag string + Tag reflect.StructTag + Struct reflect.StructField + IsForeignKey bool + Relationship *Relationship } func (structField *StructField) clone() *StructField { return &StructField{ - DBName: structField.DBName, - Name: structField.Name, - Names: structField.Names, - IsPrimaryKey: structField.IsPrimaryKey, - IsScanner: structField.IsScanner, - IsTime: structField.IsTime, - IsNormal: structField.IsNormal, - IsIgnored: structField.IsIgnored, - DefaultValue: structField.DefaultValue, - SqlTag: structField.SqlTag, - Tag: structField.Tag, - Struct: structField.Struct, - Relationship: structField.Relationship, + DBName: structField.DBName, + Name: structField.Name, + Names: structField.Names, + IsPrimaryKey: structField.IsPrimaryKey, + IsNormal: structField.IsNormal, + IsIgnored: structField.IsIgnored, + IsScanner: structField.IsScanner, + HasDefaultValue: structField.HasDefaultValue, + SqlTag: structField.SqlTag, + Tag: structField.Tag, + Struct: structField.Struct, + IsForeignKey: structField.IsForeignKey, + Relationship: structField.Relationship, } } @@ -146,7 +146,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } } else { - modelStruct.TableName = ToDBColumnName(scopeType.Name()) + modelStruct.TableName = ToDBName(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { for index, reg := range pluralMapKeys { if reg.MatchString(modelStruct.TableName) { @@ -176,14 +176,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { modelStruct.PrimaryKeyField = field } - if value, ok := sqlSettings["DEFAULT"]; ok { - field.DefaultValue = &value + if _, ok := sqlSettings["DEFAULT"]; ok { + field.HasDefaultValue = true } if value, ok := gormSettings["COLUMN"]; ok { field.DBName = value } else { - field.DBName = ToDBColumnName(fieldStruct.Name) + field.DBName = ToDBName(fieldStruct.Name) } fieldType, indirectType := fieldStruct.Type, fieldStruct.Type @@ -196,7 +196,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { - field.IsTime, field.IsNormal = true, true + field.IsNormal = true } many2many := gormSettings["MANY2MANY"] @@ -238,8 +238,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ForeignType: foreignType, ForeignFieldName: foreignKey, AssociationForeignFieldName: associationForeignKey, - ForeignDBName: ToDBColumnName(foreignKey), - AssociationForeignDBName: ToDBColumnName(associationForeignKey), + ForeignDBName: ToDBName(foreignKey), + AssociationForeignDBName: ToDBName(associationForeignKey), Kind: kind, } } else { @@ -274,7 +274,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.Relationship = &Relationship{ ForeignFieldName: foreignKey, - ForeignDBName: ToDBColumnName(foreignKey), + ForeignDBName: ToDBName(foreignKey), ForeignType: foreignType, Kind: kind, } diff --git a/scope.go b/scope.go index ce9c66c0..56b0e52b 100644 --- a/scope.go +++ b/scope.go @@ -145,7 +145,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { return field.Set(value) } - dbName = ToDBColumnName(dbName) + dbName = ToDBName(dbName) if field, ok := scope.Fields()[dbName]; ok { return field.Set(value) } diff --git a/scope_private.go b/scope_private.go index f10586fc..27457a04 100644 --- a/scope_private.go +++ b/scope_private.go @@ -316,7 +316,7 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore fields := scope.Fields() for key, value := range values { - if field, ok := fields[ToDBColumnName(key)]; ok && field.Field.IsValid() { + if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { if !equalAsString(field.Field.Interface(), value) { hasUpdate = true @@ -389,8 +389,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { fromFields := scope.Fields() toFields := toScope.Fields() for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField := fromFields[ToDBColumnName(foreignKey)] - toField := toFields[ToDBColumnName(foreignKey)] + fromField := fromFields[ToDBName(foreignKey)] + toField := toFields[ToDBName(foreignKey)] if fromField != nil { if relationship := fromField.Relationship; relationship != nil { @@ -411,7 +411,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := toScope.db.Where(sql, scope.PrimaryKeyValue()) if relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(ToDBColumnName(relationship.ForeignType))), scope.TableName()) + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(ToDBName(relationship.ForeignType))), scope.TableName()) } scope.Err(query.Find(value).Error) } diff --git a/structs_test.go b/structs_test.go index 95c1ea38..3bf76f3f 100644 --- a/structs_test.go +++ b/structs_test.go @@ -36,7 +36,7 @@ type User struct { } type CreditCard struct { - Id int8 + ID int8 Number string UserId sql.NullInt64 CreatedAt time.Time @@ -53,7 +53,7 @@ type Email struct { } type Address struct { - Id int + ID int Address1 string Address2 string Post string diff --git a/utils.go b/utils.go index b1b3166f..3298eb7d 100644 --- a/utils.go +++ b/utils.go @@ -5,15 +5,28 @@ import ( "strings" ) +// Copied from golint +var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} +var commonInitialismsReplacer *strings.Replacer + +func init() { + var commonInitialismsForReplacer []string + for _, initialism := range commonInitialisms { + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) +} + var smap = map[string]string{} -func ToDBColumnName(u string) string { - if v, ok := smap[u]; ok { +func ToDBName(name string) string { + if v, ok := smap[name]; ok { return v } + value := commonInitialismsReplacer.Replace(name) buf := bytes.NewBufferString("") - for i, v := range u { + for i, v := range value { if i > 0 && v >= 'A' && v <= 'Z' { buf.WriteRune('_') } @@ -21,7 +34,7 @@ func ToDBColumnName(u string) string { } s := strings.ToLower(buf.String()) - smap[u] = s + smap[name] = s return s } diff --git a/utils_private.go b/utils_private.go index 09791ec1..6f609ae0 100644 --- a/utils_private.go +++ b/utils_private.go @@ -44,7 +44,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { switch value := values.(type) { case map[string]interface{}: for k, v := range value { - attrs[ToDBColumnName(k)] = v + attrs[ToDBName(k)] = v } case []interface{}: for _, v := range value { @@ -58,7 +58,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: scope := Scope{Value: values}