mirror of https://github.com/ledisdb/ledisdb.git
638 lines
13 KiB
Go
638 lines
13 KiB
Go
// Lua pattern match functions for Go
|
|
package pm
|
|
|
|
import (
|
|
"fmt"
|
|
)
|
|
|
|
// Sentinel values shared by the scanner and by Error.
const (
	// EOS is the value the scanner reports, both as a character and as a
	// position, once the end of the source has been reached.
	EOS = -1
	// _UNKNOWN is the Pos value for which Error.Error prints the message
	// without any position suffix.
	_UNKNOWN = -2
)

/* Error {{{ */

// Error is the error type produced by this package. Pos is the offset the
// message refers to, or one of the sentinels EOS and _UNKNOWN.
type Error struct {
	Pos     int
	Message string
}

// newError builds an *Error for pos. When args are supplied, message is
// treated as a fmt format string; otherwise it is stored verbatim so that a
// literal '%' in the message is not interpreted.
func newError(pos int, message string, args ...interface{}) *Error {
	if len(args) > 0 {
		message = fmt.Sprintf(message, args...)
	}
	return &Error{Pos: pos, Message: message}
}

// Error implements the error interface. The message is followed by
// " at EOS" when Pos is EOS, by nothing when Pos is _UNKNOWN, and by
// " at <Pos>" for every other position.
func (e *Error) Error() string {
	switch e.Pos {
	case EOS:
		return e.Message + " at EOS"
	case _UNKNOWN:
		// No formatting needed: the message is already the full text.
		return e.Message
	default:
		return fmt.Sprintf("%s at %d", e.Message, e.Pos)
	}
}
|
|
|
|
/* }}} */
|
|
|
|
/* MatchData {{{ */
|
|
|
|
// MatchData records the capture slots written while a pattern is matched.
//
// Every slot is a uint32: the upper 31 bits hold a source position and the
// lowest bit tags the kind of slot.
//
//	xxxx xxxx xxxx xxx0 : boundary of an ordinary capture
//	xxxx xxxx xxxx xxx1 : position capture
type MatchData struct {
	captures []uint32
}

// newMatchState returns a MatchData that has no capture slots yet.
func newMatchState() *MatchData {
	return &MatchData{captures: []uint32{}}
}

// addPosCapture stores pos, tagged as a position capture, in slots s and
// s+1, growing the slot list with zero slots when it is too short.
func (md *MatchData) addPosCapture(s, pos int) {
	for len(md.captures) <= s+1 {
		md.captures = append(md.captures, 0)
	}
	tagged := uint32(pos)<<1 | 1
	md.captures[s], md.captures[s+1] = tagged, tagged
}

// setCapture stores pos as an ordinary capture boundary in slot s, growing
// the slot list when needed, and returns the raw value the slot held before
// so that the caller can undo the write with restoreCapture.
func (md *MatchData) setCapture(s, pos int) uint32 {
	for len(md.captures) <= s {
		md.captures = append(md.captures, 0)
	}
	previous := md.captures[s]
	md.captures[s] = uint32(pos) << 1
	return previous
}

// restoreCapture writes the raw slot value pos back into slot s.
func (md *MatchData) restoreCapture(s int, pos uint32) {
	md.captures[s] = pos
}

// CaptureLength returns the number of capture slots.
func (md *MatchData) CaptureLength() int {
	return len(md.captures)
}

// IsPosCapture reports whether slot idx carries the position-capture tag.
func (md *MatchData) IsPosCapture(idx int) bool {
	return md.captures[idx]&1 == 1
}

