Separate Field Struct

This commit is contained in:
Jinzhu 2013-11-14 21:26:02 +08:00
parent c1c13129f1
commit e4612bde9c
5 changed files with 121 additions and 105 deletions

8
do.go
View File

@ -640,8 +640,8 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do { func (s *Do) createTable() *Do {
var sqls []string var sqls []string
for _, field := range s.model.fields("migration") { for _, field := range s.model.fields("migration") {
if len(field.SqlType) > 0 { if len(field.SqlType()) > 0 {
sqls = append(sqls, field.DbName+" "+field.SqlType) sqls = append(sqls, field.DbName+" "+field.SqlType())
} }
} }
@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do {
s.sqlVars = []interface{}{} s.sqlVars = []interface{}{}
// If column doesn't exist // If column doesn't exist
if len(column_name) == 0 && len(field.SqlType) > 0 { if len(column_name) == 0 && len(field.SqlType()) > 0 {
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType) s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType())
s.exec() s.exec()
} }
} }

102
field.go Normal file
View File

@ -0,0 +1,102 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"time"
"strconv"
"strings"
"reflect"
)
type Field struct {
Name string
Value interface{}
DbName string
AutoCreateTime bool
AutoUpdateTime bool
IsPrimaryKey bool
IsBlank bool
structField reflect.StructField
beforeAssociation bool
afterAssociation bool
foreignKey string
model *Model
}
func (f *Field) SqlType() string {
column := getInterfaceValue(f.Value)
field_value := reflect.ValueOf(f.Value)
switch field_value.Kind() {
case reflect.Slice:
return ""
case reflect.Struct:
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
_, is_time := column.(time.Time)
if !is_time && !is_scanner {
return ""
}
}
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(tagIdentifier))
if typ == "-" {
return ""
}
if len(typ) == 0 {
if f.IsPrimaryKey {
typ = f.model.do.chain.d.dialect.PrimaryKeyTag(column, size)
} else {
typ = f.model.do.chain.d.dialect.SqlTag(column, size)
}
}
if len(addational_typ) > 0 {
typ = typ + " " + addational_typ
}
return typ
}
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str
} else if str != "" {
tags := strings.Split(str, ";")
m := make(map[string]string)
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.Trim(strings.ToUpper(v[0]), " ")
if len(v) == 2 {
m[k] = v[1]
} else {
m[k] = k
}
}
if len(m["SIZE"]) > 0 {
size, _ = strconv.Atoi(m["SIZE"])
}
if len(m["TYPE"]) > 0 {
typ = m["TYPE"]
}
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
}
return
}
func getInterfaceValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok {
column = v.Interface()
}
if valuer, ok := interface{}(column).(driver.Valuer); ok {
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
}
return column
}

View File

