mirror of https://github.com/go-gorm/gorm.git
Don't copy unnecessary variables
This commit is contained in:
parent
d550315548
commit
1c49c4ef85
10
chain.go
10
chain.go
|
@ -7,9 +7,9 @@ import (
|
|||
)
|
||||
|
||||
type Chain struct {
|
||||
db sql_common
|
||||
driver string
|
||||
value interface{}
|
||||
d *DB
|
||||
db sql_common
|
||||
value interface{}
|
||||
|
||||
Errors []error
|
||||
Error error
|
||||
|
@ -27,6 +27,9 @@ type Chain struct {
|
|||
unscoped bool
|
||||
}
|
||||
|
||||
func (s *Chain) driver() string {
|
||||
return s.d.driver
|
||||
}
|
||||
func (s *Chain) err(err error) error {
|
||||
if err != nil {
|
||||
s.Errors = append(s.Errors, err)
|
||||
|
@ -49,7 +52,6 @@ func (s *Chain) do(value interface{}) (do *Do) {
|
|||
do = &Do{
|
||||
chain: s,
|
||||
db: s.db,
|
||||
driver: s.driver,
|
||||
whereClause: s.whereClause,
|
||||
orClause: s.orClause,
|
||||
notClause: s.notClause,
|
||||
|
|
33
do.go
33
do.go
|
@ -15,7 +15,6 @@ import (
|
|||
type Do struct {
|
||||
chain *Chain
|
||||
db sql_common
|
||||
driver string
|
||||
guessedTableName string
|
||||
specifiedTableName string
|
||||
|
||||
|
@ -59,14 +58,14 @@ func (s *Do) hasError() bool {
|
|||
}
|
||||
|
||||
func (s *Do) setModel(value interface{}) *Do {
|
||||
s.model = &Model{data: value, driver: s.driver}
|
||||
s.model = &Model{data: value, do: s}
|
||||
s.value = value
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) addToVars(value interface{}) string {
|
||||
s.sqlVars = append(s.sqlVars, value)
|
||||
if s.driver == "postgres" {
|
||||
if s.chain.driver() == "postgres" {
|
||||
return fmt.Sprintf("$%d", len(s.sqlVars))
|
||||
} else {
|
||||
return "?"
|
||||
|
@ -116,14 +115,14 @@ func (s *Do) prepareCreateSql() {
|
|||
func (s *Do) saveBeforeAssociations() {
|
||||
for _, field := range s.model.beforeAssociations() {
|
||||
var id interface{}
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
do := &Do{chain: s.chain, db: s.db}
|
||||
|
||||
reflect_value := reflect.ValueOf(field.Value)
|
||||
if reflect_value.CanAddr() {
|
||||
id = do.setModel(reflect_value.Addr().Interface()).save()
|
||||
} else {
|
||||
dest_value := reflect.New(reflect_value.Type()).Elem()
|
||||
m := &Model{data: field.Value, driver: s.driver}
|
||||
m := &Model{data: field.Value, do: s}
|
||||
for _, f := range m.columnsHasValue("other") {
|
||||
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
}
|
||||
|
@ -145,20 +144,20 @@ func (s *Do) saveAfterAssociations() {
|
|||
case reflect.Slice:
|
||||
for i := 0; i < reflect_value.Len(); i++ {
|
||||
value := reflect_value.Index(i).Addr().Interface()
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
do := &Do{chain: s.chain, db: s.db}
|
||||
if len(field.foreignKey) > 0 {
|
||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
|
||||
}
|
||||
do.setModel(value).save()
|
||||
}
|
||||
default:
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
do := &Do{chain: s.chain, db: s.db}
|
||||
if reflect_value.CanAddr() {
|
||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
|
||||
do.setModel(field.Value).save()
|
||||
} else {
|
||||
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
|
||||
m := &Model{data: field.Value, driver: s.driver}
|
||||
m := &Model{data: field.Value, do: s}
|
||||
for _, f := range m.columnsHasValue("other") {
|
||||
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
}
|
||||
|
@ -181,7 +180,7 @@ func (s *Do) create() (i interface{}) {
|
|||
|
||||
if !s.hasError() {
|
||||
var id interface{}
|
||||
if s.driver == "postgres" {
|
||||
if s.chain.driver() == "postgres" {
|
||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||
} else {
|
||||
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
|
||||
|
@ -216,7 +215,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool)
|
|||
s.setUpdateAttrs(value, ignore_protected_attrs...)
|
||||
}
|
||||
case interface{}:
|
||||
m := &Model{data: values, driver: s.driver}
|
||||
m := &Model{data: values, do: s}
|
||||
fields := m.columnsHasValue("other")
|
||||
|
||||
s.updateAttrs = make(map[string]interface{}, len(fields))
|
||||
|
@ -348,8 +347,8 @@ func (s *Do) related(value interface{}, foreign_keys ...string) {
|
|||
var foreign_key string
|
||||
var err error
|
||||
|
||||
from := &Model{data: value, driver: s.driver}
|
||||
to := &Model{data: s.value, driver: s.driver}
|
||||
from := &Model{data: value, do: s}
|
||||
to := &Model{data: s.value, do: s}
|
||||
foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id")
|
||||
|
||||
for _, fk := range foreign_keys {
|
||||
|
@ -535,7 +534,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
m := &Model{data: query, driver: s.driver}
|
||||
m := &Model{data: query, do: s}
|
||||
var sqls []string
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value)))
|
||||
|
@ -596,7 +595,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
m := &Model{data: query, driver: s.driver}
|
||||
m := &Model{data: query, do: s}
|
||||
var sqls []string
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value)))
|
||||
|
@ -759,7 +758,7 @@ func (s *Do) autoMigrate() *Do {
|
|||
}
|
||||
|
||||
func (s *Do) initializeWithSearchCondition() {
|
||||
m := Model{data: s.value, driver: s.driver}
|
||||
m := Model{data: s.value, do: s}
|
||||
|
||||
for _, clause := range s.whereClause {
|
||||
query := clause["query"]
|
||||
|
@ -772,7 +771,7 @@ func (s *Do) initializeWithSearchCondition() {
|
|||
for _, obj := range query.([]interface{}) {
|
||||
switch reflect.ValueOf(obj).Kind() {
|
||||
case reflect.Struct:
|
||||
m := &Model{data: obj, driver: s.driver}
|
||||
m := &Model{data: obj, do: s}
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
m.setValueByColumn(field.DbName, field.Value, s.value)
|
||||
}
|
||||
|
@ -783,7 +782,7 @@ func (s *Do) initializeWithSearchCondition() {
|
|||
}
|
||||
}
|
||||
case interface{}:
|
||||
m := &Model{data: query, driver: s.driver}
|
||||
m := &Model{data: query, do: s}
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
m.setValueByColumn(field.DbName, field.Value, s.value)
|
||||
}
|
||||
|
|
2
main.go
2
main.go
|
@ -34,7 +34,7 @@ func (s *DB) SingularTable(result bool) {
|
|||
}
|
||||
|
||||
func (s *DB) buildChain() *Chain {
|
||||
return &Chain{db: s.db, driver: s.driver}
|
||||
return &Chain{db: s.db, d: s}
|
||||
}
|
||||
|
||||
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
|
||||
|
|
14
model.go
14
model.go
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
type Model struct {
|
||||
data interface{}
|
||||
driver string
|
||||
do *Do
|
||||
_cache_fields map[string][]Field
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
if is_scanner {
|
||||
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
|
||||
} else {
|
||||
m := &Model{data: value.Interface(), driver: m.driver}
|
||||
m := &Model{data: value.Interface(), do: m.do}
|
||||
|
||||
fields := m.columnsHasValue("other")
|
||||
if len(fields) == 0 {
|
||||
|
@ -135,9 +135,9 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
}
|
||||
field.SqlType = getSqlType(m.driver, value, 0)
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||
} else if field.IsPrimaryKey {
|
||||
field.SqlType = getPrimaryKeySqlType(m.driver, value, 0)
|
||||
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
|
||||
} else {
|
||||
field_value := reflect.Indirect(value)
|
||||
|
||||
|
@ -152,7 +152,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||
|
||||
if is_scanner {
|
||||
field.SqlType = getSqlType(m.driver, value, 0)
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||
} else {
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.foreignKey = p.Name + "Id"
|
||||
|
@ -166,7 +166,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
}
|
||||
default:
|
||||
field.SqlType = getSqlType(m.driver, value, 0)
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -326,7 +326,7 @@ func (m *Model) callMethod(method string) error {
|
|||
}
|
||||
|
||||
func (m *Model) returningStr() (str string) {
|
||||
if m.driver == "postgres" {
|
||||
if m.do.chain.driver() == "postgres" {
|
||||
str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb())
|
||||
}
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue