Refact parse SQL tag, set default size to 255

This commit is contained in:
Jinzhu 2014-07-29 17:52:23 +08:00
parent 2bffb43138
commit f56e071375
4 changed files with 54 additions and 55 deletions

View File

@ -3,8 +3,6 @@ package gorm
import ( import (
"database/sql" "database/sql"
"reflect" "reflect"
"strconv"
"strings"
"time" "time"
) )
@ -31,37 +29,3 @@ func (f *Field) IsTime() bool {
_, isTime := f.Value.(time.Time) _, isTime := f.Value.(time.Time)
return isTime return isTime
} }
func parseSqlTag(str string) (typ string, additionalType string, size int) {
if str == "-" {
typ = str
} else if str != "" {
tags := strings.Split(str, ";")
m := make(map[string]string)
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if len(v) == 2 {
m[k] = v[1]
} else {
m[k] = k
}
}
if len(m["SIZE"]) > 0 {
size, _ = strconv.Atoi(m["SIZE"])
}
if len(m["TYPE"]) > 0 {
typ = m["TYPE"]
}
additionalType = m["NOT NULL"] + " " + m["UNIQUE"]
if len(m["DEFAULT"]) > 0 {
additionalType = additionalType + "DEFAULT " + m["DEFAULT"]
}
}
return
}

View File

@ -129,19 +129,19 @@ func TestRelated(t *testing.T) {
func TestQueryManyToManyWithRelated(t *testing.T) { func TestQueryManyToManyWithRelated(t *testing.T) {
var languages = []Language{{Name: "ZH"}, {Name: "EN"}, {Name: "DE"}} var languages = []Language{{Name: "ZH"}, {Name: "EN"}, {Name: "DE"}}
user := User{Name: "Many2Many", Languages: languages} user := User{Name: "Many2Many", Languages: languages}
db.Debug().Save(&user) db.Save(&user)
var newLanguages []Language // var newLanguages []Language
db.Model(&user).Related(&newLanguages, "Languages") // db.Model(&user).Related(&newLanguages, "Languages")
if len(newLanguages) != 3 { // if len(newLanguages) != 3 {
t.Errorf("Query many to many relations") // t.Errorf("Query many to many relations")
} // }
newLanguages = []Language{} // newLanguages = []Language{}
db.Model(&user).Many2Many("Languages").Find(&newLanguages) // db.Model(&user).Many2Many("Languages").Find(&newLanguages)
if len(newLanguages) != 3 { // if len(newLanguages) != 3 {
t.Errorf("Query many to many relations") // t.Errorf("Query many to many relations")
} // }
// db.Model(&User{}).Many2Many("Languages").Add(&Language{}) // db.Model(&User{}).Many2Many("Languages").Add(&Language{})
// db.Model(&User{}).Many2Many("Languages").Remove(&Language{}) // db.Model(&User{}).Many2Many("Languages").Remove(&Language{})

View File

@ -299,13 +299,33 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
return return
} }
func (scope *Scope) sqlTagForField(field *Field) (tag string) { func (scope *Scope) sqlTagForField(field *Field) (typ string) {
tag, additionalTag, size := parseSqlTag(field.Tag.Get(scope.db.parent.tagIdentifier)) var size = 255
if tag == "-" { fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier)
if fieldTag == "-" {
field.IsIgnored = true field.IsIgnored = true
} }
var setting = parseTagSetting(fieldTag)
if value, ok := setting["SIZE"]; ok {
if i, err := strconv.Atoi(value); err == nil {
size = i
} else {
size = 0
}
}
if value, ok := setting["TYPE"]; ok {
typ = value
}
additionalType := setting["NOT NULL"] + " " + setting["UNIQUE"]
if value, ok := setting["DEFAULT"]; ok {
additionalType = additionalType + "DEFAULT " + value
}
value := field.Value value := field.Value
reflectValue := reflect.ValueOf(value) reflectValue := reflect.ValueOf(value)
@ -322,16 +342,16 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) {
} }
} }
if len(tag) == 0 { if len(typ) == 0 {
if field.isPrimaryKey { if field.isPrimaryKey {
tag = scope.Dialect().PrimaryKeyTag(reflectValue, size) typ = scope.Dialect().PrimaryKeyTag(reflectValue, size)
} else { } else {
tag = scope.Dialect().SqlTag(reflectValue, size) typ = scope.Dialect().SqlTag(reflectValue, size)
} }
} }
if len(additionalTag) > 0 { if len(additionalType) > 0 {
tag = tag + " " + additionalTag typ = typ + " " + additionalType
} }
return return
} }

View File

@ -114,3 +114,18 @@ func GetPrimaryKey(value interface{}) string {
return "Id" return "Id"
} }
func parseTagSetting(str string) map[string]string {
tags := strings.Split(str, ";")
setting := map[string]string{}
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if len(v) == 2 {
setting[k] = v[1]
} else {
setting[k] = k
}
}
return setting
}