Cleanup unused code

This commit is contained in:
Jinzhu 2013-11-10 23:07:09 +08:00
parent 0cb1c1ba32
commit 874856a592
6 changed files with 61 additions and 94 deletions

View File

@ -5,13 +5,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"strconv"
) )
type Chain struct { type Chain struct {
db *sql.DB db *sql.DB
driver string driver string
debug bool
value interface{} value interface{}
Errors []error Errors []error
@ -30,20 +28,10 @@ type Chain struct {
unscoped bool unscoped bool
} }
func (s *Chain) msg(str string) {
if s.debug {
debug(str)
}
}
func (s *Chain) err(err error) error { func (s *Chain) err(err error) error {
if err != nil { if err != nil {
s.Errors = append(s.Errors, err) s.Errors = append(s.Errors, err)
s.Error = err s.Error = err
if s.debug {
debug(err)
}
} }
return err return err
} }
@ -53,26 +41,25 @@ func (s *Chain) deleteLastError() {
s.Errors = s.Errors[:len(s.Errors)-1] s.Errors = s.Errors[:len(s.Errors)-1]
} }
func (s *Chain) do(value interface{}) *Do { func (s *Chain) do(value interface{}) (do *Do) {
var do Do do = &Do{
do.chain = s chain: s,
do.db = s.db db: s.db,
do.driver = s.driver driver: s.driver,
whereClause: s.whereClause,
do.whereClause = s.whereClause orClause: s.orClause,
do.orClause = s.orClause notClause: s.notClause,
do.notClause = s.notClause selectStr: s.selectStr,
do.selectStr = s.selectStr orderStrs: s.orderStrs,
do.orderStrs = s.orderStrs offsetStr: s.offsetStr,
do.offsetStr = s.offsetStr limitStr: s.limitStr,
do.limitStr = s.limitStr specifiedTableName: s.specifiedTableName,
do.specifiedTableName = s.specifiedTableName unscoped: s.unscoped,
do.unscoped = s.unscoped }
do.debug = s.debug
s.value = value s.value = value
do.setModel(value) do.setModel(value)
return &do return
} }
func (s *Chain) Model(model interface{}) *Chain { func (s *Chain) Model(model interface{}) *Chain {
@ -91,32 +78,18 @@ func (s *Chain) Not(querystring interface{}, args ...interface{}) *Chain {
} }
func (s *Chain) Limit(value interface{}) *Chain { func (s *Chain) Limit(value interface{}) *Chain {
switch value := value.(type) { if str, err := getInterfaceAsString(value); err == nil {
case string: s.limitStr = str
s.limitStr = value } else {
case int:
if value < 0 {
s.limitStr = ""
} else {
s.limitStr = strconv.Itoa(value)
}
default:
s.err(errors.New("Can' understand the value of Limit, Should be int")) s.err(errors.New("Can' understand the value of Limit, Should be int"))
} }
return s return s
} }
func (s *Chain) Offset(value interface{}) *Chain { func (s *Chain) Offset(value interface{}) *Chain {
switch value := value.(type) { if str, err := getInterfaceAsString(value); err == nil {
case string: s.offsetStr = str
s.offsetStr = value } else {
case int:
if value < 0 {
s.offsetStr = ""
} else {
s.offsetStr = strconv.Itoa(value)
}
default:
s.err(errors.New("Can' understand the value of Offset, Should be int")) s.err(errors.New("Can' understand the value of Offset, Should be int"))
} }
return s return s
@ -125,7 +98,7 @@ func (s *Chain) Offset(value interface{}) *Chain {
func (s *Chain) Order(value string, reorder ...bool) *Chain { func (s *Chain) Order(value string, reorder ...bool) *Chain {
defer s.validSql(value) defer s.validSql(value)
if len(reorder) > 0 && reorder[0] { if len(reorder) > 0 && reorder[0] {
s.orderStrs = append([]string{}, value) s.orderStrs = []string{value}
} else { } else {
s.orderStrs = append(s.orderStrs, value) s.orderStrs = append(s.orderStrs, value)
} }
@ -196,8 +169,8 @@ func (s *Chain) Assign(attrs ...interface{}) *Chain {
func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain { func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain {
if s.First(out, where...).Error != nil { if s.First(out, where...).Error != nil {
s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
s.deleteLastError() s.deleteLastError()
s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
} else { } else {
if len(s.assignAttrs) > 0 { if len(s.assignAttrs) > 0 {
s.do(out).setUpdateAttrs(s.assignAttrs).prepareUpdateAttrs() s.do(out).setUpdateAttrs(s.assignAttrs).prepareUpdateAttrs()
@ -208,8 +181,8 @@ func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain {
func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain { func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain {
if s.First(out, where...).Error != nil { if s.First(out, where...).Error != nil {
s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
s.deleteLastError() s.deleteLastError()
s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition()
s.Save(out) s.Save(out)
} else { } else {
if len(s.assignAttrs) > 0 { if len(s.assignAttrs) > 0 {
@ -265,11 +238,6 @@ func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain {
return s return s
} }
func (s *Chain) Debug() *Chain {
s.debug = true
return s
}
func (s *Chain) validSql(str string) (result bool) { func (s *Chain) validSql(str string) (result bool) {
result = regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str) result = regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str)
if !result { if !result {

42
do.go
View File

@ -18,14 +18,11 @@ type Do struct {
driver string driver string
guessedTableName string guessedTableName string
specifiedTableName string specifiedTableName string
debug bool
Errors []error
model *Model model *Model
value interface{} value interface{}
sqlResult sql.Result sql string
sql string sqlVars []interface{}
sqlVars []interface{}
whereClause []map[string]interface{} whereClause []map[string]interface{}
orClause []map[string]interface{} orClause []map[string]interface{}
@ -40,7 +37,7 @@ type Do struct {
} }
func (s *Do) tableName() string { func (s *Do) tableName() string {
if s.specifiedTableName == "" { if len(s.specifiedTableName) == 0 {
var err error var err error
s.guessedTableName, err = s.model.tableName() s.guessedTableName, err = s.model.tableName()
s.err(err) s.err(err)
@ -52,18 +49,17 @@ func (s *Do) tableName() string {
func (s *Do) err(err error) error { func (s *Do) err(err error) error {
if err != nil { if err != nil {
s.Errors = append(s.Errors, err)
s.chain.err(err) s.chain.err(err)
} }
return err return err
} }
func (s *Do) hasError() bool { func (s *Do) hasError() bool {
return len(s.Errors) > 0 return len(s.chain.Errors) > 0
} }
func (s *Do) setModel(value interface{}) *Do { func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, driver: s.driver, debug: s.debug} s.model = &Model{data: value, driver: s.driver}
s.value = value s.value = value
return s return s
} }
@ -77,20 +73,15 @@ func (s *Do) addToVars(value interface{}) string {
} }
} }
func (s *Do) exec(sql ...string) { func (s *Do) exec(sqls ...string) (err error) {
if s.hasError() { if s.hasError() {
return return
} else if len(sqls) > 0 {
_, err = s.db.Exec(sqls[0])
} else if len(s.sql) > 0 {
_, err = s.db.Exec(s.sql, s.sqlVars...)
} }
return s.err(err)
var err error
if len(sql) == 0 {
if len(s.sql) > 0 {
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
}
} else {
s.sqlResult, err = s.db.Exec(sql[0])
}
s.err(err)
} }
func (s *Do) save() (i interface{}) { func (s *Do) save() (i interface{}) {
@ -123,7 +114,6 @@ func (s *Do) prepareCreateSql() {
func (s *Do) saveBeforeAssociations() { func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() { for _, field := range s.model.beforeAssociations() {
var id interface{} var id interface{}
do := &Do{chain: s.chain, db: s.db, driver: s.driver} do := &Do{chain: s.chain, db: s.db, driver: s.driver}
reflect_value := reflect.ValueOf(field.Value) reflect_value := reflect.ValueOf(field.Value)
@ -192,10 +182,8 @@ func (s *Do) create() (i interface{}) {
if s.driver == "postgres" { if s.driver == "postgres" {
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
} else { } else {
var err error if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...) id, err = sql_result.LastInsertId()
if s.err(err) == nil {
id, err = s.sqlResult.LastInsertId()
s.err(err) s.err(err)
} }
} }

View File

@ -92,7 +92,6 @@ func init() {
} }
db.SetPool(10) db.SetPool(10)
// db.DebugMode = true
err = db.DropTable(&User{}).Error err = db.DropTable(&User{}).Error
if err != nil { if err != nil {

11
main.go
View File

@ -5,9 +5,8 @@ import "database/sql"
var singularTableName bool var singularTableName bool
type DB struct { type DB struct {
db *sql.DB db *sql.DB
driver string driver string
DebugMode bool
} }
func Open(driver, source string) (db DB, err error) { func Open(driver, source string) (db DB, err error) {
@ -25,7 +24,7 @@ func (s *DB) SingularTable(result bool) {
} }
func (s *DB) buildChain() *Chain { func (s *DB) buildChain() *Chain {
return &Chain{db: s.db, driver: s.driver, debug: s.DebugMode} return &Chain{db: s.db, driver: s.driver}
} }
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
@ -104,10 +103,6 @@ func (s *DB) Table(name string) *Chain {
return s.buildChain().Table(name) return s.buildChain().Table(name)
} }
func (s *DB) Debug() *Chain {
return s.buildChain().Debug()
}
func (s *DB) CreateTable(value interface{}) *Chain { func (s *DB) CreateTable(value interface{}) *Chain {
return s.buildChain().CreateTable(value) return s.buildChain().CreateTable(value)
} }

View File

@ -14,7 +14,6 @@ import (
type Model struct { type Model struct {
data interface{} data interface{}
driver string driver string
debug bool
_cache_fields map[string][]Field _cache_fields map[string][]Field
} }

View File

@ -2,6 +2,8 @@ package gorm
import ( import (
"bytes" "bytes"
"errors"
"strconv"
"fmt" "fmt"
"strings" "strings"
@ -46,6 +48,22 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
return return
} }
func getInterfaceAsString(value interface{}) (str string, err error) {
switch value := value.(type) {
case string:
str = value
case int:
if value < 0 {
str = ""
} else {
str = strconv.Itoa(value)
}
default:
err = errors.New(fmt.Sprintf("Can't understand %v", value))
}
return
}
func debug(value interface{}) { func debug(value interface{}) {
fmt.Printf("***************\n") fmt.Printf("***************\n")
fmt.Printf("%+v\n\n", value) fmt.Printf("%+v\n\n", value)