From bae0799bd8e56d8f3097577afb3fcbd8d99a895d Mon Sep 17 00:00:00 2001 From: Rob Rodriguez Date: Wed, 19 Apr 2017 00:21:56 -0700 Subject: [PATCH] Adding better binary type support for common SQL dialects --- dialect_common.go | 5 +++++ dialect_mysql.go | 2 +- dialect_postgres.go | 6 +----- dialect_sqlite3.go | 2 +- dialects/mssql/mssql.go | 12 ++++++------ 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index 1554151c..abe7532d 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -149,3 +149,8 @@ func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") return keyName } + +// IsByteArrayOrSlice returns true of the reflected value is an array or slice +func IsByteArrayOrSlice(value reflect.Value) bool { + return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) +} diff --git a/dialect_mysql.go b/dialect_mysql.go index b471a162..fa63e982 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -87,7 +87,7 @@ func (s *mysql) DataTypeOf(field *StructField) string { } } default: - if _, ok := dataValue.Interface().([]byte); ok { + if IsByteArrayOrSlice(dataValue) { if size > 0 && size < 65532 { sqlType = fmt.Sprintf("varbinary(%d)", size) } else { diff --git a/dialect_postgres.go b/dialect_postgres.go index 7d07a02c..b9161f68 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -65,7 +65,7 @@ func (s *postgres) DataTypeOf(field *StructField) string { sqlType = "hstore" } default: - if isByteArrayOrSlice(dataValue) { + if IsByteArrayOrSlice(dataValue) { sqlType = "bytea" } else if isUUID(dataValue) { sqlType = "uuid" @@ -120,10 +120,6 @@ func (postgres) SupportLastInsertID() bool { return false } -func isByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} - func isUUID(value reflect.Value) bool { if value.Kind() != reflect.Array || value.Type().Len() != 16 { return false diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 46edea0c..de9c05cb 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -54,7 +54,7 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { sqlType = "datetime" } default: - if _, ok := dataValue.Interface().([]byte); ok { + if IsByteArrayOrSlice(dataValue) { sqlType = "blob" } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 7541b222..eb810cfa 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -58,21 +58,21 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { case reflect.Float32, reflect.Float64: sqlType = "float" case reflect.String: - if size > 0 && size < 65532 { + if size > 0 && size < 8000 { sqlType = fmt.Sprintf("nvarchar(%d)", size) } else { - sqlType = "text" + sqlType = "nvarchar(max)" } case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { sqlType = "datetime2" } default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) + if gorm.IsByteArrayOrSlice(dataValue) { + if size > 0 && size < 8000 { + sqlType = fmt.Sprintf("varbinary(%d)", size) } else { - sqlType = "text" + sqlType = "varbinary(max)" } } }