// Capture returns the source position stored in slot idx.
func (md *MatchData) Capture(idx int) int {
	return int(md.captures[idx] >> 1)
}
|
|
|
|
/* }}} */
|
|
|
|
/* scanner {{{ */
|
|
|
|
type scannerState struct {
|
|
Pos int
|
|
started bool
|
|
}
|
|
|
|
type scanner struct {
|
|
src []byte
|
|
State scannerState
|
|
saved scannerState
|
|
}
|
|
|
|
func newScanner(src []byte) *scanner {
|
|
return &scanner{
|
|
src: src,
|
|
State: scannerState{
|
|
Pos: 0,
|
|
started: false,
|
|
},
|
|
saved: scannerState{},
|
|
}
|
|
}
|
|
|
|
func (sc *scanner) Length() int { return len(sc.src) }
|
|
|
|
func (sc *scanner) Next() int {
|
|
if !sc.State.started {
|
|
sc.State.started = true
|
|
if len(sc.src) == 0 {
|
|
sc.State.Pos = EOS
|
|
}
|
|
} else {
|
|
sc.State.Pos = sc.NextPos()
|
|
}
|
|
if sc.State.Pos == EOS {
|
|
return EOS
|
|
}
|
|
return int(sc.src[sc.State.Pos])
|
|
}
|
|
|
|
func (sc *scanner) CurrentPos() int {
|
|
return sc.State.Pos
|
|
}
|
|
|
|
func (sc *scanner) NextPos() int {
|
|
if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 {
|
|
return EOS
|
|
}
|
|
if !sc.State.started {
|
|
return 0
|
|
} else {
|
|
return sc.State.Pos + 1
|
|
}
|
|
}
|
|
|
|
func (sc *scanner) Peek() int {
|
|
cureof := sc.State.Pos == EOS
|
|
ch := sc.Next()
|
|
if !cureof {
|
|
if sc.State.Pos == EOS {
|
|
sc.State.Pos = len(sc.src) - 1
|
|
} else {
|
|
sc.State.Pos--
|
|
if sc.State.Pos < 0 {
|
|
sc.State.Pos = 0
|
|
sc.State.started = false
|
|
}
|
|
}
|
|
}
|
|
return ch
|
|
}
|
|
|
|
func (sc *scanner) Save() { sc.saved = sc.State }
|
|
|
|
func (sc *scanner) Restore() { sc.State = sc.saved }
|
|
|
|
/* }}} */
|
|
|
|
/* bytecode {{{ */
|
|
|
|
// opCode selects the operation a compiled instruction performs.
type opCode int

// Instruction set of the matching virtual machine.
const (
	opChar      opCode = iota // consume one byte accepted by the instruction's Class
	opMatch                   // stop: the pattern matched
	opTailMatch               // stop: matched only if the whole source was consumed
	opJmp                     // continue at Operand1
	opSplit                   // try Operand1 first, fall back to Operand2
	opSave                    // record the current offset in capture slot Operand1
	opPSave                   // record a position capture starting at slot Operand1
	opBrace                   // consume a balanced Operand1 ... Operand2 run
	opNumber                  // consume a copy of capture number Operand1
)
|
|
|
|
type inst struct {
|
|
OpCode opCode
|
|
Class class
|
|
Operand1 int
|
|
Operand2 int
|
|
}
|
|
|
|
/* }}} */
|
|
|
|
/* classes {{{ */
|
|
|
|
// class is a predicate over a single character value.
type class interface {
	// Matches reports whether ch belongs to the class.
	Matches(ch int) bool
}

// dotClass accepts every character.
type dotClass struct{}

// Matches always reports true.
func (*dotClass) Matches(int) bool {
	return true
}

// charClass accepts exactly one character, Ch.
type charClass struct {
	Ch int
}

