mirror of https://github.com/go-gorm/gorm.git
Don't copy unnecessary variables
This commit is contained in:
parent
d550315548
commit
1c49c4ef85
6
chain.go
6
chain.go
|
@ -7,8 +7,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Chain struct {
|
type Chain struct {
|
||||||
|
d *DB
|
||||||
db sql_common
|
db sql_common
|
||||||
driver string
|
|
||||||
value interface{}
|
value interface{}
|
||||||
|
|
||||||
Errors []error
|
Errors []error
|
||||||
|
@ -27,6 +27,9 @@ type Chain struct {
|
||||||
unscoped bool
|
unscoped bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Chain) driver() string {
|
||||||
|
return s.d.driver
|
||||||
|
}
|
||||||
func (s *Chain) err(err error) error {
|
func (s *Chain) err(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Errors = append(s.Errors, err)
|
s.Errors = append(s.Errors, err)
|
||||||
|
@ -49,7 +52,6 @@ func (s *Chain) do(value interface{}) (do *Do) {
|
||||||
do = &Do{
|
do = &Do{
|
||||||
chain: s,
|
chain: s,
|
||||||
db: s.db,
|
db: s.db,
|
||||||
driver: s.driver,
|
|
||||||
whereClause: s.whereClause,
|
whereClause: s.whereClause,
|
||||||
orClause: s.orClause,
|
orClause: s.orClause,
|
||||||
notClause: s.notClause,
|
notClause: s.notClause,
|
||||||
|
|
33
do.go
33
do.go
|
@ -15,7 +15,6 @@ import (
|
||||||
type Do struct {
|
type Do struct {
|
||||||
chain *Chain
|
chain *Chain
|
||||||
db sql_common
|
db sql_common
|
||||||
driver string
|
|
||||||
guessedTableName string
|
guessedTableName string
|
||||||
specifiedTableName string
|
specifiedTableName string
|
||||||
|
|
||||||
|
@ -59,14 +58,14 @@ func (s *Do) hasError() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) setModel(value interface{}) *Do {
|
func (s *Do) setModel(value interface{}) *Do {
|
||||||
s.model = &Model{data: value, driver: s.driver}
|
s.model = &Model{data: value, do: s}
|
||||||
s.value = value
|
s.value = value
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) addToVars(value interface{}) string {
|
func (s *Do) addToVars(value interface{}) string {
|
||||||
s.sqlVars = append(s.sqlVars, value)
|
s.sqlVars = append(s.sqlVars, value)
|
||||||
if s.driver == "postgres" {
|
if s.chain.driver() == "postgres" {
|
||||||
return fmt.Sprintf("$%d", len(s.sqlVars))
|
return fmt.Sprintf("$%d", len(s.sqlVars))
|
||||||
} else {
|
} else {
|
||||||
return "?"
|
return "?"
|
||||||
|
@ -116,14 +115,14 @@ func (s *Do) prepareCreateSql() {
|
||||||
func (s *Do) saveBeforeAssociations() {
|
func (s *Do) saveBeforeAssociations() {
|
||||||
for _, field := range s.model.beforeAssociations() {
|
for _, field := range s.model.beforeAssociations() {
|
||||||
var id interface{}
|
var id interface{}
|
||||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
do := &Do{chain: s.chain, db: s.db}
|
||||||
|
|
||||||
reflect_value := reflect.ValueOf(field.Value)
|
reflect_value := reflect.ValueOf(field.Value)
|
||||||
if reflect_value.CanAddr() {
|
if reflect_value.CanAddr() {
|
||||||
id = do.setModel(reflect_value.Addr().Interface()).save()
|
id = do.setModel(reflect_value.Addr().Interface()).save()
|
||||||
} else {
|
} else {
|
||||||
dest_value := reflect.New(reflect_value.Type()).Elem()
|
dest_value := reflect.New(reflect_value.Type()).Elem()
|
||||||
m := &Model{data: field.Value, driver: s.driver}
|
m := &Model{data: field.Value, do: s}
|
||||||
for _, f := range m.columnsHasValue("other") {
|
for _, f := range m.columnsHasValue("other") {
|
||||||
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||||
}
|
}
|
||||||
|
@ -145,20 +144,20 @@ func (s *Do) saveAfterAssociations() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
for i := 0; i < reflect_value.Len(); i++ {
|
for i := 0; i < reflect_value.Len(); i++ {
|
||||||
value := reflect_value.Index(i).Addr().Interface()
|
value := reflect_value.Index(i).Addr().Interface()
|
||||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
do := &Do{chain: s.chain, db: s.db}
|
||||||
if len(field.foreignKey) > 0 {
|
if len(field.foreignKey) > 0 {
|
||||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
|
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
|
||||||
}
|
}
|
||||||
do.setModel(value).save()
|
do.setModel(value).save()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
do := &Do{chain: s.chain, db: s.db}
|
||||||
if reflect_value.CanAddr() {
|
if reflect_value.CanAddr() {
|
||||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
|
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
|
||||||
do.setModel(field.Value).save()
|
do.setModel(field.Value).save()
|
||||||
} else {
|
} else {
|
||||||
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
|
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
|
||||||
m := &Model{data: field.Value, driver: s.driver}
|
m := &Model{data: field.Value, do: s}
|
||||||
for _, f := range m.columnsHasValue("other") {
|
for _, f := range m.columnsHasValue("other") {
|
||||||
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||||
}
|
}
|
||||||
|
@ -181,7 +180,7 @@ func (s *Do) create() (i interface{}) {
|
||||||
|
|
||||||
if !s.hasError() {
|
if !s.hasError() {
|
||||||
var id interface{}
|
var id interface{}
|
||||||
if s.driver == "postgres" {
|
if s.chain.driver() == "postgres" {
|
||||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||||
} else {
|
} else {
|
||||||
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
|
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
|
||||||
|
@ -216,7 +215,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool)
|
||||||
s.setUpdateAttrs(value, ignore_protected_attrs...)
|
s.setUpdateAttrs(value, ignore_protected_attrs...)
|
||||||
}
|
}
|
||||||
case interface{}:
|
case interface{}:
|
||||||
m := &Model{data: values, driver: s.driver}
|
m := &Model{data: values, do: s}
|
||||||
fields := m.columnsHasValue("other")
|
fields := m.columnsHasValue("other")
|
||||||
|
|
||||||
s.updateAttrs = make(map[string]interface{}, len(fields))
|
s.updateAttrs = make(map[string]interface{}, len(fields))
|
||||||
|
@ -348,8 +347,8 @@ func (s *Do) related(value interface{}, foreign_keys ...string) {
|
||||||
var foreign_key string
|
var foreign_key string
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
from := &Model{data: value, driver: s.driver}
|
from := &Model{data: value, do: s}
|
||||||
to := &Model{data: s.value, driver: s.driver}
|
to := &Model{data: s.value, do: s}
|
||||||
foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id")
|
foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id")
|
||||||
|
|
||||||
for _, fk := range foreign_keys {
|
for _, fk := range foreign_keys {
|
||||||
|
@ -535,7 +534,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
m := &Model{data: query, driver: s.driver}
|
m := &Model{data: query, do: s}
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range m.columnsHasValue("other") {
|
for _, field := range m.columnsHasValue("other") {
|
||||||
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value)))
|
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value)))
|
||||||
|
@ -596,7 +595,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
m := &Model{data: query, driver: s.driver}
|
m := &Model{data: query, do: s}
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range m.columnsHasValue("other") {
|
for _, field := range m.columnsHasValue("other") {
|
||||||
sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value)))
|
sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value)))
|
||||||
|
@ -759,7 +758,7 @@ func (s *Do) autoMigrate() *Do {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) initializeWithSearchCondition() {
|
func (s *Do) initializeWithSearchCondition() {
|
||||||
m := Model{data: s.value, driver: s.driver}
|
m := Model{data: s.value, do: s}
|
||||||
|
|
||||||
for _, clause := range s.whereClause {
|
for _, clause := range s.whereClause {
|
||||||
query := clause["query"]
|
query := clause["query"]
|
||||||
|
@ -772,7 +771,7 @@ func (s *Do) initializeWithSearchCondition() {
|
||||||
for _, obj := range query.([]interface{}) {
|
for _, obj := range query.([]interface{}) {
|
||||||
switch reflect.ValueOf(obj).Kind() {
|
switch reflect.ValueOf(obj).Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
m := &Model{data: obj, driver: s.driver}
|
m := &Model{data: obj, do: s}
|
||||||
for _, field := range m.columnsHasValue("other") {
|
for _, field := range m.columnsHasValue("other") {
|
||||||
m.setValueByColumn(field.DbName, field.Value, s.value)
|
m.setValueByColumn(field.DbName, field.Value, s.value)
|
||||||
}
|
}
|
||||||
|
@ -783,7 +782,7 @@ func (s *Do) initializeWithSearchCondition() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case interface{}:
|
case interface{}:
|
||||||
m := &Model{data: query, driver: s.driver}
|
m := &Model{data: query, do: s}
|
||||||
for _, field := range m.columnsHasValue("other") {
|
for _, field := range m.columnsHasValue("other") {
|
||||||
m.setValueByColumn(field.DbName, field.Value, s.value)
|
m.setValueByColumn(field.DbName, field.Value, s.value)
|
||||||
}
|
}
|
||||||
|
|
2
main.go
2
main.go
|
@ -34,7 +34,7 @@ func (s *DB) SingularTable(result bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) buildChain() *Chain {
|
func (s *DB) buildChain() *Chain {
|
||||||
return &Chain{db: s.db, driver: s.driver}
|
return &Chain{db: s.db, d: s}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
|
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
|
||||||
|
|
14
model.go
14
model.go
|
@ -13,7 +13,7 @@ import (
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
data interface{}
|
data interface{}
|
||||||
driver string
|
do *Do
|
||||||
_cache_fields map[string][]Field
|
_cache_fields map[string][]Field
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
if is_scanner {
|
if is_scanner {
|
||||||
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
|
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
|
||||||
} else {
|
} else {
|
||||||
m := &Model{data: value.Interface(), driver: m.driver}
|
m := &Model{data: value.Interface(), do: m.do}
|
||||||
|
|
||||||
fields := m.columnsHasValue("other")
|
fields := m.columnsHasValue("other")
|
||||||
if len(fields) == 0 {
|
if len(fields) == 0 {
|
||||||
|
@ -135,9 +135,9 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
value.Set(reflect.ValueOf(time.Now()))
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
field.SqlType = getSqlType(m.driver, value, 0)
|
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||||
} else if field.IsPrimaryKey {
|
} else if field.IsPrimaryKey {
|
||||||
field.SqlType = getPrimaryKeySqlType(m.driver, value, 0)
|
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
|
||||||
} else {
|
} else {
|
||||||
field_value := reflect.Indirect(value)
|
field_value := reflect.Indirect(value)
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||||
|
|
||||||
if is_scanner {
|
if is_scanner {
|
||||||
field.SqlType = getSqlType(m.driver, value, 0)
|
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||||
} else {
|
} else {
|
||||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||||
field.foreignKey = p.Name + "Id"
|
field.foreignKey = p.Name + "Id"
|
||||||
|
@ -166,7 +166,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
field.SqlType = getSqlType(m.driver, value, 0)
|
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,7 +326,7 @@ func (m *Model) callMethod(method string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) returningStr() (str string) {
|
func (m *Model) returningStr() (str string) {
|
||||||
if m.driver == "postgres" {
|
if m.do.chain.driver() == "postgres" {
|
||||||
str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb())
|
str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue