diff --git a/main.go b/main.go index 43ce867e..abf20c59 100644 --- a/main.go +++ b/main.go @@ -224,7 +224,7 @@ func (s *DB) Related(value interface{}, foreign_keys ...string) *DB { } func (s *DB) Pluck(column string, value interface{}) *DB { - return s.do(s.Value).pluck(column, value).db + return s.NewScope(s.Value).pluck(column, value).db } func (s *DB) Count(value interface{}) *DB { diff --git a/scope.go b/scope.go index 6e1456d0..aef54e51 100644 --- a/scope.go +++ b/scope.go @@ -191,6 +191,10 @@ func (scope *Scope) TableName() string { if len(scope.Search.tableName) > 0 { return scope.Search.tableName } else { + if scope.Value == nil { + scope.Err(errors.New("can't get table name")) + return "" + } data := reflect.Indirect(reflect.ValueOf(scope.Value)) if data.Kind() == reflect.Slice { @@ -380,22 +384,6 @@ func (scope *Scope) CommitOrRollback() *Scope { return scope } -func (scope *Scope) prepareQuerySql() { - if scope.Search.raw { - scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.TableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values []interface{}) *Scope { - if len(values) > 0 { - scope.Search = scope.Search.clone().where(values[0], values[1:]...) - } - return scope -} - func (scope *Scope) row() *sql.Row { defer scope.Trace(time.Now()) scope.prepareQuerySql() @@ -416,3 +404,23 @@ func (scope *Scope) initialize() *Scope { scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) return scope } + +func (scope *Scope) pluck(column string, value interface{}) *Scope { + dest := reflect.Indirect(reflect.ValueOf(value)) + scope.Search = scope.Search.clone().selects(column) + if dest.Kind() != reflect.Slice { + scope.Err(errors.New("Results should be a slice")) + return scope + } + + rows, err := scope.rows() + if scope.Err(err) == nil { + defer rows.Close() + for rows.Next() { + elem := reflect.New(dest.Type().Elem()).Interface() + scope.Err(rows.Scan(elem)) + dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) + } + } + return scope +} diff --git a/scope_condition.go b/scope_condition.go index d57d3a12..556855b3 100644 --- a/scope_condition.go +++ b/scope_condition.go @@ -233,3 +233,19 @@ func (s *Scope) havingSql() string { func (s *Scope) joinsSql() string { return s.Search.joinsStr + " " } + +func (scope *Scope) prepareQuerySql() { + if scope.Search.raw { + scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) + } else { + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.TableName(), scope.CombinedConditionSql())) + } + return +} + +func (scope *Scope) inlineCondition(values []interface{}) *Scope { + if len(values) > 0 { + scope.Search = scope.Search.clone().where(values[0], values[1:]...) + } + return scope +}