diff --git a/README.md b/README.md index 3baf57d1..cabb2fbb 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ Yet Another ORM library for Go, aims for developer friendly ## TODO -* Update * Delete * Complex where query (= / > / < / <> / in) * Order diff --git a/model.go b/model.go index dd7189b8..48d68fae 100644 --- a/model.go +++ b/model.go @@ -26,7 +26,7 @@ func (s *Orm) toModel(value interface{}) *Model { func (m *Model) PrimaryKeyIsEmpty() bool { result := reflect.ValueOf(m.Data).Elem() - return result.FieldByName(m.PrimaryKey()).Interface().(int64) == int64(0) + return result.FieldByName(m.PrimaryKey()).Interface().(int64) == 0 } func (m *Model) PrimaryKey() string { diff --git a/orm_test.go b/orm_test.go index 946036ac..06f0e28f 100644 --- a/orm_test.go +++ b/orm_test.go @@ -39,6 +39,34 @@ func TestSaveAndFind(t *testing.T) { db.Find(&users) } +func TestUpdate(t *testing.T) { + name := "update" + user := User{Name: name} + db.Save(&user) + + user_id := user.Id + if user_id == 0 { + t.Errorf("User Id should exist after create") + } + + orm := db.Where("name = ?", "update").First(&User{}) + if orm.Error != nil { + t.Errorf("No error should raise when looking for a exiting user") + } + + user.Name = "update2" + db.Save(&user) + orm = db.Where("name = ?", "update").First(&User{}) + if orm.Error == nil { + t.Errorf("Should raise error when looking for a existing user with an outdated name") + } + + orm = db.Where("name = ?", "update2").First(&User{}) + if orm.Error != nil { + t.Errorf("Shouldn't raise error when looking for a existing user with the new name") + } +} + func TestWhere(t *testing.T) { name := "where" db.Save(&User{Name: name}) diff --git a/sql.go b/sql.go index 90590976..936d0fbd 100644 --- a/sql.go +++ b/sql.go @@ -77,7 +77,7 @@ func (s *Orm) createSql(value interface{}) { s.Sql = fmt.Sprintf( "INSERT INTO \"%v\" (%v) VALUES (%v) %v", s.TableName, - strings.Join(quoteMap(columns), ","), + strings.Join(s.quoteMap(columns), ","), valuesToBinVar(values), s.Model.ReturningStr(), ) @@ -99,10 +99,23 @@ func (s *Orm) create(value interface{}) { } func (s *Orm) updateSql(value interface{}) { + columns, values := s.Model.ColumnsAndValues() + var sets []string + for index, column := range columns { + s.SqlVars = append(s.SqlVars, values[index]) + sets = append(sets, fmt.Sprintf("%v = $%d", s.quote(column), len(s.SqlVars))) + } + + s.Sql = fmt.Sprintf( + "UPDATE %v SET %v", + s.TableName, + strings.Join(sets, ", "), + ) return } func (s *Orm) update(value interface{}) { + s.Exec() return } diff --git a/utils.go b/utils.go index f2b9f612..c96b1f48 100644 --- a/utils.go +++ b/utils.go @@ -16,9 +16,13 @@ func valuesToBinVar(values []interface{}) string { return strings.Join(sqls, ",") } -func quoteMap(values []string) (results []string) { +func (s *Orm) quote(value string) string { + return "\"" + value + "\"" +} + +func (s *Orm) quoteMap(values []string) (results []string) { for _, value := range values { - results = append(results, "\""+value+"\"") + results = append(results, s.quote(value)) } return }