mirror of https://github.com/go-gorm/gorm.git
Refact field struct
This commit is contained in:
parent
e4612bde9c
commit
c354b0fb06
8
do.go
8
do.go
|
@ -640,8 +640,8 @@ func (s *Do) combinedSql() string {
|
|||
func (s *Do) createTable() *Do {
|
||||
var sqls []string
|
||||
for _, field := range s.model.fields("migration") {
|
||||
if len(field.SqlType()) > 0 {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType())
|
||||
if len(field.sqlTag()) > 0 {
|
||||
sqls = append(sqls, field.DbName+" "+field.sqlTag())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do {
|
|||
s.sqlVars = []interface{}{}
|
||||
|
||||
// If column doesn't exist
|
||||
if len(column_name) == 0 && len(field.SqlType()) > 0 {
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType())
|
||||
if len(column_name) == 0 && len(field.sqlTag()) > 0 {
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.sqlTag())
|
||||
s.exec()
|
||||
}
|
||||
}
|
||||
|
|
78
field.go
78
field.go
|
@ -3,6 +3,7 @@ package gorm
|
|||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
|
||||
"time"
|
||||
|
||||
"strconv"
|
||||
|
@ -12,22 +13,50 @@ import (
|
|||
)
|
||||
|
||||
type Field struct {
|
||||
Name string
|
||||
Value interface{}
|
||||
DbName string
|
||||
AutoCreateTime bool
|
||||
AutoUpdateTime bool
|
||||
IsPrimaryKey bool
|
||||
IsBlank bool
|
||||
structField reflect.StructField
|
||||
|
||||
Name string
|
||||
Value interface{}
|
||||
DbName string
|
||||
AutoCreateTime bool
|
||||
AutoUpdateTime bool
|
||||
IsPrimaryKey bool
|
||||
structField reflect.StructField
|
||||
modelValue reflect.Value
|
||||
beforeAssociation bool
|
||||
afterAssociation bool
|
||||
foreignKey string
|
||||
model *Model
|
||||
}
|
||||
|
||||
func (f *Field) SqlType() string {
|
||||
func (f *Field) isBlank() bool {
|
||||
value := reflect.ValueOf(f.Value)
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int64, reflect.Int32:
|
||||
return value.Int() == 0
|
||||
case reflect.String:
|
||||
return value.String() == ""
|
||||
case reflect.Slice:
|
||||
return value.Len() == 0
|
||||
case reflect.Struct:
|
||||
time_value, is_time := f.Value.(time.Time)
|
||||
if is_time {
|
||||
return time_value.IsZero()
|
||||
} else {
|
||||
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
|
||||
if is_scanner {
|
||||
return !value.FieldByName("Valid").Interface().(bool)
|
||||
} else {
|
||||
m := &Model{data: value.Interface(), do: f.model.do}
|
||||
fields := m.columnsHasValue("other")
|
||||
if len(fields) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *Field) sqlTag() string {
|
||||
column := getInterfaceValue(f.Value)
|
||||
field_value := reflect.ValueOf(f.Value)
|
||||
switch field_value.Kind() {
|
||||
|
@ -61,6 +90,35 @@ func (f *Field) SqlType() string {
|
|||
return typ
|
||||
}
|
||||
|
||||
func (f *Field) parseAssociation() {
|
||||
field_value := reflect.ValueOf(f.Value)
|
||||
|
||||
switch field_value.Kind() {
|
||||
case reflect.Slice:
|
||||
foreign_key := f.model.typeName() + "Id"
|
||||
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
f.foreignKey = foreign_key
|
||||
}
|
||||
f.afterAssociation = true
|
||||
case reflect.Struct:
|
||||
_, is_time := f.Value.(time.Time)
|
||||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||
|
||||
if !is_scanner && !is_time {
|
||||
if f.modelValue.FieldByName(f.Name + "Id").IsValid() {
|
||||
f.foreignKey = f.Name + "Id"
|
||||
f.beforeAssociation = true
|
||||
} else {
|
||||
foreign_key := f.model.typeName() + "Id"
|
||||
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
f.foreignKey = foreign_key
|
||||
}
|
||||
f.afterAssociation = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
|
||||
if str == "-" {
|
||||
typ = str
|
||||
|
|
64
model.go
64
model.go
|
@ -1,7 +1,6 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
|
@ -73,31 +72,7 @@ func (m *Model) fields(operation string) (fields []*Field) {
|
|||
value := indirect_value.FieldByName(p.Name)
|
||||
time_value, is_time := value.Interface().(time.Time)
|
||||
field.model = m
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int64, reflect.Int32:
|
||||
field.IsBlank = value.Int() == 0
|
||||
case reflect.String:
|
||||
field.IsBlank = value.String() == ""
|
||||
case reflect.Slice:
|
||||
field.IsBlank = value.Len() == 0
|
||||
case reflect.Struct:
|
||||
if is_time {
|
||||
field.IsBlank = time_value.IsZero()
|
||||
} else {
|
||||
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
|
||||
|
||||
if is_scanner {
|
||||
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
|
||||
} else {
|
||||
m := &Model{data: value.Interface(), do: m.do}
|
||||
fields := m.columnsHasValue("other")
|
||||
if len(fields) == 0 {
|
||||
field.IsBlank = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
field.modelValue = indirect_value
|
||||
|
||||
if is_time {
|
||||
field.AutoCreateTime = "created_at" == field.DbName
|
||||
|
@ -113,37 +88,10 @@ func (m *Model) fields(operation string) (fields []*Field) {
|
|||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field_value := reflect.Indirect(value)
|
||||
|
||||
switch field_value.Kind() {
|
||||
case reflect.Slice:
|
||||
foreign_key := typ.Name() + "Id"
|
||||
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
field.foreignKey = foreign_key
|
||||
}
|
||||
field.afterAssociation = true
|
||||
case reflect.Struct:
|
||||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||
|
||||
if !is_scanner {
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.foreignKey = p.Name + "Id"
|
||||
field.beforeAssociation = true
|
||||
} else {
|
||||
foreign_key := typ.Name() + "Id"
|
||||
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
field.foreignKey = foreign_key
|
||||
}
|
||||
field.afterAssociation = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
field.structField = p
|
||||
field.Value = value.Interface()
|
||||
|
||||
fields = append(fields, &field)
|
||||
}
|
||||
}
|
||||
|
@ -157,7 +105,7 @@ func (m *Model) fields(operation string) (fields []*Field) {
|
|||
|
||||
func (m *Model) columnsHasValue(operation string) (fields []*Field) {
|
||||
for _, field := range m.fields(operation) {
|
||||
if !field.IsBlank {
|
||||
if !field.isBlank() {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
|
@ -199,7 +147,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
|||
|
||||
if m.data != nil {
|
||||
for _, field := range m.fields(operation) {
|
||||
if !field.IsPrimaryKey && (len(field.SqlType()) > 0) {
|
||||
if !field.IsPrimaryKey && (len(field.sqlTag()) > 0) {
|
||||
results[field.DbName] = field.Value
|
||||
}
|
||||
}
|
||||
|
@ -297,7 +245,8 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
|
|||
|
||||
func (m *Model) beforeAssociations() (fields []*Field) {
|
||||
for _, field := range m.fields("null") {
|
||||
if field.beforeAssociation && !field.IsBlank {
|
||||
field.parseAssociation()
|
||||
if field.beforeAssociation && !field.isBlank() {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
|
@ -306,7 +255,8 @@ func (m *Model) beforeAssociations() (fields []*Field) {
|
|||
|
||||
func (m *Model) afterAssociations() (fields []*Field) {
|
||||
for _, field := range m.fields("null") {
|
||||
if field.afterAssociation && !field.IsBlank {
|
||||
field.parseAssociation()
|
||||
if field.afterAssociation && !field.isBlank() {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue