winnow_rule/
lib.rs

1use pratt::{Affix, Associativity, PrattError, PrattParser, Precedence};
2use proc_macro2::{Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
3use proc_macro_error::{abort, abort_call_site, proc_macro_error};
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{punctuated::Punctuated, Token};
6use winnow::{
7    combinator::{alt, opt, repeat, separated, trace},
8    error::ContextError,
9    token::any,
10    PResult, Parser,
11};
12use wrapper::InputWrapper;
13
14mod wrapper;
15
16#[proc_macro]
17#[proc_macro_error]
18pub fn rule(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
19    let tokens: TokenStream = tokens.into();
20    let i: Vec<TokenTree> = tokens.into_iter().collect();
21
22    let rule = parse_rule(i.iter().cloned().collect());
23    rule.check_return_type();
24    rule.to_token_stream().into()
25}
26
27#[derive(Debug, Clone)]
28struct Path {
29    segments: Vec<Ident>,
30}
31
32#[derive(Debug, Clone)]
33enum Rule {
34    MatchText(Span, Literal),
35    MatchToken(Span, Path),
36    ExternalFunction(Span, Path, Option<Group>),
37    Context(Span, Literal, Box<Rule>),
38    Peek(Span, Box<Rule>),
39    Not(Span, Box<Rule>),
40    Opt(Span, Box<Rule>),
41    Cut(Span, Box<Rule>),
42    Many0(Span, Box<Rule>),
43    Many1(Span, Box<Rule>),
44    Sequence(Span, Vec<Rule>),
45    Alt(Span, Vec<Rule>),
46}
47
48#[derive(Debug, Clone)]
49enum RuleElement {
50    MatchText(Literal),
51    MatchToken(Path),
52    ExternalFunction(Path, Option<Group>),
53    Context(Literal),
54    Peek,
55    Not,
56    Opt,
57    Cut,
58    Many0,
59    Many1,
60    Sequence,
61    Alt,
62    SubRule(Rule),
63}
64
65#[derive(Debug, Clone)]
66struct WithSpan {
67    elem: RuleElement,
68    span: Span,
69}
70
/// Coarse model of a rule's output type, used to compare the branches of a choice.
#[derive(Clone, Debug)]
enum ReturnType {
    /// `Option<T>`, produced by an optional rule.
    Option(Box<ReturnType>),
    /// `Vec<T>`, produced by a repeated rule.
    Vec(Box<ReturnType>),
    /// `()`, produced by a negative lookahead.
    Unit,
    /// A type the macro cannot determine; it compares equal to every other type.
    Unknown,
}
78
79type Input<'a> = InputWrapper<'a>;
80
81fn match_punct<'a>(punct: char) -> impl Parser<Input<'a>, TokenTree, ContextError> {
82    trace(
83        punct,
84        any.verify_map(move |token| match token {
85            TokenTree::Punct(ref p) if p.as_char() == punct => Some(token.clone()),
86            _ => None,
87        }),
88    )
89}
90
91fn group<'a>(input: &mut Input<'a>) -> PResult<Group> {
92    any.verify_map(move |token| match token {
93        TokenTree::Group(ref group) => Some(group.clone()),
94        _ => None,
95    })
96    .parse_next(input)
97}
98
99fn literal<'a>(input: &mut Input<'a>) -> PResult<Literal> {
100    any.verify_map(move |token| match token {
101        TokenTree::Literal(ref lit) => Some(lit.clone()),
102        _ => None,
103    })
104    .parse_next(input)
105}
106
107fn ident<'a>(input: &mut Input<'a>) -> PResult<Ident> {
108    trace(
109        "ident",
110        any.verify_map(move |token| match token {
111            TokenTree::Ident(ref ident) => Some(ident.clone()),
112            _ => None,
113        }),
114    )
115    .parse_next(input)
116}
117
118fn path<'a>(input: &mut Input<'a>) -> PResult<(Span, Path)> {
119    separated(1.., ident, (match_punct(':'), match_punct(':')))
120        .map(|segments: Vec<_>| {
121            let span = segments[1..]
122                .iter()
123                .fold(segments[0].span(), |acc, segment| {
124                    acc.join(segment.span()).unwrap()
125                })
126                .unwrap()
127                .into();
128            let path = Path { segments };
129            (span, path)
130        })
131        .parse_next(input)
132}
133
134fn parse_rule(tokens: TokenStream) -> Rule {
135    let i: Vec<TokenTree> = tokens.into_iter().collect();
136    let i = &mut InputWrapper(&i[..]);
137
138    let elems: Vec<_> = repeat(0.., parse_rule_element).parse_next(i).unwrap();
139    let i = i.0;
140    if !i.is_empty() {
141        let rest: TokenStream = i.iter().cloned().collect();
142        abort!(rest, "unable to parse the following rules: {}", rest);
143    }
144
145    let mut iter = elems.into_iter().peekable();
146    let rule = unwrap_pratt(RuleParser.parse(&mut iter));
147    if iter.peek().is_some() {
148        let rest: Vec<_> = iter.collect();
149        abort!(
150            rest[0].span,
151            "unable to parse the following rules: {:?}",
152            rest
153        );
154    }
155
156    rule
157}
158
159fn parse_rule_element<'a>(i: &mut Input<'a>) -> PResult<WithSpan> {
160    let function_call = |i: &mut Input<'a>| {
161        let hashtag = match_punct('#').parse_next(i)?;
162        let (path_span, fn_path) = path(i)?;
163        let args = opt(group).parse_next(i)?;
164        let span = hashtag.span().join(path_span).unwrap();
165        let span = args
166            .as_ref()
167            .map(|args| args.span().join(span).unwrap())
168            .unwrap_or(span);
169
170        Ok(WithSpan {
171            elem: RuleElement::ExternalFunction(fn_path, args),
172            span,
173        })
174    };
175    let context = (match_punct(':'), literal).map(|(colon, msg)| {
176        let span = colon.span().join(msg.span()).unwrap();
177        WithSpan {
178            elem: RuleElement::Context(msg),
179            span,
180        }
181    });
182    alt((
183        match_punct('|').map(|token| WithSpan {
184            span: token.span(),
185            elem: RuleElement::Alt,
186        }),
187        match_punct('*').map(|token| WithSpan {
188            span: token.span(),
189            elem: RuleElement::Many0,
190        }),
191        match_punct('+').map(|token| WithSpan {
192            span: token.span(),
193            elem: RuleElement::Many1,
194        }),
195        match_punct('?').map(|token| WithSpan {
196            span: token.span(),
197            elem: RuleElement::Opt,
198        }),
199        match_punct('^').map(|token| WithSpan {
200            span: token.span(),
201            elem: RuleElement::Cut,
202        }),
203        match_punct('&').map(|token| WithSpan {
204            span: token.span(),
205            elem: RuleElement::Peek,
206        }),
207        match_punct('!').map(|token| WithSpan {
208            span: token.span(),
209            elem: RuleElement::Not,
210        }),
211        match_punct('~').map(|token| WithSpan {
212            span: token.span(),
213            elem: RuleElement::Sequence,
214        }),
215        literal.map(|lit| WithSpan {
216            span: lit.span(),
217            elem: RuleElement::MatchText(lit),
218        }),
219        path.map(|(span, p)| WithSpan {
220            span,
221            elem: RuleElement::MatchToken(p),
222        }),
223        group.map(|group| WithSpan {
224            span: group.span(),
225            elem: RuleElement::SubRule(parse_rule(group.stream())),
226        }),
227        function_call,
228        context,
229    ))
230    .parse_next(i)
231}
232
233fn unwrap_pratt(res: Result<Rule, PrattError<WithSpan, pratt::NoError>>) -> Rule {
234    match res {
235        Ok(res) => res,
236        Err(PrattError::EmptyInput) => abort_call_site!("expected more tokens for rule"),
237        Err(PrattError::UnexpectedNilfix(input)) => {
238            abort!(input.span, "unable to parse the value")
239        }
240        Err(PrattError::UnexpectedPrefix(input)) => {
241            abort!(input.span, "unable to parse the prefix operator")
242        }
243        Err(PrattError::UnexpectedInfix(input)) => {
244            abort!(input.span, "unable to parse the binary operator")
245        }
246        Err(PrattError::UnexpectedPostfix(input)) => {
247            abort!(input.span, "unable to parse the postfix operator")
248        }
249        Err(PrattError::UserError(_)) => unreachable!(),
250    }
251}
252
253struct RuleParser;
254
255impl<I: Iterator<Item = WithSpan>> PrattParser<I> for RuleParser {
256    type Error = pratt::NoError;
257    type Input = WithSpan;
258    type Output = Rule;
259
260    fn query(&mut self, elem: &WithSpan) -> pratt::Result<Affix> {
261        let affix = match elem.elem {
262            RuleElement::Alt => Affix::Infix(Precedence(1), Associativity::Left),
263            RuleElement::Context(_) => Affix::Postfix(Precedence(2)),
264            RuleElement::Sequence => Affix::Infix(Precedence(3), Associativity::Left),
265            RuleElement::Opt => Affix::Postfix(Precedence(4)),
266            RuleElement::Many1 => Affix::Postfix(Precedence(4)),
267            RuleElement::Many0 => Affix::Postfix(Precedence(4)),
268            RuleElement::Cut => Affix::Prefix(Precedence(5)),
269            RuleElement::Peek => Affix::Prefix(Precedence(5)),
270            RuleElement::Not => Affix::Prefix(Precedence(5)),
271            _ => Affix::Nilfix,
272        };
273        Ok(affix)
274    }
275
276    fn primary(&mut self, elem: WithSpan) -> pratt::Result<Rule> {
277        let rule = match elem.elem {
278            RuleElement::SubRule(rule) => rule,
279            RuleElement::MatchText(text) => Rule::MatchText(elem.span, text),
280            RuleElement::MatchToken(token) => Rule::MatchToken(elem.span, token),
281            RuleElement::ExternalFunction(func, args) => {
282                Rule::ExternalFunction(elem.span, func, args)
283            }
284            _ => unreachable!(),
285        };
286        Ok(rule)
287    }
288
289    fn infix(&mut self, lhs: Rule, elem: WithSpan, rhs: Rule) -> pratt::Result<Rule> {
290        let rule = match elem.elem {
291            RuleElement::Sequence => match lhs {
292                Rule::Sequence(span, mut seq) => {
293                    let span = span.join(elem.span).unwrap().join(rhs.span()).unwrap();
294                    seq.push(rhs);
295                    Rule::Sequence(span, seq)
296                }
297                lhs => {
298                    let span = lhs.span().join(rhs.span()).unwrap();
299                    Rule::Sequence(span, vec![lhs, rhs])
300                }
301            },
302            RuleElement::Alt => match lhs {
303                Rule::Alt(span, mut choices) => {
304                    let span = span.join(elem.span).unwrap().join(rhs.span()).unwrap();
305                    choices.push(rhs);
306                    Rule::Alt(span, choices)
307                }
308                lhs => {
309                    let span = lhs.span().join(rhs.span()).unwrap();
310                    Rule::Alt(span, vec![lhs, rhs])
311                }
312            },
313            _ => unreachable!(),
314        };
315        Ok(rule)
316    }
317
318    fn prefix(&mut self, elem: WithSpan, rhs: Rule) -> pratt::Result<Rule> {
319        let rule = match elem.elem {
320            RuleElement::Cut => {
321                let span = elem.span.join(rhs.span()).unwrap();
322                Rule::Cut(span, Box::new(rhs))
323            }
324            RuleElement::Peek => {
325                let span = elem.span.join(rhs.span()).unwrap();
326                Rule::Peek(span, Box::new(rhs))
327            }
328            RuleElement::Not => {
329                let span = elem.span.join(rhs.span()).unwrap();
330                Rule::Not(span, Box::new(rhs))
331            }
332            _ => unreachable!(),
333        };
334        Ok(rule)
335    }
336
337    fn postfix(&mut self, lhs: Rule, elem: WithSpan) -> pratt::Result<Rule> {
338        let rule = match elem.elem {
339            RuleElement::Opt => {
340                let span = lhs.span().join(elem.span).unwrap();
341                Rule::Opt(span, Box::new(lhs))
342            }
343            RuleElement::Many0 => {
344                let span = lhs.span().join(elem.span).unwrap();
345                Rule::Many0(span, Box::new(lhs))
346            }
347            RuleElement::Many1 => {
348                let span = lhs.span().join(elem.span).unwrap();
349                Rule::Many1(span, Box::new(lhs))
350            }
351            RuleElement::Context(msg) => {
352                let span = lhs.span().join(elem.span).unwrap();
353                Rule::Context(span, msg, Box::new(lhs))
354            }
355            _ => unreachable!(),
356        };
357        Ok(rule)
358    }
359}
360
361impl std::fmt::Display for ReturnType {
362    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363        match self {
364            ReturnType::Option(ty) => write!(f, "Option<{}>", ty),
365            ReturnType::Vec(ty) => write!(f, "Vec<{}>", ty),
366            ReturnType::Unit => write!(f, "()"),
367            ReturnType::Unknown => write!(f, "_"),
368        }
369    }
370}
371
372impl PartialEq for ReturnType {
373    fn eq(&self, other: &ReturnType) -> bool {
374        match (self, other) {
375            (ReturnType::Option(lhs), ReturnType::Option(rhs)) => lhs == rhs,
376            (ReturnType::Vec(lhs), ReturnType::Vec(rhs)) => lhs == rhs,
377            (ReturnType::Unit, ReturnType::Unit) => true,
378            (ReturnType::Unknown, _) => true,
379            (_, ReturnType::Unknown) => true,
380            _ => false,
381        }
382    }
383}
384
385impl Rule {
386    fn check_return_type(&self) -> ReturnType {
387        match self {
388            Rule::MatchText(_, _) | Rule::MatchToken(_, _) | Rule::ExternalFunction(_, _, _) => {
389                ReturnType::Unknown
390            }
391            Rule::Context(_, _, rule) | Rule::Peek(_, rule) => rule.check_return_type(),
392            Rule::Not(_, _) => ReturnType::Unit,
393            Rule::Opt(_, rule) => ReturnType::Option(Box::new(rule.check_return_type())),
394            Rule::Cut(_, rule) => rule.check_return_type(),
395            Rule::Many0(_, rule) | Rule::Many1(_, rule) => {
396                ReturnType::Vec(Box::new(rule.check_return_type()))
397            }
398            Rule::Sequence(_, rules) => {
399                rules.iter().for_each(|rule| {
400                    rule.check_return_type();
401                });
402                ReturnType::Vec(Box::new(ReturnType::Unknown))
403            }
404            Rule::Alt(_, rules) => {
405                for slice in rules.windows(2) {
406                    match (slice[0].check_return_type(), slice[1].check_return_type()) {
407                        (ReturnType::Option(_), _) => {
408                            abort!(
409                                slice[0].span(),
410                                "optional shouldn't be in a choice because it will shortcut the following branches",
411                            )
412                        }
413                        (a, b) if a != b => abort!(
414                            slice[0].span().join(slice[1].span()).unwrap(),
415                            "type mismatched between {:} and {:}",
416                            a,
417                            b,
418                        ),
419                        _ => (),
420                    }
421                }
422                ReturnType::Vec(Box::new(rules[0].check_return_type()))
423            }
424        }
425    }
426
427    fn span(&self) -> Span {
428        match self {
429            Rule::MatchText(span, _)
430            | Rule::MatchToken(span, _)
431            | Rule::ExternalFunction(span, _, _)
432            | Rule::Context(span, _, _)
433            | Rule::Peek(span, _)
434            | Rule::Not(span, _)
435            | Rule::Opt(span, _)
436            | Rule::Cut(span, _)
437            | Rule::Many0(span, _)
438            | Rule::Many1(span, _)
439            | Rule::Sequence(span, _)
440            | Rule::Alt(span, _) => *span,
441        }
442    }
443
444    fn to_tokens(&self, tokens: &mut TokenStream) {
445        let token = match self {
446            Rule::ExternalFunction(_, name, arg) => {
447                quote! { #name #arg }
448            }
449            Rule::Context(_, msg, rule) => {
450                let rule = rule.to_token_stream();
451                quote! { #rule.context(winnow::error::StrContext::Label(#msg)) }
452            }
453            Rule::Peek(_, rule) => {
454                let rule = rule.to_token_stream();
455                quote! { winnow::combinator::peek(#rule) }
456            }
457            Rule::Not(_, rule) => {
458                let rule = rule.to_token_stream();
459                quote! { winnow::combinator::not(#rule) }
460            }
461            Rule::Opt(_, rule) => {
462                let rule = rule.to_token_stream();
463                quote! { winnow::combinator::opt(#rule) }
464            }
465            Rule::Cut(_, rule) => {
466                let rule = rule.to_token_stream();
467                quote! { winnow::combinator::cut_err(#rule) }
468            }
469            Rule::Many0(_, rule) => {
470                let rule = rule.to_token_stream();
471                quote! { winnow::combinator::repeat(0.., #rule) }
472            }
473            Rule::Many1(_, rule) => {
474                let rule = rule.to_token_stream();
475                quote! { winnow::combinator::repeat(1.., #rule) }
476            }
477            Rule::Sequence(_, rules) => {
478                let list: Punctuated<TokenStream, Token![,]> =
479                    rules.iter().map(|rule| rule.to_token_stream()).collect();
480                quote! { ((#list)) }
481            }
482            Rule::Alt(_, rules) => {
483                let list: Punctuated<TokenStream, Token![,]> =
484                    rules.iter().map(|rule| rule.to_token_stream()).collect();
485                quote! { nom::branch::alt((#list)) }
486            }
487            _ => unimplemented!(),
488        };
489
490        tokens.extend(token);
491    }
492
493    fn to_token_stream(&self) -> TokenStream {
494        let mut tokens = TokenStream::new();
495        self.to_tokens(&mut tokens);
496        tokens
497    }
498}
499
500impl ToTokens for Path {
501    fn to_tokens(&self, tokens: &mut TokenStream) {
502        for (i, segment) in self.segments.iter().enumerate() {
503            if i > 0 {
504                // Double colon `::`
505                tokens.append(Punct::new(':', Spacing::Joint));
506                tokens.append(Punct::new(':', Spacing::Alone));
507            }
508            segment.to_tokens(tokens);
509        }
510    }
511}