feat: add MapColumns method (#6901)

* add MapColumns method

* fix MapColumns desc

* add TestMapColumns
This commit is contained in:
molon 2024-06-24 17:42:59 +08:00 committed by GitHub
parent 8a0af58cc5
commit 11c4331058
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 42 additions and 4 deletions

View File

@ -185,6 +185,13 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
return return
} }
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
tx = db.getInstance()
tx.Statement.ColumnMapping = m
return
}
// Where add conditions // Where add conditions
// //
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. // See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.

View File

@ -131,6 +131,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
onConflictDonothing = mode&ScanOnConflictDoNothing != 0 onConflictDonothing = mode&ScanOnConflictDoNothing != 0
) )
if len(db.Statement.ColumnMapping) > 0 {
for i, column := range columns {
v, ok := db.Statement.ColumnMapping[column]
if ok {
columns[i] = v
}
}
}
db.RowsAffected = 0 db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {

View File

@ -30,8 +30,9 @@ type Statement struct {
Clauses map[string]clause.Clause Clauses map[string]clause.Clause
BuildClauses []string BuildClauses []string
Distinct bool Distinct bool
Selects []string // selected columns Selects []string // selected columns
Omits []string // omit columns Omits []string // omit columns
ColumnMapping map[string]string // map columns
Joins []join Joins []join
Preloads map[string][]interface{} Preloads map[string][]interface{}
Settings sync.Map Settings sync.Map
@ -513,6 +514,7 @@ func (stmt *Statement) clone() *Statement {
Distinct: stmt.Distinct, Distinct: stmt.Distinct,
Selects: stmt.Selects, Selects: stmt.Selects,
Omits: stmt.Omits, Omits: stmt.Omits,
ColumnMapping: stmt.ColumnMapping,
Preloads: map[string][]interface{}{}, Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool, ConnPool: stmt.ConnPool,
Schema: stmt.Schema, Schema: stmt.Schema,

View File

@ -860,6 +860,28 @@ func TestOmitWithAllFields(t *testing.T) {
} }
} }
func TestMapColumns(t *testing.T) {
user := User{Name: "MapColumnsUser", Age: 12}
DB.Save(&user)
type result struct {
Name string
Nickname string
Age uint
}
var res result
DB.Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "nickname"}).Scan(&res)
if res.Nickname != user.Name {
t.Errorf("Expected res.Nickname to be %s, but got %s", user.Name, res.Nickname)
}
if res.Name != "" {
t.Errorf("Expected res.Name to be empty, but got %s", res.Name)
}
if res.Age != user.Age {
t.Errorf("Expected res.Age to be %d, but got %d", user.Age, res.Age)
}
}
func TestPluckWithSelect(t *testing.T) { func TestPluckWithSelect(t *testing.T) {
users := []User{ users := []User{
{Name: "pluck_with_select_1", Age: 25}, {Name: "pluck_with_select_1", Age: 25},
@ -1194,7 +1216,6 @@ func TestSubQueryWithRaw(t *testing.T) {
Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}).
Group("name"), Group("name"),
).Count(&count).Error ).Count(&count).Error
if err != nil { if err != nil {
t.Errorf("Expected to get no errors, but got %v", err) t.Errorf("Expected to get no errors, but got %v", err)
} }
@ -1210,7 +1231,6 @@ func TestSubQueryWithRaw(t *testing.T) {
Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}). Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}).
Group("name"), Group("name"),
).Count(&count).Error ).Count(&count).Error
if err != nil { if err != nil {
t.Errorf("Expected to get no errors, but got %v", err) t.Errorf("Expected to get no errors, but got %v", err)
} }