From 09b6fc3ab0a88f7cd2088b78bc36b9a531081505 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 26 Oct 2013 23:30:17 +0800 Subject: [PATCH] Fix some bugs about Update and Delete --- README.md | 5 ++--- model.go | 14 ++++++++++++-- sql.go | 42 ++++++++++++++++++++++++++++++------------ utils.go | 9 --------- 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index cabb2fbb..723389ec 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ Yet Another ORM library for Go, aims for developer friendly ## TODO -* Delete * Complex where query (= / > / < / <> / in) * Order * Limit @@ -11,12 +10,12 @@ Yet Another ORM library for Go, aims for developer friendly * Offset * Or query * Not query +* Better First method (First(&user, primary_key, where conditions)) * Even more complex where query (with map or struct) * ORM.Errors -* Better First method (First(&user, primary_key, where conditions)) -* Soft Delete * After/Before Save/Update/Create/Delete * CreatedAt, UpdatedAt +* Soft Delete * FindOrInitialize / FindOrCreate * SQL Log * Auto Migration diff --git a/model.go b/model.go index 48d68fae..e1865433 100644 --- a/model.go +++ b/model.go @@ -25,8 +25,18 @@ func (s *Orm) toModel(value interface{}) *Model { } func (m *Model) PrimaryKeyIsEmpty() bool { - result := reflect.ValueOf(m.Data).Elem() - return result.FieldByName(m.PrimaryKey()).Interface().(int64) == 0 + return m.PrimaryKeyValue() == 0 +} + +func (m *Model) PrimaryKeyValue() int64 { + t := reflect.TypeOf(m.Data).Elem() + switch t.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: + return 0 + default: + result := reflect.ValueOf(m.Data).Elem() + return result.FieldByName(m.PrimaryKey()).Interface().(int64) + } } func (m *Model) PrimaryKey() string { diff --git a/sql.go b/sql.go index 4ea0d077..6198c42c 100644 --- a/sql.go +++ b/sql.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "reflect" - "strconv" "strings" ) @@ -74,14 +73,19 @@ func (s *Orm) query(out interface{}) { func (s *Orm) createSql(value interface{}) { columns, values := s.Model.ColumnsAndValues() + + var sqls []string + for _, value := range values { + sqls = append(sqls, s.addToVars(value)) + } + s.Sql = fmt.Sprintf( "INSERT INTO \"%v\" (%v) VALUES (%v) %v", s.TableName, strings.Join(s.quoteMap(columns), ","), - valuesToBinVar(values), + strings.Join(sqls, ","), s.Model.ReturningStr(), ) - s.SqlVars = values return } @@ -102,15 +106,16 @@ func (s *Orm) updateSql(value interface{}) { columns, values := s.Model.ColumnsAndValues() var sets []string for index, column := range columns { - s.SqlVars = append(s.SqlVars, values[index]) - sets = append(sets, fmt.Sprintf("%v = $%d", s.quote(column), len(s.SqlVars))) + sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index]))) } s.Sql = fmt.Sprintf( - "UPDATE %v SET %v", + "UPDATE %v SET %v %v", s.TableName, strings.Join(sets, ", "), + s.whereSql(), ) + return } @@ -125,18 +130,31 @@ func (s *Orm) deleteSql(value interface{}) { } func (s *Orm) whereSql() (sql string) { - if len(s.whereClause) == 0 { - return - } else { - sql = "WHERE " + var conditions []string + if !s.Model.PrimaryKeyIsEmpty() { + conditions = append(conditions, fmt.Sprintf("(%v = %v)", s.quote(s.Model.PrimaryKeyDb()), s.addToVars(s.Model.PrimaryKeyValue()))) + } + + if len(s.whereClause) > 0 { for _, clause := range s.whereClause { - sql += clause["query"].(string) + str := "( " + clause["query"].(string) + " )" args := clause["args"].([]interface{}) for _, arg := range args { s.SqlVars = append(s.SqlVars, arg.([]interface{})...) - sql = strings.Replace(sql, "?", "$"+strconv.Itoa(len(s.SqlVars)), 1) + str = strings.Replace(str, "?", fmt.Sprintf("$%d", len(s.SqlVars)), 1) } + conditions = append(conditions, str) } } + + if len(conditions) > 0 { + sql = "WHERE " + strings.Join(conditions, " AND ") + } + return } + +func (s *Orm) addToVars(value interface{}) string { + s.SqlVars = append(s.SqlVars, value) + return fmt.Sprintf("$%d", len(s.SqlVars)) +} diff --git a/utils.go b/utils.go index c96b1f48..f77f8ae3 100644 --- a/utils.go +++ b/utils.go @@ -7,15 +7,6 @@ import ( "strings" ) -// FIXME -func valuesToBinVar(values []interface{}) string { - var sqls []string - for index, _ := range values { - sqls = append(sqls, fmt.Sprintf("$%d", index+1)) - } - return strings.Join(sqls, ",") -} - func (s *Orm) quote(value string) string { return "\"" + value + "\"" }