forked from mirror/gorm
Add Update, Updates back
This commit is contained in:
parent
ea67d1d377
commit
cc03f438ef
112
do.go
112
do.go
|
@ -13,13 +13,11 @@ import (
|
|||
)
|
||||
|
||||
type Do struct {
|
||||
chain *Chain
|
||||
db *DB
|
||||
guessedTableName string
|
||||
specifiedTableName string
|
||||
model *Model
|
||||
tableName string
|
||||
startedTransaction bool
|
||||
|
||||
model *Model
|
||||
value interface{}
|
||||
sql string
|
||||
sqlVars []interface{}
|
||||
|
@ -36,18 +34,20 @@ type Do struct {
|
|||
ignoreProtectedAttrs bool
|
||||
}
|
||||
|
||||
func (s *Do) tableName() string {
|
||||
if len(s.specifiedTableName) == 0 {
|
||||
s.guessedTableName = s.model.tableName()
|
||||
return s.guessedTableName
|
||||
func (s *Do) table() string {
|
||||
if len(s.tableName) == 0 {
|
||||
if len(s.db.search.tableName) == 0 {
|
||||
s.tableName = s.model.tableName()
|
||||
} else {
|
||||
return s.specifiedTableName
|
||||
s.tableName = s.db.search.tableName
|
||||
}
|
||||
}
|
||||
return s.tableName
|
||||
}
|
||||
|
||||
func (s *Do) err(err error) error {
|
||||
if err != nil {
|
||||
s.chain.err(err)
|
||||
s.db.err(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -60,18 +60,18 @@ func (s *Do) setModel(value interface{}) *Do {
|
|||
|
||||
func (s *Do) addToVars(value interface{}) string {
|
||||
s.sqlVars = append(s.sqlVars, value)
|
||||
return fmt.Sprintf(s.chain.d.dialect.BinVar(), len(s.sqlVars))
|
||||
return fmt.Sprintf(s.db.dialect.BinVar(), len(s.sqlVars))
|
||||
}
|
||||
|
||||
func (s *Do) exec(sqls ...string) (err error) {
|
||||
if !s.chain.hasError() {
|
||||
if !s.db.hasError() {
|
||||
if len(sqls) > 0 {
|
||||
s.sql = sqls[0]
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
_, err = s.db.Exec(s.sql, s.sqlVars...)
|
||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||
_, err = s.db.db.Exec(s.sql, s.sqlVars...)
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
}
|
||||
return s.err(err)
|
||||
}
|
||||
|
@ -95,17 +95,17 @@ func (s *Do) prepareCreateSql() {
|
|||
|
||||
s.sql = fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) VALUES (%v) %v",
|
||||
s.tableName(),
|
||||
s.table(),
|
||||
strings.Join(columns, ","),
|
||||
strings.Join(sqls, ","),
|
||||
s.chain.d.dialect.ReturningStr(s.model.primaryKeyDb()),
|
||||
s.db.dialect.ReturningStr(s.model.primaryKeyDb()),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) saveBeforeAssociations() {
|
||||
for _, field := range s.model.beforeAssociations() {
|
||||
do := &Do{chain: s.chain, db: s.db}
|
||||
do := &Do{db: s.db}
|
||||
|
||||
reflect_value := reflect.ValueOf(field.Value)
|
||||
if reflect_value.CanAddr() {
|
||||
|
@ -134,7 +134,7 @@ func (s *Do) saveAfterAssociations() {
|
|||
switch reflect_value.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < reflect_value.Len(); i++ {
|
||||
do := &Do{chain: s.chain, db: s.db}
|
||||
do := &Do{db: s.db}
|
||||
|
||||
value := reflect_value.Index(i).Addr().Interface()
|
||||
if len(field.foreignKey) > 0 {
|
||||
|
@ -143,7 +143,7 @@ func (s *Do) saveAfterAssociations() {
|
|||
do.setModel(value).save()
|
||||
}
|
||||
default:
|
||||
do := &Do{chain: s.chain, db: s.db}
|
||||
do := &Do{db: s.db}
|
||||
if reflect_value.CanAddr() {
|
||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
|
||||
do.setModel(field.Value).save()
|
||||
|
@ -170,21 +170,21 @@ func (s *Do) create() (i interface{}) {
|
|||
s.saveBeforeAssociations()
|
||||
s.prepareCreateSql()
|
||||
|
||||
if !s.chain.hasError() {
|
||||
if !s.db.hasError() {
|
||||
var id interface{}
|
||||
|
||||
now := time.Now()
|
||||
if s.chain.d.dialect.SupportLastInsertId() {
|
||||
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
|
||||
if s.db.dialect.SupportLastInsertId() {
|
||||
if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
|
||||
id, err = sql_result.LastInsertId()
|
||||
s.err(err)
|
||||
}
|
||||
} else {
|
||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||
}
|
||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
|
||||
if !s.chain.hasError() {
|
||||
if !s.db.hasError() {
|
||||
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
|
||||
|
||||
s.saveAfterAssociations()
|
||||
|
@ -236,7 +236,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
|||
|
||||
s.sql = fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
s.tableName(),
|
||||
s.table(),
|
||||
strings.Join(sqls, ", "),
|
||||
s.combinedSql(),
|
||||
)
|
||||
|
@ -258,7 +258,7 @@ func (s *Do) update() *Do {
|
|||
s.saveBeforeAssociations()
|
||||
s.prepareUpdateSql(update_attrs)
|
||||
|
||||
if !s.chain.hasError() {
|
||||
if !s.db.hasError() {
|
||||
s.exec()
|
||||
s.saveAfterAssociations()
|
||||
|
||||
|
@ -272,11 +272,11 @@ func (s *Do) update() *Do {
|
|||
func (s *Do) delete() *Do {
|
||||
s.model.callMethod("BeforeDelete")
|
||||
|
||||
if !s.chain.hasError() {
|
||||
if !s.db.hasError() {
|
||||
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
||||
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql())
|
||||
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql())
|
||||
} else {
|
||||
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql())
|
||||
}
|
||||
s.exec()
|
||||
s.model.callMethod("AfterDelete")
|
||||
|
@ -285,7 +285,7 @@ func (s *Do) delete() *Do {
|
|||
}
|
||||
|
||||
func (s *Do) prepareQuerySql() {
|
||||
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
|
||||
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.table(), s.combinedSql())
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -358,10 +358,10 @@ func (s *Do) query() {
|
|||
}
|
||||
|
||||
s.prepareQuerySql()
|
||||
if !s.chain.hasError() {
|
||||
if !s.db.hasError() {
|
||||
now := time.Now()
|
||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
|
||||
if s.err(err) != nil {
|
||||
return
|
||||
|
@ -402,10 +402,10 @@ func (s *Do) query() {
|
|||
|
||||
func (s *Do) count(value interface{}) {
|
||||
s.prepareQuerySql()
|
||||
if !s.chain.hasError() {
|
||||
if !s.db.hasError() {
|
||||
now := time.Now()
|
||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
|
||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -420,10 +420,10 @@ func (s *Do) pluck(column string, value interface{}) {
|
|||
|
||||
s.prepareQuerySql()
|
||||
|
||||
if !s.chain.hasError() {
|
||||
if !s.db.hasError() {
|
||||
now := time.Now()
|
||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
|
||||
if s.err(err) == nil {
|
||||
defer rows.Close()
|
||||
|
@ -645,25 +645,25 @@ func (s *Do) createTable() *Do {
|
|||
}
|
||||
}
|
||||
|
||||
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.tableName(), strings.Join(sqls, ","))
|
||||
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))
|
||||
|
||||
s.exec()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) dropTable() *Do {
|
||||
s.sql = fmt.Sprintf("DROP TABLE %v", s.tableName())
|
||||
s.sql = fmt.Sprintf("DROP TABLE %v", s.table())
|
||||
s.exec()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) updateColumn(column string, typ string) {
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.tableName(), column, typ)
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ)
|
||||
s.exec()
|
||||
}
|
||||
|
||||
func (s *Do) dropColumn(column string) {
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.tableName(), column)
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column)
|
||||
s.exec()
|
||||
}
|
||||
|
||||
|
@ -672,22 +672,22 @@ func (s *Do) addIndex(column string, names ...string) {
|
|||
if len(names) > 0 {
|
||||
index_name = names[0]
|
||||
} else {
|
||||
index_name = fmt.Sprintf("index_%v_on_%v", s.tableName(), column)
|
||||
index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column)
|
||||
}
|
||||
|
||||
s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.tableName(), column)
|
||||
s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column)
|
||||
s.exec()
|
||||
}
|
||||
|
||||
func (s *Do) removeIndex(index_name string) {
|
||||
s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.tableName())
|
||||
s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.table())
|
||||
s.exec()
|
||||
}
|
||||
|
||||
func (s *Do) autoMigrate() *Do {
|
||||
var table_name string
|
||||
sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.tableName()))
|
||||
s.db.QueryRow(sql, s.sqlVars...).Scan(&table_name)
|
||||
sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.table()))
|
||||
s.db.db.QueryRow(sql, s.sqlVars...).Scan(&table_name)
|
||||
s.sqlVars = []interface{}{}
|
||||
|
||||
// If table doesn't exist
|
||||
|
@ -696,13 +696,13 @@ func (s *Do) autoMigrate() *Do {
|
|||
} else {
|
||||
for _, field := range s.model.fields("migration") {
|
||||
var column_name, data_type string
|
||||
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName()))
|
||||
s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type)
|
||||
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.table()))
|
||||
s.db.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type)
|
||||
s.sqlVars = []interface{}{}
|
||||
|
||||
// If column doesn't exist
|
||||
if len(column_name) == 0 && len(field.sqlTag()) > 0 {
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.dbName, field.sqlTag())
|
||||
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.table(), field.dbName, field.sqlTag())
|
||||
s.exec()
|
||||
}
|
||||
}
|
||||
|
@ -711,9 +711,9 @@ func (s *Do) autoMigrate() *Do {
|
|||
}
|
||||
|
||||
func (s *Do) begin() *Do {
|
||||
if db, ok := s.db.(sqlDb); ok {
|
||||
if db, ok := s.db.db.(sqlDb); ok {
|
||||
if tx, err := db.Begin(); err == nil {
|
||||
s.db = interface{}(tx).(sqlCommon)
|
||||
s.db.db = interface{}(tx).(sqlCommon)
|
||||
s.startedTransaction = true
|
||||
}
|
||||
}
|
||||
|
@ -722,8 +722,8 @@ func (s *Do) begin() *Do {
|
|||
|
||||
func (s *Do) commit_or_rollback() {
|
||||
if s.startedTransaction {
|
||||
if db, ok := s.db.(sqlTx); ok {
|
||||
if s.chain.hasError() {
|
||||
if db, ok := s.db.db.(sqlTx); ok {
|
||||
if s.db.hasError() {
|
||||
db.Rollback()
|
||||
} else {
|
||||
db.Commit()
|
||||
|
|
13
main.go
13
main.go
|
@ -109,7 +109,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
|||
s.clone().do(out).where(where).initialize()
|
||||
} else {
|
||||
if len(s.search.assignAttrs) > 0 {
|
||||
s.do(out).updateAttrs(s.assignAttrs) //updated or not
|
||||
s.do(out).updateAttrs(s.search.assignAttrs) //updated or not
|
||||
}
|
||||
}
|
||||
return s
|
||||
|
@ -127,13 +127,22 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||
return s.Updates(toSearchableMap(attrs...), true)
|
||||
}
|
||||
|
||||
func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB {
|
||||
s.do(s.data).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DB) Save(value interface{}) *DB {
|
||||
s.do(value).begin().save().commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DB) Delete(value interface{}) *DB {
|
||||
s.do(value).bengin().delete(value).commit_or_rollback()
|
||||
s.do(value).begin().delete().commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
|
|
|
@ -29,13 +29,13 @@ func (s *DB) hasError() bool {
|
|||
}
|
||||
|
||||
func (s *DB) print(level string, v ...interface{}) {
|
||||
if s.d.logMode || s.debug_mode || level == "debug" {
|
||||
if _, ok := s.d.logger.(Logger); !ok {
|
||||
if s.logMode || level == "debug" {
|
||||
if _, ok := s.parent.logger.(Logger); !ok {
|
||||
fmt.Println("logger haven't been set, using os.Stdout")
|
||||
s.d.logger = default_logger
|
||||
s.parent.logger = default_logger
|
||||
}
|
||||
args := []interface{}{level}
|
||||
s.d.logger.(Logger).Print(append(args, v...)...)
|
||||
s.parent.logger.(Logger).Print(append(args, v...)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue