Compatible with both gorm, sql tag

This commit is contained in:
Jinzhu 2016-01-03 09:52:27 +08:00
parent 5d2b9bfe34
commit 19b85b1f17
3 changed files with 17 additions and 16 deletions

View File

@ -512,7 +512,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
scope := s.NewScope(source) scope := s.NewScope(source)
for _, field := range scope.GetModelStruct().StructFields { for _, field := range scope.GetModelStruct().StructFields {
if field.Name == column || field.DBName == column { if field.Name == column || field.DBName == column {
if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { if many2many := parseTagSetting(field.Tag)["MANY2MANY"]; many2many != "" {
source := (&Scope{Value: source}).GetModelStruct().ModelType source := (&Scope{Value: source}).GetModelStruct().ModelType
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination) handler.Setup(field.Relationship, many2many, source, destination)

View File

@ -145,14 +145,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
field.IsIgnored = true field.IsIgnored = true
} }
sqlSettings := parseTagSetting(field.Tag.Get("sql")) gormSettings := parseTagSetting(field.Tag)
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
if _, ok := gormSettings["PRIMARY_KEY"]; ok { if _, ok := gormSettings["PRIMARY_KEY"]; ok {
field.IsPrimaryKey = true field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
} }
if _, ok := sqlSettings["DEFAULT"]; ok { if _, ok := gormSettings["DEFAULT"]; ok {
field.HasDefaultValue = true field.HasDefaultValue = true
} }
@ -185,7 +184,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
if !field.IsNormal { if !field.IsNormal {
gormSettings := parseTagSetting(field.Tag.Get("gorm")) gormSettings := parseTagSetting(field.Tag)
toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
getForeignField := func(column string, fields []*StructField) *StructField { getForeignField := func(column string, fields []*StructField) *StructField {
@ -400,7 +399,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
structType = structType.Elem() structType = structType.Elem()
} }
reflectValue := reflect.Indirect(reflect.New(structType)) reflectValue := reflect.Indirect(reflect.New(structType))
sqlSettings := parseTagSetting(field.Tag.Get("sql")) sqlSettings := parseTagSetting(field.Tag)
if value, ok := sqlSettings["TYPE"]; ok { if value, ok := sqlSettings["TYPE"]; ok {
sqlType = value sqlType = value
@ -447,9 +446,10 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
} }
} }
func parseTagSetting(str string) map[string]string { func parseTagSetting(tags reflect.StructTag) map[string]string {
tags := strings.Split(str, ";")
setting := map[string]string{} setting := map[string]string{}
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
tags := strings.Split(str, ";")
for _, value := range tags { for _, value := range tags {
v := strings.Split(value, ":") v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0])) k := strings.TrimSpace(strings.ToUpper(v[0]))
@ -459,5 +459,6 @@ func parseTagSetting(str string) map[string]string {
setting[k] = k setting[k] = k
} }
} }
}
return setting return setting
} }

View File

@ -630,7 +630,7 @@ func (scope *Scope) autoIndex() *Scope {
var uniqueIndexes = map[string][]string{} var uniqueIndexes = map[string][]string{}
for _, field := range scope.GetStructFields() { for _, field := range scope.GetStructFields() {
sqlSettings := parseTagSetting(field.Tag.Get("sql")) sqlSettings := parseTagSetting(field.Tag)
if name, ok := sqlSettings["INDEX"]; ok { if name, ok := sqlSettings["INDEX"]; ok {
if name == "INDEX" { if name == "INDEX" {
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)