Refact field struct

This commit is contained in:
Jinzhu 2013-11-14 21:38:35 +08:00
parent e4612bde9c
commit c354b0fb06
3 changed files with 79 additions and 71 deletions

8
do.go
View File

@ -640,8 +640,8 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do {
var sqls []string
for _, field := range s.model.fields("migration") {
if len(field.SqlType()) > 0 {
sqls = append(sqls, field.DbName+" "+field.SqlType())
if len(field.sqlTag()) > 0 {
sqls = append(sqls, field.DbName+" "+field.sqlTag())
}
}
@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do {
s.sqlVars = []interface{}{}
// If column doesn't exist
if len(column_name) == 0 && len(field.SqlType()) > 0 {
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType())
if len(column_name) == 0 && len(field.sqlTag()) > 0 {
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.sqlTag())
s.exec()
}
}

View File

@ -3,6 +3,7 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"time"
"strconv"
@ -12,22 +13,50 @@ import (
)
type Field struct {
Name string
Value interface{}
DbName string
AutoCreateTime bool
AutoUpdateTime bool
IsPrimaryKey bool
IsBlank bool
structField reflect.StructField
Name string
Value interface{}
DbName string
AutoCreateTime bool
AutoUpdateTime bool
IsPrimaryKey bool
structField reflect.StructField
modelValue reflect.Value
beforeAssociation bool
afterAssociation bool
foreignKey string
model *Model
}
func (f *Field) SqlType() string {
func (f *Field) isBlank() bool {
value := reflect.ValueOf(f.Value)
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
return value.Int() == 0
case reflect.String:
return value.String() == ""
case reflect.Slice:
return value.Len() == 0
case reflect.Struct:
time_value, is_time := f.Value.(time.Time)
if is_time {
return time_value.IsZero()
} else {
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
if is_scanner {
return !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), do: f.model.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
return true
}
}
}
}
return false
}
func (f *Field) sqlTag() string {
column := getInterfaceValue(f.Value)
field_value := reflect.ValueOf(f.Value)
switch field_value.Kind() {
@ -61,6 +90,35 @@ func (f *Field) SqlType() string {
return typ
}
func (f *Field) parseAssociation() {
field_value := reflect.ValueOf(f.Value)
switch field_value.Kind() {
case reflect.Slice:
foreign_key := f.model.typeName() + "Id"
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
case reflect.Struct:
_, is_time := f.Value.(time.Time)
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if !is_scanner && !is_time {
if f.modelValue.FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id"
f.beforeAssociation = true
} else {
foreign_key := f.model.typeName() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
}
}
}
}
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str

View File

@ -1,7 +1,6 @@
package gorm
import (
"database/sql"
"errors"
"go/ast"
"reflect"
@ -73,31 +72,7 @@ func (m *Model) fields(operation string) (fields []*Field) {
value := indirect_value.FieldByName(p.Name)
time_value, is_time := value.Interface().(time.Time)
field.model = m
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
field.IsBlank = value.Int() == 0
case reflect.String:
field.IsBlank = value.String() == ""
case reflect.Slice:
field.IsBlank = value.Len() == 0
case reflect.Struct:
if is_time {
field.IsBlank = time_value.IsZero()
} else {
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
if is_scanner {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), do: m.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
}
}
}
}
field.modelValue = indirect_value
if is_time {
field.AutoCreateTime = "created_at" == field.DbName
@ -113,37 +88,10 @@ func (m *Model) fields(operation string) (fields []*Field) {
value.Set(reflect.ValueOf(time.Now()))
}
}
} else {
field_value := reflect.Indirect(value)
switch field_value.Kind() {
case reflect.Slice:
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
case reflect.Struct:
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if !is_scanner {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true
} else {
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
}
}
}
}
field.structField = p
field.Value = value.Interface()
fields = append(fields, &field)
}
}
@ -157,7 +105,7 @@ func (m *Model) fields(operation string) (fields []*Field) {
func (m *Model) columnsHasValue(operation string) (fields []*Field) {
for _, field := range m.fields(operation) {
if !field.IsBlank {
if !field.isBlank() {
fields = append(fields, field)
}
}
@ -199,7 +147,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
if m.data != nil {
for _, field := range m.fields(operation) {
if !field.IsPrimaryKey && (len(field.SqlType()) > 0) {
if !field.IsPrimaryKey && (len(field.sqlTag()) > 0) {
results[field.DbName] = field.Value
}
}
@ -297,7 +245,8 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
func (m *Model) beforeAssociations() (fields []*Field) {
for _, field := range m.fields("null") {
if field.beforeAssociation && !field.IsBlank {
field.parseAssociation()
if field.beforeAssociation && !field.isBlank() {
fields = append(fields, field)
}
}
@ -306,7 +255,8 @@ func (m *Model) beforeAssociations() (fields []*Field) {
func (m *Model) afterAssociations() (fields []*Field) {
for _, field := range m.fields("null") {
if field.afterAssociation && !field.IsBlank {
field.parseAssociation()
if field.afterAssociation && !field.isBlank() {
fields = append(fields, field)
}
}