From 28b49124eb23c90e6a5c3906cfdee24f3d4d36b0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 2 Nov 2013 14:12:18 +0800 Subject: [PATCH] Handle SubStruct --- do.go | 4 +++- gorm_test.go | 36 ++++++++++++++++++++++++++++++++++++ model.go | 21 +++++++++++++++------ 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/do.go b/do.go index 51beb230..8f3b0213 100644 --- a/do.go +++ b/do.go @@ -566,7 +566,9 @@ func (s *Do) combinedSql() string { func (s *Do) createTable() *Do { var sqls []string for _, field := range s.model.fields("") { - sqls = append(sqls, field.DbName+" "+field.SqlType) + if len(field.SqlType) > 0 { + sqls = append(sqls, field.DbName+" "+field.SqlType) + } } s.sql = fmt.Sprintf( diff --git a/gorm_test.go b/gorm_test.go index 5aa6e304..e975a4ae 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -995,3 +995,39 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } } + +type Category struct { + Id int64 + Name string +} + +type Post struct { + Id int64 + Title string + Body string + Comments []Comment + Category Category +} + +type Comment struct { + Id int64 + Content string + Post Post +} + +func TestSubStruct(t *testing.T) { + db.CreateTable(Category{}) + db.CreateTable(Post{}) + db.CreateTable(Comment{}) + + post := Post{ + Title: "post 1", + Body: "body 1", + Comments: []Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, + Category: Category{Name: "category"}, + } + + if err := db.Save(&post).Error; err != nil { + t.Errorf("Got errors when save post", err) + } +} diff --git a/model.go b/model.go index 3c79cb1e..17f80388 100644 --- a/model.go +++ b/model.go @@ -77,6 +77,7 @@ func (m *Model) fields(operation string) (fields []Field) { field.AutoCreateTime = "created_at" == field.DbName field.AutoUpdateTime = "updated_at" == field.DbName value := indirect_value.FieldByName(p.Name) + time_value, is_time := value.Interface().(time.Time) switch value.Kind() { case reflect.Int, reflect.Int64, reflect.Int32: @@ -84,15 +85,15 @@ func (m *Model) fields(operation string) (fields []Field) { case reflect.String: field.IsBlank = value.String() == "" default: - if value, ok := value.Interface().(time.Time); ok { - field.IsBlank = value.IsZero() + if is_time { + field.IsBlank = time_value.IsZero() } } - if v, ok := value.Interface().(time.Time); ok { + if is_time { switch operation { case "create": - if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() { + if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() { value.Set(reflect.ValueOf(time.Now())) } case "update": @@ -107,7 +108,15 @@ func (m *Model) fields(operation string) (fields []Field) { if field.IsPrimaryKey { field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0) } else { - field.SqlType = getSqlType(m.driver, field.Value, 0) + switch reflect.TypeOf(field.Value).Kind() { + case reflect.Slice: + case reflect.Struct: + if is_time { + field.SqlType = getSqlType(m.driver, field.Value, 0) + } + default: + field.SqlType = getSqlType(m.driver, field.Value, 0) + } } fields = append(fields, field) } @@ -165,7 +174,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} { results := map[string]interface{}{} for _, field := range m.fields(operation) { - if !field.IsPrimaryKey { + if !field.IsPrimaryKey && (len(field.SqlType) > 0) { results[field.DbName] = field.Value } }