diff --git a/chain.go b/chain.go index 36380ce1..28ba14fe 100644 --- a/chain.go +++ b/chain.go @@ -7,9 +7,10 @@ import ( ) type Chain struct { - d *DB - db sql_common - value interface{} + d *DB + db sql_common + value interface{} + debug_mode bool Errors []error Error error @@ -35,7 +36,7 @@ func (s *Chain) err(err error) error { if err != nil { s.Errors = append(s.Errors, err) s.Error = err - warn(err) + s.warn(err) } return err } @@ -257,6 +258,11 @@ func (s *Chain) Begin() *Chain { return s } +func (s *Chain) Debug() *Chain { + s.debug_mode = true + return s +} + func (s *Chain) Commit() *Chain { if db, ok := s.db.(sql_tx); ok { s.err(db.Commit()) diff --git a/do.go b/do.go index f2c35fa7..32e974c7 100644 --- a/do.go +++ b/do.go @@ -77,8 +77,9 @@ func (s *Do) exec(sqls ...string) (err error) { if len(sqls) > 0 { s.sql = sqls[0] } + now := time.Now() _, err = s.db.Exec(s.sql, s.sqlVars...) - slog(s.sql, s.sqlVars...) + s.chain.slog(s.sql, now, s.sqlVars...) } return s.err(err) } @@ -177,6 +178,7 @@ func (s *Do) create() (i interface{}) { s.prepareCreateSql() if !s.hasError() { + now := time.Now() var id interface{} if s.chain.driver() == "postgres" { s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) @@ -186,7 +188,7 @@ func (s *Do) create() (i interface{}) { s.err(err) } } - slog(s.sql, s.sqlVars...) + s.chain.slog(s.sql, now, s.sqlVars...) if !s.hasError() { result := reflect.Indirect(reflect.ValueOf(s.value)) @@ -377,8 +379,9 @@ func (s *Do) query() { s.prepareQuerySql() if !s.hasError() { + now := time.Now() rows, err := s.db.Query(s.sql, s.sqlVars...) - slog(s.sql, s.sqlVars...) + s.chain.slog(s.sql, now, s.sqlVars...) if s.err(err) != nil { return } @@ -433,8 +436,9 @@ func (s *Do) count(value interface{}) { s.prepareQuerySql() if !s.hasError() { + now := time.Now() rows, err := s.db.Query(s.sql, s.sqlVars...) - slog(s.sql, s.sqlVars...) + s.chain.slog(s.sql, now, s.sqlVars...) if s.err(err) != nil { return } @@ -463,8 +467,9 @@ func (s *Do) pluck(column string, value interface{}) { s.prepareQuerySql() if !s.hasError() { + now := time.Now() rows, err := s.db.Query(s.sql, s.sqlVars...) - slog(s.sql, s.sqlVars...) + s.chain.slog(s.sql, now, s.sqlVars...) if s.err(err) != nil { return } @@ -512,7 +517,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { id, _ := strconv.Atoi(value) return s.primaryCondiation(s.addToVars(id)) } else { - str = "( " + value + " )" + str = "(" + value + ")" } case int, int64, int32: return s.primaryCondiation(s.addToVars(query)) @@ -524,14 +529,14 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { case map[string]interface{}: var sqls []string for key, value := range query.(map[string]interface{}) { - sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", key, s.addToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value))) } return strings.Join(sqls, " AND ") case interface{}: m := &Model{data: query, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { - sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf(" (%v = %v) ", field.DbName, s.addToVars(field.Value))) } return strings.Join(sqls, " AND ") } @@ -585,14 +590,14 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { case map[string]interface{}: var sqls []string for key, value := range query.(map[string]interface{}) { - sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", key, s.addToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value))) } return strings.Join(sqls, " AND ") case interface{}: m := &Model{data: query, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { - sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.DbName, s.addToVars(field.Value))) } return strings.Join(sqls, " AND ") } @@ -654,7 +659,7 @@ func (s *Do) whereSql() (sql string) { if len(primary_condiations) > 0 { sql = "WHERE " + strings.Join(primary_condiations, " AND ") if len(combined_conditions) > 0 { - sql = sql + " AND ( " + combined_conditions + " )" + sql = sql + " AND (" + combined_conditions + ")" } } else if len(combined_conditions) > 0 { sql = "WHERE " + combined_conditions @@ -664,7 +669,7 @@ func (s *Do) whereSql() (sql string) { func (s *Do) selectSql() string { if len(s.selectStr) == 0 { - return " * " + return "*" } else { return s.selectStr } diff --git a/gorm_test.go b/gorm_test.go index 8c38d61c..25938301 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -86,7 +86,7 @@ func init() { // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; // db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") // db, err = Open("sqlite3", "/tmp/gorm.db") - db.LogMode(false) + db.LogMode(true) if err != nil { panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err)) diff --git a/logger.go b/logger.go index b23869f8..2caf9a5c 100644 --- a/logger.go +++ b/logger.go @@ -5,42 +5,32 @@ import ( "log" "os" "regexp" + "time" ) -var logger interface{} -var logger_disabled bool - type Logger interface { Print(v ...interface{}) } -func print(level string, v ...interface{}) { - if logger_disabled && level != "debug" { - return +func (s *Chain) print(level string, v ...interface{}) { + if s.d.log_mode || s.debug_mode || level == "debug" { + if _, ok := s.d.logger.(Logger); !ok { + fmt.Println("logger haven't been set, using os.Stdout") + s.d.logger = log.New(os.Stdout, "", 0) + } + args := []interface{}{level} + s.d.logger.(Logger).Print(append(args, v...)) } - - var has_valid_logger bool - if logger, has_valid_logger = logger.(Logger); !has_valid_logger { - fmt.Println("logger haven't been set, using os.Stdout") - logger = log.New(os.Stdout, "", 0) - } - - args := []interface{}{level} - logger.(Logger).Print(append(args, v...)) } -func warn(v ...interface{}) { - go print("warn", v...) +func (s *Chain) warn(v ...interface{}) { + go s.print("warn", v...) } -func info(v ...interface{}) { - go print("info", v...) +func (s *Chain) slog(sql string, t time.Time, vars ...interface{}) { + go s.print("sql", time.Now().Sub(t), fmt.Sprintf(regexp.MustCompile(`\$\d|\?`).ReplaceAllString(sql, "'%v'"), vars...)) } -func slog(sql string, vars ...interface{}) { - go print("sql", fmt.Sprintf(regexp.MustCompile(`\$\d|\?`).ReplaceAllString(sql, "'%v'"), vars...)) -} - -func debug(v ...interface{}) { - go print("debug", v...) +func (s *Chain) debug(v ...interface{}) { + go s.print("debug", v...) } diff --git a/main.go b/main.go index da8d33f1..f5f61401 100644 --- a/main.go +++ b/main.go @@ -5,8 +5,10 @@ import "database/sql" var singularTableName bool type DB struct { - db sql_common - driver string + db sql_common + driver string + logger Logger + log_mode bool } func Open(driver, source string) (db DB, err error) { @@ -21,12 +23,12 @@ func (s *DB) SetPool(n int) { } } -func (s *DB) SetLogger(l interface{}) { - logger = l +func (s *DB) SetLogger(l Logger) { + s.logger = l } func (s *DB) LogMode(b bool) { - logger_disabled = !b + s.log_mode = b } func (s *DB) SingularTable(result bool) { @@ -125,6 +127,10 @@ func (s *DB) AutoMigrate(value interface{}) *Chain { return s.buildChain().AutoMigrate(value) } +func (s *DB) Debug() *Chain { + return s.buildChain().Debug() +} + func (s *DB) Begin() *Chain { return s.buildChain().Begin() }