From 281c5d10f6e00b1644f1473c7b7ecc6cf88724cb Mon Sep 17 00:00:00 2001 From: kimiby Date: Sun, 16 Aug 2015 12:36:23 +0300 Subject: [PATCH] preload_m2m improve --- model_struct.go | 20 ++++++++++++-------- preload.go | 17 ++++------------- utils_private.go | 10 ---------- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/model_struct.go b/model_struct.go index 26c58fc5..db6d9a88 100644 --- a/model_struct.go +++ b/model_struct.go @@ -62,14 +62,16 @@ func (structField *StructField) clone() *StructField { } type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface + Kind string + PolymorphicType string + PolymorphicDBName string + ForeignFieldNames []string + ForeignStructFieldNames []string + ForeignDBNames []string + AssociationForeignFieldNames []string + AssociationForeignStructFieldNames []string + AssociationForeignDBNames []string + JoinTableHandler JoinTableHandlerInterface } func (scope *Scope) GetModelStruct() *ModelStruct { @@ -224,6 +226,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if field, ok := scope.FieldByName(foreignKey); ok { relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) + relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name) joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) } @@ -242,6 +245,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, name := range associationForeignKeys { if field, ok := toScope.FieldByName(name); ok { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } diff --git a/preload.go b/preload.go index dd19a8bc..c7810b63 100644 --- a/preload.go +++ b/preload.go @@ -16,15 +16,6 @@ func getRealValue(value reflect.Value, columns []string) (results []interface{}) result, _ = r.Value() } results = append(results, result) - } else { - column = upFL(column) - if reflect.Indirect(value).FieldByName(column).IsValid() { - result := reflect.Indirect(value).FieldByName(column).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } } } return @@ -283,11 +274,11 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf var checked []string object := reflect.Indirect(objects.Index(j)) - source := getRealValue(object, relation.AssociationForeignFieldNames) + source := getRealValue(object, relation.AssociationForeignStructFieldNames) for i := 0; i < results.Len(); i++ { result := results.Index(i) - value := getRealValue(result, relation.ForeignFieldNames) + value := getRealValue(result, relation.ForeignStructFieldNames) if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { f := object.FieldByName(field.Name) @@ -300,11 +291,11 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf } else { object := scope.IndirectValue() var checked []string - source := getRealValue(object, relation.AssociationForeignFieldNames) + source := getRealValue(object, relation.AssociationForeignStructFieldNames) for i := 0; i < results.Len(); i++ { result := results.Index(i) - value := getRealValue(result, relation.ForeignFieldNames) + value := getRealValue(result, relation.ForeignStructFieldNames) if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { f := object.FieldByName(field.Name) diff --git a/utils_private.go b/utils_private.go index 8b43453f..b82aa807 100644 --- a/utils_private.go +++ b/utils_private.go @@ -5,8 +5,6 @@ import ( "reflect" "regexp" "runtime" - "unicode" - "unicode/utf8" ) func fileWithLineNum() string { @@ -86,11 +84,3 @@ func strInSlice(a string, list []string) bool { } return false } - -func upFL(s string) string { - if s == "" { - return "" - } - r, n := utf8.DecodeRuneInString(s) - return string(unicode.ToUpper(r)) + s[n:] -}