ledisdb/vendor/github.com/yuin/gopher-lua/pm/pm.go

638 lines
13 KiB
Go
Raw Permalink Normal View History

2017-04-15 16:51:03 +03:00
// Lua pattern match functions for Go
package pm
import (
"fmt"
)
const EOS = -1
const _UNKNOWN = -2
/* Error {{{ */
type Error struct {
Pos int
Message string
}
func newError(pos int, message string, args ...interface{}) *Error {
if len(args) == 0 {
return &Error{pos, message}
}
return &Error{pos, fmt.Sprintf(message, args...)}
}
func (e *Error) Error() string {
switch e.Pos {
case EOS:
return fmt.Sprintf("%s at EOS", e.Message)
case _UNKNOWN:
return fmt.Sprintf("%s", e.Message)
default:
return fmt.Sprintf("%s at %d", e.Message, e.Pos)
}
}
/* }}} */
/* MatchData {{{ */
type MatchData struct {
// captured positions
// layout
// xxxx xxxx xxxx xxx0 : caputured positions
// xxxx xxxx xxxx xxx1 : position captured positions
captures []uint32
}
func newMatchState() *MatchData { return &MatchData{[]uint32{}} }
func (st *MatchData) addPosCapture(s, pos int) {
for s+1 >= len(st.captures) {
st.captures = append(st.captures, 0)
}
st.captures[s] = (uint32(pos) << 1) | 1
st.captures[s+1] = (uint32(pos) << 1) | 1
}
func (st *MatchData) setCapture(s, pos int) uint32 {
for s >= len(st.captures) {
st.captures = append(st.captures, 0)
}
v := st.captures[s]
st.captures[s] = (uint32(pos) << 1)
return v
}
func (st *MatchData) restoreCapture(s int, pos uint32) { st.captures[s] = pos }
func (st *MatchData) CaptureLength() int { return len(st.captures) }
func (st *MatchData) IsPosCapture(idx int) bool { return (st.captures[idx] & 1) == 1 }
func (st *MatchData) Capture(idx int) int { return int(st.captures[idx] >> 1) }
/* }}} */
/* scanner {{{ */
type scannerState struct {
Pos int
started bool
}
type scanner struct {
src []byte
State scannerState
saved scannerState
}
func newScanner(src []byte) *scanner {
return &scanner{
src: src,
State: scannerState{
Pos: 0,
started: false,
},
saved: scannerState{},
}
}
func (sc *scanner) Length() int { return len(sc.src) }
func (sc *scanner) Next() int {
if !sc.State.started {
sc.State.started = true
if len(sc.src) == 0 {
sc.State.Pos = EOS
}
} else {
sc.State.Pos = sc.NextPos()
}
if sc.State.Pos == EOS {
return EOS
}
return int(sc.src[sc.State.Pos])
}
func (sc *scanner) CurrentPos() int {
return sc.State.Pos
}
func (sc *scanner) NextPos() int {
if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 {
return EOS
}
if !sc.State.started {
return 0
} else {
return sc.State.Pos + 1
}
}
func (sc *scanner) Peek() int {
cureof := sc.State.Pos == EOS
ch := sc.Next()
if !cureof {
if sc.State.Pos == EOS {
sc.State.Pos = len(sc.src) - 1
} else {
sc.State.Pos--
if sc.State.Pos < 0 {
sc.State.Pos = 0
sc.State.started = false
}
}
}
return ch
}
func (sc *scanner) Save() { sc.saved = sc.State }
func (sc *scanner) Restore() { sc.State = sc.saved }
/* }}} */
/* bytecode {{{ */
type opCode int
const (
opChar opCode = iota
opMatch
opTailMatch
opJmp
opSplit
opSave
opPSave
opBrace
opNumber
)
type inst struct {
OpCode opCode
Class class
Operand1 int
Operand2 int
}
/* }}} */
/* classes {{{ */
type class interface {
Matches(ch int) bool
}
type dotClass struct{}
func (pn *dotClass) Matches(ch int) bool { return true }
type charClass struct {
Ch int
}
func (pn *charClass) Matches(ch int) bool { return pn.Ch == ch }
type singleClass struct {
Class int
}
func (pn *singleClass) Matches(ch int) bool {
ret := false
switch pn.Class {
case 'a', 'A':
ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
case 'c', 'C':
ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F
case 'd', 'D':
ret = '0' <= ch && ch <= '9'
case 'l', 'L':
ret = 'a' <= ch && ch <= 'z'
case 'p', 'P':
ret = (0x21 <= ch && ch <= 0x2f) || (0x30 <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e)
case 's', 'S':
switch ch {
case ' ', '\f', '\n', '\r', '\t', '\v':
ret = true
}
case 'u', 'U':
ret = 'A' <= ch && ch <= 'Z'
case 'w', 'W':
ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
case 'x', 'X':
ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
case 'z', 'Z':
ret = ch == 0
default:
return ch == pn.Class
}
if 'A' <= pn.Class && pn.Class <= 'Z' {
return !ret
}
return ret
}
type setClass struct {
IsNot bool
Classes []class
}
func (pn *setClass) Matches(ch int) bool {
for _, class := range pn.Classes {
if class.Matches(ch) {
return !pn.IsNot
}
}
return pn.IsNot
}
type rangeClass struct {
Begin class
End class
}
func (pn *rangeClass) Matches(ch int) bool {
switch begin := pn.Begin.(type) {
case *charClass:
end, ok := pn.End.(*charClass)
if !ok {
return false
}
return begin.Ch <= ch && ch <= end.Ch
}
return false
}
// }}}
// patterns {{{
type pattern interface{}
type singlePattern struct {
Class class
}
type seqPattern struct {
MustHead bool
MustTail bool
Patterns []pattern
}
type repeatPattern struct {
Type int
Class class
}
type posCapPattern struct{}
type capPattern struct {
Pattern pattern
}
type numberPattern struct {
N int
}
type bracePattern struct {
Begin int
End int
}
// }}}
/* parse {{{ */
func parseClass(sc *scanner, allowset bool) class {
ch := sc.Next()
switch ch {
case '%':
return &singleClass{sc.Next()}
case '.':
if allowset {
return &dotClass{}
} else {
return &charClass{ch}
}
case '[':
if !allowset {
panic(newError(sc.CurrentPos(), "invalid '['"))
}
return parseClassSet(sc)
//case '^' '$', '(', ')', ']', '*', '+', '-', '?':
// panic(newError(sc.CurrentPos(), "invalid %c", ch))
case EOS:
panic(newError(sc.CurrentPos(), "unexpected EOS"))
default:
return &charClass{ch}
}
}
func parseClassSet(sc *scanner) class {
set := &setClass{false, []class{}}
if sc.Peek() == '^' {
set.IsNot = true
sc.Next()
}
isrange := false
for {
ch := sc.Peek()
switch ch {
case '[':
panic(newError(sc.CurrentPos(), "'[' can not be nested"))
case ']':
sc.Next()
goto exit
case EOS:
panic(newError(sc.CurrentPos(), "unexpected EOS"))
case '-':
if len(set.Classes) > 0 {
sc.Next()
isrange = true
continue
}
fallthrough
default:
set.Classes = append(set.Classes, parseClass(sc, false))
}
if isrange {
begin := set.Classes[len(set.Classes)-2]
end := set.Classes[len(set.Classes)-1]
set.Classes = set.Classes[0 : len(set.Classes)-2]
set.Classes = append(set.Classes, &rangeClass{begin, end})
isrange = false
}
}
exit:
if isrange {
set.Classes = append(set.Classes, &charClass{'-'})
}
return set
}
func parsePattern(sc *scanner, toplevel bool) *seqPattern {
pat := &seqPattern{}
if toplevel {
if sc.Peek() == '^' {
sc.Next()
pat.MustHead = true
}
}
for {
ch := sc.Peek()
switch ch {
case '%':
sc.Save()
sc.Next()
switch sc.Peek() {
case '0':
panic(newError(sc.CurrentPos(), "invalid capture index"))
case '1', '2', '3', '4', '5', '6', '7', '8', '9':
pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - 48})
case 'b':
sc.Next()
pat.Patterns = append(pat.Patterns, &bracePattern{sc.Next(), sc.Next()})
default:
sc.Restore()
pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
}
case '.', '[':
pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
case ']':
panic(newError(sc.CurrentPos(), "invalid ']'"))
case ')':
if toplevel {
panic(newError(sc.CurrentPos(), "invalid ')'"))
}
return pat
case '(':
sc.Next()
if sc.Peek() == ')' {
sc.Next()
pat.Patterns = append(pat.Patterns, &posCapPattern{})
} else {
ret := &capPattern{parsePattern(sc, false)}
if sc.Peek() != ')' {
panic(newError(sc.CurrentPos(), "unfinished capture"))
}
sc.Next()
pat.Patterns = append(pat.Patterns, ret)
}
case '*', '+', '-', '?':
sc.Next()
if len(pat.Patterns) > 0 {
spat, ok := pat.Patterns[len(pat.Patterns)-1].(*singlePattern)
if ok {
pat.Patterns = pat.Patterns[0 : len(pat.Patterns)-1]
pat.Patterns = append(pat.Patterns, &repeatPattern{ch, spat.Class})
continue
}
}
pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
case '$':
if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) {
pat.MustTail = true
} else {
pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
}
sc.Next()
case EOS:
sc.Next()
goto exit
default:
sc.Next()
pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
}
}
exit:
return pat
}
type iptr struct {
insts []inst
capture int
}
func compilePattern(p pattern, ps ...*iptr) []inst {
var ptr *iptr
toplevel := false
if len(ps) == 0 {
toplevel = true
ptr = &iptr{[]inst{inst{opSave, nil, 0, -1}}, 2}
} else {
ptr = ps[0]
}
switch pat := p.(type) {
case *singlePattern:
ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1})
case *seqPattern:
for _, cp := range pat.Patterns {
compilePattern(cp, ptr)
}
case *repeatPattern:
idx := len(ptr.insts)
switch pat.Type {
case '*':
ptr.insts = append(ptr.insts,
inst{opSplit, nil, idx + 1, idx + 3},
inst{opChar, pat.Class, -1, -1},
inst{opJmp, nil, idx, -1})
case '+':
ptr.insts = append(ptr.insts,
inst{opChar, pat.Class, -1, -1},
inst{opSplit, nil, idx, idx + 2})
case '-':
ptr.insts = append(ptr.insts,
inst{opSplit, nil, idx + 3, idx + 1},
inst{opChar, pat.Class, -1, -1},
inst{opJmp, nil, idx, -1})
case '?':
ptr.insts = append(ptr.insts,
inst{opSplit, nil, idx + 1, idx + 2},
inst{opChar, pat.Class, -1, -1})
}
case *posCapPattern:
ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1})
ptr.capture += 2
case *capPattern:
c0, c1 := ptr.capture, ptr.capture+1
ptr.capture += 2
ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1})
compilePattern(pat.Pattern, ptr)
ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1})
case *bracePattern:
ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End})
case *numberPattern:
ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1})
}
if toplevel {
if p.(*seqPattern).MustTail {
ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opTailMatch, nil, -1, -1})
}
ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opMatch, nil, -1, -1})
}
return ptr.insts
}
/* }}} parse */
/* VM {{{ */
// Simple recursive virtual machine based on the
// "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html)
func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) {
var m *MatchData
if len(ms) == 0 {
m = newMatchState()
} else {
m = ms[0]
}
redo:
inst := insts[pc]
switch inst.OpCode {
case opChar:
if sp >= len(src) || !inst.Class.Matches(int(src[sp])) {
return false, sp, m
}
pc++
sp++
goto redo
case opMatch:
return true, sp, m
case opTailMatch:
return sp >= len(src), sp, m
case opJmp:
pc = inst.Operand1
goto redo
case opSplit:
if ok, nsp, _ := recursiveVM(src, insts, inst.Operand1, sp, m); ok {
return true, nsp, m
}
pc = inst.Operand2
goto redo
case opSave:
s := m.setCapture(inst.Operand1, sp)
if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok {
return true, nsp, m
}
m.restoreCapture(inst.Operand1, s)
return false, sp, m
case opPSave:
m.addPosCapture(inst.Operand1, sp+1)
pc++
goto redo
case opBrace:
if sp >= len(src) || int(src[sp]) != inst.Operand1 {
return false, sp, m
}
count := 1
for sp = sp + 1; sp < len(src); sp++ {
if int(src[sp]) == inst.Operand2 {
count--
}
if count == 0 {
pc++
sp++
goto redo
}
if int(src[sp]) == inst.Operand1 {
count++
}
}
return false, sp, m
case opNumber:
idx := inst.Operand1 * 2
if idx >= m.CaptureLength()-1 {
panic(newError(_UNKNOWN, "invalid capture index"))
}
capture := src[m.Capture(idx):m.Capture(idx+1)]
for i := 0; i < len(capture); i++ {
if i+sp >= len(src) || capture[i] != src[i+sp] {
return false, sp, m
}
}
pc++
sp += len(capture)
goto redo
}
panic("should not reach here")
return false, sp, m
}
/* }}} */
/* API {{{ */
func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) {
defer func() {
if v := recover(); v != nil {
if perr, ok := v.(*Error); ok {
err = perr
} else {
panic(v)
}
}
}()
pat := parsePattern(newScanner([]byte(p)), true)
insts := compilePattern(pat)
matches = []*MatchData{}
for sp := offset; sp <= len(src); {
ok, nsp, ms := recursiveVM(src, insts, 0, sp)
sp++
if ok {
if sp < nsp {
sp = nsp
}
matches = append(matches, ms)
}
if len(matches) == limit || pat.MustHead {
break
}
}
return
}
/* }}} */