forked from mirror/gorm
Yay, Add Update, Updates Support
This commit is contained in:
parent
2a20e551ed
commit
84b280c0ff
17
README.md
17
README.md
|
@ -9,6 +9,8 @@ Yet Another ORM library for Go, aims for developer friendly
|
||||||
* Before/After Create/Save/Update/Delete Callbacks
|
* Before/After Create/Save/Update/Delete Callbacks
|
||||||
* Order/Select/Limit/Offset Support
|
* Order/Select/Limit/Offset Support
|
||||||
* Automatically CreatedAt, UpdatedAt
|
* Automatically CreatedAt, UpdatedAt
|
||||||
|
* Update, Updates Like Rails's update_attribute, update_attributes
|
||||||
|
* Dynamically set table name when search, update, delete...
|
||||||
* Create table from struct
|
* Create table from struct
|
||||||
* Prevent SQL Injection
|
* Prevent SQL Injection
|
||||||
* Goroutines friendly
|
* Goroutines friendly
|
||||||
|
@ -83,7 +85,6 @@ db.Order("age desc").Find(&users1).Order("age", true).Find(&users2)
|
||||||
//// users1 -> select * from users order by age desc;
|
//// users1 -> select * from users order by age desc;
|
||||||
//// users2 -> select * from users order by age;
|
//// users2 -> select * from users order by age;
|
||||||
|
|
||||||
|
|
||||||
// Limit
|
// Limit
|
||||||
db.Limit(3).Find(&users)
|
db.Limit(3).Find(&users)
|
||||||
//// users -> select * from users limit 3;
|
//// users -> select * from users limit 3;
|
||||||
|
@ -92,7 +93,6 @@ db.Limit(10).Find(&users1).Limit(20).Find(&users2).Limit(-1).Find(&users3)
|
||||||
//// users2 -> select * from users limit 20;
|
//// users2 -> select * from users limit 20;
|
||||||
//// users3 -> select * from users;
|
//// users3 -> select * from users;
|
||||||
|
|
||||||
|
|
||||||
// Offset
|
// Offset
|
||||||
//// select * from users offset 3;
|
//// select * from users offset 3;
|
||||||
db.Offset(3).Find(&users)
|
db.Offset(3).Find(&users)
|
||||||
|
@ -159,6 +159,18 @@ db.Table("deleted_users").Find(&deleted_users)
|
||||||
db.Table("deleted_users").Find(&deleted_user)
|
db.Table("deleted_users").Find(&deleted_user)
|
||||||
//// deleted_user -> select * from deleted_users limit 1;
|
//// deleted_user -> select * from deleted_users limit 1;
|
||||||
|
|
||||||
|
// Update
|
||||||
|
db.Table("users").Where(10).Update("name", "hello")
|
||||||
|
//// update users set name='hello' where id = 10;
|
||||||
|
db.Table("users").Update("name", "hello")
|
||||||
|
//// update users set name='hello';
|
||||||
|
|
||||||
|
// Updates
|
||||||
|
db.Table("users").Where(10).Updates(map[string]interface{}{"name": "hello", "age": 18})
|
||||||
|
//// update users set name='hello', age=18 where id = 10;
|
||||||
|
db.Table("users").Updates(map[string]interface{}{"name": "hello", "age": 18})
|
||||||
|
//// update users set name='hello', age=18;
|
||||||
|
|
||||||
// Run Raw SQL
|
// Run Raw SQL
|
||||||
db.Exec("drop table users;")
|
db.Exec("drop table users;")
|
||||||
|
|
||||||
|
@ -203,7 +215,6 @@ db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&us
|
||||||
```
|
```
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
* Update, Updates like rails's update_attribute, update_attributes
|
|
||||||
* Soft Delete
|
* Soft Delete
|
||||||
* Query with map or struct
|
* Query with map or struct
|
||||||
* FindOrInitialize / FindOrCreate
|
* FindOrInitialize / FindOrCreate
|
||||||
|
|
10
chain.go
10
chain.go
|
@ -132,11 +132,15 @@ func (s *Chain) Delete(value interface{}) *Chain {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) Update(column string, value string) *Chain {
|
func (s *Chain) Update(column string, value interface{}) *Chain {
|
||||||
return s
|
return s.Updates(map[string]interface{}{column: value}, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) Updates(values map[string]string) *Chain {
|
func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...interface{}) *Chain {
|
||||||
|
do := s.do(s.value)
|
||||||
|
do.updateAttrs = values
|
||||||
|
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0
|
||||||
|
do.update()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
23
do.go
23
do.go
|
@ -33,6 +33,9 @@ type Do struct {
|
||||||
offsetStr string
|
offsetStr string
|
||||||
limitStr string
|
limitStr string
|
||||||
operation string
|
operation string
|
||||||
|
|
||||||
|
updateAttrs map[string]interface{}
|
||||||
|
ignoreProtectedAttrs bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) tableName() string {
|
func (s *Do) tableName() string {
|
||||||
|
@ -93,10 +96,10 @@ func (s *Do) save() *Do {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareCreateSql() *Do {
|
func (s *Do) prepareCreateSql() *Do {
|
||||||
columns, values := s.model.columnsAndValues("create")
|
var sqls, columns []string
|
||||||
|
|
||||||
var sqls []string
|
for key, value := range s.model.columnsAndValues("create") {
|
||||||
for _, value := range values {
|
columns = append(columns, key)
|
||||||
sqls = append(sqls, s.addToVars(value))
|
sqls = append(sqls, s.addToVars(value))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,16 +141,20 @@ func (s *Do) create() *Do {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareUpdateSql() *Do {
|
func (s *Do) prepareUpdateSql() *Do {
|
||||||
columns, values := s.model.columnsAndValues("update")
|
update_attrs := s.updateAttrs
|
||||||
var sets []string
|
if len(update_attrs) == 0 {
|
||||||
for index, column := range columns {
|
update_attrs = s.model.columnsAndValues("update")
|
||||||
sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index])))
|
}
|
||||||
|
|
||||||
|
var sqls []string
|
||||||
|
for key, value := range update_attrs {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Sql = fmt.Sprintf(
|
s.Sql = fmt.Sprintf(
|
||||||
"UPDATE %v SET %v %v",
|
"UPDATE %v SET %v %v",
|
||||||
s.tableName(),
|
s.tableName(),
|
||||||
strings.Join(sets, ", "),
|
strings.Join(sqls, ", "),
|
||||||
s.combinedSql(),
|
s.combinedSql(),
|
||||||
)
|
)
|
||||||
return s
|
return s
|
||||||
|
|
68
gorm_test.go
68
gorm_test.go
|
@ -116,7 +116,7 @@ func TestSaveAndFind(t *testing.T) {
|
||||||
db.Find(&users)
|
db.Find(&users)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdate(t *testing.T) {
|
func TestSaveAndUpdate(t *testing.T) {
|
||||||
name, name2, new_name := "update", "update2", "new_update"
|
name, name2, new_name := "update", "update2", "new_update"
|
||||||
user := User{Name: name, Age: 1}
|
user := User{Name: name, Age: 1}
|
||||||
db.Save(&user)
|
db.Save(&user)
|
||||||
|
@ -172,6 +172,10 @@ func TestWhere(t *testing.T) {
|
||||||
t.Errorf("Should found out user with name '%v'", name)
|
t.Errorf("Should found out user with name '%v'", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db.Where(user.Id).First(&User{}).Error != nil {
|
||||||
|
t.Errorf("Should found out users only with id")
|
||||||
|
}
|
||||||
|
|
||||||
user = &User{}
|
user = &User{}
|
||||||
orm := db.Where("name LIKE ?", "%nonono%").First(user)
|
orm := db.Where("name LIKE ?", "%nonono%").First(user)
|
||||||
if orm.Error == nil {
|
if orm.Error == nil {
|
||||||
|
@ -616,3 +620,65 @@ func TestSetTableDirectly(t *testing.T) {
|
||||||
t.Errorf("Set Table Chain Should works well")
|
t.Errorf("Set Table Chain Should works well")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdate(t *testing.T) {
|
||||||
|
product1 := Product{Code: "123"}
|
||||||
|
product2 := Product{Code: "234"}
|
||||||
|
db.Save(&product1).Save(&product2).Update("code", "456")
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = '123'").Error != nil {
|
||||||
|
t.Errorf("Product 123's code should not be changed!")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = '234'").Error == nil {
|
||||||
|
t.Errorf("Product 234's code should be changed to 456!")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = '456'").Error != nil {
|
||||||
|
t.Errorf("Product 234's code should be 456 now!")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
|
||||||
|
if db.First(&Product{}, "code = '123'").Error == nil {
|
||||||
|
t.Errorf("Product 123's code should be changed to 789")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = '456'").Error != nil {
|
||||||
|
t.Errorf("Product 456's code should not be changed to 789")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = '789'").Error != nil {
|
||||||
|
t.Errorf("We should have Product 789")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdates(t *testing.T) {
|
||||||
|
product1 := Product{Code: "abc", Price: 10}
|
||||||
|
product2 := Product{Code: "cde", Price: 20}
|
||||||
|
db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100})
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil {
|
||||||
|
t.Errorf("Product abc should not be updated!")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = 'cde'").Error == nil {
|
||||||
|
t.Errorf("Product cde should be renamed to edf!")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = 'edf' and price = 100").Error != nil {
|
||||||
|
t.Errorf("We should have product edf!")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Table("products").Where("code in (?)", []string{"abc"}).Updates(map[string]interface{}{"code": "fgh", "price": 200})
|
||||||
|
if db.First(&Product{}, "code = 'abc'").Error == nil {
|
||||||
|
t.Errorf("Product abc's code should be changed to fgh")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil {
|
||||||
|
t.Errorf("Product cde's code should not be changed to fgh")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Product{}, "code = 'fgh' and price = 200").Error != nil {
|
||||||
|
t.Errorf("We should have Product fgh")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
12
model.go
12
model.go
|
@ -96,14 +96,14 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) columnsAndValues(operation string) (columns []string, values []interface{}) {
|
func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
||||||
|
results := map[string]interface{}{}
|
||||||
for _, field := range m.fields(operation) {
|
for _, field := range m.fields(operation) {
|
||||||
if !field.IsPrimaryKey {
|
if !field.IsPrimaryKey {
|
||||||
columns = append(columns, field.DbName)
|
results[field.DbName] = field.Value
|
||||||
values = append(values, field.Value)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) tableName() (str string, err error) {
|
func (m *Model) tableName() (str string, err error) {
|
||||||
|
@ -138,6 +138,10 @@ func (m *Model) tableName() (str string, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) callMethod(method string) error {
|
func (m *Model) callMethod(method string) error {
|
||||||
|
if m.Data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
fm := reflect.ValueOf(m.Data).MethodByName(method)
|
fm := reflect.ValueOf(m.Data).MethodByName(method)
|
||||||
if fm.IsValid() {
|
if fm.IsValid() {
|
||||||
v := fm.Call([]reflect.Value{})
|
v := fm.Call([]reflect.Value{})
|
||||||
|
|
Loading…
Reference in New Issue