From 874856a5927c11bf68765d171e3a5331c2aef992 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Nov 2013 23:07:09 +0800 Subject: [PATCH] Cleanup unused code --- chain.go | 82 ++++++++++++++++------------------------------------ do.go | 42 ++++++++++----------------- gorm_test.go | 1 - main.go | 11 ++----- model.go | 1 - utils.go | 18 ++++++++++++ 6 files changed, 61 insertions(+), 94 deletions(-) diff --git a/chain.go b/chain.go index 88a19dc3..01064467 100644 --- a/chain.go +++ b/chain.go @@ -5,13 +5,11 @@ import ( "errors" "fmt" "regexp" - "strconv" ) type Chain struct { db *sql.DB driver string - debug bool value interface{} Errors []error @@ -30,20 +28,10 @@ type Chain struct { unscoped bool } -func (s *Chain) msg(str string) { - if s.debug { - debug(str) - } -} - func (s *Chain) err(err error) error { if err != nil { s.Errors = append(s.Errors, err) s.Error = err - - if s.debug { - debug(err) - } } return err } @@ -53,26 +41,25 @@ func (s *Chain) deleteLastError() { s.Errors = s.Errors[:len(s.Errors)-1] } -func (s *Chain) do(value interface{}) *Do { - var do Do - do.chain = s - do.db = s.db - do.driver = s.driver - - do.whereClause = s.whereClause - do.orClause = s.orClause - do.notClause = s.notClause - do.selectStr = s.selectStr - do.orderStrs = s.orderStrs - do.offsetStr = s.offsetStr - do.limitStr = s.limitStr - do.specifiedTableName = s.specifiedTableName - do.unscoped = s.unscoped - do.debug = s.debug +func (s *Chain) do(value interface{}) (do *Do) { + do = &Do{ + chain: s, + db: s.db, + driver: s.driver, + whereClause: s.whereClause, + orClause: s.orClause, + notClause: s.notClause, + selectStr: s.selectStr, + orderStrs: s.orderStrs, + offsetStr: s.offsetStr, + limitStr: s.limitStr, + specifiedTableName: s.specifiedTableName, + unscoped: s.unscoped, + } s.value = value do.setModel(value) - return &do + return } func (s *Chain) Model(model interface{}) *Chain { @@ -91,32 +78,18 @@ func (s *Chain) Not(querystring interface{}, args ...interface{}) *Chain { } func (s *Chain) Limit(value interface{}) *Chain { - switch value := value.(type) { - case string: - s.limitStr = value - case int: - if value < 0 { - s.limitStr = "" - } else { - s.limitStr = strconv.Itoa(value) - } - default: + if str, err := getInterfaceAsString(value); err == nil { + s.limitStr = str + } else { s.err(errors.New("Can' understand the value of Limit, Should be int")) } return s } func (s *Chain) Offset(value interface{}) *Chain { - switch value := value.(type) { - case string: - s.offsetStr = value - case int: - if value < 0 { - s.offsetStr = "" - } else { - s.offsetStr = strconv.Itoa(value) - } - default: + if str, err := getInterfaceAsString(value); err == nil { + s.offsetStr = str + } else { s.err(errors.New("Can' understand the value of Offset, Should be int")) } return s @@ -125,7 +98,7 @@ func (s *Chain) Offset(value interface{}) *Chain { func (s *Chain) Order(value string, reorder ...bool) *Chain { defer s.validSql(value) if len(reorder) > 0 && reorder[0] { - s.orderStrs = append([]string{}, value) + s.orderStrs = []string{value} } else { s.orderStrs = append(s.orderStrs, value) } @@ -196,8 +169,8 @@ func (s *Chain) Assign(attrs ...interface{}) *Chain { func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain { if s.First(out, where...).Error != nil { - s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition() s.deleteLastError() + s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition() } else { if len(s.assignAttrs) > 0 { s.do(out).setUpdateAttrs(s.assignAttrs).prepareUpdateAttrs() @@ -208,8 +181,8 @@ func (s *Chain) FirstOrInit(out interface{}, where ...interface{}) *Chain { func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain { if s.First(out, where...).Error != nil { - s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition() s.deleteLastError() + s.do(out).where(where...).where(s.initAttrs).where(s.assignAttrs).initializeWithSearchCondition() s.Save(out) } else { if len(s.assignAttrs) > 0 { @@ -265,11 +238,6 @@ func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain { return s } -func (s *Chain) Debug() *Chain { - s.debug = true - return s -} - func (s *Chain) validSql(str string) (result bool) { result = regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str) if !result { diff --git a/do.go b/do.go index 5529b80c..cb60fbdc 100644 --- a/do.go +++ b/do.go @@ -18,14 +18,11 @@ type Do struct { driver string guessedTableName string specifiedTableName string - debug bool - Errors []error - model *Model - value interface{} - sqlResult sql.Result - sql string - sqlVars []interface{} + model *Model + value interface{} + sql string + sqlVars []interface{} whereClause []map[string]interface{} orClause []map[string]interface{} @@ -40,7 +37,7 @@ type Do struct { } func (s *Do) tableName() string { - if s.specifiedTableName == "" { + if len(s.specifiedTableName) == 0 { var err error s.guessedTableName, err = s.model.tableName() s.err(err) @@ -52,18 +49,17 @@ func (s *Do) tableName() string { func (s *Do) err(err error) error { if err != nil { - s.Errors = append(s.Errors, err) s.chain.err(err) } return err } func (s *Do) hasError() bool { - return len(s.Errors) > 0 + return len(s.chain.Errors) > 0 } func (s *Do) setModel(value interface{}) *Do { - s.model = &Model{data: value, driver: s.driver, debug: s.debug} + s.model = &Model{data: value, driver: s.driver} s.value = value return s } @@ -77,20 +73,15 @@ func (s *Do) addToVars(value interface{}) string { } } -func (s *Do) exec(sql ...string) { +func (s *Do) exec(sqls ...string) (err error) { if s.hasError() { return + } else if len(sqls) > 0 { + _, err = s.db.Exec(sqls[0]) + } else if len(s.sql) > 0 { + _, err = s.db.Exec(s.sql, s.sqlVars...) } - - var err error - if len(sql) == 0 { - if len(s.sql) > 0 { - s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...) - } - } else { - s.sqlResult, err = s.db.Exec(sql[0]) - } - s.err(err) + return s.err(err) } func (s *Do) save() (i interface{}) { @@ -123,7 +114,6 @@ func (s *Do) prepareCreateSql() { func (s *Do) saveBeforeAssociations() { for _, field := range s.model.beforeAssociations() { var id interface{} - do := &Do{chain: s.chain, db: s.db, driver: s.driver} reflect_value := reflect.ValueOf(field.Value) @@ -192,10 +182,8 @@ func (s *Do) create() (i interface{}) { if s.driver == "postgres" { s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) } else { - var err error - s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...) - if s.err(err) == nil { - id, err = s.sqlResult.LastInsertId() + if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { + id, err = sql_result.LastInsertId() s.err(err) } } diff --git a/gorm_test.go b/gorm_test.go index 7dab681f..e446d94e 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -92,7 +92,6 @@ func init() { } db.SetPool(10) - // db.DebugMode = true err = db.DropTable(&User{}).Error if err != nil { diff --git a/main.go b/main.go index daa28535..1f4ff90f 100644 --- a/main.go +++ b/main.go @@ -5,9 +5,8 @@ import "database/sql" var singularTableName bool type DB struct { - db *sql.DB - driver string - DebugMode bool + db *sql.DB + driver string } func Open(driver, source string) (db DB, err error) { @@ -25,7 +24,7 @@ func (s *DB) SingularTable(result bool) { } func (s *DB) buildChain() *Chain { - return &Chain{db: s.db, driver: s.driver, debug: s.DebugMode} + return &Chain{db: s.db, driver: s.driver} } func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { @@ -104,10 +103,6 @@ func (s *DB) Table(name string) *Chain { return s.buildChain().Table(name) } -func (s *DB) Debug() *Chain { - return s.buildChain().Debug() -} - func (s *DB) CreateTable(value interface{}) *Chain { return s.buildChain().CreateTable(value) } diff --git a/model.go b/model.go index c84d3bb1..50f7743b 100644 --- a/model.go +++ b/model.go @@ -14,7 +14,6 @@ import ( type Model struct { data interface{} driver string - debug bool _cache_fields map[string][]Field } diff --git a/utils.go b/utils.go index 0848ec0b..b4b7e4c2 100644 --- a/utils.go +++ b/utils.go @@ -2,6 +2,8 @@ package gorm import ( "bytes" + "errors" + "strconv" "fmt" "strings" @@ -46,6 +48,22 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) { return } +func getInterfaceAsString(value interface{}) (str string, err error) { + switch value := value.(type) { + case string: + str = value + case int: + if value < 0 { + str = "" + } else { + str = strconv.Itoa(value) + } + default: + err = errors.New(fmt.Sprintf("Can't understand %v", value)) + } + return +} + func debug(value interface{}) { fmt.Printf("***************\n") fmt.Printf("%+v\n\n", value)