Yay, Add Update, Updates Support

This commit is contained in:
Jinzhu 2013-10-28 21:52:22 +08:00
parent 2a20e551ed
commit 84b280c0ff
5 changed files with 111 additions and 19 deletions

View File

@ -9,6 +9,8 @@ Yet Another ORM library for Go, aims for developer friendly
* Before/After Create/Save/Update/Delete Callbacks
* Order/Select/Limit/Offset Support
* Automatically CreatedAt, UpdatedAt
* Update, Updates Like Rails's update_attribute, update_attributes
* Dynamically set table name when search, update, delete...
* Create table from struct
* Prevent SQL Injection
* Goroutines friendly
@ -83,7 +85,6 @@ db.Order("age desc").Find(&users1).Order("age", true).Find(&users2)
//// users1 -> select * from users order by age desc;
//// users2 -> select * from users order by age;
// Limit
db.Limit(3).Find(&users)
//// users -> select * from users limit 3;
@ -92,7 +93,6 @@ db.Limit(10).Find(&users1).Limit(20).Find(&users2).Limit(-1).Find(&users3)
//// users2 -> select * from users limit 20;
//// users3 -> select * from users;
// Offset
//// select * from users offset 3;
db.Offset(3).Find(&users)
@ -159,6 +159,18 @@ db.Table("deleted_users").Find(&deleted_users)
db.Table("deleted_users").Find(&deleted_user)
//// deleted_user -> select * from deleted_users limit 1;
// Update
db.Table("users").Where(10).Update("name", "hello")
//// update users set name='hello' where id = 10;
db.Table("users").Update("name", "hello")
//// update users set name='hello';
// Updates
db.Table("users").Where(10).Updates(map[string]interface{}{"name": "hello", "age": 18})
//// update users set name='hello', age=18 where id = 10;
db.Table("users").Updates(map[string]interface{}{"name": "hello", "age": 18})
//// update users set name='hello', age=18;
// Run Raw SQL
db.Exec("drop table users;")
@ -203,7 +215,6 @@ db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&us
```
## TODO
* Update, Updates like rails's update_attribute, update_attributes
* Soft Delete
* Query with map or struct
* FindOrInitialize / FindOrCreate

View File

@ -132,11 +132,15 @@ func (s *Chain) Delete(value interface{}) *Chain {
return s
}
func (s *Chain) Update(column string, value string) *Chain {
return s
func (s *Chain) Update(column string, value interface{}) *Chain {
return s.Updates(map[string]interface{}{column: value}, true)
}
func (s *Chain) Updates(values map[string]string) *Chain {
func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...interface{}) *Chain {
do := s.do(s.value)
do.updateAttrs = values
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0
do.update()
return s
}

23
do.go
View File

@ -33,6 +33,9 @@ type Do struct {
offsetStr string
limitStr string
operation string
updateAttrs map[string]interface{}
ignoreProtectedAttrs bool
}
func (s *Do) tableName() string {
@ -93,10 +96,10 @@ func (s *Do) save() *Do {
}
func (s *Do) prepareCreateSql() *Do {
columns, values := s.model.columnsAndValues("create")
var sqls, columns []string
var sqls []string
for _, value := range values {
for key, value := range s.model.columnsAndValues("create") {
columns = append(columns, key)
sqls = append(sqls, s.addToVars(value))
}
@ -138,16 +141,20 @@ func (s *Do) create() *Do {
}
func (s *Do) prepareUpdateSql() *Do {
columns, values := s.model.columnsAndValues("update")
var sets []string
for index, column := range columns {
sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index])))
update_attrs := s.updateAttrs
if len(update_attrs) == 0 {
update_attrs = s.model.columnsAndValues("update")
}
var sqls []string
for key, value := range update_attrs {
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
}
s.Sql = fmt.Sprintf(
"UPDATE %v SET %v %v",
s.tableName(),
strings.Join(sets, ", "),
strings.Join(sqls, ", "),
s.combinedSql(),
)
return s

View File

@ -116,7 +116,7 @@ func TestSaveAndFind(t *testing.T) {
db.Find(&users)
}
func TestUpdate(t *testing.T) {
func TestSaveAndUpdate(t *testing.T) {
name, name2, new_name := "update", "update2", "new_update"
user := User{Name: name, Age: 1}
db.Save(&user)
@ -172,6 +172,10 @@ func TestWhere(t *testing.T) {
t.Errorf("Should found out user with name '%v'", name)
}
if db.Where(user.Id).First(&User{}).Error != nil {
t.Errorf("Should found out users only with id")
}
user = &User{}
orm := db.Where("name LIKE ?", "%nonono%").First(user)
if orm.Error == nil {
@ -616,3 +620,65 @@ func TestSetTableDirectly(t *testing.T) {
t.Errorf("Set Table Chain Should works well")
}
}
func TestUpdate(t *testing.T) {
product1 := Product{Code: "123"}
product2 := Product{Code: "234"}
db.Save(&product1).Save(&product2).Update("code", "456")
if db.First(&Product{}, "code = '123'").Error != nil {
t.Errorf("Product 123's code should not be changed!")
}
if db.First(&Product{}, "code = '234'").Error == nil {
t.Errorf("Product 234's code should be changed to 456!")
}
if db.First(&Product{}, "code = '456'").Error != nil {
t.Errorf("Product 234's code should be 456 now!")
}
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
if db.First(&Product{}, "code = '123'").Error == nil {
t.Errorf("Product 123's code should be changed to 789")
}
if db.First(&Product{}, "code = '456'").Error != nil {
t.Errorf("Product 456's code should not be changed to 789")
}
if db.First(&Product{}, "code = '789'").Error != nil {
t.Errorf("We should have Product 789")
}
}
func TestUpdates(t *testing.T) {
product1 := Product{Code: "abc", Price: 10}
product2 := Product{Code: "cde", Price: 20}
db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100})
if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil {
t.Errorf("Product abc should not be updated!")
}
if db.First(&Product{}, "code = 'cde'").Error == nil {
t.Errorf("Product cde should be renamed to edf!")
}
if db.First(&Product{}, "code = 'edf' and price = 100").Error != nil {
t.Errorf("We should have product edf!")
}
db.Table("products").Where("code in (?)", []string{"abc"}).Updates(map[string]interface{}{"code": "fgh", "price": 200})
if db.First(&Product{}, "code = 'abc'").Error == nil {
t.Errorf("Product abc's code should be changed to fgh")
}
if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil {
t.Errorf("Product cde's code should not be changed to fgh")
}
if db.First(&Product{}, "code = 'fgh' and price = 200").Error != nil {
t.Errorf("We should have Product fgh")
}
}

View File

@ -96,14 +96,14 @@ func (m *Model) fields(operation string) (fields []Field) {
return
}
func (m *Model) columnsAndValues(operation string) (columns []string, values []interface{}) {
func (m *Model) columnsAndValues(operation string) map[string]interface{} {
results := map[string]interface{}{}
for _, field := range m.fields(operation) {
if !field.IsPrimaryKey {
columns = append(columns, field.DbName)
values = append(values, field.Value)
results[field.DbName] = field.Value
}
}
return
return results
}
func (m *Model) tableName() (str string, err error) {
@ -138,6 +138,10 @@ func (m *Model) tableName() (str string, err error) {
}
func (m *Model) callMethod(method string) error {
if m.Data == nil {
return nil
}
fm := reflect.ValueOf(m.Data).MethodByName(method)
if fm.IsValid() {
v := fm.Call([]reflect.Value{})