diff --git a/README.md b/README.md index 88e8dd80..3baf57d1 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,6 @@ Yet Another ORM library for Go, aims for developer friendly ## TODO -* Perimary key -* Save and fill the record * Update * Delete * Complex where query (= / > / < / <> / in) @@ -16,9 +14,10 @@ Yet Another ORM library for Go, aims for developer friendly * Not query * Even more complex where query (with map or struct) * ORM.Errors -* Better First method (First(&user, primary_key)) +* Better First method (First(&user, primary_key, where conditions)) * Soft Delete * After/Before Save/Update/Create/Delete +* CreatedAt, UpdatedAt * FindOrInitialize / FindOrCreate * SQL Log * Auto Migration diff --git a/model.go b/model.go index e22a5c64..7df9b7f3 100644 --- a/model.go +++ b/model.go @@ -27,6 +27,10 @@ func (m *Model) PrimaryKey() string { return "Id" } +func (m *Model) PrimaryKeyDb() string { + return toSnake(m.PrimaryKey()) +} + func (m *Model) Fields() (fields []Field) { typ := reflect.TypeOf(m.Data).Elem() @@ -37,7 +41,7 @@ func (m *Model) Fields() (fields []Field) { field.Name = p.Name field.DbName = toSnake(p.Name) field.Value = reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Interface() - if m.PrimaryKey() == p.Name { + if m.PrimaryKeyDb() == field.DbName { field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0) } else { field.SqlType = getSqlType(m.driver, field.Value, 0) @@ -54,9 +58,12 @@ func (m *Model) ColumnsAndValues() (columns []string, values []interface{}) { for i := 0; i < typ.NumField(); i++ { p := typ.Field(i) if !p.Anonymous { - columns = append(columns, toSnake(p.Name)) - value := reflect.ValueOf(m.Data).Elem().FieldByName(p.Name) - values = append(values, value.Interface()) + db_name := toSnake(p.Name) + if m.PrimaryKeyDb() != db_name { + columns = append(columns, db_name) + value := reflect.ValueOf(m.Data).Elem().FieldByName(p.Name) + values = append(values, value.Interface()) + } } } return @@ -83,10 +90,6 @@ func (model *Model) MissingColumns() (results []string) { return } -func (model *Model) ColumnType(column string) (result string) { - return -} - func (model *Model) CreateTable() (sql string) { var sqls []string for _, field := range model.Fields() { @@ -100,3 +103,10 @@ func (model *Model) CreateTable() (sql string) { ) return } + +func (model *Model) ReturningStr() (str string) { + if model.driver == "postgres" { + str = fmt.Sprintf("RETURNING \"%v\"", model.PrimaryKeyDb()) + } + return +} diff --git a/orm.go b/orm.go index ed8494f4..1c09e75a 100644 --- a/orm.go +++ b/orm.go @@ -28,7 +28,7 @@ type Orm struct { func (s *Orm) setModel(model interface{}) (err error) { s.Model = s.toModel(model) s.TableName = s.Model.TableName() - s.PrimaryKey = s.Model.PrimaryKey() + s.PrimaryKey = s.Model.PrimaryKeyDb() return } @@ -86,7 +86,8 @@ func (s *Orm) Select(value interface{}) *Orm { } func (s *Orm) Save(value interface{}) *Orm { - s.explain(value, "Save").Exec() + s.explain(value, "Create").create(value) + // s.explain(value, "Update").update(value) return s } diff --git a/orm_test.go b/orm_test.go index 82876ff0..0c9b28bd 100644 --- a/orm_test.go +++ b/orm_test.go @@ -17,6 +17,9 @@ func TestSaveAndFirst(t *testing.T) { db := getDB() u := &User{Name: "jinzhu"} db.Save(u) + if u.Id == 0 { + t.Errorf("Should have ID after create record") + } user := &User{} db.First(user) diff --git a/sql.go b/sql.go index 0e74ff57..90590976 100644 --- a/sql.go +++ b/sql.go @@ -11,8 +11,10 @@ import ( func (s *Orm) explain(value interface{}, operation string) *Orm { s.setModel(value) switch operation { - case "Save": - s.saveSql(value) + case "Create": + s.createSql(value) + case "Update": + s.updateSql(value) case "Delete": s.deleteSql(value) case "Query": @@ -70,18 +72,40 @@ func (s *Orm) query(out interface{}) { } } -func (s *Orm) saveSql(value interface{}) { +func (s *Orm) createSql(value interface{}) { columns, values := s.Model.ColumnsAndValues() s.Sql = fmt.Sprintf( - "INSERT INTO \"%v\" (%v) VALUES (%v)", + "INSERT INTO \"%v\" (%v) VALUES (%v) %v", s.TableName, strings.Join(quoteMap(columns), ","), valuesToBinVar(values), + s.Model.ReturningStr(), ) s.SqlVars = values return } +func (s *Orm) create(value interface{}) { + var id int64 + if s.driver == "postgres" { + s.Error = s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id) + } else { + s.SqlResult, s.Error = s.db.Exec(s.Sql, s.SqlVars...) + id, s.Error = s.SqlResult.LastInsertId() + } + + result := reflect.ValueOf(s.Model.Data).Elem() + result.FieldByName(s.Model.PrimaryKey()).SetInt(id) +} + +func (s *Orm) updateSql(value interface{}) { + return +} + +func (s *Orm) update(value interface{}) { + return +} + func (s *Orm) deleteSql(value interface{}) { s.Sql = fmt.Sprintf("DELETE FROM %v WHERE %v", s.TableName, s.whereSql) return