diff --git a/argh.go b/argh.go index 87a6476..f1cc2f9 100644 --- a/argh.go +++ b/argh.go @@ -21,18 +21,6 @@ func init() { traceLogger = log.New(os.Stderr, "ARGH TRACING: ", 0) } -type Argh struct { - ParseTree *ParseTree `json:"parse_tree"` -} - -func (a *Argh) TypedAST() []TypedNode { - return a.ParseTree.typedAST() -} - -func (a *Argh) AST() []Node { - return a.ParseTree.ast() -} - func tracef(format string, v ...any) { if !tracingEnabled { return diff --git a/node.go b/node.go index f51f48c..270be73 100644 --- a/node.go +++ b/node.go @@ -16,8 +16,9 @@ type CompoundShortFlag struct { } type Program struct { - Name string `json:"name"` - Values []string `json:"values"` + Name string `json:"name"` + Values map[string]string `json:"values"` + Nodes []Node `json:"nodes"` } type Ident struct { @@ -25,13 +26,14 @@ type Ident struct { } type Command struct { - Name string `json:"name"` - Values []string `json:"values"` + Name string `json:"name"` + Values map[string]string `json:"values"` + Nodes []Node `json:"nodes"` } type Flag struct { - Name string `json:"name"` - Values []string `json:"values"` + Name string `json:"name"` + Values map[string]string `json:"values"` } type StdinFlag struct{} diff --git a/parser.go b/parser.go index 47c6fca..dc7409f 100644 --- a/parser.go +++ b/parser.go @@ -3,6 +3,7 @@ package argh import ( + "fmt" "io" "strings" @@ -19,15 +20,15 @@ var ( ErrSyntax = errors.New("syntax error") DefaultParserConfig = &ParserConfig{ - Commands: map[string]NValue{}, - Flags: map[string]NValue{}, + Commands: map[string]CommandConfig{}, + Flags: map[string]FlagConfig{}, ScannerConfig: DefaultScannerConfig, } ) type NValue int -func ParseArgs(args []string, pCfg *ParserConfig) (*Argh, error) { +func ParseArgs(args []string, pCfg *ParserConfig) (*ParseTree, error) { reEncoded := strings.Join(args, string(nul)) return NewParser( @@ -39,38 +40,50 @@ func ParseArgs(args []string, pCfg *ParserConfig) (*Argh, error) { type Parser struct { s *Scanner - buf []ScanEntry + buf []scanEntry cfg *ParserConfig - nodes []Node - stopSeen bool + nodes []Node + node Node } -type ScanEntry struct { +type ParseTree struct { + Nodes []Node `json:"nodes"` +} + +type scanEntry struct { tok Token lit string pos int } type ParserConfig struct { - ProgValues NValue - Commands map[string]NValue - Flags map[string]NValue - - OnUnknownFlag func(string) error - OnUnknownCommand func(string) error + Prog CommandConfig + Commands map[string]CommandConfig + Flags map[string]FlagConfig ScannerConfig *ScannerConfig } +type CommandConfig struct { + NValue NValue + ValueNames []string + Flags map[string]FlagConfig +} + +type FlagConfig struct { + NValue NValue + ValueNames []string +} + func NewParser(r io.Reader, pCfg *ParserConfig) *Parser { if pCfg == nil { pCfg = DefaultParserConfig } parser := &Parser{ - buf: []ScanEntry{}, + buf: []scanEntry{}, s: NewScanner(r, pCfg.ScannerConfig), cfg: pCfg, } @@ -81,7 +94,7 @@ func NewParser(r io.Reader, pCfg *ParserConfig) *Parser { return parser } -func (p *Parser) Parse() (*Argh, error) { +func (p *Parser) Parse() (*ParseTree, error) { p.nodes = []Node{} for { @@ -95,7 +108,7 @@ func (p *Parser) Parse() (*Argh, error) { } } - return &Argh{ParseTree: &ParseTree{Nodes: p.nodes}}, nil + return &ParseTree{Nodes: p.nodes}, nil } func (p *Parser) parseArg() (bool, error) { @@ -110,7 +123,7 @@ func (p *Parser) parseArg() (bool, error) { p.unscan(tok, lit, pos) - node, err := p.nodify() + node, err := p.scanNode() tracef("parseArg node=%+#v err=%+#v", node, err) @@ -125,10 +138,10 @@ func (p *Parser) parseArg() (bool, error) { return false, nil } -func (p *Parser) nodify() (Node, error) { +func (p *Parser) scanNode() (Node, error) { tok, lit, pos := p.scan() - tracef("nodify tok=%s lit=%q pos=%v", tok, lit, pos) + tracef("scanNode tok=%s lit=%q pos=%v", tok, lit, pos) switch tok { case ARG_DELIMITER: @@ -136,94 +149,120 @@ func (p *Parser) nodify() (Node, error) { case ASSIGN: return nil, errors.Wrapf(ErrSyntax, "bare assignment operator at pos=%v", pos) case IDENT: - if len(p.nodes) == 0 { - values, err := p.scanValues(lit, pos, p.cfg.ProgValues) - if err != nil { - return nil, err - } - - return Program{Name: lit, Values: values}, nil - } - - if n, ok := p.cfg.Commands[lit]; ok { - values, err := p.scanValues(lit, pos, n) - if err != nil { - return nil, err - } - - return Command{Name: lit, Values: values}, nil - } - - return Ident{Literal: lit}, nil + p.unscan(tok, lit, pos) + return p.scanCommandOrIdent() case COMPOUND_SHORT_FLAG: - flagNodes := []Node{} - - withoutFlagPrefix := lit[1:] - - for i, r := range withoutFlagPrefix { - if i == len(withoutFlagPrefix)-1 { - flagName := string(r) - - if n, ok := p.cfg.Flags[flagName]; ok { - values, err := p.scanValues(flagName, pos, n) - if err != nil { - return nil, err - } - - flagNodes = append(flagNodes, Flag{Name: flagName, Values: values}) - - continue - } - } - - flagNodes = append( - flagNodes, - Flag{ - Name: string(r), - }, - ) - } - - return CompoundShortFlag{Nodes: flagNodes}, nil - case SHORT_FLAG: - flagName := string(lit[1:]) - if n, ok := p.cfg.Flags[flagName]; ok { - values, err := p.scanValues(flagName, pos, n) - if err != nil { - return nil, err - } - - return Flag{Name: flagName, Values: values}, nil - } - - return Flag{Name: flagName}, nil - case LONG_FLAG: - flagName := string(lit[2:]) - if n, ok := p.cfg.Flags[flagName]; ok { - values, err := p.scanValues(flagName, pos, n) - if err != nil { - return nil, err - } - - return Flag{Name: flagName, Values: values}, nil - } - - return Flag{Name: flagName}, nil + p.unscan(tok, lit, pos) + return p.scanCompoundShortFlag() + case SHORT_FLAG, LONG_FLAG: + p.unscan(tok, lit, pos) + return p.scanFlag() default: } return Ident{Literal: lit}, nil } -func (p *Parser) scanValues(lit string, pos int, n NValue) ([]string, error) { - tracef("scanValues lit=%q pos=%v n=%v", lit, pos, n) +func (p *Parser) scanCommandOrIdent() (Node, error) { + tok, lit, pos := p.scan() - values, err := func() ([]string, error) { - if n == ZeroValue { - return []string{}, nil + if len(p.nodes) == 0 { + p.unscan(tok, lit, pos) + values, err := p.scanValues(p.cfg.Prog.NValue, p.cfg.Prog.ValueNames) + if err != nil { + return nil, err } - ret := []string{} + return Program{Name: lit, Values: values}, nil + } + + if cfg, ok := p.cfg.Commands[lit]; ok { + p.unscan(tok, lit, pos) + values, err := p.scanValues(cfg.NValue, cfg.ValueNames) + if err != nil { + return nil, err + } + + return Command{Name: lit, Values: values}, nil + } + + return Ident{Literal: lit}, nil +} + +func (p *Parser) scanFlag() (Node, error) { + tok, lit, pos := p.scan() + + flagName := string(lit[1:]) + if tok == LONG_FLAG { + flagName = string(lit[2:]) + } + + if cfg, ok := p.cfg.Flags[flagName]; ok { + p.unscan(tok, flagName, pos) + + values, err := p.scanValues(cfg.NValue, cfg.ValueNames) + if err != nil { + return nil, err + } + + return Flag{Name: flagName, Values: values}, nil + } + + return Flag{Name: flagName}, nil +} + +func (p *Parser) scanCompoundShortFlag() (Node, error) { + tok, lit, pos := p.scan() + + flagNodes := []Node{} + + withoutFlagPrefix := lit[1:] + + for i, r := range withoutFlagPrefix { + if i == len(withoutFlagPrefix)-1 { + flagName := string(r) + + if cfg, ok := p.cfg.Flags[flagName]; ok { + p.unscan(tok, flagName, pos) + + values, err := p.scanValues(cfg.NValue, cfg.ValueNames) + if err != nil { + return nil, err + } + + flagNodes = append(flagNodes, Flag{Name: flagName, Values: values}) + + continue + } + } + + flagNodes = append( + flagNodes, + Flag{ + Name: string(r), + }, + ) + } + + return CompoundShortFlag{Nodes: flagNodes}, nil +} + +func (p *Parser) scanValuesAndFlags() (map[string]string, []Node, error) { + return nil, nil, nil +} + +func (p *Parser) scanValues(n NValue, valueNames []string) (map[string]string, error) { + _, lit, pos := p.scan() + + tracef("scanValues lit=%q pos=%v n=%v valueNames=%+v", lit, pos, n, valueNames) + + values, err := func() (map[string]string, error) { + if n == ZeroValue { + return map[string]string{}, nil + } + + ret := map[string]string{} + i := 0 for { lit, err := p.scanIdent() @@ -237,11 +276,20 @@ func (p *Parser) scanValues(lit string, pos int, n NValue) ([]string, error) { } } - ret = append(ret, lit) + name := fmt.Sprintf("%d", i) + if len(valueNames)-1 >= i { + name = valueNames[i] + } else if len(valueNames) > 0 && strings.HasSuffix(valueNames[len(valueNames)-1], "+") { + name = strings.TrimSuffix(valueNames[len(valueNames)-1], "+") + } + + ret[name] = lit if n == NValue(1) && len(ret) == 1 { break } + + i++ } return ret, nil @@ -263,14 +311,14 @@ func (p *Parser) scanIdent() (string, error) { tracef("scanIdent scanned tok=%s lit=%q pos=%v", tok, lit, pos) - unscanBuf := []ScanEntry{} + unscanBuf := []scanEntry{} if tok == ASSIGN || tok == ARG_DELIMITER { - entry := ScanEntry{tok: tok, lit: lit, pos: pos} + entry := scanEntry{tok: tok, lit: lit, pos: pos} tracef("scanIdent tok=%s; scanning next and pushing to unscan buffer entry=%+#v", tok, entry) - unscanBuf = append([]ScanEntry{entry}, unscanBuf...) + unscanBuf = append([]scanEntry{entry}, unscanBuf...) tok, lit, pos = p.scan() } @@ -279,11 +327,11 @@ func (p *Parser) scanIdent() (string, error) { return lit, nil } - entry := ScanEntry{tok: tok, lit: lit, pos: pos} + entry := scanEntry{tok: tok, lit: lit, pos: pos} tracef("scanIdent tok=%s; unscanning entry=%+#v", tok, entry) - unscanBuf = append([]ScanEntry{entry}, unscanBuf...) + unscanBuf = append([]scanEntry{entry}, unscanBuf...) for _, entry := range unscanBuf { p.unscan(entry.tok, entry.lit, entry.pos) @@ -303,13 +351,13 @@ func (p *Parser) scan() (Token, string, int) { tok, lit, pos := p.s.Scan() - tracef("scan returning next=%s %+#v", tok, ScanEntry{tok: tok, lit: lit, pos: pos}) + tracef("scan returning next=%s %+#v", tok, scanEntry{tok: tok, lit: lit, pos: pos}) return tok, lit, pos } func (p *Parser) unscan(tok Token, lit string, pos int) { - entry := ScanEntry{tok: tok, lit: lit, pos: pos} + entry := scanEntry{tok: tok, lit: lit, pos: pos} tracef("unscan entry=%s %+#v", tok, entry) diff --git a/parser_test.go b/parser_test.go index db9bc03..743de16 100644 --- a/parser_test.go +++ b/parser_test.go @@ -7,10 +7,6 @@ import ( "github.com/stretchr/testify/assert" ) -func ptr[T any](v T) *T { - return &v -} - func TestParser(t *testing.T) { for _, tc := range []struct { name string @@ -19,7 +15,6 @@ func TestParser(t *testing.T) { expPT []argh.Node expAST []argh.Node expErr error - skip bool }{ { name: "bare", @@ -35,26 +30,26 @@ func TestParser(t *testing.T) { name: "one positional arg", args: []string{"pizzas", "excel"}, cfg: &argh.ParserConfig{ - ProgValues: 1, + Prog: argh.CommandConfig{NValue: 1}, }, expPT: []argh.Node{ - argh.Program{Name: "pizzas", Values: []string{"excel"}}, + argh.Program{Name: "pizzas", Values: map[string]string{"0": "excel"}}, }, expAST: []argh.Node{ - argh.Program{Name: "pizzas", Values: []string{"excel"}}, + argh.Program{Name: "pizzas", Values: map[string]string{"0": "excel"}}, }, }, { name: "many positional args", args: []string{"pizzas", "excel", "wildly", "when", "feral"}, cfg: &argh.ParserConfig{ - ProgValues: argh.OneOrMoreValue, + Prog: argh.CommandConfig{NValue: argh.OneOrMoreValue}, }, expPT: []argh.Node{ - argh.Program{Name: "pizzas", Values: []string{"excel", "wildly", "when", "feral"}}, + argh.Program{Name: "pizzas", Values: map[string]string{"0": "excel", "1": "wildly", "2": "when", "3": "feral"}}, }, expAST: []argh.Node{ - argh.Program{Name: "pizzas", Values: []string{"excel", "wildly", "when", "feral"}}, + argh.Program{Name: "pizzas", Values: map[string]string{"0": "excel", "1": "wildly", "2": "when", "3": "feral"}}, }, }, { @@ -87,10 +82,10 @@ func TestParser(t *testing.T) { "--please", }, cfg: &argh.ParserConfig{ - Commands: map[string]argh.NValue{}, - Flags: map[string]argh.NValue{ - "fresh": 1, - "box": argh.OneOrMoreValue, + Commands: map[string]argh.CommandConfig{}, + Flags: map[string]argh.FlagConfig{ + "fresh": argh.FlagConfig{NValue: 1}, + "box": argh.FlagConfig{NValue: argh.OneOrMoreValue}, }, }, expPT: []argh.Node{ @@ -98,20 +93,20 @@ func TestParser(t *testing.T) { argh.ArgDelimiter{}, argh.Flag{Name: "tasty"}, argh.ArgDelimiter{}, - argh.Flag{Name: "fresh", Values: []string{"soon"}}, + argh.Flag{Name: "fresh", Values: map[string]string{"0": "soon"}}, argh.ArgDelimiter{}, argh.Flag{Name: "super-hot-right-now"}, argh.ArgDelimiter{}, - argh.Flag{Name: "box", Values: []string{"square", "shaped", "hot"}}, + argh.Flag{Name: "box", Values: map[string]string{"0": "square", "1": "shaped", "2": "hot"}}, argh.ArgDelimiter{}, argh.Flag{Name: "please"}, }, expAST: []argh.Node{ argh.Program{Name: "pizzas"}, argh.Flag{Name: "tasty"}, - argh.Flag{Name: "fresh", Values: []string{"soon"}}, + argh.Flag{Name: "fresh", Values: map[string]string{"0": "soon"}}, argh.Flag{Name: "super-hot-right-now"}, - argh.Flag{Name: "box", Values: []string{"square", "shaped", "hot"}}, + argh.Flag{Name: "box", Values: map[string]string{"0": "square", "1": "shaped", "2": "hot"}}, argh.Flag{Name: "please"}, }, }, @@ -172,8 +167,10 @@ func TestParser(t *testing.T) { name: "mixed long short value flags", args: []string{"pizzas", "-a", "--ca", "-b", "1312", "-lol"}, cfg: &argh.ParserConfig{ - Commands: map[string]argh.NValue{}, - Flags: map[string]argh.NValue{"b": 1}, + Commands: map[string]argh.CommandConfig{}, + Flags: map[string]argh.FlagConfig{ + "b": argh.FlagConfig{NValue: 1}, + }, }, expPT: []argh.Node{ argh.Program{Name: "pizzas"}, @@ -182,7 +179,7 @@ func TestParser(t *testing.T) { argh.ArgDelimiter{}, argh.Flag{Name: "ca"}, argh.ArgDelimiter{}, - argh.Flag{Name: "b", Values: []string{"1312"}}, + argh.Flag{Name: "b", Values: map[string]string{"0": "1312"}}, argh.ArgDelimiter{}, argh.CompoundShortFlag{ Nodes: []argh.Node{ @@ -196,7 +193,7 @@ func TestParser(t *testing.T) { argh.Program{Name: "pizzas"}, argh.Flag{Name: "a"}, argh.Flag{Name: "ca"}, - argh.Flag{Name: "b", Values: []string{"1312"}}, + argh.Flag{Name: "b", Values: map[string]string{"0": "1312"}}, argh.Flag{Name: "l"}, argh.Flag{Name: "o"}, argh.Flag{Name: "l"}, @@ -206,8 +203,11 @@ func TestParser(t *testing.T) { name: "commands", args: []string{"pizzas", "fly", "fry"}, cfg: &argh.ParserConfig{ - Commands: map[string]argh.NValue{"fly": argh.ZeroValue, "fry": argh.ZeroValue}, - Flags: map[string]argh.NValue{}, + Commands: map[string]argh.CommandConfig{ + "fly": argh.CommandConfig{}, + "fry": argh.CommandConfig{}, + }, + Flags: map[string]argh.FlagConfig{}, }, expPT: []argh.Node{ argh.Program{Name: "pizzas"}, @@ -217,17 +217,60 @@ func TestParser(t *testing.T) { argh.Command{Name: "fry"}, }, }, + { + name: "command specific flags", + args: []string{"pizzas", "fly", "--freely", "fry", "--deeply", "-wAt"}, + cfg: &argh.ParserConfig{ + Commands: map[string]argh.CommandConfig{ + "fly": argh.CommandConfig{ + Flags: map[string]argh.FlagConfig{ + "freely": {}, + }, + }, + "fry": argh.CommandConfig{ + Flags: map[string]argh.FlagConfig{ + "deeply": {}, + "w": {}, + "A": {}, + "t": {}, + }, + }, + }, + Flags: map[string]argh.FlagConfig{}, + }, + expPT: []argh.Node{ + argh.Program{Name: "pizzas"}, + argh.ArgDelimiter{}, + argh.Command{Name: "fly"}, + argh.ArgDelimiter{}, + argh.Flag{Name: "freely"}, + argh.ArgDelimiter{}, + argh.Command{Name: "fry"}, + argh.ArgDelimiter{}, + argh.Flag{Name: "deeply"}, + argh.ArgDelimiter{}, + argh.CompoundShortFlag{ + Nodes: []argh.Node{ + argh.Flag{Name: "w"}, + argh.Flag{Name: "A"}, + argh.Flag{Name: "t"}, + }, + }, + }, + }, { name: "total weirdo", args: []string{"PIZZAs", "^wAT@golf", "^^hecKing", "goose", "bonk", "^^FIERCENESS@-2"}, cfg: &argh.ParserConfig{ - Commands: map[string]argh.NValue{"goose": 1}, - Flags: map[string]argh.NValue{ - "w": 0, - "A": 0, - "T": 1, - "hecking": 0, - "FIERCENESS": 1, + Commands: map[string]argh.CommandConfig{ + "goose": argh.CommandConfig{NValue: 1}, + }, + Flags: map[string]argh.FlagConfig{ + "w": argh.FlagConfig{}, + "A": argh.FlagConfig{}, + "T": argh.FlagConfig{NValue: 1}, + "hecking": argh.FlagConfig{}, + "FIERCENESS": argh.FlagConfig{NValue: 1}, }, ScannerConfig: &argh.ScannerConfig{ AssignmentOperator: '@', @@ -242,15 +285,15 @@ func TestParser(t *testing.T) { Nodes: []argh.Node{ argh.Flag{Name: "w"}, argh.Flag{Name: "A"}, - argh.Flag{Name: "T", Values: []string{"golf"}}, + argh.Flag{Name: "T", Values: map[string]string{"0": "golf"}}, }, }, argh.ArgDelimiter{}, argh.Flag{Name: "hecKing"}, argh.ArgDelimiter{}, - argh.Command{Name: "goose", Values: []string{"bonk"}}, + argh.Command{Name: "goose", Values: map[string]string{"0": "bonk"}}, argh.ArgDelimiter{}, - argh.Flag{Name: "FIERCENESS", Values: []string{"-2"}}, + argh.Flag{Name: "FIERCENESS", Values: map[string]string{"0": "-2"}}, }, }, { @@ -263,10 +306,6 @@ func TestParser(t *testing.T) { }, {}, } { - if tc.skip { - continue - } - if tc.expPT != nil { t.Run(tc.name+" parse tree", func(ct *testing.T) { actual, err := argh.ParseArgs(tc.args, tc.cfg) @@ -275,7 +314,7 @@ func TestParser(t *testing.T) { return } - assert.Equal(ct, tc.expPT, actual.ParseTree.Nodes) + assert.Equal(ct, tc.expPT, actual.Nodes) }) } @@ -287,7 +326,7 @@ func TestParser(t *testing.T) { return } - assert.Equal(ct, tc.expAST, actual.AST()) + assert.Equal(ct, tc.expAST, argh.NewQuerier(actual).AST()) }) } } diff --git a/parse_tree.go b/querier.go similarity index 52% rename from parse_tree.go rename to querier.go index 7478abe..27e2b9c 100644 --- a/parse_tree.go +++ b/querier.go @@ -2,14 +2,33 @@ package argh import "fmt" -type ParseTree struct { - Nodes []Node `json:"nodes"` +type Querier interface { + Program() (Program, bool) + TypedAST() []TypedNode + AST() []Node } -func (pt *ParseTree) typedAST() []TypedNode { +func NewQuerier(pt *ParseTree) Querier { + return &defaultQuerier{pt: pt} +} + +type defaultQuerier struct { + pt *ParseTree +} + +func (dq *defaultQuerier) Program() (Program, bool) { + if len(dq.pt.Nodes) == 0 { + return Program{}, false + } + + v, ok := dq.pt.Nodes[0].(Program) + return v, ok +} + +func (dq *defaultQuerier) TypedAST() []TypedNode { ret := []TypedNode{} - for _, node := range pt.Nodes { + for _, node := range dq.pt.Nodes { if _, ok := node.(ArgDelimiter); ok { continue } @@ -30,10 +49,10 @@ func (pt *ParseTree) typedAST() []TypedNode { return ret } -func (pt *ParseTree) ast() []Node { +func (dq *defaultQuerier) AST() []Node { ret := []Node{} - for _, node := range pt.Nodes { + for _, node := range dq.pt.Nodes { if _, ok := node.(ArgDelimiter); ok { continue } diff --git a/querier_test.go b/querier_test.go new file mode 100644 index 0000000..3ec6793 --- /dev/null +++ b/querier_test.go @@ -0,0 +1,46 @@ +package argh_test + +import ( + "testing" + + "git.meatballhat.com/x/box-o-sand/argh" + "github.com/stretchr/testify/require" +) + +func TestQuerier_Program(t *testing.T) { + for _, tc := range []struct { + name string + args []string + cfg *argh.ParserConfig + exp argh.Program + expOK bool + }{ + { + name: "typical", + args: []string{"pizzas", "ahoy", "--treatsa", "fun"}, + exp: argh.Program{Name: "pizzas"}, + expOK: true, + }, + { + name: "minimal", + args: []string{"pizzas"}, + exp: argh.Program{Name: "pizzas"}, + expOK: true, + }, + { + name: "invalid", + args: []string{}, + exp: argh.Program{}, + expOK: false, + }, + } { + t.Run(tc.name, func(ct *testing.T) { + pt, err := argh.ParseArgs(tc.args, tc.cfg) + require.Nil(ct, err) + + prog, ok := argh.NewQuerier(pt).Program() + require.Equal(ct, tc.exp, prog) + require.Equal(ct, tc.expOK, ok) + }) + } +}