diff --git a/dialect.go b/dialect.go index facde0d0..de72b79a 100644 --- a/dialect.go +++ b/dialect.go @@ -68,10 +68,14 @@ func RegisterDialect(name string, dialect Dialect) { dialectsMap[name] = dialect } -// ParseFieldStructForDialect parse field struct for dialect -func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { +// ParseFieldStructForDialect get field's sql data type +var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { // Get redirected field type - var reflectType = field.Struct.Type + var ( + reflectType = field.Struct.Type + dataType = field.TagSettings["TYPE"] + ) + for reflectType.Kind() == reflect.Ptr { reflectType = reflectType.Elem() } @@ -79,6 +83,12 @@ func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, s // Get redirected field value fieldValue = reflect.Indirect(reflect.New(reflectType)) + if gormDataType, ok := fieldValue.Interface().(interface { + GormDataType(Dialect) string + }); ok { + dataType = gormDataType.GormDataType(dialect) + } + // Get scanner's real value var getScannerValue func(reflect.Value) getScannerValue = func(value reflect.Value) { @@ -102,5 +112,5 @@ func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, s additionalType = additionalType + " DEFAULT " + value } - return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType) + return fieldValue, dataType, size, strings.TrimSpace(additionalType) } diff --git a/dialect_common.go b/dialect_common.go index 5b5682c5..601afd4c 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -39,8 +39,8 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } -func (commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) +func (s *commonDialect) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) if sqlType == "" { switch dataValue.Kind() { diff --git a/dialect_mysql.go b/dialect_mysql.go index 11b894b3..b471a162 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -27,8 +27,8 @@ func (mysql) Quote(key string) string { } // Get Data Type for MySQL Dialect -func (mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) +func (s *mysql) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) // MySQL allows only one auto increment column per table, and it must // be a KEY column. diff --git a/dialect_postgres.go b/dialect_postgres.go index 5a6114c0..7d07a02c 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -23,8 +23,8 @@ func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } -func (postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) +func (s *postgres) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) if sqlType == "" { switch dataValue.Kind() { diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 2abcefa5..33f4aa50 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -21,8 +21,8 @@ func (sqlite3) GetName() string { } // Get Data Type for Sqlite Dialect -func (sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) +func (s *sqlite3) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) if sqlType == "" { switch dataValue.Kind() { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a7bca6b8..ad2960ef 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -44,8 +44,8 @@ func (mssql) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } -func (mssql) DataTypeOf(field *gorm.StructField) string { - var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field) +func (s *mssql) DataTypeOf(field *gorm.StructField) string { + var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s) if sqlType == "" { switch dataValue.Kind() {