Fix scan columns with same name

This commit is contained in:
Jinzhu 2016-03-10 17:13:48 +08:00
parent bd99af5067
commit 846a2d401a
4 changed files with 41 additions and 42 deletions

View File

@ -67,7 +67,7 @@ func queryCallback(scope *Scope) {
elem = reflect.New(resultType).Elem()
}
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap())
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
if isSlice {
if isPtr {

View File

@ -255,20 +255,23 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
for rows.Next() {
var (
elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).fieldsMap()
fields = scope.New(elem.Addr().Interface()).Fields()
)
// register foreign keys in join tables
var joinTableFields []*Field
for _, sourceKey := range sourceKeys {
fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()}
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
}
scope.scan(rows, columns, fields)
scope.scan(rows, columns, append(fields, joinTableFields...))
// generate hashed forkey keys in join table
var foreignKeys = make([]interface{}, len(sourceKeys))
for idx, sourceKey := range sourceKeys {
foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface()
// generate hashed forkey keys in join table
for idx, joinTableField := range joinTableFields {
if !joinTableField.Field.IsNil() {
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
}
}
hashedSourceKeys := toString(foreignKeys)
@ -284,11 +287,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
indirectScopeValue = scope.IndirectValue()
fieldsSourceMap = map[string]reflect.Value{}
foreignFieldNames = []string{}
fields = scope.fieldsMap()
)
for _, dbName := range relation.ForeignFieldNames {
if field, ok := fields[dbName]; ok {
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}

View File

@ -277,7 +277,7 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
)
if clone.AddError(err) == nil {
scope.scan(rows, columns, scope.fieldsMap())
scope.scan(rows, columns, scope.Fields())
}
return clone.Error

View File

@ -412,16 +412,6 @@ func (scope *Scope) CommitOrRollback() *Scope {
// Private Methods For *gorm.Scope
////////////////////////////////////////////////////////////////////////////////
func (scope *Scope) fieldsMap() map[string]*Field {
var results = map[string]*Field{}
for _, field := range scope.Fields() {
if field.IsNormal {
results[field.DBName] = field
}
}
return results
}
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
if reflectValue.CanAddr() {
reflectValue = reflectValue.Addr()
@ -458,33 +448,43 @@ func (scope *Scope) quoteIfPossible(str string) string {
return str
}
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
var values = make([]interface{}, len(columns))
var ignored interface{}
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
var (
ignored interface{}
selectFields []*Field
values = make([]interface{}, len(columns))
selectedColumnsMap = map[string]int{}
resetFields = map[*Field]int{}
)
for index, column := range columns {
if field, ok := fieldsMap[column]; ok {
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
reflectValue.Elem().Set(field.Field.Addr())
values[index] = reflectValue.Interface()
values[index] = &ignored
selectFields = fields
if idx, ok := selectedColumnsMap[column]; ok {
selectFields = selectFields[idx:]
}
for _, field := range selectFields {
if field.DBName == column {
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
reflectValue.Elem().Set(field.Field.Addr())
values[index] = reflectValue.Interface()
resetFields[field] = index
}
break
}
} else {
values[index] = &ignored
}
}
scope.Err(rows.Scan(values...))
for index, column := range columns {
if field, ok := fieldsMap[column]; ok {
if field.Field.Kind() != reflect.Ptr {
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
}
for field, index := range resetFields {
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
}
}
@ -710,9 +710,6 @@ func (scope *Scope) whereSQL() (sql string) {
func (scope *Scope) selectSQL() string {
if len(scope.Search.selects) == 0 {
if len(scope.Search.joinConditions) > 0 {
return fmt.Sprintf("%v.*", scope.QuotedTableName())
}
return "*"
}
return scope.buildSelectQuery(scope.Search.selects)