From 97f234e481b304fdb325b247c83476e81e88a31b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Oct 2013 14:34:34 +0800 Subject: [PATCH] yay, created_at, updated_at works --- README.md | 1 - model.go | 57 +++++++++++++++++++++++++++++++++-------------------- orm_test.go | 38 +++++++++++++++++++++++++++++++---- sql.go | 4 ++-- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index e1e8357f..801f8e6d 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ Yet Another ORM library for Go, aims for developer friendly ## TODO -* CreatedAt, UpdatedAt * After/Before Save/Update/Create/Delete * Soft Delete * Better First method (First(&user, primary_key, where conditions)) diff --git a/model.go b/model.go index abe72af9..bdee653e 100644 --- a/model.go +++ b/model.go @@ -4,8 +4,8 @@ import ( "fmt" "reflect" "regexp" - "strings" + "time" ) type Model struct { @@ -14,10 +14,13 @@ type Model struct { } type Field struct { - Name string - Value interface{} - SqlType string - DbName string + Name string + Value interface{} + SqlType string + DbName string + AutoCreateTime bool + AutoUpdateTime bool + IsPrimaryKey bool } func (s *Orm) toModel(value interface{}) *Model { @@ -52,7 +55,7 @@ func (m *Model) PrimaryKeyDb() string { return toSnake(m.PrimaryKey()) } -func (m *Model) Fields() (fields []Field) { +func (m *Model) Fields(operation string) (fields []Field) { typ := reflect.TypeOf(m.Data).Elem() for i := 0; i < typ.NumField(); i++ { @@ -61,8 +64,27 @@ func (m *Model) Fields() (fields []Field) { var field Field field.Name = p.Name field.DbName = toSnake(p.Name) - field.Value = reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Interface() - if m.PrimaryKeyDb() == field.DbName { + field.IsPrimaryKey = m.PrimaryKeyDb() == field.DbName + field.AutoCreateTime = "created_at" == field.DbName + field.AutoUpdateTime = "updated_at" == field.DbName + value := reflect.ValueOf(m.Data).Elem().FieldByName(p.Name) + + switch operation { + case "create": + if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() { + value = reflect.ValueOf(time.Now()) + reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value) + } + case "update": + if field.AutoUpdateTime { + value = reflect.ValueOf(time.Now()) + reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value) + } + default: + } + field.Value = value.Interface() + + if field.IsPrimaryKey { field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0) } else { field.SqlType = getSqlType(m.driver, field.Value, 0) @@ -73,18 +95,11 @@ func (m *Model) Fields() (fields []Field) { return } -func (m *Model) ColumnsAndValues() (columns []string, values []interface{}) { - typ := reflect.TypeOf(m.Data).Elem() - - for i := 0; i < typ.NumField(); i++ { - p := typ.Field(i) - if !p.Anonymous { - db_name := toSnake(p.Name) - if m.PrimaryKeyDb() != db_name { - columns = append(columns, db_name) - value := reflect.ValueOf(m.Data).Elem().FieldByName(p.Name) - values = append(values, value.Interface()) - } +func (m *Model) ColumnsAndValues(operation string) (columns []string, values []interface{}) { + for _, field := range m.Fields(operation) { + if !field.IsPrimaryKey { + columns = append(columns, field.DbName) + values = append(values, field.Value) } } return @@ -113,7 +128,7 @@ func (model *Model) MissingColumns() (results []string) { func (model *Model) CreateTable() (sql string) { var sqls []string - for _, field := range model.Fields() { + for _, field := range model.Fields("null") { sqls = append(sqls, field.DbName+" "+field.SqlType) } diff --git a/orm_test.go b/orm_test.go index 69b2b1d8..1a127f68 100644 --- a/orm_test.go +++ b/orm_test.go @@ -6,10 +6,12 @@ import ( ) type User struct { - Id int64 - Age int64 - Birthday time.Time - Name string + Id int64 + Age int64 + Birthday time.Time + Name string + CreatedAt time.Time + UpdatedAt time.Time } var ( @@ -281,3 +283,31 @@ func TestOrAndNot(t *testing.T) { t.Errorf("Should find three users with name 1 and 3") } } + +func TestCreatedAtAndUpdatedAt(t *testing.T) { + name := "check_created_at_and_updated_at" + u := User{Name: name, Age: 1} + db.Save(&u) + created_at := u.CreatedAt + updated_at := u.UpdatedAt + + if created_at.IsZero() { + t.Errorf("Should have created_at after create") + } + if updated_at.IsZero() { + t.Errorf("Should have updated_at after create") + } + + u.Name = "check_created_at_and_updated_at_2" + db.Save(&u) + created_at2 := u.CreatedAt + updated_at2 := u.UpdatedAt + + if created_at != created_at2 { + t.Errorf("Created At should not changed after update") + } + + if updated_at == updated_at2 { + t.Errorf("Updated At should be changed after update") + } +} diff --git a/sql.go b/sql.go index 8fc23e96..4e75a947 100644 --- a/sql.go +++ b/sql.go @@ -110,7 +110,7 @@ func (s *Orm) pluck(value interface{}) { } func (s *Orm) createSql(value interface{}) { - columns, values := s.model.ColumnsAndValues() + columns, values := s.model.ColumnsAndValues("create") var sqls []string for _, value := range values { @@ -144,7 +144,7 @@ func (s *Orm) create(value interface{}) { } func (s *Orm) updateSql(value interface{}) { - columns, values := s.model.ColumnsAndValues() + columns, values := s.model.ColumnsAndValues("update") var sets []string for index, column := range columns { sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index])))