Add Update, Updates back

This commit is contained in:
Jinzhu 2013-11-16 12:19:35 +08:00
parent ea67d1d377
commit cc03f438ef
3 changed files with 72 additions and 63 deletions

112
do.go
View File

@ -13,13 +13,11 @@ import (
)
type Do struct {
chain *Chain
db *DB
guessedTableName string
specifiedTableName string
model *Model
tableName string
startedTransaction bool
model *Model
value interface{}
sql string
sqlVars []interface{}
@ -36,18 +34,20 @@ type Do struct {
ignoreProtectedAttrs bool
}
func (s *Do) tableName() string {
if len(s.specifiedTableName) == 0 {
s.guessedTableName = s.model.tableName()
return s.guessedTableName
func (s *Do) table() string {
if len(s.tableName) == 0 {
if len(s.db.search.tableName) == 0 {
s.tableName = s.model.tableName()
} else {
return s.specifiedTableName
s.tableName = s.db.search.tableName
}
}
return s.tableName
}
func (s *Do) err(err error) error {
if err != nil {
s.chain.err(err)
s.db.err(err)
}
return err
}
@ -60,18 +60,18 @@ func (s *Do) setModel(value interface{}) *Do {
func (s *Do) addToVars(value interface{}) string {
s.sqlVars = append(s.sqlVars, value)
return fmt.Sprintf(s.chain.d.dialect.BinVar(), len(s.sqlVars))
return fmt.Sprintf(s.db.dialect.BinVar(), len(s.sqlVars))
}
func (s *Do) exec(sqls ...string) (err error) {
if !s.chain.hasError() {
if !s.db.hasError() {
if len(sqls) > 0 {
s.sql = sqls[0]
}
now := time.Now()
_, err = s.db.Exec(s.sql, s.sqlVars...)
s.chain.slog(s.sql, now, s.sqlVars...)
_, err = s.db.db.Exec(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
}
return s.err(err)
}
@ -95,17 +95,17 @@ func (s *Do) prepareCreateSql() {
s.sql = fmt.Sprintf(
"INSERT INTO %v (%v) VALUES (%v) %v",
s.tableName(),
s.table(),
strings.Join(columns, ","),
strings.Join(sqls, ","),
s.chain.d.dialect.ReturningStr(s.model.primaryKeyDb()),
s.db.dialect.ReturningStr(s.model.primaryKeyDb()),
)
return
}
func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() {
do := &Do{chain: s.chain, db: s.db}
do := &Do{db: s.db}
reflect_value := reflect.ValueOf(field.Value)
if reflect_value.CanAddr() {
@ -134,7 +134,7 @@ func (s *Do) saveAfterAssociations() {
switch reflect_value.Kind() {
case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ {
do := &Do{chain: s.chain, db: s.db}
do := &Do{db: s.db}
value := reflect_value.Index(i).Addr().Interface()
if len(field.foreignKey) > 0 {
@ -143,7 +143,7 @@ func (s *Do) saveAfterAssociations() {
do.setModel(value).save()
}
default:
do := &Do{chain: s.chain, db: s.db}
do := &Do{db: s.db}
if reflect_value.CanAddr() {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
do.setModel(field.Value).save()
@ -170,21 +170,21 @@ func (s *Do) create() (i interface{}) {
s.saveBeforeAssociations()
s.prepareCreateSql()
if !s.chain.hasError() {
if !s.db.hasError() {
var id interface{}
now := time.Now()
if s.chain.d.dialect.SupportLastInsertId() {
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
if s.db.dialect.SupportLastInsertId() {
if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
id, err = sql_result.LastInsertId()
s.err(err)
}
} else {
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
}
s.chain.slog(s.sql, now, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
if !s.chain.hasError() {
if !s.db.hasError() {
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
s.saveAfterAssociations()
@ -236,7 +236,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
s.sql = fmt.Sprintf(
"UPDATE %v SET %v %v",
s.tableName(),
s.table(),
strings.Join(sqls, ", "),
s.combinedSql(),
)
@ -258,7 +258,7 @@ func (s *Do) update() *Do {
s.saveBeforeAssociations()
s.prepareUpdateSql(update_attrs)
if !s.chain.hasError() {
if !s.db.hasError() {
s.exec()
s.saveAfterAssociations()
@ -272,11 +272,11 @@ func (s *Do) update() *Do {
func (s *Do) delete() *Do {
s.model.callMethod("BeforeDelete")
if !s.chain.hasError() {
if !s.db.hasError() {
if !s.unscoped && s.model.hasColumn("DeletedAt") {
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql())
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql())
} else {
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql())
}
s.exec()
s.model.callMethod("AfterDelete")
@ -285,7 +285,7 @@ func (s *Do) delete() *Do {
}
func (s *Do) prepareQuerySql() {
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.table(), s.combinedSql())
return
}
@ -358,10 +358,10 @@ func (s *Do) query() {
}
s.prepareQuerySql()
if !s.chain.hasError() {
if !s.db.hasError() {
now := time.Now()
rows, err := s.db.Query(s.sql, s.sqlVars...)
s.chain.slog(s.sql, now, s.sqlVars...)
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
if s.err(err) != nil {
return
@ -402,10 +402,10 @@ func (s *Do) query() {
func (s *Do) count(value interface{}) {
s.prepareQuerySql()
if !s.chain.hasError() {
if !s.db.hasError() {
now := time.Now()
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
s.chain.slog(s.sql, now, s.sqlVars...)
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
s.db.slog(s.sql, now, s.sqlVars...)
}
}
@ -420,10 +420,10 @@ func (s *Do) pluck(column string, value interface{}) {
s.prepareQuerySql()
if !s.chain.hasError() {
if !s.db.hasError() {
now := time.Now()
rows, err := s.db.Query(s.sql, s.sqlVars...)
s.chain.slog(s.sql, now, s.sqlVars...)
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
if s.err(err) == nil {
defer rows.Close()
@ -645,25 +645,25 @@ func (s *Do) createTable() *Do {
}
}
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.tableName(), strings.Join(sqls, ","))
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))
s.exec()
return s
}
func (s *Do) dropTable() *Do {
s.sql = fmt.Sprintf("DROP TABLE %v", s.tableName())
s.sql = fmt.Sprintf("DROP TABLE %v", s.table())
s.exec()
return s
}
func (s *Do) updateColumn(column string, typ string) {
s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.tableName(), column, typ)
s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ)
s.exec()
}
func (s *Do) dropColumn(column string) {
s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.tableName(), column)
s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column)
s.exec()
}
@ -672,22 +672,22 @@ func (s *Do) addIndex(column string, names ...string) {
if len(names) > 0 {
index_name = names[0]
} else {
index_name = fmt.Sprintf("index_%v_on_%v", s.tableName(), column)
index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column)
}
s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.tableName(), column)
s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column)
s.exec()
}
func (s *Do) removeIndex(index_name string) {
s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.tableName())
s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.table())
s.exec()
}
func (s *Do) autoMigrate() *Do {
var table_name string
sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.tableName()))
s.db.QueryRow(sql, s.sqlVars...).Scan(&table_name)
sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.table()))
s.db.db.QueryRow(sql, s.sqlVars...).Scan(&table_name)
s.sqlVars = []interface{}{}
// If table doesn't exist
@ -696,13 +696,13 @@ func (s *Do) autoMigrate() *Do {
} else {
for _, field := range s.model.fields("migration") {
var column_name, data_type string
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName()))
s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type)
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.table()))
s.db.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type)
s.sqlVars = []interface{}{}
// If column doesn't exist
if len(column_name) == 0 && len(field.sqlTag()) > 0 {
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.dbName, field.sqlTag())
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.table(), field.dbName, field.sqlTag())
s.exec()
}
}
@ -711,9 +711,9 @@ func (s *Do) autoMigrate() *Do {
}
func (s *Do) begin() *Do {
if db, ok := s.db.(sqlDb); ok {
if db, ok := s.db.db.(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
s.db = interface{}(tx).(sqlCommon)
s.db.db = interface{}(tx).(sqlCommon)
s.startedTransaction = true
}
}
@ -722,8 +722,8 @@ func (s *Do) begin() *Do {
func (s *Do) commit_or_rollback() {
if s.startedTransaction {
if db, ok := s.db.(sqlTx); ok {
if s.chain.hasError() {
if db, ok := s.db.db.(sqlTx); ok {
if s.db.hasError() {
db.Rollback()
} else {
db.Commit()

13
main.go
View File

@ -109,7 +109,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
s.clone().do(out).where(where).initialize()
} else {
if len(s.search.assignAttrs) > 0 {
s.do(out).updateAttrs(s.assignAttrs) //updated or not
s.do(out).updateAttrs(s.search.assignAttrs) //updated or not
}
}
return s
@ -127,13 +127,22 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
return s
}
func (s *DB) Update(attrs ...interface{}) *DB {
return s.Updates(toSearchableMap(attrs...), true)
}
func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB {
s.do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
return s
}
func (s *DB) Save(value interface{}) *DB {
s.do(value).begin().save().commit_or_rollback()
return s
}
func (s *DB) Delete(value interface{}) *DB {
s.do(value).bengin().delete(value).commit_or_rollback()
s.do(value).begin().delete().commit_or_rollback()
return s
}

View File

@ -29,13 +29,13 @@ func (s *DB) hasError() bool {
}
func (s *DB) print(level string, v ...interface{}) {
if s.d.logMode || s.debug_mode || level == "debug" {
if _, ok := s.d.logger.(Logger); !ok {
if s.logMode || level == "debug" {
if _, ok := s.parent.logger.(Logger); !ok {
fmt.Println("logger haven't been set, using os.Stdout")
s.d.logger = default_logger
s.parent.logger = default_logger
}
args := []interface{}{level}
s.d.logger.(Logger).Print(append(args, v...)...)
s.parent.logger.(Logger).Print(append(args, v...)...)
}
}