Refact Code

This commit is contained in:
Jinzhu 2013-11-15 11:36:27 +08:00
parent c354b0fb06
commit 2d34ecd6a9
9 changed files with 132 additions and 156 deletions

View File

@ -1,7 +1,7 @@
package dialect package dialect
type Dialect interface { type Dialect interface {
BinVar(i int) string BinVar() string
SupportLastInsertId() bool SupportLastInsertId() bool
SqlTag(column interface{}, size int) string SqlTag(column interface{}, size int) string
PrimaryKeyTag(column interface{}, size int) string PrimaryKeyTag(column interface{}, size int) string

View File

@ -8,7 +8,7 @@ import (
type mysql struct{} type mysql struct{}
func (s *mysql) BinVar(i int) string { func (s *mysql) BinVar() string {
return "?" return "?"
} }

View File

@ -9,8 +9,8 @@ import (
type postgres struct { type postgres struct {
} }
func (s *postgres) BinVar(i int) string { func (s *postgres) BinVar() string {
return fmt.Sprintf("$%v", i) return "$%v"
} }
func (s *postgres) SupportLastInsertId() bool { func (s *postgres) SupportLastInsertId() bool {

View File

@ -8,7 +8,7 @@ import (
type sqlite3 struct{} type sqlite3 struct{}
func (s *sqlite3) BinVar(i int) string { func (s *sqlite3) BinVar() string {
return "?" return "?"
} }

18
do.go
View File

@ -60,7 +60,7 @@ func (s *Do) setModel(value interface{}) *Do {
func (s *Do) addToVars(value interface{}) string { func (s *Do) addToVars(value interface{}) string {
s.sqlVars = append(s.sqlVars, value) s.sqlVars = append(s.sqlVars, value)
return s.chain.d.dialect.BinVar(len(s.sqlVars)) return fmt.Sprintf(s.chain.d.dialect.BinVar(), len(s.sqlVars))
} }
func (s *Do) exec(sqls ...string) (err error) { func (s *Do) exec(sqls ...string) (err error) {
@ -209,7 +209,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool)
m := &Model{data: values, do: s} m := &Model{data: values, do: s}
s.updateAttrs = map[string]interface{}{} s.updateAttrs = map[string]interface{}{}
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
s.updateAttrs[field.DbName] = field.Value s.updateAttrs[field.dbName] = field.Value
} }
} }
@ -473,7 +473,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
m := &Model{data: value, do: s} m := &Model{data: value, do: s}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.dbName, s.addToVars(field.Value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
} }
@ -532,7 +532,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
m := &Model{data: value, do: s} m := &Model{data: value, do: s}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.DbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.dbName, s.addToVars(field.Value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
} }
@ -641,7 +641,7 @@ func (s *Do) createTable() *Do {
var sqls []string var sqls []string
for _, field := range s.model.fields("migration") { for _, field := range s.model.fields("migration") {
if len(field.sqlTag()) > 0 { if len(field.sqlTag()) > 0 {
sqls = append(sqls, field.DbName+" "+field.sqlTag()) sqls = append(sqls, field.dbName+" "+field.sqlTag())
} }
} }
@ -697,12 +697,12 @@ func (s *Do) autoMigrate() *Do {
for _, field := range s.model.fields("migration") { for _, field := range s.model.fields("migration") {
var column_name, data_type string var column_name, data_type string
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName())) sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName()))
s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.DbName)), s.sqlVars...).Scan(&column_name, &data_type) s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type)
s.sqlVars = []interface{}{} s.sqlVars = []interface{}{}
// If column doesn't exist // If column doesn't exist
if len(column_name) == 0 && len(field.sqlTag()) > 0 { if len(column_name) == 0 && len(field.sqlTag()) > 0 {
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.sqlTag()) s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.dbName, field.sqlTag())
s.exec() s.exec()
} }
} }
@ -745,7 +745,7 @@ func (s *Do) initializeWithSearchCondition() {
case reflect.Struct: case reflect.Struct:
m := &Model{data: obj, do: s} m := &Model{data: obj, do: s}
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
s.model.setValueByColumn(field.DbName, field.Value, s.value) s.model.setValueByColumn(field.dbName, field.Value, s.value)
} }
case reflect.Map: case reflect.Map:
for key, value := range obj.(map[string]interface{}) { for key, value := range obj.(map[string]interface{}) {
@ -756,7 +756,7 @@ func (s *Do) initializeWithSearchCondition() {
case interface{}: case interface{}:
m := &Model{data: value, do: s} m := &Model{data: value, do: s}
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
s.model.setValueByColumn(field.DbName, field.Value, s.value) s.model.setValueByColumn(field.dbName, field.Value, s.value)
} }
} }
} }

