mirror of https://github.com/go-gorm/gorm.git
Separate Field Struct
This commit is contained in:
parent
c1c13129f1
commit
e4612bde9c
8
do.go
8
do.go
|
@ -640,8 +640,8 @@ func (s *Do) combinedSql() string {
|
|||
func (s *Do) createTable() *Do {
|
||||
var sqls []string
|
||||
for _, field := range s.model.fields("migration") {
|
||||
if len(field.SqlType) > 0 {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
if len(field.SqlType()) > 0 {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do {
|
|||
s.sqlVars = []interface{}{}
|
||||
|
||||
// If column doesn't exist
|
||||
if len(column_name) == 0 && len(field.SqlType) > 0 {
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType)
|
||||
if len(column_name) == 0 && len(field.SqlType()) > 0 {
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType())
|
||||
s.exec()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type Field struct {
|
||||
Name string
|
||||
Value interface{}
|
||||
DbName string
|
||||
AutoCreateTime bool
|
||||
AutoUpdateTime bool
|
||||
IsPrimaryKey bool
|
||||
IsBlank bool
|
||||
structField reflect.StructField
|
||||
|
||||
beforeAssociation bool
|
||||
afterAssociation bool
|
||||
foreignKey string
|
||||
model *Model
|
||||
}
|
||||
|
||||
func (f *Field) SqlType() string {
|
||||
column := getInterfaceValue(f.Value)
|
||||
field_value := reflect.ValueOf(f.Value)
|
||||
switch field_value.Kind() {
|
||||
case reflect.Slice:
|
||||
return ""
|
||||
case reflect.Struct:
|
||||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||
_, is_time := column.(time.Time)
|
||||
if !is_time && !is_scanner {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(tagIdentifier))
|
||||
|
||||
if typ == "-" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(typ) == 0 {
|
||||
if f.IsPrimaryKey {
|
||||
typ = f.model.do.chain.d.dialect.PrimaryKeyTag(column, size)
|
||||
} else {
|
||||
typ = f.model.do.chain.d.dialect.SqlTag(column, size)
|
||||
}
|
||||
}
|
||||
|
||||
if len(addational_typ) > 0 {
|
||||
typ = typ + " " + addational_typ
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
||||
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
|
||||
if str == "-" {
|
||||
typ = str
|
||||
} else if str != "" {
|
||||
tags := strings.Split(str, ";")
|
||||
m := make(map[string]string)
|
||||
for _, value := range tags {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.Trim(strings.ToUpper(v[0]), " ")
|
||||
if len(v) == 2 {
|
||||
m[k] = v[1]
|
||||
} else {
|
||||
m[k] = k
|
||||
}
|
||||
}
|
||||
|
||||
if len(m["SIZE"]) > 0 {
|
||||
size, _ = strconv.Atoi(m["SIZE"])
|
||||
}
|
||||
|
||||
if len(m["TYPE"]) > 0 {
|
||||
typ = m["TYPE"]
|
||||
}
|
||||
|
||||
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getInterfaceValue(column interface{}) interface{} {
|
||||
if v, ok := column.(reflect.Value); ok {
|
||||
column = v.Interface()
|
||||
}
|
||||
|
||||
if valuer, ok := interface{}(column).(driver.Valuer); ok {
|
||||
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
|
||||
}
|
||||
return column
|
||||
}
|
|
@ -351,6 +351,7 @@ func TestComplexWhere(t *testing.T) {
|
|||
for _, user := range users {
|
||||
user_ids = append(user_ids, user.Id)
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
db.Where("id in (?)", user_ids).Find(&users)
|
||||
if len(users) != 3 {
|
||||
|
|
73
model.go
73
model.go
|
@ -3,33 +3,16 @@ package gorm
|
|||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"regexp"
|
||||
|
||||
"time"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
data interface{}
|
||||
do *Do
|
||||
_cache_fields map[string][]Field
|
||||
}
|
||||
|
||||
type Field struct {
|
||||
Name string
|
||||
Value interface{}
|
||||
SqlType string
|
||||
DbName string
|
||||
AutoCreateTime bool
|
||||
AutoUpdateTime bool
|
||||
IsPrimaryKey bool
|
||||
IsBlank bool
|
||||
|
||||
beforeAssociation bool
|
||||
afterAssociation bool
|
||||
foreignKey string
|
||||
_cache_fields map[string][]*Field
|
||||
}
|
||||
|
||||
func (m *Model) primaryKeyZero() bool {
|
||||
|
@ -69,7 +52,7 @@ func (m *Model) primaryKeyDb() string {
|
|||
return toSnake(m.primaryKey())
|
||||
}
|
||||
|
||||
func (m *Model) fields(operation string) (fields []Field) {
|
||||
func (m *Model) fields(operation string) (fields []*Field) {
|
||||
if len(m._cache_fields[operation]) > 0 {
|
||||
return m._cache_fields[operation]
|
||||
}
|
||||
|
@ -89,6 +72,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
|
||||
value := indirect_value.FieldByName(p.Name)
|
||||
time_value, is_time := value.Interface().(time.Time)
|
||||
field.model = m
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int64, reflect.Int32:
|
||||
|
@ -129,14 +113,6 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
field.Value = value.Interface()
|
||||
|
||||
if is_time {
|
||||
field.SqlType = m.getSqlTag(field, p)
|
||||
} else if field.IsPrimaryKey {
|
||||
field.SqlType = m.getSqlTag(field, p)
|
||||
} else {
|
||||
field_value := reflect.Indirect(value)
|
||||
|
||||
|
@ -150,9 +126,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
case reflect.Struct:
|
||||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||
|
||||
if is_scanner {
|
||||
field.SqlType = m.getSqlTag(field, p)
|
||||
} else {
|
||||
if !is_scanner {
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.foreignKey = p.Name + "Id"
|
||||
field.beforeAssociation = true
|
||||
|
@ -164,23 +138,24 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.afterAssociation = true
|
||||
}
|
||||
}
|
||||
default:
|
||||
field.SqlType = m.getSqlTag(field, p)
|
||||
}
|
||||
}
|
||||
|
||||
fields = append(fields, field)
|
||||
field.structField = p
|
||||
field.Value = value.Interface()
|
||||
|
||||
fields = append(fields, &field)
|
||||
}
|
||||
}
|
||||
|
||||
if len(m._cache_fields) == 0 {
|
||||
m._cache_fields = map[string][]Field{}
|
||||
m._cache_fields = map[string][]*Field{}
|
||||
}
|
||||
m._cache_fields[operation] = fields
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) columnsHasValue(operation string) (fields []Field) {
|
||||
func (m *Model) columnsHasValue(operation string) (fields []*Field) {
|
||||
for _, field := range m.fields(operation) {
|
||||
if !field.IsBlank {
|
||||
fields = append(fields, field)
|
||||
|
@ -224,7 +199,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
|||
|
||||
if m.data != nil {
|
||||
for _, field := range m.fields(operation) {
|
||||
if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
|
||||
if !field.IsPrimaryKey && (len(field.SqlType()) > 0) {
|
||||
results[field.DbName] = field.Value
|
||||
}
|
||||
}
|
||||
|
@ -320,7 +295,7 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
|
|||
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
|
||||
}
|
||||
|
||||
func (m *Model) beforeAssociations() (fields []Field) {
|
||||
func (m *Model) beforeAssociations() (fields []*Field) {
|
||||
for _, field := range m.fields("null") {
|
||||
if field.beforeAssociation && !field.IsBlank {
|
||||
fields = append(fields, field)
|
||||
|
@ -329,7 +304,7 @@ func (m *Model) beforeAssociations() (fields []Field) {
|
|||
return
|
||||
}
|
||||
|
||||
func (m *Model) afterAssociations() (fields []Field) {
|
||||
func (m *Model) afterAssociations() (fields []*Field) {
|
||||
for _, field := range m.fields("null") {
|
||||
if field.afterAssociation && !field.IsBlank {
|
||||
fields = append(fields, field)
|
||||
|
@ -337,25 +312,3 @@ func (m *Model) afterAssociations() (fields []Field) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) getSqlTag(field Field, struct_field reflect.StructField) string {
|
||||
column := getInterfaceValue(field.Value)
|
||||
typ, addational_typ, size := parseSqlTag(struct_field.Tag.Get(tagIdentifier))
|
||||
|
||||
if typ == "-" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(typ) == 0 {
|
||||
if field.IsPrimaryKey {
|
||||
typ = m.do.chain.d.dialect.PrimaryKeyTag(column, size)
|
||||
} else {
|
||||
typ = m.do.chain.d.dialect.SqlTag(column, size)
|
||||
}
|
||||
}
|
||||
|
||||
if len(addational_typ) > 0 {
|
||||
typ = typ + " " + addational_typ
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
|
42
utils.go
42
utils.go
|
@ -3,7 +3,7 @@ package gorm
|
|||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
@ -67,46 +67,6 @@ func getInterfaceAsString(value interface{}) (str string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
|
||||
if str == "-" {
|
||||
typ = str
|
||||
} else if str != "" {
|
||||
tags := strings.Split(str, ";")
|
||||
m := make(map[string]string)
|
||||
for _, value := range tags {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.Trim(strings.ToUpper(v[0]), " ")
|
||||
if len(v) == 2 {
|
||||
m[k] = v[1]
|
||||
} else {
|
||||
m[k] = k
|
||||
}
|
||||
}
|
||||
|
||||
if len(m["SIZE"]) > 0 {
|
||||
size, _ = strconv.Atoi(m["SIZE"])
|
||||
}
|
||||
|
||||
if len(m["TYPE"]) > 0 {
|
||||
typ = m["TYPE"]
|
||||
}
|
||||
|
||||
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getInterfaceValue(column interface{}) interface{} {
|
||||
if v, ok := column.(reflect.Value); ok {
|
||||
column = v.Interface()
|
||||
}
|
||||
|
||||
if valuer, ok := interface{}(column).(driver.Valuer); ok {
|
||||
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
|
||||
}
|
||||
return column
|
||||
}
|
||||
|
||||
func setFieldValue(field reflect.Value, value interface{}) bool {
|
||||
if field.IsValid() && field.CanAddr() {
|
||||
switch field.Kind() {
|
||||
|
|
Loading…
Reference in New Issue