Don't copy unnecessary variables

This commit is contained in:
Jinzhu 2013-11-11 13:40:35 +08:00
parent d550315548
commit 1c49c4ef85
4 changed files with 30 additions and 29 deletions

View File

@ -7,8 +7,8 @@ import (
) )
type Chain struct { type Chain struct {
d *DB
db sql_common db sql_common
driver string
value interface{} value interface{}
Errors []error Errors []error
@ -27,6 +27,9 @@ type Chain struct {
unscoped bool unscoped bool
} }
func (s *Chain) driver() string {
return s.d.driver
}
func (s *Chain) err(err error) error { func (s *Chain) err(err error) error {
if err != nil { if err != nil {
s.Errors = append(s.Errors, err) s.Errors = append(s.Errors, err)
@ -49,7 +52,6 @@ func (s *Chain) do(value interface{}) (do *Do) {
do = &Do{ do = &Do{
chain: s, chain: s,
db: s.db, db: s.db,
driver: s.driver,
whereClause: s.whereClause, whereClause: s.whereClause,
orClause: s.orClause, orClause: s.orClause,
notClause: s.notClause, notClause: s.notClause,

33
do.go
View File

@ -15,7 +15,6 @@ import (
type Do struct { type Do struct {
chain *Chain chain *Chain
db sql_common db sql_common
driver string
guessedTableName string guessedTableName string
specifiedTableName string specifiedTableName string
@ -59,14 +58,14 @@ func (s *Do) hasError() bool {
} }
func (s *Do) setModel(value interface{}) *Do { func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, driver: s.driver} s.model = &Model{data: value, do: s}
s.value = value s.value = value
return s return s
} }
func (s *Do) addToVars(value interface{}) string { func (s *Do) addToVars(value interface{}) string {
s.sqlVars = append(s.sqlVars, value) s.sqlVars = append(s.sqlVars, value)
if s.driver == "postgres" { if s.chain.driver() == "postgres" {
return fmt.Sprintf("$%d", len(s.sqlVars)) return fmt.Sprintf("$%d", len(s.sqlVars))
} else { } else {
return "?" return "?"
@ -116,14 +115,14 @@ func (s *Do) prepareCreateSql() {
func (s *Do) saveBeforeAssociations() { func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() { for _, field := range s.model.beforeAssociations() {
var id interface{} var id interface{}
do := &Do{chain: s.chain, db: s.db, driver: s.driver} do := &Do{chain: s.chain, db: s.db}
reflect_value := reflect.ValueOf(field.Value) reflect_value := reflect.ValueOf(field.Value)
if reflect_value.CanAddr() { if reflect_value.CanAddr() {
id = do.setModel(reflect_value.Addr().Interface()).save() id = do.setModel(reflect_value.Addr().Interface()).save()
} else { } else {
dest_value := reflect.New(reflect_value.Type()).Elem() dest_value := reflect.New(reflect_value.Type()).Elem()
m := &Model{data: field.Value, driver: s.driver} m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") { for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
} }
@ -145,20 +144,20 @@ func (s *Do) saveAfterAssociations() {
case reflect.Slice: case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ { for i := 0; i < reflect_value.Len(); i++ {
value := reflect_value.Index(i).Addr().Interface() value := reflect_value.Index(i).Addr().Interface()
do := &Do{chain: s.chain, db: s.db, driver: s.driver} do := &Do{chain: s.chain, db: s.db}
if len(field.foreignKey) > 0 { if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value) s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
} }
do.setModel(value).save() do.setModel(value).save()
} }
default: default:
do := &Do{chain: s.chain, db: s.db, driver: s.driver} do := &Do{chain: s.chain, db: s.db}
if reflect_value.CanAddr() { if reflect_value.CanAddr() {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value) s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
do.setModel(field.Value).save() do.setModel(field.Value).save()
} else { } else {
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem() dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
m := &Model{data: field.Value, driver: s.driver} m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") { for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
} }
@ -181,7 +180,7 @@ func (s *Do) create() (i interface{}) {
if !s.hasError() { if !s.hasError() {
var id interface{} var id interface{}
if s.driver == "postgres" { if s.chain.driver() == "postgres" {
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
} else { } else {
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
@ -216,7 +215,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool)
s.setUpdateAttrs(value, ignore_protected_attrs...) s.setUpdateAttrs(value, ignore_protected_attrs...)
} }
case interface{}: case interface{}:
m := &Model{data: values, driver: s.driver} m := &Model{data: values, do: s}
fields := m.columnsHasValue("other") fields := m.columnsHasValue("other")
s.updateAttrs = make(map[string]interface{}, len(fields)) s.updateAttrs = make(map[string]interface{}, len(fields))
@ -348,8 +347,8 @@ func (s *Do) related(value interface{}, foreign_keys ...string) {
var foreign_key string var foreign_key string
var err error var err error
from := &Model{data: value, driver: s.driver} from := &Model{data: value, do: s}
to := &Model{data: s.value, driver: s.driver} to := &Model{data: s.value, do: s}
foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id") foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id")
for _, fk := range foreign_keys { for _, fk := range foreign_keys {
@ -535,7 +534,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
m := &Model{data: query, driver: s.driver} m := &Model{data: query, do: s}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value)))
@ -596,7 +595,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
m := &Model{data: query, driver: s.driver} m := &Model{data: query, do: s}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value)))
@ -759,7 +758,7 @@ func (s *Do) autoMigrate() *Do {
} }
func (s *Do) initializeWithSearchCondition() { func (s *Do) initializeWithSearchCondition() {
m := Model{data: s.value, driver: s.driver} m := Model{data: s.value, do: s}
for _, clause := range s.whereClause { for _, clause := range s.whereClause {
query := clause["query"] query := clause["query"]
@ -772,7 +771,7 @@ func (s *Do) initializeWithSearchCondition() {
for _, obj := range query.([]interface{}) { for _, obj := range query.([]interface{}) {
switch reflect.ValueOf(obj).Kind() { switch reflect.ValueOf(obj).Kind() {
case reflect.Struct: case reflect.Struct:
m := &Model{data: obj, driver: s.driver} m := &Model{data: obj, do: s}
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value) m.setValueByColumn(field.DbName, field.Value, s.value)
} }
@ -783,7 +782,7 @@ func (s *Do) initializeWithSearchCondition() {
} }
} }
case interface{}: case interface{}:
m := &Model{data: query, driver: s.driver} m := &Model{data: query, do: s}
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value) m.setValueByColumn(field.DbName, field.Value, s.value)
} }

View File

@ -34,7 +34,7 @@ func (s *DB) SingularTable(result bool) {
} }
func (s *DB) buildChain() *Chain { func (s *DB) buildChain() *Chain {
return &Chain{db: s.db, driver: s.driver} return &Chain{db: s.db, d: s}
} }
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {

View File

@ -13,7 +13,7 @@ import (
type Model struct { type Model struct {
data interface{} data interface{}
driver string do *Do
_cache_fields map[string][]Field _cache_fields map[string][]Field
} }
@ -110,7 +110,7 @@ func (m *Model) fields(operation string) (fields []Field) {
if is_scanner { if is_scanner {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool) field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
} else { } else {
m := &Model{data: value.Interface(), driver: m.driver} m := &Model{data: value.Interface(), do: m.do}
fields := m.columnsHasValue("other") fields := m.columnsHasValue("other")
if len(fields) == 0 { if len(fields) == 0 {
@ -135,9 +135,9 @@ func (m *Model) fields(operation string) (fields []Field) {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
} }
field.SqlType = getSqlType(m.driver, value, 0) field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
} else if field.IsPrimaryKey { } else if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.driver, value, 0) field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
} else { } else {
field_value := reflect.Indirect(value) field_value := reflect.Indirect(value)
@ -152,7 +152,7 @@ func (m *Model) fields(operation string) (fields []Field) {
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if is_scanner { if is_scanner {
field.SqlType = getSqlType(m.driver, value, 0) field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
} else { } else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() { if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id" field.foreignKey = p.Name + "Id"
@ -166,7 +166,7 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
} }
default: default:
field.SqlType = getSqlType(m.driver, value, 0) field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
} }
} }
@ -326,7 +326,7 @@ func (m *Model) callMethod(method string) error {
} }
func (m *Model) returningStr() (str string) { func (m *Model) returningStr() (str string) {
if m.driver == "postgres" { if m.do.chain.driver() == "postgres" {
str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb()) str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb())
} }
return return