Skip to content

Commit

Permalink
syntax: initial support for recovering from missing tokens
Browse files Browse the repository at this point in the history
This helps with the most basic forms of incomplete shell, such as

    (foo |

being an incomplete version of a complete shell input like

    (foo | bar)

In particular, this is helpful when writing interactive shell prompts,
as they may want to provide tab completion even if the shell written
so far is not complete yet.
  • Loading branch information
mvdan committed Dec 21, 2024
1 parent 71f43b4 commit c6175e6
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 17 deletions.
18 changes: 14 additions & 4 deletions cmd/shfmt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ var (
diff = &multiFlag[bool]{"d", "diff", false}
applyIgnore = &multiFlag[bool]{"", "apply-ignore", false}

lang = &multiFlag[syntax.LangVariant]{"ln", "language-dialect", syntax.LangAuto}
posix = &multiFlag[bool]{"p", "posix", false}
filename = &multiFlag[string]{"", "filename", ""}
lang = &multiFlag[syntax.LangVariant]{"ln", "language-dialect", syntax.LangAuto}
posix = &multiFlag[bool]{"p", "posix", false}
filename = &multiFlag[string]{"", "filename", ""}
expRecover = &multiFlag[int]{"", "exp.recover", 0}

indent = &multiFlag[uint]{"i", "indent", 0}
binNext = &multiFlag[bool]{"bn", "binary-next-line", false}
Expand All @@ -81,7 +82,7 @@ var (

allFlags = []any{
versionFlag, list, write, simplify, minify, find, diff, applyIgnore,
lang, posix, filename,
lang, posix, filename, expRecover,
indent, binNext, caseIndent, spaceRedirs, keepPadding, funcNext, toJSON, fromJSON,
}
)
Expand Down Expand Up @@ -113,6 +114,13 @@ func init() {
if name := f.long; name != "" {
flag.StringVar(&f.val, name, f.val, "")
}
case *multiFlag[int]:
if name := f.short; name != "" {
flag.IntVar(&f.val, name, f.val, "")
}
if name := f.long; name != "" {
flag.IntVar(&f.val, name, f.val, "")
}
case *multiFlag[uint]:
if name := f.short; name != "" {
flag.UintVar(&f.val, name, f.val, "")
Expand Down Expand Up @@ -227,6 +235,8 @@ For more information and to report bugs, see https://github.com/mvdan/sh.
parser = syntax.NewParser(syntax.KeepComments(true))
printer = syntax.NewPrinter(syntax.Minify(minify.val))

syntax.RecoverErrors(expRecover.val)(parser)

if !useEditorConfig {
if posix.val {
// -p equals -ln=posix
Expand Down
48 changes: 40 additions & 8 deletions syntax/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,16 @@ type Pos struct {
offs, lineCol uint32
}

// We used to split line and column numbers evenly in 16 bits, but line numbers
// are significantly more important in practice. Use more bits for them.
const (
offsetMax = math.MaxUint32
// Offsets use 32 bits for a reasonable amount of precision.
// We reserve a few of the highest values to represent types of invalid positions.
// We leave some space before the real uint32 maximum so that we can easily detect
// when arithmetic on invalid positions is done by mistake.
offsetRecovered = math.MaxUint32 - 10
offsetMax = math.MaxUint32 - 11

// We used to split line and column numbers evenly in 16 bits, but line numbers
// are significantly more important in practice. Use more bits for them.

lineBitSize = 18
lineMax = (1 << lineBitSize) - 1
Expand Down Expand Up @@ -109,20 +115,26 @@ func NewPos(offset, line, column uint) Pos {
}

// Offset returns the byte offset of the position in the original source file.
// Byte offsets start at 0.
// Byte offsets start at 0. Invalid positions always report the offset 0.
//
// Offset has basic protection against overflows; if an input is too large,
// offset numbers will stop increasing past a very large number.
func (p Pos) Offset() uint { return uint(p.offs) }
func (p Pos) Offset() uint {
if p.offs > offsetMax {
return 0 // invalid
}
return uint(p.offs)
}

// Line returns the line number of the position, starting at 1.
// Invalid positions always report the line number 0.
//
// Line is protected against overflows; if an input has too many lines, extra
// lines will have a line number of 0, rendered as "?" by [Pos.String].
func (p Pos) Line() uint { return uint(p.lineCol >> colBitSize) }

// Col returns the column number of the position, starting at 1. It counts in
// bytes.
// bytes. Invalid positions always report the column number 0.
//
// Col is protected against overflows; if an input line has too many columns,
// extra columns will have a column number of 0, rendered as "?" by [Pos.String].
Expand All @@ -147,13 +159,33 @@ func (p Pos) String() string {
// IsValid reports whether the position contains useful position information.
// Some positions returned via [Parse] may be invalid: for example, [Stmt.Semicolon]
// will only be valid if a statement contained a closing token such as ';'.
func (p Pos) IsValid() bool { return p != Pos{} }
//
// Recovered positions, as reported by [Pos.IsRecovered], are not considered valid
// given that they don't contain position information.
func (p Pos) IsValid() bool {
return p.offs <= offsetMax && p.lineCol != 0
}

var recoveredPos = Pos{offs: offsetRecovered}

// IsRecovered reports whether the position that the token or node belongs to
// was missing in the original input and recovered via [RecoverErrors].
func (p Pos) IsRecovered() bool { return p == recoveredPos }

// After reports whether the position p is after p2. It is a more expressive
// version of p.Offset() > p2.Offset().
func (p Pos) After(p2 Pos) bool { return p.offs > p2.offs }
// It always returns false if p is an invalid position.
func (p Pos) After(p2 Pos) bool {
if !p.IsValid() {
return false
}
return p.offs > p2.offs
}

func posAddCol(p Pos, n int) Pos {
if !p.IsValid() {
return p
}
// TODO: guard against overflows
p.lineCol += uint32(n)
p.offs += uint32(n)
Expand Down
77 changes: 72 additions & 5 deletions syntax/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,26 @@ func StopAt(word string) ParserOption {
return func(p *Parser) { p.stopAt = []byte(word) }
}

// RecoverErrors allows the parser to skip up to a maximum number of
// errors in the given input on a best-effort basis.
// This can be useful to tab-complete an interactive shell prompt,
// or when providing diagnostics on slightly incomplete shell source.
//
// Currently, this only helps with mandatory tokens from the shell grammar
// which are not present in the input. They result in position fields
// or nodes whose position report [Pos.IsRecovered] as true.
//
// For example, given the input
//
// (foo |
//
// the result will contain two recovered positions; first, the pipe requires
// a statement to follow, and as [Stmt.Pos] reports, the entire node is recovered.
// Second, the subshell needs to be closed, so [Subshell.Rparen] is recovered.
func RecoverErrors(maximum int) ParserOption {
return func(p *Parser) { p.recoverErrorsMax = maximum }
}

// NewParser allocates a new [Parser] and applies any number of options.
func NewParser(options ...ParserOption) *Parser {
p := &Parser{}
Expand Down Expand Up @@ -364,6 +384,9 @@ type Parser struct {

stopAt []byte

recoveredErrors int
recoverErrorsMax int

forbidNested bool

// list of pending heredoc bodies
Expand Down Expand Up @@ -422,6 +445,7 @@ func (p *Parser) reset() {
p.err, p.readErr = nil, nil
p.quote, p.forbidNested = noState, false
p.openStmts = 0
p.recoveredErrors = 0
p.heredocs, p.buriedHdocs = p.heredocs[:0], 0
p.hdocStops = nil
p.parsingDoc = false
Expand Down Expand Up @@ -649,6 +673,14 @@ func (p *Parser) gotRsrv(val string) (Pos, bool) {
return pos, false
}

func (p *Parser) recoverError() bool {
if p.recoveredErrors < p.recoverErrorsMax {
p.recoveredErrors++
return true
}
return false
}

func readableStr(s string) string {
// don't quote tokens like & or }
if s != "" && s[0] >= 'a' && s[0] <= 'z' {
Expand All @@ -675,6 +707,9 @@ func (p *Parser) follow(lpos Pos, left string, tok token) {
func (p *Parser) followRsrv(lpos Pos, left, val string) Pos {
pos, ok := p.gotRsrv(val)
if !ok {
if p.recoverError() {
return recoveredPos
}
p.followErr(lpos, left, fmt.Sprintf("%q", val))
}
return pos
Expand All @@ -687,6 +722,9 @@ func (p *Parser) followStmts(left string, lpos Pos, stops ...string) ([]*Stmt, [
newLine := p.got(_Newl)
stmts, last := p.stmtList(stops...)
if len(stmts) < 1 && !newLine {
if p.recoverError() {
return []*Stmt{{Position: recoveredPos}}, nil
}
p.followErr(lpos, left, "a statement list")
}
return stmts, last
Expand All @@ -695,6 +733,9 @@ func (p *Parser) followStmts(left string, lpos Pos, stops ...string) ([]*Stmt, [
func (p *Parser) followWordTok(tok token, pos Pos) *Word {
w := p.getWord()
if w == nil {
if p.recoverError() {
return p.wordOne(&Lit{ValuePos: recoveredPos})
}
p.followErr(pos, tok.String(), "a word")
}
return w
Expand All @@ -703,6 +744,9 @@ func (p *Parser) followWordTok(tok token, pos Pos) *Word {
func (p *Parser) stmtEnd(n Node, start, end string) Pos {
pos, ok := p.gotRsrv(end)
if !ok {
if p.recoverError() {
return recoveredPos
}
p.posErr(n.Pos(), "%s statement must end with %q", start, end)
}
return pos
Expand All @@ -721,6 +765,9 @@ func (p *Parser) matchingErr(lpos Pos, left, right any) {
func (p *Parser) matched(lpos Pos, left, right token) Pos {
pos := p.pos
if !p.got(right) {
if p.recoverError() {
return recoveredPos
}
p.matchingErr(lpos, left, right)
}
return pos
Expand Down Expand Up @@ -1107,6 +1154,10 @@ func (p *Parser) wordPart() WordPart {
p.litBs = append(p.litBs, '\\', '\n')
case utf8.RuneSelf:
p.tok = _EOF
if p.recoverError() {
sq.Right = recoveredPos
return sq
}
p.quoteErr(sq.Pos(), sglQuote)
return nil
}
Expand Down Expand Up @@ -1144,7 +1195,11 @@ func (p *Parser) wordPart() WordPart {
// Like above, the lexer didn't call p.rune for us.
p.rune()
if !p.got(bckQuote) {
p.quoteErr(cs.Pos(), bckQuote)
if p.recoverError() {
cs.Right = recoveredPos
} else {
p.quoteErr(cs.Pos(), bckQuote)
}
}
return cs
case globQuest, globStar, globPlus, globAt, globExcl:
Expand Down Expand Up @@ -1194,7 +1249,11 @@ func (p *Parser) dblQuoted() *DblQuoted {
p.quote = old
q.Right = p.pos
if !p.got(dblQuote) {
p.quoteErr(q.Pos(), dblQuote)
if p.recoverError() {
q.Right = recoveredPos
} else {
p.quoteErr(q.Pos(), dblQuote)
}
}
return q
}
Expand Down Expand Up @@ -1661,6 +1720,9 @@ func (p *Parser) getStmt(readEnd, binCmd, fnBody bool) *Stmt {
p.got(_Newl)
b.Y = p.getStmt(false, true, false)
if b.Y == nil || p.err != nil {
if p.recoverError() {
return &Stmt{Position: recoveredPos}
}
p.followErr(b.OpPos, b.Op.String(), "a statement")
return nil
}
Expand Down Expand Up @@ -1834,6 +1896,9 @@ func (p *Parser) gotStmtPipe(s *Stmt, binCmd bool) *Stmt {
p.next()
p.got(_Newl)
if b.Y = p.gotStmtPipe(&Stmt{Position: p.pos}, true); b.Y == nil || p.err != nil {
if p.recoverError() {
return &Stmt{Position: recoveredPos}
}
p.followErr(b.OpPos, b.Op.String(), "a statement")
break
}
Expand Down Expand Up @@ -1876,9 +1941,11 @@ func (p *Parser) block(s *Stmt) {
b := &Block{Lbrace: p.pos}
p.next()
b.Stmts, b.Last = p.stmtList("}")
pos, ok := p.gotRsrv("}")
b.Rbrace = pos
if !ok {
if pos, ok := p.gotRsrv("}"); ok {
b.Rbrace = pos
} else if p.recoverError() {
b.Rbrace = recoveredPos
} else {
p.matchingErr(b.Lbrace, "{", "}")
}
s.Cmd = b
Expand Down
3 changes: 3 additions & 0 deletions syntax/parser_arithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ func (p *Parser) matchedArithm(lpos Pos, left, right token) {

func (p *Parser) arithmEnd(ltok token, lpos Pos, old saveState) Pos {
if !p.peekArithmEnd() {
if p.recoverError() {
return recoveredPos
}
p.arithmMatchingErr(lpos, ltok, dblRightParen)
}
p.rune()
Expand Down
Loading

0 comments on commit c6175e6

Please sign in to comment.