pkger/parser/visitor.go

294 lines
5.4 KiB
Go
Raw Normal View History

2019-08-01 19:03:12 +03:00
package parser
import (
"fmt"
"go/ast"
"strconv"
2019-08-02 05:34:32 +03:00
"github.com/markbates/pkger"
2019-08-01 19:03:12 +03:00
)
2019-08-02 06:21:37 +03:00
type visitor struct {
2019-08-01 19:03:12 +03:00
File string
2019-08-02 05:34:32 +03:00
Found map[pkger.Path]bool
2019-08-01 19:03:12 +03:00
errors []error
}
2019-08-02 06:21:37 +03:00
func newVisitor(p string) (*visitor, error) {
return &visitor{
2019-08-01 19:03:12 +03:00
File: p,
2019-08-02 05:34:32 +03:00
Found: map[pkger.Path]bool{},
2019-08-01 19:03:12 +03:00
}, nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) Run() ([]pkger.Path, error) {
2019-08-01 19:06:50 +03:00
pf, err := parseFile(v.File)
2019-08-01 19:03:12 +03:00
if err != nil {
return nil, err
}
ast.Walk(v, pf.Ast)
2019-08-02 05:34:32 +03:00
var found []pkger.Path
2019-08-01 19:03:12 +03:00
for k := range v.Found {
found = append(found, k)
}
return found, nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) addPath(p string) error {
2019-08-01 19:03:12 +03:00
p, _ = strconv.Unquote(p)
2019-08-02 05:34:32 +03:00
pt, err := pkger.Parse(p)
2019-08-01 19:03:12 +03:00
if err != nil {
return err
}
v.Found[pt] = true
return nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) Visit(node ast.Node) ast.Visitor {
2019-08-01 19:03:12 +03:00
if node == nil {
return v
}
if err := v.eval(node); err != nil {
v.errors = append(v.errors, err)
}
return v
}
2019-08-02 06:21:37 +03:00
func (v *visitor) eval(node ast.Node) error {
2019-08-01 19:03:12 +03:00
switch t := node.(type) {
case *ast.CallExpr:
return v.evalExpr(t)
case *ast.Ident:
return v.evalIdent(t)
case *ast.GenDecl:
for _, n := range t.Specs {
if err := v.eval(n); err != nil {
return err
}
}
case *ast.FuncDecl:
if t.Body == nil {
return nil
}
for _, b := range t.Body.List {
if err := v.evalStmt(b); err != nil {
return err
}
}
return nil
case *ast.ValueSpec:
for _, e := range t.Values {
if err := v.evalExpr(e); err != nil {
return err
}
}
}
return nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) evalStmt(stmt ast.Stmt) error {
2019-08-01 19:03:12 +03:00
switch t := stmt.(type) {
case *ast.ExprStmt:
return v.evalExpr(t.X)
case *ast.AssignStmt:
for _, e := range t.Rhs {
if err := v.evalArgs(e); err != nil {
return err
}
}
}
return nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) evalExpr(expr ast.Expr) error {
2019-08-01 19:03:12 +03:00
switch t := expr.(type) {
case *ast.CallExpr:
if t.Fun == nil {
return nil
}
for _, a := range t.Args {
switch at := a.(type) {
case *ast.CallExpr:
if sel, ok := t.Fun.(*ast.SelectorExpr); ok {
return v.evalSelector(at, sel)
}
if err := v.evalArgs(at); err != nil {
return err
}
case *ast.CompositeLit:
for _, e := range at.Elts {
if err := v.evalExpr(e); err != nil {
return err
}
}
}
}
if ft, ok := t.Fun.(*ast.SelectorExpr); ok {
return v.evalSelector(t, ft)
}
case *ast.KeyValueExpr:
return v.evalExpr(t.Value)
}
return nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) evalArgs(expr ast.Expr) error {
2019-08-01 19:03:12 +03:00
switch at := expr.(type) {
case *ast.CompositeLit:
for _, e := range at.Elts {
if err := v.evalExpr(e); err != nil {
return err
}
}
case *ast.CallExpr:
if at.Fun == nil {
return nil
}
switch st := at.Fun.(type) {
case *ast.SelectorExpr:
if err := v.evalSelector(at, st); err != nil {
return err
}
case *ast.Ident:
return v.evalIdent(st)
}
for _, a := range at.Args {
if err := v.evalArgs(a); err != nil {
return err
}
}
}
return nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) evalSelector(expr *ast.CallExpr, sel *ast.SelectorExpr) error {
2019-08-01 19:03:12 +03:00
x, ok := sel.X.(*ast.Ident)
if !ok {
return nil
}
if x.Name == "pkger" {
switch sel.Sel.Name {
case "Walk":
if len(expr.Args) != 2 {
return fmt.Errorf("`New` requires two arguments")
}
zz := func(e ast.Expr) (string, error) {
switch at := e.(type) {
case *ast.Ident:
switch at.Obj.Kind {
case ast.Var:
if as, ok := at.Obj.Decl.(*ast.AssignStmt); ok {
return v.fromVariable(as)
}
case ast.Con:
if vs, ok := at.Obj.Decl.(*ast.ValueSpec); ok {
return v.fromConstant(vs)
}
}
return "", v.evalIdent(at)
case *ast.BasicLit:
return at.Value, nil
case *ast.CallExpr:
return "", v.evalExpr(at)
}
return "", fmt.Errorf("can't handle %T", e)
}
k1, err := zz(expr.Args[0])
if err != nil {
return err
}
if err := v.addPath(k1); err != nil {
return err
}
return nil
case "Open":
for _, e := range expr.Args {
switch at := e.(type) {
case *ast.Ident:
switch at.Obj.Kind {
case ast.Var:
if as, ok := at.Obj.Decl.(*ast.AssignStmt); ok {
v.addVariable("", as)
}
case ast.Con:
if vs, ok := at.Obj.Decl.(*ast.ValueSpec); ok {
v.addConstant("", vs)
}
}
return v.evalIdent(at)
case *ast.BasicLit:
return v.addPath(at.Value)
case *ast.CallExpr:
return v.evalExpr(at)
}
}
}
}
return nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) evalIdent(i *ast.Ident) error {
2019-08-01 19:03:12 +03:00
if i.Obj == nil {
return nil
}
if s, ok := i.Obj.Decl.(*ast.AssignStmt); ok {
return v.evalStmt(s)
}
return nil
}
2019-08-02 06:21:37 +03:00
func (v *visitor) fromVariable(as *ast.AssignStmt) (string, error) {
2019-08-01 19:03:12 +03:00
if len(as.Rhs) == 1 {
if bs, ok := as.Rhs[0].(*ast.BasicLit); ok {
return bs.Value, nil
}
}
return "", fmt.Errorf("unable to find value from variable %v", as)
}
2019-08-02 06:21:37 +03:00
func (v *visitor) addVariable(bn string, as *ast.AssignStmt) error {
2019-08-01 19:03:12 +03:00
bv, err := v.fromVariable(as)
if err != nil {
return nil
}
if len(bn) == 0 {
bn = bv
}
return v.addPath(bn)
}
2019-08-02 06:21:37 +03:00
func (v *visitor) fromConstant(vs *ast.ValueSpec) (string, error) {
2019-08-01 19:03:12 +03:00
if len(vs.Values) == 1 {
if bs, ok := vs.Values[0].(*ast.BasicLit); ok {
return bs.Value, nil
}
}
return "", fmt.Errorf("unable to find value from constant %v", vs)
}
2019-08-02 06:21:37 +03:00
func (v *visitor) addConstant(bn string, vs *ast.ValueSpec) error {
2019-08-01 19:03:12 +03:00
if len(vs.Values) == 1 {
if bs, ok := vs.Values[0].(*ast.BasicLit); ok {
bv := bs.Value
if len(bn) == 0 {
bn = bv
}
return v.addPath(bn)
}
}
return nil
}