// Matches reports whether ch is the literal character of the class.
func (c *charClass) Matches(ch int) bool {
	return ch == c.Ch
}
|
|
|
|
type singleClass struct {
|
|
Class int
|
|
}
|
|
|
|
func (pn *singleClass) Matches(ch int) bool {
|
|
ret := false
|
|
switch pn.Class {
|
|
case 'a', 'A':
|
|
ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
|
|
case 'c', 'C':
|
|
ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F
|
|
case 'd', 'D':
|
|
ret = '0' <= ch && ch <= '9'
|
|
case 'l', 'L':
|
|
ret = 'a' <= ch && ch <= 'z'
|
|
case 'p', 'P':
|
|
ret = (0x21 <= ch && ch <= 0x2f) || (0x30 <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e)
|
|
case 's', 'S':
|
|
switch ch {
|
|
case ' ', '\f', '\n', '\r', '\t', '\v':
|
|
ret = true
|
|
}
|
|
case 'u', 'U':
|
|
ret = 'A' <= ch && ch <= 'Z'
|
|
case 'w', 'W':
|
|
ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
|
|
case 'x', 'X':
|
|
ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
|
|
case 'z', 'Z':
|
|
ret = ch == 0
|
|
default:
|
|
return ch == pn.Class
|
|
}
|
|
if 'A' <= pn.Class && pn.Class <= 'Z' {
|
|
return !ret
|
|
}
|
|
return ret
|
|
}
|
|
|
|
type setClass struct {
|
|
IsNot bool
|
|
Classes []class
|
|
}
|
|
|
|
func (pn *setClass) Matches(ch int) bool {
|
|
for _, class := range pn.Classes {
|
|
if class.Matches(ch) {
|
|
return !pn.IsNot
|
|
}
|
|
}
|
|
return pn.IsNot
|
|
}
|
|
|
|
type rangeClass struct {
|
|
Begin class
|
|
End class
|
|
}
|
|
|
|
func (pn *rangeClass) Matches(ch int) bool {
|
|
switch begin := pn.Begin.(type) {
|
|
case *charClass:
|
|
end, ok := pn.End.(*charClass)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return begin.Ch <= ch && ch <= end.Ch
|
|
}
|
|
return false
|
|
}
|
|
|
|
// }}}
|
|
|
|
// patterns {{{
|
|
|
|
type pattern interface{}
|
|
|
|
type singlePattern struct {
|
|
Class class
|
|
}
|
|
|
|
type seqPattern struct {
|
|
MustHead bool
|
|
MustTail bool
|
|
Patterns []pattern
|
|
}
|
|
|
|
type repeatPattern struct {
|
|
Type int
|
|
Class class
|
|
}
|
|
|
|
type posCapPattern struct{}
|
|
|
|
type capPattern struct {
|
|
Pattern pattern
|
|
}
|
|
|
|
type numberPattern struct {
|
|
N int
|
|
}
|
|
|
|
type bracePattern struct {
|
|
Begin int
|
|
End int
|
|
}
|
|
|
|
// }}}
|
|
|
|
/* parse {{{ */
|
|
|
|
func parseClass(sc *scanner, allowset bool) class {
|
|
ch := sc.Next()
|
|
switch ch {
|
|
case '%':
|
|
return &singleClass{sc.Next()}
|
|
case '.':
|
|
if allowset {
|
|
return &dotClass{}
|
|
} else {
|
|
return &charClass{ch}
|
|
}
|
|
case '[':
|
|
if !allowset {
|
|
panic(newError(sc.CurrentPos(), "invalid '['"))
|
|
}
|
|
return parseClassSet(sc)
|
|
//case '^' '$', '(', ')', ']', '*', '+', '-', '?':
|
|
// panic(newError(sc.CurrentPos(), "invalid %c", ch))
|
|
case EOS:
|
|
panic(newError(sc.CurrentPos(), "unexpected EOS"))
|
|
default:
|
|
return &charClass{ch}
|
|
}
|
|
}
|
|
|
|
func parseClassSet(sc *scanner) class {
|
|
set := &setClass{false, []class{}}
|
|
if sc.Peek() == '^' {
|
|
set.IsNot = true
|
|
sc.Next()
|
|
}
|
|
isrange := false
|
|
for {
|
|
ch := sc.Peek()
|
|
switch ch {
|
|
case '[':
|
|
panic(newError(sc.CurrentPos(), "'[' can not be nested"))
|
|
case ']':
|
|
sc.Next()
|
|
goto exit
|
|
case EOS:
|
|
panic(newError(sc.CurrentPos(), "unexpected EOS"))
|
|
case '-':
|
|
if len(set.Classes) > 0 {
|
|
sc.Next()
|
|
isrange = true
|
|
continue
|
|
}
|
|
fallthrough
|
|
default:
|
|
set.Classes = append(set.Classes, parseClass(sc, false))
|
|
}
|
|
if isrange {
|
|
begin := set.Classes[len(set.Classes)-2]
|
|
end := set.Classes[len(set.Classes)-1]
|
|
set.Classes = set.Classes[0 : len(set.Classes)-2]
|
|
set.Classes = append(set.Classes, &rangeClass{begin, end})
|
|
isrange = false
|
|
}
|
|
}
|
|
exit:
|
|
if isrange {
|
|
set.Classes = append(set.Classes, &charClass{'-'})
|
|
}
|
|
|
|
return set
|
|
}
|
|
|
|
func parsePattern(sc *scanner, toplevel bool) *seqPattern {
|
|
pat := &seqPattern{}
|
|
if toplevel {
|
|
if sc.Peek() == '^' {
|
|
sc.Next()
|
|
pat.MustHead = true
|
|
}
|
|
}
|
|
for {
|
|
ch := sc.Peek()
|
|
switch ch {
|
|
case '%':
|
|
sc.Save()
|
|
sc.Next()
|
|
switch sc.Peek() {
|
|
case '0':
|
|
panic(newError(sc.CurrentPos(), "invalid capture index"))
|
|
case '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
|
pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - 48})
|
|
case 'b':
|
|
sc.Next()
|
|
pat.Patterns = append(pat.Patterns, &bracePattern{sc.Next(), sc.Next()})
|
|
default:
|
|
sc.Restore()
|
|
pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
|
|
}
|
|
case '.', '[':
|
|
pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
|
|
case ']':
|
|
panic(newError(sc.CurrentPos(), "invalid ']'"))
|
|
case ')':
|
|
if toplevel {
|
|
panic(newError(sc.CurrentPos(), "invalid ')'"))
|
|
}
|
|
return pat
|
|
case '(':
|
|
sc.Next()
|
|
if sc.Peek() == ')' {
|
|
sc.Next()
|
|
pat.Patterns = append(pat.Patterns, &posCapPattern{})
|
|
} else {
|
|
ret := &capPattern{parsePattern(sc, false)}
|
|
if sc.Peek() != ')' {
|
|
panic(newError(sc.CurrentPos(), "unfinished capture"))
|
|
}
|
|
sc.Next()
|
|
pat.Patterns = append(pat.Patterns, ret)
|
|
}
|
|
case '*', '+', '-', '?':
|
|
sc.Next()
|
|
if len(pat.Patterns) > 0 {
|
|
spat, ok := pat.Patterns[len(pat.Patterns)-1].(*singlePattern)
|
|
if ok {
|
|
pat.Patterns = pat.Patterns[0 : len(pat.Patterns)-1]
|
|
pat.Patterns = append(pat.Patterns, &repeatPattern{ch, spat.Class})
|
|
continue
|
|
}
|
|
}
|
|
pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
|
|
case '$':
|
|
if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) {
|
|
pat.MustTail = true
|
|
} else {
|
|
pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
|
|
}
|
|
sc.Next()
|
|
case EOS:
|
|
sc.Next()
|
|
goto exit
|
|
default:
|
|
sc.Next()
|
|
pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
|
|
}
|
|
}
|
|
exit:
|
|
return pat
|
|
}
|
|
|
|
type iptr struct {
|
|
insts []inst
|
|
capture int
|
|
}
|
|
|
|
func compilePattern(p pattern, ps ...*iptr) []inst {
|
|
var ptr *iptr
|
|
toplevel := false
|
|
if len(ps) == 0 {
|
|
toplevel = true
|
|
ptr = &iptr{[]inst{inst{opSave, nil, 0, -1}}, 2}
|
|
} else {
|
|
ptr = ps[0]
|
|
}
|
|
switch pat := p.(type) {
|
|
case *singlePattern:
|
|
ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1})
|
|
case *seqPattern:
|
|
for _, cp := range pat.Patterns {
|
|
compilePattern(cp, ptr)
|
|
}
|
|
case *repeatPattern:
|
|
idx := len(ptr.insts)
|
|
switch pat.Type {
|
|
case '*':
|
|
ptr.insts = append(ptr.insts,
|
|
inst{opSplit, nil, idx + 1, idx + 3},
|
|
inst{opChar, pat.Class, -1, -1},
|
|
inst{opJmp, nil, idx, -1})
|
|
case '+':
|
|
ptr.insts = append(ptr.insts,
|
|
inst{opChar, pat.Class, -1, -1},
|
|
inst{opSplit, nil, idx, idx + 2})
|
|
case '-':
|
|
ptr.insts = append(ptr.insts,
|
|
inst{opSplit, nil, idx + 3, idx + 1},
|
|
inst{opChar, pat.Class, -1, -1},
|
|
inst{opJmp, nil, idx, -1})
|
|
case '?':
|
|
ptr.insts = append(ptr.insts,
|
|
inst{opSplit, nil, idx + 1, idx + 2},
|
|
inst{opChar, pat.Class, -1, -1})
|
|
}
|
|
case *posCapPattern:
|
|
ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1})
|
|
ptr.capture += 2
|
|
case *capPattern:
|
|
c0, c1 := ptr.capture, ptr.capture+1
|
|
ptr.capture += 2
|
|
ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1})
|
|
compilePattern(pat.Pattern, ptr)
|
|
ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1})
|
|
case *bracePattern:
|
|
ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End})
|
|
case *numberPattern:
|
|
ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1})
|
|
}
|
|
if toplevel {
|
|
if p.(*seqPattern).MustTail {
|
|
ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opTailMatch, nil, -1, -1})
|
|
}
|
|
ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opMatch, nil, -1, -1})
|
|
}
|
|
return ptr.insts
|
|
}
|
|
|
|
/* }}} parse */
|
|
|
|
/* VM {{{ */
|
|
|
|
// Simple recursive virtual machine based on the
|
|
// "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html)
|
|
func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) {
|
|
var m *MatchData
|
|
if len(ms) == 0 {
|
|
m = newMatchState()
|
|
} else {
|
|
m = ms[0]
|
|
}
|
|
redo:
|
|
inst := insts[pc]
|
|
switch inst.OpCode {
|
|
case opChar:
|
|
if sp >= len(src) || !inst.Class.Matches(int(src[sp])) {
|
|
return false, sp, m
|
|
}
|
|
pc++
|
|
sp++
|
|
goto redo
|
|
case opMatch:
|
|
return true, sp, m
|
|
case opTailMatch:
|
|
return sp >= len(src), sp, m
|
|
case opJmp:
|
|
pc = inst.Operand1
|
|
goto redo
|
|
case opSplit:
|
|
if ok, nsp, _ := recursiveVM(src, insts, inst.Operand1, sp, m); ok {
|
|
return true, nsp, m
|
|
}
|
|
pc = inst.Operand2
|
|
goto redo
|
|
case opSave:
|
|
s := m.setCapture(inst.Operand1, sp)
|
|
if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok {
|
|
return true, nsp, m
|
|
}
|
|
m.restoreCapture(inst.Operand1, s)
|
|
return false, sp, m
|
|
case opPSave:
|
|
m.addPosCapture(inst.Operand1, sp+1)
|
|
pc++
|
|
goto redo
|
|
case opBrace:
|
|
if sp >= len(src) || int(src[sp]) != inst.Operand1 {
|
|
return false, sp, m
|
|
}
|
|
count := 1
|
|
for sp = sp + 1; sp < len(src); sp++ {
|
|
if int(src[sp]) == inst.Operand2 {
|
|
count--
|
|
}
|
|
if count == 0 {
|
|
pc++
|
|
sp++
|
|
goto redo
|
|
}
|
|
if int(src[sp]) == inst.Operand1 {
|
|
count++
|
|
}
|
|
}
|
|
return false, sp, m
|
|
case opNumber:
|
|
idx := inst.Operand1 * 2
|
|
if idx >= m.CaptureLength()-1 {
|
|
panic(newError(_UNKNOWN, "invalid capture index"))
|
|
}
|
|
capture := src[m.Capture(idx):m.Capture(idx+1)]
|
|
for i := 0; i < len(capture); i++ {
|
|
if i+sp >= len(src) || capture[i] != src[i+sp] {
|
|
return false, sp, m
|
|
}
|
|
}
|
|
pc++
|
|
sp += len(capture)
|
|
goto redo
|
|
}
|
|
panic("should not reach here")
|
|
return false, sp, m
|
|
}
|
|
|
|
/* }}} */
|
|
|
|
/* API {{{ */
|
|
|
|
func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) {
|
|
defer func() {
|
|
if v := recover(); v != nil {
|
|
if perr, ok := v.(*Error); ok {
|
|
err = perr
|
|
} else {
|
|
panic(v)
|
|
}
|
|
}
|
|
}()
|
|
pat := parsePattern(newScanner([]byte(p)), true)
|
|
insts := compilePattern(pat)
|
|
matches = []*MatchData{}
|
|
for sp := offset; sp <= len(src); {
|
|
ok, nsp, ms := recursiveVM(src, insts, 0, sp)
|
|
sp++
|
|
if ok {
|
|
if sp < nsp {
|
|
sp = nsp
|
|
}
|
|
matches = append(matches, ms)
|
|
}
|
|
if len(matches) == limit || pat.MustHead {
|
|
break
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
/* }}} */
|