forked from mirror/gorm
Separate Field Struct
This commit is contained in:
parent
c1c13129f1
commit
e4612bde9c
8
do.go
8
do.go
|
@ -640,8 +640,8 @@ func (s *Do) combinedSql() string {
|
||||||
func (s *Do) createTable() *Do {
|
func (s *Do) createTable() *Do {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range s.model.fields("migration") {
|
for _, field := range s.model.fields("migration") {
|
||||||
if len(field.SqlType) > 0 {
|
if len(field.SqlType()) > 0 {
|
||||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
sqls = append(sqls, field.DbName+" "+field.SqlType())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do {
|
||||||
s.sqlVars = []interface{}{}
|
s.sqlVars = []interface{}{}
|
||||||
|
|
||||||
// If column doesn't exist
|
// If column doesn't exist
|
||||||
if len(column_name) == 0 && len(field.SqlType) > 0 {
|
if len(column_name) == 0 && len(field.SqlType()) > 0 {
|
||||||
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType)
|
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType())
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Field struct {
|
||||||
|
Name string
|
||||||
|
Value interface{}
|
||||||
|
DbName string
|
||||||
|
AutoCreateTime bool
|
||||||
|
AutoUpdateTime bool
|
||||||
|
IsPrimaryKey bool
|
||||||
|
IsBlank bool
|
||||||
|
structField reflect.StructField
|
||||||
|
|
||||||
|
beforeAssociation bool
|
||||||
|
afterAssociation bool
|
||||||
|
foreignKey string
|
||||||
|
model *Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Field) SqlType() string {
|
||||||
|
column := getInterfaceValue(f.Value)
|
||||||
|
field_value := reflect.ValueOf(f.Value)
|
||||||
|
switch field_value.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
return ""
|
||||||
|
case reflect.Struct:
|
||||||
|
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||||
|
_, is_time := column.(time.Time)
|
||||||
|
if !is_time && !is_scanner {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(tagIdentifier))
|
||||||
|
|
||||||
|
if typ == "-" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(typ) == 0 {
|
||||||
|
if f.IsPrimaryKey {
|
||||||
|
typ = f.model.do.chain.d.dialect.PrimaryKeyTag(column, size)
|
||||||
|
} else {
|
||||||
|
typ = f.model.do.chain.d.dialect.SqlTag(column, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addational_typ) > 0 {
|
||||||
|
typ = typ + " " + addational_typ
|
||||||
|
}
|
||||||
|
return typ
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
|
||||||
|
if str == "-" {
|
||||||
|
typ = str
|
||||||
|
} else if str != "" {
|
||||||
|
tags := strings.Split(str, ";")
|
||||||
|
m := make(map[string]string)
|
||||||
|
for _, value := range tags {
|
||||||
|
v := strings.Split(value, ":")
|
||||||
|
k := strings.Trim(strings.ToUpper(v[0]), " ")
|
||||||
|
if len(v) == 2 {
|
||||||
|
m[k] = v[1]
|
||||||
|
} else {
|
||||||
|
m[k] = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m["SIZE"]) > 0 {
|
||||||
|
size, _ = strconv.Atoi(m["SIZE"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m["TYPE"]) > 0 {
|
||||||
|
typ = m["TYPE"]
|
||||||
|
}
|
||||||
|
|
||||||
|
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInterfaceValue(column interface{}) interface{} {
|
||||||
|
if v, ok := column.(reflect.Value); ok {
|
||||||
|
column = v.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
if valuer, ok := interface{}(column).(driver.Valuer); ok {
|
||||||
|
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
|
||||||
|
}
|
||||||
|
return column
|
||||||
|
}
|
|
@ -351,6 +351,7 @@ func TestComplexWhere(t *testing.T) {
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
user_ids = append(user_ids, user.Id)
|
user_ids = append(user_ids, user.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
users = []User{}
|
users = []User{}
|
||||||
db.Where("id in (?)", user_ids).Find(&users)
|
db.Where("id in (?)", user_ids).Find(&users)
|
||||||
if len(users) != 3 {
|
if len(users) != 3 {
|
||||||
|
|
73
model.go
73
model.go
|
@ -3,33 +3,16 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
data interface{}
|
data interface{}
|
||||||
do *Do
|
do *Do
|
||||||
_cache_fields map[string][]Field
|
_cache_fields map[string][]*Field
|
||||||
}
|
|
||||||
|
|
||||||
type Field struct {
|
|
||||||
Name string
|
|
||||||
Value interface{}
|
|
||||||
SqlType string
|
|
||||||
DbName string
|
|
||||||
AutoCreateTime bool
|
|
||||||
AutoUpdateTime bool
|
|
||||||
IsPrimaryKey bool
|
|
||||||
IsBlank bool
|
|
||||||
|
|
||||||
beforeAssociation bool
|
|
||||||
afterAssociation bool
|
|
||||||
foreignKey string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) primaryKeyZero() bool {
|
func (m *Model) primaryKeyZero() bool {
|
||||||
|
@ -69,7 +52,7 @@ func (m *Model) primaryKeyDb() string {
|
||||||
return toSnake(m.primaryKey())
|
return toSnake(m.primaryKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) fields(operation string) (fields []Field) {
|
func (m *Model) fields(operation string) (fields []*Field) {
|
||||||
if len(m._cache_fields[operation]) > 0 {
|
if len(m._cache_fields[operation]) > 0 {
|
||||||
return m._cache_fields[operation]
|
return m._cache_fields[operation]
|
||||||
}
|
}
|
||||||
|
@ -89,6 +72,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
|
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
|
||||||
value := indirect_value.FieldByName(p.Name)
|
value := indirect_value.FieldByName(p.Name)
|
||||||
time_value, is_time := value.Interface().(time.Time)
|
time_value, is_time := value.Interface().(time.Time)
|
||||||
|
field.model = m
|
||||||
|
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Int, reflect.Int64, reflect.Int32:
|
case reflect.Int, reflect.Int64, reflect.Int32:
|
||||||
|
@ -129,14 +113,6 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
value.Set(reflect.ValueOf(time.Now()))
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
field.Value = value.Interface()
|
|
||||||
|
|
||||||
if is_time {
|
|
||||||
field.SqlType = m.getSqlTag(field, p)
|
|
||||||
} else if field.IsPrimaryKey {
|
|
||||||
field.SqlType = m.getSqlTag(field, p)
|
|
||||||
} else {
|
} else {
|
||||||
field_value := reflect.Indirect(value)
|
field_value := reflect.Indirect(value)
|
||||||
|
|
||||||
|
@ -150,9 +126,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||||
|
|
||||||
if is_scanner {
|
if !is_scanner {
|
||||||
field.SqlType = m.getSqlTag(field, p)
|
|
||||||
} else {
|
|
||||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||||
field.foreignKey = p.Name + "Id"
|
field.foreignKey = p.Name + "Id"
|
||||||
field.beforeAssociation = true
|
field.beforeAssociation = true
|
||||||
|
@ -164,23 +138,24 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
field.afterAssociation = true
|
field.afterAssociation = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
field.SqlType = m.getSqlTag(field, p)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fields = append(fields, field)
|
field.structField = p
|
||||||
|
field.Value = value.Interface()
|
||||||
|
|
||||||
|
fields = append(fields, &field)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m._cache_fields) == 0 {
|
if len(m._cache_fields) == 0 {
|
||||||
m._cache_fields = map[string][]Field{}
|
m._cache_fields = map[string][]*Field{}
|
||||||
}
|
}
|
||||||
m._cache_fields[operation] = fields
|
m._cache_fields[operation] = fields
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) columnsHasValue(operation string) (fields []Field) {
|
func (m *Model) columnsHasValue(operation string) (fields []*Field) {
|
||||||
for _, field := range m.fields(operation) {
|
for _, field := range m.fields(operation) {
|
||||||
if !field.IsBlank {
|
if !field.IsBlank {
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
|
@ -224,7 +199,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
||||||
|
|
||||||
if m.data != nil {
|
if m.data != nil {
|
||||||
for _, field := range m.fields(operation) {
|
for _, field := range m.fields(operation) {
|
||||||
if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
|
if !field.IsPrimaryKey && (len(field.SqlType()) > 0) {
|
||||||
results[field.DbName] = field.Value
|
results[field.DbName] = field.Value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -320,7 +295,7 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
|
||||||
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
|
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) beforeAssociations() (fields []Field) {
|
func (m *Model) beforeAssociations() (fields []*Field) {
|
||||||
for _, field := range m.fields("null") {
|
for _, field := range m.fields("null") {
|
||||||
if field.beforeAssociation && !field.IsBlank {
|
if field.beforeAssociation && !field.IsBlank {
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
|
@ -329,7 +304,7 @@ func (m *Model) beforeAssociations() (fields []Field) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) afterAssociations() (fields []Field) {
|
func (m *Model) afterAssociations() (fields []*Field) {
|
||||||
for _, field := range m.fields("null") {
|
for _, field := range m.fields("null") {
|
||||||
if field.afterAssociation && !field.IsBlank {
|
if field.afterAssociation && !field.IsBlank {
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
|
@ -337,25 +312,3 @@ func (m *Model) afterAssociations() (fields []Field) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) getSqlTag(field Field, struct_field reflect.StructField) string {
|
|
||||||
column := getInterfaceValue(field.Value)
|
|
||||||
typ, addational_typ, size := parseSqlTag(struct_field.Tag.Get(tagIdentifier))
|
|
||||||
|
|
||||||
if typ == "-" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(typ) == 0 {
|
|
||||||
if field.IsPrimaryKey {
|
|
||||||
typ = m.do.chain.d.dialect.PrimaryKeyTag(column, size)
|
|
||||||
} else {
|
|
||||||
typ = m.do.chain.d.dialect.SqlTag(column, size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(addational_typ) > 0 {
|
|
||||||
typ = typ + " " + addational_typ
|
|
||||||
}
|
|
||||||
return typ
|
|
||||||
}
|
|
||||||
|
|
42
utils.go
42
utils.go
|
@ -3,7 +3,7 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
|
||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -67,46 +67,6 @@ func getInterfaceAsString(value interface{}) (str string, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
|
|
||||||
if str == "-" {
|
|
||||||
typ = str
|
|
||||||
} else if str != "" {
|
|
||||||
tags := strings.Split(str, ";")
|
|
||||||
m := make(map[string]string)
|
|
||||||
for _, value := range tags {
|
|
||||||
v := strings.Split(value, ":")
|
|
||||||
k := strings.Trim(strings.ToUpper(v[0]), " ")
|
|
||||||
if len(v) == 2 {
|
|
||||||
m[k] = v[1]
|
|
||||||
} else {
|
|
||||||
m[k] = k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m["SIZE"]) > 0 {
|
|
||||||
size, _ = strconv.Atoi(m["SIZE"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m["TYPE"]) > 0 {
|
|
||||||
typ = m["TYPE"]
|
|
||||||
}
|
|
||||||
|
|
||||||
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getInterfaceValue(column interface{}) interface{} {
|
|
||||||
if v, ok := column.(reflect.Value); ok {
|
|
||||||
column = v.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
if valuer, ok := interface{}(column).(driver.Valuer); ok {
|
|
||||||
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
|
|
||||||
}
|
|
||||||
return column
|
|
||||||
}
|
|
||||||
|
|
||||||
func setFieldValue(field reflect.Value, value interface{}) bool {
|
func setFieldValue(field reflect.Value, value interface{}) bool {
|
||||||
if field.IsValid() && field.CanAddr() {
|
if field.IsValid() && field.CanAddr() {
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
|
|
Loading…
Reference in New Issue