115
field.go
View File

@ -2,85 +2,68 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "reflect"
"time"
"strconv" "strconv"
"strings" "strings"
"time"
"reflect"
) )
type Field struct { type Field struct {
Name string Name string
Value interface{} Value interface{}
DbName string model *Model
AutoCreateTime bool dbName string
AutoUpdateTime bool isPrimaryKey bool
IsPrimaryKey bool autoCreateTime bool
structField reflect.StructField autoUpdateTime bool
modelValue reflect.Value foreignKey string
beforeAssociation bool beforeAssociation bool
afterAssociation bool afterAssociation bool
foreignKey string reflectValue reflect.Value
model *Model structField reflect.StructField
} }
func (f *Field) isBlank() bool { func (f *Field) isBlank() bool {
value := reflect.ValueOf(f.Value) return isBlank(f.reflectValue)
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
return value.Int() == 0
case reflect.String:
return value.String() == ""
case reflect.Slice:
return value.Len() == 0
case reflect.Struct:
time_value, is_time := f.Value.(time.Time)
if is_time {
return time_value.IsZero()
} else {
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
if is_scanner {
return !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), do: f.model.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
return true
}
}
}
}
return false
} }
func (f *Field) sqlTag() string { func (f *Field) isScanner() bool {
column := getInterfaceValue(f.Value) _, is_scanner := reflect.New(f.reflectValue.Type()).Interface().(sql.Scanner)
field_value := reflect.ValueOf(f.Value) return is_scanner
switch field_value.Kind() { }
func (f *Field) isTime() bool {
_, is_time := f.Value.(time.Time)
return is_time
}
func (f *Field) sqlTag() (str string) {
value := f.Value
if f.isScanner() {
value = f.reflectValue.Field(0).Interface()
}
reflect_value := f.reflectValue
switch reflect_value.Kind() {
case reflect.Slice: case reflect.Slice:
return "" return
case reflect.Struct: case reflect.Struct:
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) if !f.isTime() && !f.isScanner() {
_, is_time := column.(time.Time) return
if !is_time && !is_scanner {
return ""
} }
} }
typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(tagIdentifier)) typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(tagIdentifier))
if typ == "-" { if typ == "-" {
return "" return
} }
if len(typ) == 0 { if len(typ) == 0 {
if f.IsPrimaryKey { if f.isPrimaryKey {
typ = f.model.do.chain.d.dialect.PrimaryKeyTag(column, size) typ = f.model.do.chain.d.dialect.PrimaryKeyTag(value, size)
} else { } else {
typ = f.model.do.chain.d.dialect.SqlTag(column, size) typ = f.model.do.chain.d.dialect.SqlTag(value, size)
} }
} }
@ -91,26 +74,23 @@ func (f *Field) sqlTag() string {
} }
func (f *Field) parseAssociation() { func (f *Field) parseAssociation() {
field_value := reflect.ValueOf(f.Value) reflect_value := f.reflectValue
switch field_value.Kind() { switch reflect_value.Kind() {
case reflect.Slice: case reflect.Slice:
foreign_key := f.model.typeName() + "Id" foreign_key := f.model.typeName() + "Id"
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() { if reflect.New(reflect_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key f.foreignKey = foreign_key
} }
f.afterAssociation = true f.afterAssociation = true
case reflect.Struct: case reflect.Struct:
_, is_time := f.Value.(time.Time) if !f.isTime() && !f.isScanner() {
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) if f.model.reflectData().FieldByName(f.Name + "Id").IsValid() {
if !is_scanner && !is_time {
if f.modelValue.FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id" f.foreignKey = f.Name + "Id"
f.beforeAssociation = true f.beforeAssociation = true
} else { } else {
foreign_key := f.model.typeName() + "Id" foreign_key := f.model.typeName() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { if reflect.New(reflect_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key f.foreignKey = foreign_key
} }
f.afterAssociation = true f.afterAssociation = true
@ -147,14 +127,3 @@ func parseSqlTag(str string) (typ string, addational_typ string, size int) {
} }
return return
} }
func getInterfaceValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok {
column = v.Interface()
}
if valuer, ok := interface{}(column).(driver.Valuer); ok {
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
}
return column
}

110
model.go
View File

@ -14,33 +14,22 @@ type Model struct {
_cache_fields map[string][]*Field _cache_fields map[string][]*Field
} }
func (m *Model) reflectData() reflect.Value {
return reflect.Indirect(reflect.ValueOf(m.data))
}
func (m *Model) primaryKeyZero() bool { func (m *Model) primaryKeyZero() bool {
return m.primaryKeyValue() <= 0 return isBlank(reflect.ValueOf(m.primaryKeyValue()))
} }
func (m *Model) primaryKeyValue() int64 { func (m *Model) primaryKeyValue() interface{} {
if m.data == nil { if data := m.reflectData(); data.Kind() == reflect.Struct {
return -1 field := data.FieldByName(m.primaryKey())
if data.FieldByName(m.primaryKey()).IsValid() {
return field.Interface()
}
} }
data := reflect.Indirect(reflect.ValueOf(m.data))
switch data.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
return 0 return 0
default:
value := data.FieldByName(m.primaryKey())
if value.IsValid() {
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
return value.Int()
default:
return 0
}
} else {
return 0
}
}
} }
func (m *Model) primaryKey() string { func (m *Model) primaryKey() string {
@ -56,7 +45,7 @@ func (m *Model) fields(operation string) (fields []*Field) {
return m._cache_fields[operation] return m._cache_fields[operation]
} }
indirect_value := reflect.Indirect(reflect.ValueOf(m.data)) indirect_value := m.reflectData()
if !indirect_value.IsValid() { if !indirect_value.IsValid() {
return return
} }
@ -67,30 +56,29 @@ func (m *Model) fields(operation string) (fields []*Field) {
if !p.Anonymous && ast.IsExported(p.Name) { if !p.Anonymous && ast.IsExported(p.Name) {
var field Field var field Field
field.Name = p.Name field.Name = p.Name
field.DbName = toSnake(p.Name) field.dbName = toSnake(p.Name)
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName field.isPrimaryKey = m.primaryKeyDb() == field.dbName
value := indirect_value.FieldByName(p.Name) value := indirect_value.FieldByName(p.Name)
time_value, is_time := value.Interface().(time.Time)
field.model = m field.model = m
field.modelValue = indirect_value
if is_time { if time_value, is_time := value.Interface().(time.Time); is_time {
field.AutoCreateTime = "created_at" == field.DbName field.autoCreateTime = "created_at" == field.dbName
field.AutoUpdateTime = "updated_at" == field.DbName field.autoUpdateTime = "updated_at" == field.dbName
switch operation { switch operation {
case "create": case "create":
if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() { if (field.autoCreateTime || field.autoUpdateTime) && time_value.IsZero() {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
case "update": case "update":
if field.AutoUpdateTime { if field.autoUpdateTime {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
} }
} }
field.structField = p field.structField = p
field.reflectValue = value
field.Value = value.Interface() field.Value = value.Interface()
fields = append(fields, &field) fields = append(fields, &field)
} }
@ -117,17 +105,16 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results
return values, true return values, true
} }
data := reflect.Indirect(reflect.ValueOf(m.data)) data := m.reflectData()
for key, value := range values { for key, value := range values {
field := data.FieldByName(snakeToUpperCamel(key)) if field := data.FieldByName(snakeToUpperCamel(key)); field.IsValid() {
if field.IsValid() {
if field.Interface() != value { if field.Interface() != value {
switch field.Kind() { switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32, reflect.Int64:
if field.Int() != reflect.ValueOf(value).Int() { if field.Int() != reflect.ValueOf(value).Int() {
any_updated = true any_updated = true
}
field.SetInt(reflect.ValueOf(value).Int()) field.SetInt(reflect.ValueOf(value).Int())
}
default: default:
any_updated = true any_updated = true
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))
@ -145,52 +132,47 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results
func (m *Model) columnsAndValues(operation string) map[string]interface{} { func (m *Model) columnsAndValues(operation string) map[string]interface{} {
results := map[string]interface{}{} results := map[string]interface{}{}
if m.data != nil {
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsPrimaryKey && (len(field.sqlTag()) > 0) { if !field.isPrimaryKey && (len(field.sqlTag()) > 0) {
results[field.DbName] = field.Value results[field.dbName] = field.Value
}
} }
} }
return results return results
} }
func (m *Model) hasColumn(name string) bool { func (m *Model) hasColumn(name string) bool {
if m.data == nil { data := m.reflectData()
if data.Kind() == reflect.Struct {
return data.FieldByName(name).IsValid()
} else if data.Kind() == reflect.Slice {
return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
}
return false return false
} }
data := reflect.Indirect(reflect.ValueOf(m.data))
if data.Kind() == reflect.Slice {
return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
} else {
return data.FieldByName(name).IsValid()
}
}
func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) { func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) {
if m.data != nil { data := m.reflectData()
data := reflect.Indirect(reflect.ValueOf(m.data))
if data.Kind() == reflect.Slice { if data.Kind() == reflect.Struct {
has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
is_slice = true
} else {
if has_column = data.FieldByName(name).IsValid(); has_column { if has_column = data.FieldByName(name).IsValid(); has_column {
value = data.FieldByName(name).Interface() value = data.FieldByName(name).Interface()
} }
} } else if data.Kind() == reflect.Slice {
has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
is_slice = true
} }
return return
} }
func (m *Model) typeName() string { func (m *Model) typeName() string {
typ := reflect.Indirect(reflect.ValueOf(m.data)).Type() typ := m.reflectData().Type()
if typ.Kind() == reflect.Slice { if typ.Kind() == reflect.Slice {
typ = typ.Elem() return typ.Elem().Name()
} } else {
return typ.Name() return typ.Name()
} }
}
func (m *Model) tableName() (str string) { func (m *Model) tableName() (str string) {
if m.data == nil { if m.data == nil {
@ -198,8 +180,8 @@ func (m *Model) tableName() (str string) {
return return
} }
fm := reflect.Indirect(reflect.ValueOf(m.data)).MethodByName("TableName") data := m.reflectData()
if fm.IsValid() { if fm := data.MethodByName("TableName"); fm.IsValid() {
if v := fm.Call([]reflect.Value{}); len(v) > 0 { if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if result, ok := v[0].Interface().(string); ok { if result, ok := v[0].Interface().(string); ok {
return result return result
@ -227,8 +209,7 @@ func (m *Model) callMethod(method string) {
return return
} }
fm := reflect.ValueOf(m.data).MethodByName(method) if fm := reflect.ValueOf(m.data).MethodByName(method); fm.IsValid() {
if fm.IsValid() {
if v := fm.Call([]reflect.Value{}); len(v) > 0 { if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if verr, ok := v[0].Interface().(error); ok { if verr, ok := v[0].Interface().(error); ok {
m.do.err(verr) m.do.err(verr)
@ -255,7 +236,6 @@ func (m *Model) beforeAssociations() (fields []*Field) {
func (m *Model) afterAssociations() (fields []*Field) { func (m *Model) afterAssociations() (fields []*Field) {
for _, field := range m.fields("null") { for _, field := range m.fields("null") {
field.parseAssociation()
if field.afterAssociation && !field.isBlank() { if field.afterAssociation && !field.isBlank() {
fields = append(fields, field) fields = append(fields, field)
} }

View File

View File

@ -3,13 +3,12 @@ package gorm
import ( import (
"bytes" "bytes"
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"reflect" "reflect"
"strconv" "strconv"
"fmt"
"strings" "strings"
"time"
) )
func toSnake(s string) string { func toSnake(s string) string {
@ -87,3 +86,31 @@ func setFieldValue(field reflect.Value, value interface{}) bool {
return false return false
} }
func isBlank(value reflect.Value) bool {
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
return value.Int() == 0
case reflect.String:
return value.String() == ""
case reflect.Slice:
return value.Len() == 0
case reflect.Struct:
time_value, is_time := value.Interface().(time.Time)
if is_time {
return time_value.IsZero()
} else {
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
if is_scanner {
return !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface()}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
return true
}
}
}
}
return false
}