From cc03f438efbdd46d4c74695ab84829633169e29b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 16 Nov 2013 12:19:35 +0800 Subject: [PATCH] Add Update, Updates back --- do.go | 114 ++++++++++++++++++++++++++--------------------------- main.go | 13 +++++- private.go | 8 ++-- 3 files changed, 72 insertions(+), 63 deletions(-) diff --git a/do.go b/do.go index bb637e19..3414efcc 100644 --- a/do.go +++ b/do.go @@ -13,13 +13,11 @@ import ( ) type Do struct { - chain *Chain db *DB - guessedTableName string - specifiedTableName string + model *Model + tableName string startedTransaction bool - model *Model value interface{} sql string sqlVars []interface{} @@ -36,18 +34,20 @@ type Do struct { ignoreProtectedAttrs bool } -func (s *Do) tableName() string { - if len(s.specifiedTableName) == 0 { - s.guessedTableName = s.model.tableName() - return s.guessedTableName - } else { - return s.specifiedTableName +func (s *Do) table() string { + if len(s.tableName) == 0 { + if len(s.db.search.tableName) == 0 { + s.tableName = s.model.tableName() + } else { + s.tableName = s.db.search.tableName + } } + return s.tableName } func (s *Do) err(err error) error { if err != nil { - s.chain.err(err) + s.db.err(err) } return err } @@ -60,18 +60,18 @@ func (s *Do) setModel(value interface{}) *Do { func (s *Do) addToVars(value interface{}) string { s.sqlVars = append(s.sqlVars, value) - return fmt.Sprintf(s.chain.d.dialect.BinVar(), len(s.sqlVars)) + return fmt.Sprintf(s.db.dialect.BinVar(), len(s.sqlVars)) } func (s *Do) exec(sqls ...string) (err error) { - if !s.chain.hasError() { + if !s.db.hasError() { if len(sqls) > 0 { s.sql = sqls[0] } now := time.Now() - _, err = s.db.Exec(s.sql, s.sqlVars...) - s.chain.slog(s.sql, now, s.sqlVars...) + _, err = s.db.db.Exec(s.sql, s.sqlVars...) + s.db.slog(s.sql, now, s.sqlVars...) } return s.err(err) } @@ -95,17 +95,17 @@ func (s *Do) prepareCreateSql() { s.sql = fmt.Sprintf( "INSERT INTO %v (%v) VALUES (%v) %v", - s.tableName(), + s.table(), strings.Join(columns, ","), strings.Join(sqls, ","), - s.chain.d.dialect.ReturningStr(s.model.primaryKeyDb()), + s.db.dialect.ReturningStr(s.model.primaryKeyDb()), ) return } func (s *Do) saveBeforeAssociations() { for _, field := range s.model.beforeAssociations() { - do := &Do{chain: s.chain, db: s.db} + do := &Do{db: s.db} reflect_value := reflect.ValueOf(field.Value) if reflect_value.CanAddr() { @@ -134,7 +134,7 @@ func (s *Do) saveAfterAssociations() { switch reflect_value.Kind() { case reflect.Slice: for i := 0; i < reflect_value.Len(); i++ { - do := &Do{chain: s.chain, db: s.db} + do := &Do{db: s.db} value := reflect_value.Index(i).Addr().Interface() if len(field.foreignKey) > 0 { @@ -143,7 +143,7 @@ func (s *Do) saveAfterAssociations() { do.setModel(value).save() } default: - do := &Do{chain: s.chain, db: s.db} + do := &Do{db: s.db} if reflect_value.CanAddr() { s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value) do.setModel(field.Value).save() @@ -170,21 +170,21 @@ func (s *Do) create() (i interface{}) { s.saveBeforeAssociations() s.prepareCreateSql() - if !s.chain.hasError() { + if !s.db.hasError() { var id interface{} now := time.Now() - if s.chain.d.dialect.SupportLastInsertId() { - if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { + if s.db.dialect.SupportLastInsertId() { + if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { id, err = sql_result.LastInsertId() s.err(err) } } else { - s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) + s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) } - s.chain.slog(s.sql, now, s.sqlVars...) + s.db.slog(s.sql, now, s.sqlVars...) - if !s.chain.hasError() { + if !s.db.hasError() { s.model.setValueByColumn(s.model.primaryKey(), id, s.value) s.saveAfterAssociations() @@ -236,7 +236,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) { s.sql = fmt.Sprintf( "UPDATE %v SET %v %v", - s.tableName(), + s.table(), strings.Join(sqls, ", "), s.combinedSql(), ) @@ -258,7 +258,7 @@ func (s *Do) update() *Do { s.saveBeforeAssociations() s.prepareUpdateSql(update_attrs) - if !s.chain.hasError() { + if !s.db.hasError() { s.exec() s.saveAfterAssociations() @@ -272,11 +272,11 @@ func (s *Do) update() *Do { func (s *Do) delete() *Do { s.model.callMethod("BeforeDelete") - if !s.chain.hasError() { + if !s.db.hasError() { if !s.unscoped && s.model.hasColumn("DeletedAt") { - s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql()) + s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql()) } else { - s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql()) + s.sql = fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql()) } s.exec() s.model.callMethod("AfterDelete") @@ -285,7 +285,7 @@ func (s *Do) delete() *Do { } func (s *Do) prepareQuerySql() { - s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql()) + s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.table(), s.combinedSql()) return } @@ -358,10 +358,10 @@ func (s *Do) query() { } s.prepareQuerySql() - if !s.chain.hasError() { + if !s.db.hasError() { now := time.Now() - rows, err := s.db.Query(s.sql, s.sqlVars...) - s.chain.slog(s.sql, now, s.sqlVars...) + rows, err := s.db.db.Query(s.sql, s.sqlVars...) + s.db.slog(s.sql, now, s.sqlVars...) if s.err(err) != nil { return @@ -402,10 +402,10 @@ func (s *Do) query() { func (s *Do) count(value interface{}) { s.prepareQuerySql() - if !s.chain.hasError() { + if !s.db.hasError() { now := time.Now() - s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(value)) - s.chain.slog(s.sql, now, s.sqlVars...) + s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value)) + s.db.slog(s.sql, now, s.sqlVars...) } } @@ -420,10 +420,10 @@ func (s *Do) pluck(column string, value interface{}) { s.prepareQuerySql() - if !s.chain.hasError() { + if !s.db.hasError() { now := time.Now() - rows, err := s.db.Query(s.sql, s.sqlVars...) - s.chain.slog(s.sql, now, s.sqlVars...) + rows, err := s.db.db.Query(s.sql, s.sqlVars...) + s.db.slog(s.sql, now, s.sqlVars...) if s.err(err) == nil { defer rows.Close() @@ -645,25 +645,25 @@ func (s *Do) createTable() *Do { } } - s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.tableName(), strings.Join(sqls, ",")) + s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ",")) s.exec() return s } func (s *Do) dropTable() *Do { - s.sql = fmt.Sprintf("DROP TABLE %v", s.tableName()) + s.sql = fmt.Sprintf("DROP TABLE %v", s.table()) s.exec() return s } func (s *Do) updateColumn(column string, typ string) { - s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.tableName(), column, typ) + s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ) s.exec() } func (s *Do) dropColumn(column string) { - s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.tableName(), column) + s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column) s.exec() } @@ -672,22 +672,22 @@ func (s *Do) addIndex(column string, names ...string) { if len(names) > 0 { index_name = names[0] } else { - index_name = fmt.Sprintf("index_%v_on_%v", s.tableName(), column) + index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column) } - s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.tableName(), column) + s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column) s.exec() } func (s *Do) removeIndex(index_name string) { - s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.tableName()) + s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.table()) s.exec() } func (s *Do) autoMigrate() *Do { var table_name string - sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.tableName())) - s.db.QueryRow(sql, s.sqlVars...).Scan(&table_name) + sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.table())) + s.db.db.QueryRow(sql, s.sqlVars...).Scan(&table_name) s.sqlVars = []interface{}{} // If table doesn't exist @@ -696,13 +696,13 @@ func (s *Do) autoMigrate() *Do { } else { for _, field := range s.model.fields("migration") { var column_name, data_type string - sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName())) - s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type) + sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.table())) + s.db.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type) s.sqlVars = []interface{}{} // If column doesn't exist if len(column_name) == 0 && len(field.sqlTag()) > 0 { - s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.dbName, field.sqlTag()) + s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.table(), field.dbName, field.sqlTag()) s.exec() } } @@ -711,9 +711,9 @@ func (s *Do) autoMigrate() *Do { } func (s *Do) begin() *Do { - if db, ok := s.db.(sqlDb); ok { + if db, ok := s.db.db.(sqlDb); ok { if tx, err := db.Begin(); err == nil { - s.db = interface{}(tx).(sqlCommon) + s.db.db = interface{}(tx).(sqlCommon) s.startedTransaction = true } } @@ -722,8 +722,8 @@ func (s *Do) begin() *Do { func (s *Do) commit_or_rollback() { if s.startedTransaction { - if db, ok := s.db.(sqlTx); ok { - if s.chain.hasError() { + if db, ok := s.db.db.(sqlTx); ok { + if s.db.hasError() { db.Rollback() } else { db.Commit() diff --git a/main.go b/main.go index 2b3c677d..512a6f48 100644 --- a/main.go +++ b/main.go @@ -109,7 +109,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { s.clone().do(out).where(where).initialize() } else { if len(s.search.assignAttrs) > 0 { - s.do(out).updateAttrs(s.assignAttrs) //updated or not + s.do(out).updateAttrs(s.search.assignAttrs) //updated or not } } return s @@ -127,13 +127,22 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { return s } +func (s *DB) Update(attrs ...interface{}) *DB { + return s.Updates(toSearchableMap(attrs...), true) +} + +func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB { + s.do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback() + return s +} + func (s *DB) Save(value interface{}) *DB { s.do(value).begin().save().commit_or_rollback() return s } func (s *DB) Delete(value interface{}) *DB { - s.do(value).bengin().delete(value).commit_or_rollback() + s.do(value).begin().delete().commit_or_rollback() return s } diff --git a/private.go b/private.go index 8844bfc1..2ff6aef3 100644 --- a/private.go +++ b/private.go @@ -29,13 +29,13 @@ func (s *DB) hasError() bool { } func (s *DB) print(level string, v ...interface{}) { - if s.d.logMode || s.debug_mode || level == "debug" { - if _, ok := s.d.logger.(Logger); !ok { + if s.logMode || level == "debug" { + if _, ok := s.parent.logger.(Logger); !ok { fmt.Println("logger haven't been set, using os.Stdout") - s.d.logger = default_logger + s.parent.logger = default_logger } args := []interface{}{level} - s.d.logger.(Logger).Print(append(args, v...)...) + s.parent.logger.(Logger).Print(append(args, v...)...) } }