mirror of https://github.com/go-gorm/gorm.git
Handle SubStruct
This commit is contained in:
parent
49cfb0d4a0
commit
28b49124eb
4
do.go
4
do.go
|
@ -566,7 +566,9 @@ func (s *Do) combinedSql() string {
|
|||
func (s *Do) createTable() *Do {
|
||||
var sqls []string
|
||||
for _, field := range s.model.fields("") {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
if len(field.SqlType) > 0 {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
}
|
||||
}
|
||||
|
||||
s.sql = fmt.Sprintf(
|
||||
|
|
36
gorm_test.go
36
gorm_test.go
|
@ -995,3 +995,39 @@ func TestNot(t *testing.T) {
|
|||
t.Errorf("Should find all users's name not equal 3")
|
||||
}
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
Id int64
|
||||
Name string
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
Id int64
|
||||
Title string
|
||||
Body string
|
||||
Comments []Comment
|
||||
Category Category
|
||||
}
|
||||
|
||||
type Comment struct {
|
||||
Id int64
|
||||
Content string
|
||||
Post Post
|
||||
}
|
||||
|
||||
func TestSubStruct(t *testing.T) {
|
||||
db.CreateTable(Category{})
|
||||
db.CreateTable(Post{})
|
||||
db.CreateTable(Comment{})
|
||||
|
||||
post := Post{
|
||||
Title: "post 1",
|
||||
Body: "body 1",
|
||||
Comments: []Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
|
||||
Category: Category{Name: "category"},
|
||||
}
|
||||
|
||||
if err := db.Save(&post).Error; err != nil {
|
||||
t.Errorf("Got errors when save post", err)
|
||||
}
|
||||
}
|
||||
|
|
21
model.go
21
model.go
|
@ -77,6 +77,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.AutoCreateTime = "created_at" == field.DbName
|
||||
field.AutoUpdateTime = "updated_at" == field.DbName
|
||||
value := indirect_value.FieldByName(p.Name)
|
||||
time_value, is_time := value.Interface().(time.Time)
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int64, reflect.Int32:
|
||||
|
@ -84,15 +85,15 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
case reflect.String:
|
||||
field.IsBlank = value.String() == ""
|
||||
default:
|
||||
if value, ok := value.Interface().(time.Time); ok {
|
||||
field.IsBlank = value.IsZero()
|
||||
if is_time {
|
||||
field.IsBlank = time_value.IsZero()
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := value.Interface().(time.Time); ok {
|
||||
if is_time {
|
||||
switch operation {
|
||||
case "create":
|
||||
if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() {
|
||||
if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() {
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
case "update":
|
||||
|
@ -107,7 +108,15 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
if field.IsPrimaryKey {
|
||||
field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0)
|
||||
} else {
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
switch reflect.TypeOf(field.Value).Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Struct:
|
||||
if is_time {
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
}
|
||||
default:
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
}
|
||||
}
|
||||
fields = append(fields, field)
|
||||
}
|
||||
|
@ -165,7 +174,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
|||
|
||||
results := map[string]interface{}{}
|
||||
for _, field := range m.fields(operation) {
|
||||
if !field.IsPrimaryKey {
|
||||
if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
|
||||
results[field.DbName] = field.Value
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue