Test Pluck

This commit is contained in:
Jinzhu 2020-05-31 21:11:20 +08:00
parent e26abb84b3
commit 95a6539331
3 changed files with 54 additions and 11 deletions

View File

@ -273,6 +273,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
// db.Find(&users).Pluck("age", &ages) // db.Find(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Select{Columns: []clause.Column{{Name: column}}})
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) tx.callbacks.Query().Execute(tx)
return return

32
scan.go
View File

@ -58,7 +58,12 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
default: default:
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr reflectValueType := db.Statement.ReflectValue.Type().Elem()
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
fields := make([]*schema.Field, len(columns)) fields := make([]*schema.Field, len(columns))
joinFields := make([][2]*schema.Field, len(columns)) joinFields := make([][2]*schema.Field, len(columns))
@ -81,17 +86,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
elem := reflect.New(db.Statement.Schema.ModelType).Elem() elem := reflect.New(reflectValueType).Elem()
for idx, field := range fields {
if field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
} else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
}
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
values[0] = elem.Addr().Interface()
} else {
for idx, field := range fields {
if field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
} else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
}
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
}
} }
} }

View File

@ -80,3 +80,35 @@ func TestFind(t *testing.T) {
} }
} }
} }
func TestPluck(t *testing.T) {
users := []*User{
GetUser("pluck-user1", Config{}),
GetUser("pluck-user2", Config{}),
GetUser("pluck-user3", Config{}),
}
DB.Create(&users)
var names []string
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("name", &names).Error; err != nil {
t.Errorf("Raise error when pluck name, got %v", err)
}
var ids []int
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids).Error; err != nil {
t.Errorf("Raise error when pluck id, got %v", err)
}
for idx, name := range names {
if name != users[idx].Name {
t.Errorf("Unexpected result on pluck name, got %+v", names)
}
}
for idx, id := range ids {
if int(id) != int(users[idx].ID) {
t.Errorf("Unexpected result on pluck id, got %+v", ids)
}
}
}