From 846a2d401a57ba94b9d5bfc9d45f1f08c2a84d77 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Mar 2016 17:13:48 +0800 Subject: [PATCH] Fix scan columns with same name --- callback_query.go | 2 +- callback_query_preload.go | 18 +++++++----- main.go | 2 +- scope.go | 61 +++++++++++++++++++-------------------- 4 files changed, 41 insertions(+), 42 deletions(-) diff --git a/callback_query.go b/callback_query.go index 32f2e001..93782b1d 100644 --- a/callback_query.go +++ b/callback_query.go @@ -67,7 +67,7 @@ func queryCallback(scope *Scope) { elem = reflect.New(resultType).Elem() } - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap()) + scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) if isSlice { if isPtr { diff --git a/callback_query_preload.go b/callback_query_preload.go index 1c9bbc84..5746f533 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -255,20 +255,23 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface for rows.Next() { var ( elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).fieldsMap() + fields = scope.New(elem.Addr().Interface()).Fields() ) // register foreign keys in join tables + var joinTableFields []*Field for _, sourceKey := range sourceKeys { - fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()} + joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) } - scope.scan(rows, columns, fields) + scope.scan(rows, columns, append(fields, joinTableFields...)) - // generate hashed forkey keys in join table var foreignKeys = make([]interface{}, len(sourceKeys)) - for idx, sourceKey := range sourceKeys { - foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface() + // generate hashed forkey keys in join table + for idx, joinTableField := range joinTableFields { + if !joinTableField.Field.IsNil() { + foreignKeys[idx] = joinTableField.Field.Elem().Interface() + } } hashedSourceKeys := toString(foreignKeys) @@ -284,11 +287,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface indirectScopeValue = scope.IndirectValue() fieldsSourceMap = map[string]reflect.Value{} foreignFieldNames = []string{} - fields = scope.fieldsMap() ) for _, dbName := range relation.ForeignFieldNames { - if field, ok := fields[dbName]; ok { + if field, ok := scope.FieldByName(dbName); ok { foreignFieldNames = append(foreignFieldNames, field.Name) } } diff --git a/main.go b/main.go index d09cf416..243ee208 100644 --- a/main.go +++ b/main.go @@ -277,7 +277,7 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { ) if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.fieldsMap()) + scope.scan(rows, columns, scope.Fields()) } return clone.Error diff --git a/scope.go b/scope.go index da5f7ff3..9a8acbd3 100644 --- a/scope.go +++ b/scope.go @@ -412,16 +412,6 @@ func (scope *Scope) CommitOrRollback() *Scope { // Private Methods For *gorm.Scope //////////////////////////////////////////////////////////////////////////////// -func (scope *Scope) fieldsMap() map[string]*Field { - var results = map[string]*Field{} - for _, field := range scope.Fields() { - if field.IsNormal { - results[field.DBName] = field - } - } - return results -} - func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { if reflectValue.CanAddr() { reflectValue = reflectValue.Addr() @@ -458,33 +448,43 @@ func (scope *Scope) quoteIfPossible(str string) string { return str } -func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) { - var values = make([]interface{}, len(columns)) - var ignored interface{} +func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { + var ( + ignored interface{} + selectFields []*Field + values = make([]interface{}, len(columns)) + selectedColumnsMap = map[string]int{} + resetFields = map[*Field]int{} + ) for index, column := range columns { - if field, ok := fieldsMap[column]; ok { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() + values[index] = &ignored + + selectFields = fields + if idx, ok := selectedColumnsMap[column]; ok { + selectFields = selectFields[idx:] + } + + for _, field := range selectFields { + if field.DBName == column { + if field.Field.Kind() == reflect.Ptr { + values[index] = field.Field.Addr().Interface() + } else { + reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) + reflectValue.Elem().Set(field.Field.Addr()) + values[index] = reflectValue.Interface() + resetFields[field] = index + } + break } - } else { - values[index] = &ignored } } scope.Err(rows.Scan(values...)) - for index, column := range columns { - if field, ok := fieldsMap[column]; ok { - if field.Field.Kind() != reflect.Ptr { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } + for field, index := range resetFields { + if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { + field.Field.Set(v) } } } @@ -710,9 +710,6 @@ func (scope *Scope) whereSQL() (sql string) { func (scope *Scope) selectSQL() string { if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } return "*" } return scope.buildSelectQuery(scope.Search.selects)