diff --git a/chain.go b/chain.go index aa65e357..b45a373b 100644 --- a/chain.go +++ b/chain.go @@ -7,9 +7,9 @@ import ( ) type Chain struct { - db sql_common - driver string - value interface{} + d *DB + db sql_common + value interface{} Errors []error Error error @@ -27,6 +27,9 @@ type Chain struct { unscoped bool } +func (s *Chain) driver() string { + return s.d.driver +} func (s *Chain) err(err error) error { if err != nil { s.Errors = append(s.Errors, err) @@ -49,7 +52,6 @@ func (s *Chain) do(value interface{}) (do *Do) { do = &Do{ chain: s, db: s.db, - driver: s.driver, whereClause: s.whereClause, orClause: s.orClause, notClause: s.notClause, diff --git a/do.go b/do.go index 49c4dc80..546cb7c4 100644 --- a/do.go +++ b/do.go @@ -15,7 +15,6 @@ import ( type Do struct { chain *Chain db sql_common - driver string guessedTableName string specifiedTableName string @@ -59,14 +58,14 @@ func (s *Do) hasError() bool { } func (s *Do) setModel(value interface{}) *Do { - s.model = &Model{data: value, driver: s.driver} + s.model = &Model{data: value, do: s} s.value = value return s } func (s *Do) addToVars(value interface{}) string { s.sqlVars = append(s.sqlVars, value) - if s.driver == "postgres" { + if s.chain.driver() == "postgres" { return fmt.Sprintf("$%d", len(s.sqlVars)) } else { return "?" @@ -116,14 +115,14 @@ func (s *Do) prepareCreateSql() { func (s *Do) saveBeforeAssociations() { for _, field := range s.model.beforeAssociations() { var id interface{} - do := &Do{chain: s.chain, db: s.db, driver: s.driver} + do := &Do{chain: s.chain, db: s.db} reflect_value := reflect.ValueOf(field.Value) if reflect_value.CanAddr() { id = do.setModel(reflect_value.Addr().Interface()).save() } else { dest_value := reflect.New(reflect_value.Type()).Elem() - m := &Model{data: field.Value, driver: s.driver} + m := &Model{data: field.Value, do: s} for _, f := range m.columnsHasValue("other") { dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) } @@ -145,20 +144,20 @@ func (s *Do) saveAfterAssociations() { case reflect.Slice: for i := 0; i < reflect_value.Len(); i++ { value := reflect_value.Index(i).Addr().Interface() - do := &Do{chain: s.chain, db: s.db, driver: s.driver} + do := &Do{chain: s.chain, db: s.db} if len(field.foreignKey) > 0 { s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value) } do.setModel(value).save() } default: - do := &Do{chain: s.chain, db: s.db, driver: s.driver} + do := &Do{chain: s.chain, db: s.db} if reflect_value.CanAddr() { s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value) do.setModel(field.Value).save() } else { dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem() - m := &Model{data: field.Value, driver: s.driver} + m := &Model{data: field.Value, do: s} for _, f := range m.columnsHasValue("other") { dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) } @@ -181,7 +180,7 @@ func (s *Do) create() (i interface{}) { if !s.hasError() { var id interface{} - if s.driver == "postgres" { + if s.chain.driver() == "postgres" { s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) } else { if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { @@ -216,7 +215,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) s.setUpdateAttrs(value, ignore_protected_attrs...) } case interface{}: - m := &Model{data: values, driver: s.driver} + m := &Model{data: values, do: s} fields := m.columnsHasValue("other") s.updateAttrs = make(map[string]interface{}, len(fields)) @@ -348,8 +347,8 @@ func (s *Do) related(value interface{}, foreign_keys ...string) { var foreign_key string var err error - from := &Model{data: value, driver: s.driver} - to := &Model{data: s.value, driver: s.driver} + from := &Model{data: value, do: s} + to := &Model{data: s.value, do: s} foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id") for _, fk := range foreign_keys { @@ -535,7 +534,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { } return strings.Join(sqls, " AND ") case interface{}: - m := &Model{data: query, driver: s.driver} + m := &Model{data: query, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value))) @@ -596,7 +595,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { } return strings.Join(sqls, " AND ") case interface{}: - m := &Model{data: query, driver: s.driver} + m := &Model{data: query, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value))) @@ -759,7 +758,7 @@ func (s *Do) autoMigrate() *Do { } func (s *Do) initializeWithSearchCondition() { - m := Model{data: s.value, driver: s.driver} + m := Model{data: s.value, do: s} for _, clause := range s.whereClause { query := clause["query"] @@ -772,7 +771,7 @@ func (s *Do) initializeWithSearchCondition() { for _, obj := range query.([]interface{}) { switch reflect.ValueOf(obj).Kind() { case reflect.Struct: - m := &Model{data: obj, driver: s.driver} + m := &Model{data: obj, do: s} for _, field := range m.columnsHasValue("other") { m.setValueByColumn(field.DbName, field.Value, s.value) } @@ -783,7 +782,7 @@ func (s *Do) initializeWithSearchCondition() { } } case interface{}: - m := &Model{data: query, driver: s.driver} + m := &Model{data: query, do: s} for _, field := range m.columnsHasValue("other") { m.setValueByColumn(field.DbName, field.Value, s.value) } diff --git a/main.go b/main.go index 3d7d9d02..da8d33f1 100644 --- a/main.go +++ b/main.go @@ -34,7 +34,7 @@ func (s *DB) SingularTable(result bool) { } func (s *DB) buildChain() *Chain { - return &Chain{db: s.db, driver: s.driver} + return &Chain{db: s.db, d: s} } func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { diff --git a/model.go b/model.go index 50f7743b..c63392fc 100644 --- a/model.go +++ b/model.go @@ -13,7 +13,7 @@ import ( type Model struct { data interface{} - driver string + do *Do _cache_fields map[string][]Field } @@ -110,7 +110,7 @@ func (m *Model) fields(operation string) (fields []Field) { if is_scanner { field.IsBlank = !value.FieldByName("Valid").Interface().(bool) } else { - m := &Model{data: value.Interface(), driver: m.driver} + m := &Model{data: value.Interface(), do: m.do} fields := m.columnsHasValue("other") if len(fields) == 0 { @@ -135,9 +135,9 @@ func (m *Model) fields(operation string) (fields []Field) { value.Set(reflect.ValueOf(time.Now())) } } - field.SqlType = getSqlType(m.driver, value, 0) + field.SqlType = getSqlType(m.do.chain.driver(), value, 0) } else if field.IsPrimaryKey { - field.SqlType = getPrimaryKeySqlType(m.driver, value, 0) + field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0) } else { field_value := reflect.Indirect(value) @@ -152,7 +152,7 @@ func (m *Model) fields(operation string) (fields []Field) { _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) if is_scanner { - field.SqlType = getSqlType(m.driver, value, 0) + field.SqlType = getSqlType(m.do.chain.driver(), value, 0) } else { if indirect_value.FieldByName(p.Name + "Id").IsValid() { field.foreignKey = p.Name + "Id" @@ -166,7 +166,7 @@ func (m *Model) fields(operation string) (fields []Field) { } } default: - field.SqlType = getSqlType(m.driver, value, 0) + field.SqlType = getSqlType(m.do.chain.driver(), value, 0) } } @@ -326,7 +326,7 @@ func (m *Model) callMethod(method string) error { } func (m *Model) returningStr() (str string) { - if m.driver == "postgres" { + if m.do.chain.driver() == "postgres" { str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb()) } return