Don't copy unnecessary variables

This commit is contained in:
Jinzhu 2013-11-11 13:40:35 +08:00
parent d550315548
commit 1c49c4ef85
4 changed files with 30 additions and 29 deletions

View File

@ -7,9 +7,9 @@ import (
)
type Chain struct {
db sql_common
driver string
value interface{}
d *DB
db sql_common
value interface{}
Errors []error
Error error
@ -27,6 +27,9 @@ type Chain struct {
unscoped bool
}
func (s *Chain) driver() string {
return s.d.driver
}
func (s *Chain) err(err error) error {
if err != nil {
s.Errors = append(s.Errors, err)
@ -49,7 +52,6 @@ func (s *Chain) do(value interface{}) (do *Do) {
do = &Do{
chain: s,
db: s.db,
driver: s.driver,
whereClause: s.whereClause,
orClause: s.orClause,
notClause: s.notClause,

33
do.go
View File

@ -15,7 +15,6 @@ import (
type Do struct {
chain *Chain
db sql_common
driver string
guessedTableName string
specifiedTableName string
@ -59,14 +58,14 @@ func (s *Do) hasError() bool {
}
func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, driver: s.driver}
s.model = &Model{data: value, do: s}
s.value = value
return s
}
func (s *Do) addToVars(value interface{}) string {
s.sqlVars = append(s.sqlVars, value)
if s.driver == "postgres" {
if s.chain.driver() == "postgres" {
return fmt.Sprintf("$%d", len(s.sqlVars))
} else {
return "?"
@ -116,14 +115,14 @@ func (s *Do) prepareCreateSql() {
func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() {
var id interface{}
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do := &Do{chain: s.chain, db: s.db}
reflect_value := reflect.ValueOf(field.Value)
if reflect_value.CanAddr() {
id = do.setModel(reflect_value.Addr().Interface()).save()
} else {
dest_value := reflect.New(reflect_value.Type()).Elem()
m := &Model{data: field.Value, driver: s.driver}
m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
@ -145,20 +144,20 @@ func (s *Do) saveAfterAssociations() {
case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ {
value := reflect_value.Index(i).Addr().Interface()
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do := &Do{chain: s.chain, db: s.db}
if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
}
do.setModel(value).save()
}
default:
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do := &Do{chain: s.chain, db: s.db}
if reflect_value.CanAddr() {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
do.setModel(field.Value).save()
} else {
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
m := &Model{data: field.Value, driver: s.driver}
m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
@ -181,7 +180,7 @@ func (s *Do) create() (i interface{}) {
if !s.hasError() {
var id interface{}
if s.driver == "postgres" {
if s.chain.driver() == "postgres" {
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
} else {
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
@ -216,7 +215,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool)
s.setUpdateAttrs(value, ignore_protected_attrs...)
}
case interface{}:
m := &Model{data: values, driver: s.driver}
m := &Model{data: values, do: s}
fields := m.columnsHasValue("other")
s.updateAttrs = make(map[string]interface{}, len(fields))
@ -348,8 +347,8 @@ func (s *Do) related(value interface{}, foreign_keys ...string) {
var foreign_key string
var err error
from := &Model{data: value, driver: s.driver}
to := &Model{data: s.value, driver: s.driver}
from := &Model{data: value, do: s}
to := &Model{data: s.value, do: s}
foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id")
for _, fk := range foreign_keys {
@ -535,7 +534,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
}
return strings.Join(sqls, " AND ")
case interface{}:
m := &Model{data: query, driver: s.driver}
m := &Model{data: query, do: s}
var sqls []string
for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value)))
@ -596,7 +595,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
}
return strings.Join(sqls, " AND ")
case interface{}:
m := &Model{data: query, driver: s.driver}
m := &Model{data: query, do: s}
var sqls []string
for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value)))
@ -759,7 +758,7 @@ func (s *Do) autoMigrate() *Do {
}
func (s *Do) initializeWithSearchCondition() {
m := Model{data: s.value, driver: s.driver}
m := Model{data: s.value, do: s}
for _, clause := range s.whereClause {
query := clause["query"]
@ -772,7 +771,7 @@ func (s *Do) initializeWithSearchCondition() {
for _, obj := range query.([]interface{}) {
switch reflect.ValueOf(obj).Kind() {
case reflect.Struct:
m := &Model{data: obj, driver: s.driver}
m := &Model{data: obj, do: s}
for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value)
}
@ -783,7 +782,7 @@ func (s *Do) initializeWithSearchCondition() {
}
}
case interface{}:
m := &Model{data: query, driver: s.driver}
m := &Model{data: query, do: s}
for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value)
}

View File

@ -34,7 +34,7 @@ func (s *DB) SingularTable(result bool) {
}
func (s *DB) buildChain() *Chain {
return &Chain{db: s.db, driver: s.driver}
return &Chain{db: s.db, d: s}
}
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {

View File

@ -13,7 +13,7 @@ import (
type Model struct {
data interface{}
driver string
do *Do
_cache_fields map[string][]Field
}
@ -110,7 +110,7 @@ func (m *Model) fields(operation string) (fields []Field) {
if is_scanner {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), driver: m.driver}
m := &Model{data: value.Interface(), do: m.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
@ -135,9 +135,9 @@ func (m *Model) fields(operation string) (fields []Field) {
value.Set(reflect.ValueOf(time.Now()))
}
}
field.SqlType = getSqlType(m.driver, value, 0)
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
} else if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.driver, value, 0)
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
} else {
field_value := reflect.Indirect(value)
@ -152,7 +152,7 @@ func (m *Model) fields(operation string) (fields []Field) {
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if is_scanner {
field.SqlType = getSqlType(m.driver, value, 0)
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
@ -166,7 +166,7 @@ func (m *Model) fields(operation string) (fields []Field) {
}
}
default:
field.SqlType = getSqlType(m.driver, value, 0)
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
}
}
@ -326,7 +326,7 @@ func (m *Model) callMethod(method string) error {
}
func (m *Model) returningStr() (str string) {
if m.driver == "postgres" {
if m.do.chain.driver() == "postgres" {
str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb())
}
return