Handle SubStruct

This commit is contained in:
Jinzhu 2013-11-02 14:12:18 +08:00
parent 49cfb0d4a0
commit 28b49124eb
3 changed files with 54 additions and 7 deletions

4
do.go
View File

@ -566,7 +566,9 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do { func (s *Do) createTable() *Do {
var sqls []string var sqls []string
for _, field := range s.model.fields("") { for _, field := range s.model.fields("") {
sqls = append(sqls, field.DbName+" "+field.SqlType) if len(field.SqlType) > 0 {
sqls = append(sqls, field.DbName+" "+field.SqlType)
}
} }
s.sql = fmt.Sprintf( s.sql = fmt.Sprintf(

View File

@ -995,3 +995,39 @@ func TestNot(t *testing.T) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users's name not equal 3")
} }
} }
type Category struct {
Id int64
Name string
}
type Post struct {
Id int64
Title string
Body string
Comments []Comment
Category Category
}
type Comment struct {
Id int64
Content string
Post Post
}
func TestSubStruct(t *testing.T) {
db.CreateTable(Category{})
db.CreateTable(Post{})
db.CreateTable(Comment{})
post := Post{
Title: "post 1",
Body: "body 1",
Comments: []Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
Category: Category{Name: "category"},
}
if err := db.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post", err)
}
}

View File

@ -77,6 +77,7 @@ func (m *Model) fields(operation string) (fields []Field) {
field.AutoCreateTime = "created_at" == field.DbName field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName field.AutoUpdateTime = "updated_at" == field.DbName
value := indirect_value.FieldByName(p.Name) value := indirect_value.FieldByName(p.Name)
time_value, is_time := value.Interface().(time.Time)
switch value.Kind() { switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32: case reflect.Int, reflect.Int64, reflect.Int32:
@ -84,15 +85,15 @@ func (m *Model) fields(operation string) (fields []Field) {
case reflect.String: case reflect.String:
field.IsBlank = value.String() == "" field.IsBlank = value.String() == ""
default: default:
if value, ok := value.Interface().(time.Time); ok { if is_time {
field.IsBlank = value.IsZero() field.IsBlank = time_value.IsZero()
} }
} }
if v, ok := value.Interface().(time.Time); ok { if is_time {
switch operation { switch operation {
case "create": case "create":
if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() { if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
case "update": case "update":
@ -107,7 +108,15 @@ func (m *Model) fields(operation string) (fields []Field) {
if field.IsPrimaryKey { if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0) field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0)
} else { } else {
field.SqlType = getSqlType(m.driver, field.Value, 0) switch reflect.TypeOf(field.Value).Kind() {
case reflect.Slice:
case reflect.Struct:
if is_time {
field.SqlType = getSqlType(m.driver, field.Value, 0)
}
default:
field.SqlType = getSqlType(m.driver, field.Value, 0)
}
} }
fields = append(fields, field) fields = append(fields, field)
} }
@ -165,7 +174,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
results := map[string]interface{}{} results := map[string]interface{}{}
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsPrimaryKey { if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
results[field.DbName] = field.Value results[field.DbName] = field.Value
} }
} }