@ -351,6 +351,7 @@ func TestComplexWhere(t *testing.T) {
for _, user := range users { for _, user := range users {
user_ids = append(user_ids, user.Id) user_ids = append(user_ids, user.Id)
} }
users = []User{} users = []User{}
db.Where("id in (?)", user_ids).Find(&users) db.Where("id in (?)", user_ids).Find(&users)
if len(users) != 3 { if len(users) != 3 {

View File

@ -3,33 +3,16 @@ package gorm
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"go/ast" "go/ast"
"reflect" "reflect"
"regexp" "regexp"
"time" "time"
) )
type Model struct { type Model struct {
data interface{} data interface{}
do *Do do *Do
_cache_fields map[string][]Field _cache_fields map[string][]*Field
}
type Field struct {
Name string
Value interface{}
SqlType string
DbName string
AutoCreateTime bool
AutoUpdateTime bool
IsPrimaryKey bool
IsBlank bool
beforeAssociation bool
afterAssociation bool
foreignKey string
} }
func (m *Model) primaryKeyZero() bool { func (m *Model) primaryKeyZero() bool {
@ -69,7 +52,7 @@ func (m *Model) primaryKeyDb() string {
return toSnake(m.primaryKey()) return toSnake(m.primaryKey())
} }
func (m *Model) fields(operation string) (fields []Field) { func (m *Model) fields(operation string) (fields []*Field) {
if len(m._cache_fields[operation]) > 0 { if len(m._cache_fields[operation]) > 0 {
return m._cache_fields[operation] return m._cache_fields[operation]
} }
@ -89,6 +72,7 @@ func (m *Model) fields(operation string) (fields []Field) {
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
value := indirect_value.FieldByName(p.Name) value := indirect_value.FieldByName(p.Name)
time_value, is_time := value.Interface().(time.Time) time_value, is_time := value.Interface().(time.Time)
field.model = m
switch value.Kind() { switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32: case reflect.Int, reflect.Int64, reflect.Int32:
@ -129,14 +113,6 @@ func (m *Model) fields(operation string) (fields []Field) {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
} }
}
field.Value = value.Interface()
if is_time {
field.SqlType = m.getSqlTag(field, p)
} else if field.IsPrimaryKey {
field.SqlType = m.getSqlTag(field, p)
} else { } else {
field_value := reflect.Indirect(value) field_value := reflect.Indirect(value)
@ -150,9 +126,7 @@ func (m *Model) fields(operation string) (fields []Field) {
case reflect.Struct: case reflect.Struct:
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if is_scanner { if !is_scanner {
field.SqlType = m.getSqlTag(field, p)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() { if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id" field.foreignKey = p.Name + "Id"
field.beforeAssociation = true field.beforeAssociation = true
@ -164,23 +138,24 @@ func (m *Model) fields(operation string) (fields []Field) {
field.afterAssociation = true field.afterAssociation = true
} }
} }
default:
field.SqlType = m.getSqlTag(field, p)
} }
} }
fields = append(fields, field) field.structField = p
field.Value = value.Interface()
fields = append(fields, &field)
} }
} }
if len(m._cache_fields) == 0 { if len(m._cache_fields) == 0 {
m._cache_fields = map[string][]Field{} m._cache_fields = map[string][]*Field{}
} }
m._cache_fields[operation] = fields m._cache_fields[operation] = fields
return return
} }
func (m *Model) columnsHasValue(operation string) (fields []Field) { func (m *Model) columnsHasValue(operation string) (fields []*Field) {
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsBlank { if !field.IsBlank {
fields = append(fields, field) fields = append(fields, field)
@ -224,7 +199,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
if m.data != nil { if m.data != nil {
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsPrimaryKey && (len(field.SqlType) > 0) { if !field.IsPrimaryKey && (len(field.SqlType()) > 0) {
results[field.DbName] = field.Value results[field.DbName] = field.Value
} }
} }
@ -320,7 +295,7 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value) setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
} }
func (m *Model) beforeAssociations() (fields []Field) { func (m *Model) beforeAssociations() (fields []*Field) {
for _, field := range m.fields("null") { for _, field := range m.fields("null") {
if field.beforeAssociation && !field.IsBlank { if field.beforeAssociation && !field.IsBlank {
fields = append(fields, field) fields = append(fields, field)
@ -329,7 +304,7 @@ func (m *Model) beforeAssociations() (fields []Field) {
return return
} }
func (m *Model) afterAssociations() (fields []Field) { func (m *Model) afterAssociations() (fields []*Field) {
for _, field := range m.fields("null") { for _, field := range m.fields("null") {
if field.afterAssociation && !field.IsBlank { if field.afterAssociation && !field.IsBlank {
fields = append(fields, field) fields = append(fields, field)
@ -337,25 +312,3 @@ func (m *Model) afterAssociations() (fields []Field) {
} }
return return
} }
func (m *Model) getSqlTag(field Field, struct_field reflect.StructField) string {
column := getInterfaceValue(field.Value)
typ, addational_typ, size := parseSqlTag(struct_field.Tag.Get(tagIdentifier))
if typ == "-" {
return ""
}
if len(typ) == 0 {
if field.IsPrimaryKey {
typ = m.do.chain.d.dialect.PrimaryKeyTag(column, size)
} else {
typ = m.do.chain.d.dialect.SqlTag(column, size)
}
}
if len(addational_typ) > 0 {
typ = typ + " " + addational_typ
}
return typ
}

View File

@ -3,7 +3,7 @@ package gorm
import ( import (
"bytes" "bytes"
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"reflect" "reflect"
"strconv" "strconv"
@ -67,46 +67,6 @@ func getInterfaceAsString(value interface{}) (str string, err error) {
return return
} }
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str
} else if str != "" {
tags := strings.Split(str, ";")
m := make(map[string]string)
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.Trim(strings.ToUpper(v[0]), " ")
if len(v) == 2 {
m[k] = v[1]
} else {
m[k] = k
}
}
if len(m["SIZE"]) > 0 {
size, _ = strconv.Atoi(m["SIZE"])
}
if len(m["TYPE"]) > 0 {
typ = m["TYPE"]
}
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
}
return
}
func getInterfaceValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok {
column = v.Interface()
}
if valuer, ok := interface{}(column).(driver.Valuer); ok {
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
}
return column
}
func setFieldValue(field reflect.Value, value interface{}) bool { func setFieldValue(field reflect.Value, value interface{}) bool {
if field.IsValid() && field.CanAddr() { if field.IsValid() && field.CanAddr() {
switch field.Kind() { switch field.Kind() {