Fix scan columns with same name

This commit is contained in:
Jinzhu 2016-03-10 17:13:48 +08:00
parent bd99af5067
commit 846a2d401a
4 changed files with 41 additions and 42 deletions

View File

@ -67,7 +67,7 @@ func queryCallback(scope *Scope) {
elem = reflect.New(resultType).Elem() elem = reflect.New(resultType).Elem()
} }
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap()) scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
if isSlice { if isSlice {
if isPtr { if isPtr {

View File

@ -255,20 +255,23 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
for rows.Next() { for rows.Next() {
var ( var (
elem = reflect.New(fieldType).Elem() elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).fieldsMap() fields = scope.New(elem.Addr().Interface()).Fields()
) )
// register foreign keys in join tables // register foreign keys in join tables
var joinTableFields []*Field
for _, sourceKey := range sourceKeys { for _, sourceKey := range sourceKeys {
fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()} joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
} }
scope.scan(rows, columns, fields) scope.scan(rows, columns, append(fields, joinTableFields...))
// generate hashed forkey keys in join table
var foreignKeys = make([]interface{}, len(sourceKeys)) var foreignKeys = make([]interface{}, len(sourceKeys))
for idx, sourceKey := range sourceKeys { // generate hashed forkey keys in join table
foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface() for idx, joinTableField := range joinTableFields {
if !joinTableField.Field.IsNil() {
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
}
} }
hashedSourceKeys := toString(foreignKeys) hashedSourceKeys := toString(foreignKeys)
@ -284,11 +287,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
indirectScopeValue = scope.IndirectValue() indirectScopeValue = scope.IndirectValue()
fieldsSourceMap = map[string]reflect.Value{} fieldsSourceMap = map[string]reflect.Value{}
foreignFieldNames = []string{} foreignFieldNames = []string{}
fields = scope.fieldsMap()
) )
for _, dbName := range relation.ForeignFieldNames { for _, dbName := range relation.ForeignFieldNames {
if field, ok := fields[dbName]; ok { if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name) foreignFieldNames = append(foreignFieldNames, field.Name)
} }
} }

View File

@ -277,7 +277,7 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
) )
if clone.AddError(err) == nil { if clone.AddError(err) == nil {
scope.scan(rows, columns, scope.fieldsMap()) scope.scan(rows, columns, scope.Fields())
} }
return clone.Error return clone.Error

View File

@ -412,16 +412,6 @@ func (scope *Scope) CommitOrRollback() *Scope {
// Private Methods For *gorm.Scope // Private Methods For *gorm.Scope
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
func (scope *Scope) fieldsMap() map[string]*Field {
var results = map[string]*Field{}
for _, field := range scope.Fields() {
if field.IsNormal {
results[field.DBName] = field
}
}
return results
}
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
if reflectValue.CanAddr() { if reflectValue.CanAddr() {
reflectValue = reflectValue.Addr() reflectValue = reflectValue.Addr()
@ -458,33 +448,43 @@ func (scope *Scope) quoteIfPossible(str string) string {
return str return str
} }
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) { func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
var values = make([]interface{}, len(columns)) var (
var ignored interface{} ignored interface{}
selectFields []*Field
values = make([]interface{}, len(columns))
selectedColumnsMap = map[string]int{}
resetFields = map[*Field]int{}
)
for index, column := range columns { for index, column := range columns {
if field, ok := fieldsMap[column]; ok { values[index] = &ignored
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface() selectFields = fields
} else { if idx, ok := selectedColumnsMap[column]; ok {
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) selectFields = selectFields[idx:]
reflectValue.Elem().Set(field.Field.Addr()) }
values[index] = reflectValue.Interface()
for _, field := range selectFields {
if field.DBName == column {
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
reflectValue.Elem().Set(field.Field.Addr())
values[index] = reflectValue.Interface()
resetFields[field] = index
}
break
} }
} else {
values[index] = &ignored
} }
} }
scope.Err(rows.Scan(values...)) scope.Err(rows.Scan(values...))
for index, column := range columns { for field, index := range resetFields {
if field, ok := fieldsMap[column]; ok { if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
if field.Field.Kind() != reflect.Ptr { field.Field.Set(v)
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
}
} }
} }
} }
@ -710,9 +710,6 @@ func (scope *Scope) whereSQL() (sql string) {
func (scope *Scope) selectSQL() string { func (scope *Scope) selectSQL() string {
if len(scope.Search.selects) == 0 { if len(scope.Search.selects) == 0 {
if len(scope.Search.joinConditions) > 0 {
return fmt.Sprintf("%v.*", scope.QuotedTableName())
}
return "*" return "*"
} }
return scope.buildSelectQuery(scope.Search.selects) return scope.buildSelectQuery(scope.Search.selects)