diff --git a/argh.go b/argh.go index 085cc76..4077aa6 100644 --- a/argh.go +++ b/argh.go @@ -1,5 +1,35 @@ package argh +import ( + "log" + "os" +) + +// NOTE: much of this is lifted from +// https://blog.gopheracademy.com/advent-2014/parsers-lexers/ + +var ( + tracingEnabled = os.Getenv("ARGH_TRACING") == "enabled" +) + type Argh struct { - AST *AST + ParseTree *ParseTree `json:"parse_tree"` +} + +func (a *Argh) AST() []TypedNode { + return a.ParseTree.toAST() +} + +/* +func (a *Argh) String() string { + return a.ParseTree.String() +} +*/ + +func tracef(format string, v ...any) { + if !tracingEnabled { + return + } + + log.Printf(format, v...) } diff --git a/ast.go b/ast.go deleted file mode 100644 index 30834f8..0000000 --- a/ast.go +++ /dev/null @@ -1,10 +0,0 @@ -package argh - -type AST struct { - Nodes []*Node `json:"nodes"` -} - -type Node struct { - Token string `json:"token"` - Literal string `json:"literal"` -} diff --git a/cmd/argh/main.go b/cmd/argh/main.go index 2b77e52..bd9b32c 100644 --- a/cmd/argh/main.go +++ b/cmd/argh/main.go @@ -10,7 +10,9 @@ import ( ) func main() { - ast, err := argh.ParseArgs(os.Args) + log.SetFlags(0) + + ast, err := argh.ParseArgs(os.Args, nil) if err != nil { log.Fatal(err) } diff --git a/go.mod b/go.mod index f70b654..7091c9a 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,10 @@ module git.meatballhat.com/x/box-o-sand/argh go 1.18 require github.com/pkg/errors v0.9.1 + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.7.1 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/go.sum b/go.sum index 7c401c3..842edf5 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,12 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/parse_tree.go b/parse_tree.go new file mode 100644 index 0000000..7ffb182 --- /dev/null +++ b/parse_tree.go @@ -0,0 +1,82 @@ +package argh + +import "fmt" + +type ParseTree struct { + Nodes []Node `json:"nodes"` +} + +func (pt *ParseTree) toAST() []TypedNode { + ret := []TypedNode{} + + for _, node := range pt.Nodes { + if _, ok := node.(ArgDelimiter); ok { + continue + } + + if _, ok := node.(StopFlag); ok { + continue + } + + ret = append( + ret, + TypedNode{ + Type: fmt.Sprintf("%T", node), + Node: node, + }, + ) + } + + return ret +} + +type Node interface{} + +type TypedNode struct { + Type string `json:"type"` + Node Node `json:"node"` +} + +type Args struct { + Pos int `json:"pos"` + Nodes []Node `json:"nodes"` +} + +type Statement struct { + Pos int `json:"pos"` + Nodes []Node `json:"nodes"` +} + +type Program struct { + Pos int `json:"pos"` + Name string `json:"name"` +} + +type Ident struct { + Pos int `json:"pos"` + Literal string `json:"literal"` +} + +type Command struct { + Pos int `json:"pos"` + Name string `json:"name"` + Nodes []Node `json:"nodes"` +} + +type Flag struct { + Pos int `json:"pos"` + Name string `json:"name"` + Value *string `json:"value,omitempty"` +} + +type StdinFlag struct { + Pos int `json:"pos"` +} + +type StopFlag struct { + Pos int `json:"pos"` +} + +type ArgDelimiter struct { + Pos int `json:"pos"` +} diff --git a/parser.go b/parser.go index cf74b87..211ce9c 100644 --- a/parser.go +++ b/parser.go @@ -7,76 +7,220 @@ import ( "github.com/pkg/errors" ) -// NOTE: much of this is lifted from -// https://blog.gopheracademy.com/advent-2014/parsers-lexers/ - var ( errSyntax = errors.New("syntax error") + + DefaultParserConfig = &ParserConfig{ + Commands: []string{}, + ValueFlags: []string{}, + ScannerConfig: DefaultScannerConfig, + } ) -func ParseArgs(args []string) (*Argh, error) { +func ParseArgs(args []string, pCfg *ParserConfig) (*Argh, error) { reEncoded := strings.Join(args, string(nul)) return NewParser( strings.NewReader(reEncoded), - nil, + pCfg, ).Parse() } type Parser struct { s *Scanner buf ParserBuffer + + commands map[string]struct{} + valueFlags map[string]struct{} + + nodes []Node + stopSeen bool } type ParserBuffer struct { tok Token lit string + pos int n int } -func NewParser(r io.Reader, cfg *ScannerConfig) *Parser { - return &Parser{s: NewScanner(r, cfg)} +type ParserConfig struct { + Commands []string + ValueFlags []string + ScannerConfig *ScannerConfig +} + +type parseDirective struct { + Break bool +} + +func NewParser(r io.Reader, pCfg *ParserConfig) *Parser { + if pCfg == nil { + pCfg = DefaultParserConfig + } + + parser := &Parser{ + s: NewScanner(r, pCfg.ScannerConfig), + commands: map[string]struct{}{}, + valueFlags: map[string]struct{}{}, + } + + for _, command := range pCfg.Commands { + parser.commands[command] = struct{}{} + } + + for _, valueFlag := range pCfg.ValueFlags { + parser.valueFlags[valueFlag] = struct{}{} + } + + tracef("NewParser parser=%+#v", parser) + tracef("NewParser pCfg=%+#v", pCfg) + + return parser } func (p *Parser) Parse() (*Argh, error) { - arghOut := &Argh{ - AST: &AST{ - Nodes: []*Node{}, - }, - } + p.nodes = []Node{} for { - tok, lit := p.scan() - if tok == ILLEGAL { - return nil, errors.Wrapf(errSyntax, "illegal value %q", lit) + pd, err := p.parseArg() + if err != nil { + return nil, err } - if tok == EOL { + if pd != nil && pd.Break { break } - - arghOut.AST.Nodes = append( - arghOut.AST.Nodes, - &Node{Token: tok.String(), Literal: lit}, - ) } - return arghOut, nil + return &Argh{ParseTree: &ParseTree{Nodes: p.nodes}}, nil } -func (p *Parser) scan() (Token, string) { - if p.buf.n != 0 { - p.buf.n = 0 - return p.buf.tok, p.buf.lit +func (p *Parser) parseArg() (*parseDirective, error) { + tok, lit, pos := p.scan() + if tok == ILLEGAL { + return nil, errors.Wrapf(errSyntax, "illegal value %q at pos=%v", lit, pos) } - tok, lit := p.s.Scan() + if tok == EOL { + return &parseDirective{Break: true}, nil + } - p.buf.tok, p.buf.lit = tok, lit + p.unscan() - return tok, lit + node, err := p.nodify() + + tracef("parseArg node=%+#v err=%+#v", node, err) + + if err != nil { + return nil, errors.Wrapf(err, "value %q at pos=%v", lit, pos) + } + + if node != nil { + p.nodes = append(p.nodes, node) + } + + return nil, nil +} + +func (p *Parser) nodify() (Node, error) { + tok, lit, pos := p.scan() + + tracef("nodify tok=%s lit=%q pos=%v", tok, lit, pos) + + switch tok { + case IDENT: + if len(p.nodes) == 0 { + return Program{Name: lit, Pos: pos - len(lit)}, nil + } + return Ident{Literal: lit, Pos: pos - len(lit)}, nil + case ARG_DELIMITER: + return ArgDelimiter{Pos: pos - 1}, nil + case COMPOUND_SHORT_FLAG: + flagNodes := []Node{} + + for i, r := range lit[1:] { + flagNodes = append( + flagNodes, + Flag{ + Pos: pos + i + 1, + Name: string(r), + }, + ) + } + + return Statement{Pos: pos, Nodes: flagNodes}, nil + case SHORT_FLAG: + flagName := string(lit[1:]) + if _, ok := p.valueFlags[flagName]; ok { + return p.scanValueFlag(flagName, pos) + } + + return Flag{Name: flagName, Pos: pos - len(flagName) - 1}, nil + case LONG_FLAG: + flagName := string(lit[2:]) + if _, ok := p.valueFlags[flagName]; ok { + return p.scanValueFlag(flagName, pos) + } + + return Flag{Name: flagName, Pos: pos - len(flagName) - 2}, nil + default: + } + + return Ident{Literal: lit, Pos: pos - len(lit)}, nil +} + +func (p *Parser) scanValueFlag(flagName string, pos int) (Node, error) { + tracef("scanValueFlag flagName=%q pos=%v", flagName, pos) + + lit, err := p.scanIdent() + if err != nil { + return nil, err + } + + flagSepLen := len("--") + 1 + + return Flag{Name: flagName, Pos: pos - len(lit) - flagSepLen, Value: ptr(lit)}, nil +} + +func (p *Parser) scanIdent() (string, error) { + tok, lit, pos := p.scan() + + nUnscan := 0 + + if tok == ASSIGN || tok == ARG_DELIMITER { + nUnscan++ + tok, lit, pos = p.scan() + } + + if tok == IDENT { + return lit, nil + } + + for i := 0; i < nUnscan; i++ { + p.unscan() + } + + return "", errors.Wrapf(errSyntax, "expected ident at pos=%v but got %s (%q)", pos, tok, lit) +} + +func (p *Parser) scan() (Token, string, int) { + if p.buf.n != 0 { + p.buf.n = 0 + return p.buf.tok, p.buf.lit, p.buf.pos + } + + tok, lit, pos := p.s.Scan() + + p.buf.tok, p.buf.lit, p.buf.pos = tok, lit, pos + + return tok, lit, pos } func (p *Parser) unscan() { p.buf.n = 1 } + +func ptr[T any](v T) *T { + return &v +} diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..0f95be1 --- /dev/null +++ b/parser_test.go @@ -0,0 +1,119 @@ +package argh_test + +import ( + "testing" + + "git.meatballhat.com/x/box-o-sand/argh" + "github.com/stretchr/testify/assert" +) + +func ptr[T any](v T) *T { + return &v +} + +func TestParser(t *testing.T) { + for _, tc := range []struct { + name string + args []string + cfg *argh.ParserConfig + expected *argh.Argh + expectedErr error + skip bool + }{ + { + name: "bare", + args: []string{"pizzas"}, + expected: &argh.Argh{ + ParseTree: &argh.ParseTree{ + Nodes: []argh.Node{ + argh.Program{Name: "pizzas"}, + }, + }, + }, + }, + { + name: "long value-less flags", + args: []string{"pizzas", "--tasty", "--fresh", "--super-hot-right-now"}, + expected: &argh.Argh{ + ParseTree: &argh.ParseTree{ + Nodes: []argh.Node{ + argh.Program{Name: "pizzas", Pos: 0}, + argh.ArgDelimiter{Pos: 6}, + argh.Flag{Name: "tasty", Pos: 7}, + argh.ArgDelimiter{Pos: 14}, + argh.Flag{Name: "fresh", Pos: 15}, + argh.ArgDelimiter{Pos: 22}, + argh.Flag{Name: "super-hot-right-now", Pos: 23}, + }, + }, + }, + }, + { + name: "long flags mixed", + args: []string{"pizzas", "--tasty", "--fresh", "soon", "--super-hot-right-now"}, + cfg: &argh.ParserConfig{ + Commands: []string{}, + ValueFlags: []string{"fresh"}, + }, + expected: &argh.Argh{ + ParseTree: &argh.ParseTree{ + Nodes: []argh.Node{ + argh.Program{Name: "pizzas", Pos: 0}, + argh.ArgDelimiter{Pos: 6}, + argh.Flag{Name: "tasty", Pos: 7}, + argh.ArgDelimiter{Pos: 14}, + argh.Flag{Name: "fresh", Pos: 15, Value: ptr("soon")}, + argh.ArgDelimiter{Pos: 27}, + argh.Flag{Name: "super-hot-right-now", Pos: 28}, + }, + }, + }, + }, + { + skip: true, + + name: "typical", + args: []string{"pizzas", "-a", "--ca", "-b", "1312", "-lol"}, + cfg: &argh.ParserConfig{ + Commands: []string{}, + ValueFlags: []string{"b"}, + }, + expected: &argh.Argh{ + ParseTree: &argh.ParseTree{ + Nodes: []argh.Node{ + argh.Program{Name: "pizzas", Pos: 0}, + argh.ArgDelimiter{Pos: 6}, + argh.Flag{Name: "a", Pos: 7}, + argh.ArgDelimiter{Pos: 9}, + argh.Flag{Name: "ca", Pos: 10}, + argh.ArgDelimiter{Pos: 14}, + argh.Flag{Name: "b", Pos: 15, Value: ptr("1312")}, + argh.ArgDelimiter{Pos: 22}, + argh.Statement{ + Pos: 23, + Nodes: []argh.Node{ + argh.Flag{Name: "l", Pos: 29}, + argh.Flag{Name: "o", Pos: 30}, + argh.Flag{Name: "l", Pos: 31}, + }, + }, + }, + }, + }, + }, + } { + if tc.skip { + continue + } + + t.Run(tc.name, func(ct *testing.T) { + actual, err := argh.ParseArgs(tc.args, tc.cfg) + if err != nil { + assert.ErrorIs(ct, err, tc.expectedErr) + return + } + + assert.Equal(ct, tc.expected, actual) + }) + } +} diff --git a/scanner.go b/scanner.go index 02e1267..cc24842 100644 --- a/scanner.go +++ b/scanner.go @@ -1,8 +1,5 @@ package argh -// NOTE: much of this is lifted from -// https://blog.gopheracademy.com/advent-2014/parsers-lexers/ - import ( "bufio" "bytes" @@ -27,6 +24,7 @@ var ( type Scanner struct { r *bufio.Reader + i int cfg *ScannerConfig } @@ -34,8 +32,6 @@ type ScannerConfig struct { AssignmentOperator rune FlagPrefix rune MultiValueDelim rune - - Commands []string } func NewScanner(r io.Reader, cfg *ScannerConfig) *Scanner { @@ -49,52 +45,56 @@ func NewScanner(r io.Reader, cfg *ScannerConfig) *Scanner { } } -func (s *Scanner) Scan() (Token, string) { - ch := s.read() +func (s *Scanner) Scan() (Token, string, int) { + ch, pos := s.read() if s.isBlankspace(ch) { - s.unread() + _ = s.unread() return s.scanBlankspace() } if s.isAssignmentOperator(ch) { - return ASSIGN, string(ch) + return ASSIGN, string(ch), pos } if s.isMultiValueDelim(ch) { - return MULTI_VALUE_DELIMITER, string(ch) + return MULTI_VALUE_DELIMITER, string(ch), pos } if ch == eol { - return EOL, "" + return EOL, "", pos } if ch == nul { - return ARG_DELIMITER, string(ch) + return ARG_DELIMITER, string(ch), pos } if unicode.IsGraphic(ch) { - s.unread() + _ = s.unread() return s.scanArg() } - return ILLEGAL, string(ch) + return ILLEGAL, string(ch), pos } -func (s *Scanner) read() rune { +func (s *Scanner) read() (rune, int) { ch, _, err := s.r.ReadRune() + s.i++ + if errors.Is(err, io.EOF) { - return eol + return eol, s.i } else if err != nil { log.Printf("unknown scanner error=%+v", err) - return eol + return eol, s.i } - return ch + return ch, s.i } -func (s *Scanner) unread() { +func (s *Scanner) unread() int { _ = s.r.UnreadRune() + s.i-- + return s.i } func (s *Scanner) isBlankspace(ch rune) bool { @@ -117,33 +117,37 @@ func (s *Scanner) isAssignmentOperator(ch rune) bool { return ch == s.cfg.AssignmentOperator } -func (s *Scanner) scanBlankspace() (Token, string) { +func (s *Scanner) scanBlankspace() (Token, string, int) { buf := &bytes.Buffer{} - buf.WriteRune(s.read()) + ch, pos := s.read() + buf.WriteRune(ch) for { - if ch := s.read(); ch == eol { + ch, pos = s.read() + + if ch == eol { break } else if !s.isBlankspace(ch) { - s.unread() + pos = s.unread() break } else { _, _ = buf.WriteRune(ch) } } - return BS, buf.String() + return BS, buf.String(), pos } -func (s *Scanner) scanArg() (Token, string) { +func (s *Scanner) scanArg() (Token, string, int) { buf := &bytes.Buffer{} - buf.WriteRune(s.read()) + ch, pos := s.read() + buf.WriteRune(ch) for { - ch := s.read() + ch, pos = s.read() if ch == eol || ch == nul || s.isAssignmentOperator(ch) || s.isMultiValueDelim(ch) { - s.unread() + pos = s.unread() break } @@ -153,38 +157,38 @@ func (s *Scanner) scanArg() (Token, string) { str := buf.String() if len(str) == 0 { - return EMPTY, str + return EMPTY, str, pos } ch0 := rune(str[0]) if len(str) == 1 { if s.isFlagPrefix(ch0) { - return STDIN_FLAG, str + return STDIN_FLAG, str, pos } - return IDENT, str + return IDENT, str, pos } ch1 := rune(str[1]) if len(str) == 2 { if str == string(s.cfg.FlagPrefix)+string(s.cfg.FlagPrefix) { - return STOP_FLAG, str + return STOP_FLAG, str, pos } if s.isFlagPrefix(ch0) { - return SHORT_FLAG, str + return SHORT_FLAG, str, pos } } if s.isFlagPrefix(ch0) { if s.isFlagPrefix(ch1) { - return LONG_FLAG, str + return LONG_FLAG, str, pos } - return COMPOUND_SHORT_FLAG, str + return COMPOUND_SHORT_FLAG, str, pos } - return IDENT, str + return IDENT, str, pos }