forked from mirror/gorm
Handle SubStruct
This commit is contained in:
parent
49cfb0d4a0
commit
28b49124eb
2
do.go
2
do.go
|
@ -566,8 +566,10 @@ func (s *Do) combinedSql() string {
|
||||||
func (s *Do) createTable() *Do {
|
func (s *Do) createTable() *Do {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range s.model.fields("") {
|
for _, field := range s.model.fields("") {
|
||||||
|
if len(field.SqlType) > 0 {
|
||||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.sql = fmt.Sprintf(
|
s.sql = fmt.Sprintf(
|
||||||
"CREATE TABLE \"%v\" (%v)",
|
"CREATE TABLE \"%v\" (%v)",
|
||||||
|
|
36
gorm_test.go
36
gorm_test.go
|
@ -995,3 +995,39 @@ func TestNot(t *testing.T) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Category struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
Id int64
|
||||||
|
Title string
|
||||||
|
Body string
|
||||||
|
Comments []Comment
|
||||||
|
Category Category
|
||||||
|
}
|
||||||
|
|
||||||
|
type Comment struct {
|
||||||
|
Id int64
|
||||||
|
Content string
|
||||||
|
Post Post
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubStruct(t *testing.T) {
|
||||||
|
db.CreateTable(Category{})
|
||||||
|
db.CreateTable(Post{})
|
||||||
|
db.CreateTable(Comment{})
|
||||||
|
|
||||||
|
post := Post{
|
||||||
|
Title: "post 1",
|
||||||
|
Body: "body 1",
|
||||||
|
Comments: []Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
|
||||||
|
Category: Category{Name: "category"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Save(&post).Error; err != nil {
|
||||||
|
t.Errorf("Got errors when save post", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
19
model.go
19
model.go
|
@ -77,6 +77,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
field.AutoCreateTime = "created_at" == field.DbName
|
field.AutoCreateTime = "created_at" == field.DbName
|
||||||
field.AutoUpdateTime = "updated_at" == field.DbName
|
field.AutoUpdateTime = "updated_at" == field.DbName
|
||||||
value := indirect_value.FieldByName(p.Name)
|
value := indirect_value.FieldByName(p.Name)
|
||||||
|
time_value, is_time := value.Interface().(time.Time)
|
||||||
|
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Int, reflect.Int64, reflect.Int32:
|
case reflect.Int, reflect.Int64, reflect.Int32:
|
||||||
|
@ -84,15 +85,15 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
field.IsBlank = value.String() == ""
|
field.IsBlank = value.String() == ""
|
||||||
default:
|
default:
|
||||||
if value, ok := value.Interface().(time.Time); ok {
|
if is_time {
|
||||||
field.IsBlank = value.IsZero()
|
field.IsBlank = time_value.IsZero()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := value.Interface().(time.Time); ok {
|
if is_time {
|
||||||
switch operation {
|
switch operation {
|
||||||
case "create":
|
case "create":
|
||||||
if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() {
|
if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() {
|
||||||
value.Set(reflect.ValueOf(time.Now()))
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
}
|
}
|
||||||
case "update":
|
case "update":
|
||||||
|
@ -107,8 +108,16 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
if field.IsPrimaryKey {
|
if field.IsPrimaryKey {
|
||||||
field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0)
|
field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0)
|
||||||
} else {
|
} else {
|
||||||
|
switch reflect.TypeOf(field.Value).Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
case reflect.Struct:
|
||||||
|
if is_time {
|
||||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -165,7 +174,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
||||||
|
|
||||||
results := map[string]interface{}{}
|
results := map[string]interface{}{}
|
||||||
for _, field := range m.fields(operation) {
|
for _, field := range m.fields(operation) {
|
||||||
if !field.IsPrimaryKey {
|
if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
|
||||||
results[field.DbName] = field.Value
|
results[field.DbName] = field.Value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue