From e4612bde9c667655484956a8670c0526d218cbe0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 Nov 2013 21:26:02 +0800 Subject: [PATCH] Separate Field Struct --- do.go | 8 ++-- field.go | 102 +++++++++++++++++++++++++++++++++++++++++++++++++++ gorm_test.go | 1 + model.go | 73 +++++++----------------------------- utils.go | 42 +-------------------- 5 files changed, 121 insertions(+), 105 deletions(-) create mode 100644 field.go diff --git a/do.go b/do.go index 1bcc0d17..b84c5f0c 100644 --- a/do.go +++ b/do.go @@ -640,8 +640,8 @@ func (s *Do) combinedSql() string { func (s *Do) createTable() *Do { var sqls []string for _, field := range s.model.fields("migration") { - if len(field.SqlType) > 0 { - sqls = append(sqls, field.DbName+" "+field.SqlType) + if len(field.SqlType()) > 0 { + sqls = append(sqls, field.DbName+" "+field.SqlType()) } } @@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do { s.sqlVars = []interface{}{} // If column doesn't exist - if len(column_name) == 0 && len(field.SqlType) > 0 { - s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType) + if len(column_name) == 0 && len(field.SqlType()) > 0 { + s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType()) s.exec() } } diff --git a/field.go b/field.go new file mode 100644 index 00000000..01e9d22e --- /dev/null +++ b/field.go @@ -0,0 +1,102 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "time" + + "strconv" + "strings" + + "reflect" +) + +type Field struct { + Name string + Value interface{} + DbName string + AutoCreateTime bool + AutoUpdateTime bool + IsPrimaryKey bool + IsBlank bool + structField reflect.StructField + + beforeAssociation bool + afterAssociation bool + foreignKey string + model *Model +} + +func (f *Field) SqlType() string { + column := getInterfaceValue(f.Value) + field_value := reflect.ValueOf(f.Value) + switch field_value.Kind() { + case reflect.Slice: + return "" + case reflect.Struct: + _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) + _, is_time := column.(time.Time) + if !is_time && !is_scanner { + return "" + } + } + + typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(tagIdentifier)) + + if typ == "-" { + return "" + } + + if len(typ) == 0 { + if f.IsPrimaryKey { + typ = f.model.do.chain.d.dialect.PrimaryKeyTag(column, size) + } else { + typ = f.model.do.chain.d.dialect.SqlTag(column, size) + } + } + + if len(addational_typ) > 0 { + typ = typ + " " + addational_typ + } + return typ +} + +func parseSqlTag(str string) (typ string, addational_typ string, size int) { + if str == "-" { + typ = str + } else if str != "" { + tags := strings.Split(str, ";") + m := make(map[string]string) + for _, value := range tags { + v := strings.Split(value, ":") + k := strings.Trim(strings.ToUpper(v[0]), " ") + if len(v) == 2 { + m[k] = v[1] + } else { + m[k] = k + } + } + + if len(m["SIZE"]) > 0 { + size, _ = strconv.Atoi(m["SIZE"]) + } + + if len(m["TYPE"]) > 0 { + typ = m["TYPE"] + } + + addational_typ = m["NOT NULL"] + " " + m["UNIQUE"] + } + return +} + +func getInterfaceValue(column interface{}) interface{} { + if v, ok := column.(reflect.Value); ok { + column = v.Interface() + } + + if valuer, ok := interface{}(column).(driver.Valuer); ok { + column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface() + } + return column +} diff --git a/gorm_test.go b/gorm_test.go index 03db697e..092761da 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -351,6 +351,7 @@ func TestComplexWhere(t *testing.T) { for _, user := range users { user_ids = append(user_ids, user.Id) } + users = []User{} db.Where("id in (?)", user_ids).Find(&users) if len(users) != 3 { diff --git a/model.go b/model.go index 011d1bb8..7fa58950 100644 --- a/model.go +++ b/model.go @@ -3,33 +3,16 @@ package gorm import ( "database/sql" "errors" - "go/ast" "reflect" "regexp" - "time" ) type Model struct { data interface{} do *Do - _cache_fields map[string][]Field -} - -type Field struct { - Name string - Value interface{} - SqlType string - DbName string - AutoCreateTime bool - AutoUpdateTime bool - IsPrimaryKey bool - IsBlank bool - - beforeAssociation bool - afterAssociation bool - foreignKey string + _cache_fields map[string][]*Field } func (m *Model) primaryKeyZero() bool { @@ -69,7 +52,7 @@ func (m *Model) primaryKeyDb() string { return toSnake(m.primaryKey()) } -func (m *Model) fields(operation string) (fields []Field) { +func (m *Model) fields(operation string) (fields []*Field) { if len(m._cache_fields[operation]) > 0 { return m._cache_fields[operation] } @@ -89,6 +72,7 @@ func (m *Model) fields(operation string) (fields []Field) { field.IsPrimaryKey = m.primaryKeyDb() == field.DbName value := indirect_value.FieldByName(p.Name) time_value, is_time := value.Interface().(time.Time) + field.model = m switch value.Kind() { case reflect.Int, reflect.Int64, reflect.Int32: @@ -129,14 +113,6 @@ func (m *Model) fields(operation string) (fields []Field) { value.Set(reflect.ValueOf(time.Now())) } } - } - - field.Value = value.Interface() - - if is_time { - field.SqlType = m.getSqlTag(field, p) - } else if field.IsPrimaryKey { - field.SqlType = m.getSqlTag(field, p) } else { field_value := reflect.Indirect(value) @@ -150,9 +126,7 @@ func (m *Model) fields(operation string) (fields []Field) { case reflect.Struct: _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) - if is_scanner { - field.SqlType = m.getSqlTag(field, p) - } else { + if !is_scanner { if indirect_value.FieldByName(p.Name + "Id").IsValid() { field.foreignKey = p.Name + "Id" field.beforeAssociation = true @@ -164,23 +138,24 @@ func (m *Model) fields(operation string) (fields []Field) { field.afterAssociation = true } } - default: - field.SqlType = m.getSqlTag(field, p) } } - fields = append(fields, field) + field.structField = p + field.Value = value.Interface() + + fields = append(fields, &field) } } if len(m._cache_fields) == 0 { - m._cache_fields = map[string][]Field{} + m._cache_fields = map[string][]*Field{} } m._cache_fields[operation] = fields return } -func (m *Model) columnsHasValue(operation string) (fields []Field) { +func (m *Model) columnsHasValue(operation string) (fields []*Field) { for _, field := range m.fields(operation) { if !field.IsBlank { fields = append(fields, field) @@ -224,7 +199,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} { if m.data != nil { for _, field := range m.fields(operation) { - if !field.IsPrimaryKey && (len(field.SqlType) > 0) { + if !field.IsPrimaryKey && (len(field.SqlType()) > 0) { results[field.DbName] = field.Value } } @@ -320,7 +295,7 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{} setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value) } -func (m *Model) beforeAssociations() (fields []Field) { +func (m *Model) beforeAssociations() (fields []*Field) { for _, field := range m.fields("null") { if field.beforeAssociation && !field.IsBlank { fields = append(fields, field) @@ -329,7 +304,7 @@ func (m *Model) beforeAssociations() (fields []Field) { return } -func (m *Model) afterAssociations() (fields []Field) { +func (m *Model) afterAssociations() (fields []*Field) { for _, field := range m.fields("null") { if field.afterAssociation && !field.IsBlank { fields = append(fields, field) @@ -337,25 +312,3 @@ func (m *Model) afterAssociations() (fields []Field) { } return } - -func (m *Model) getSqlTag(field Field, struct_field reflect.StructField) string { - column := getInterfaceValue(field.Value) - typ, addational_typ, size := parseSqlTag(struct_field.Tag.Get(tagIdentifier)) - - if typ == "-" { - return "" - } - - if len(typ) == 0 { - if field.IsPrimaryKey { - typ = m.do.chain.d.dialect.PrimaryKeyTag(column, size) - } else { - typ = m.do.chain.d.dialect.SqlTag(column, size) - } - } - - if len(addational_typ) > 0 { - typ = typ + " " + addational_typ - } - return typ -} diff --git a/utils.go b/utils.go index 4db884e6..3ea863d4 100644 --- a/utils.go +++ b/utils.go @@ -3,7 +3,7 @@ package gorm import ( "bytes" "database/sql" - "database/sql/driver" + "errors" "reflect" "strconv" @@ -67,46 +67,6 @@ func getInterfaceAsString(value interface{}) (str string, err error) { return } -func parseSqlTag(str string) (typ string, addational_typ string, size int) { - if str == "-" { - typ = str - } else if str != "" { - tags := strings.Split(str, ";") - m := make(map[string]string) - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.Trim(strings.ToUpper(v[0]), " ") - if len(v) == 2 { - m[k] = v[1] - } else { - m[k] = k - } - } - - if len(m["SIZE"]) > 0 { - size, _ = strconv.Atoi(m["SIZE"]) - } - - if len(m["TYPE"]) > 0 { - typ = m["TYPE"] - } - - addational_typ = m["NOT NULL"] + " " + m["UNIQUE"] - } - return -} - -func getInterfaceValue(column interface{}) interface{} { - if v, ok := column.(reflect.Value); ok { - column = v.Interface() - } - - if valuer, ok := interface{}(column).(driver.Valuer); ok { - column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface() - } - return column -} - func setFieldValue(field reflect.Value, value interface{}) bool { if field.IsValid() && field.CanAddr() { switch field.Kind() {