From ef002fd7accb973c9f36931e2b1c3112d2b062ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 20 Jul 2020 18:59:28 +0800 Subject: [PATCH] Add GORMDataType to Field, close #3171 --- callbacks/update.go | 4 ++-- gorm.go | 1 + schema/field.go | 7 +++++++ schema/relationship.go | 3 +++ schema/schema.go | 2 +- 5 files changed, 14 insertions(+), 3 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 97a0e893..d549f97b 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -202,7 +202,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) - } else if field.DataType == schema.Time { + } else if field.GORMDataType == schema.Time { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } else { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) @@ -223,7 +223,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() - } else if field.DataType == schema.Time { + } else if field.GORMDataType == schema.Time { value = stmt.DB.NowFunc() } else { value = stmt.DB.NowFunc().Unix() diff --git a/gorm.go b/gorm.go index e3b1dd35..338a1473 100644 --- a/gorm.go +++ b/gorm.go @@ -300,6 +300,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac for _, ref := range relation.References { if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { f.DataType = ref.ForeignKey.DataType + f.GORMDataType = ref.ForeignKey.GORMDataType ref.ForeignKey = f } else { return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) diff --git a/schema/field.go b/schema/field.go index bc3dbc62..a170e60e 100644 --- a/schema/field.go +++ b/schema/field.go @@ -38,6 +38,7 @@ type Field struct { DBName string BindNames []string DataType DataType + GORMDataType DataType PrimaryKey bool AutoIncrement bool Creatable bool @@ -221,6 +222,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + field.GORMDataType = field.DataType + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { field.DataType = DataType(dataTyper.GormDataType()) } @@ -250,6 +253,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + if field.Size == 0 { switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: diff --git a/schema/relationship.go b/schema/relationship.go index c290c5ba..e67092b4 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -157,6 +157,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi // use same data type for foreign keys relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, @@ -285,6 +286,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, f := range relation.JoinTable.Fields { // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType + f.GORMDataType = fieldsMap[f.Name].GORMDataType relation.JoinTable.PrimaryFields[idx] = f ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] @@ -387,6 +389,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH for idx, foreignField := range foreignFields { // use same data type for foreign keys foreignField.DataType = primaryFields[idx].DataType + foreignField.GORMDataType = primaryFields[idx].GORMDataType relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/schema/schema.go b/schema/schema.go index 66e02443..bcf65939 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -182,7 +182,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } if field := schema.PrioritizedPrimaryField; field != nil { - switch field.DataType { + switch field.GORMDataType { case Int, Uint: if !field.HasDefaultValue || field.DefaultValueInterface != nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)