1use pratt::{Affix, Associativity, PrattError, PrattParser, Precedence};
2use proc_macro2::{Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
3use proc_macro_error::{abort, abort_call_site, proc_macro_error};
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{punctuated::Punctuated, Token};
6use winnow::{
7 combinator::{alt, opt, repeat, separated, trace},
8 error::ContextError,
9 token::any,
10 PResult, Parser,
11};
12use wrapper::InputWrapper;
13
14mod wrapper;
15
16#[proc_macro]
17#[proc_macro_error]
18pub fn rule(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
19 let tokens: TokenStream = tokens.into();
20 let i: Vec<TokenTree> = tokens.into_iter().collect();
21
22 let rule = parse_rule(i.iter().cloned().collect());
23 rule.check_return_type();
24 rule.to_token_stream().into()
25}
26
27#[derive(Debug, Clone)]
28struct Path {
29 segments: Vec<Ident>,
30}
31
32#[derive(Debug, Clone)]
33enum Rule {
34 MatchText(Span, Literal),
35 MatchToken(Span, Path),
36 ExternalFunction(Span, Path, Option<Group>),
37 Context(Span, Literal, Box<Rule>),
38 Peek(Span, Box<Rule>),
39 Not(Span, Box<Rule>),
40 Opt(Span, Box<Rule>),
41 Cut(Span, Box<Rule>),
42 Many0(Span, Box<Rule>),
43 Many1(Span, Box<Rule>),
44 Sequence(Span, Vec<Rule>),
45 Alt(Span, Vec<Rule>),
46}
47
48#[derive(Debug, Clone)]
49enum RuleElement {
50 MatchText(Literal),
51 MatchToken(Path),
52 ExternalFunction(Path, Option<Group>),
53 Context(Literal),
54 Peek,
55 Not,
56 Opt,
57 Cut,
58 Many0,
59 Many1,
60 Sequence,
61 Alt,
62 SubRule(Rule),
63}
64
65#[derive(Debug, Clone)]
66struct WithSpan {
67 elem: RuleElement,
68 span: Span,
69}
70
/// A coarse approximation of the type a rule's parser returns.
///
/// Only the shapes the macro can infer structurally are modelled; everything
/// else is `Unknown`.
#[derive(Clone, Debug)]
enum ReturnType {
    /// `Option<T>`, produced by `rule?`.
    Option(Box<ReturnType>),
    /// `Vec<T>`, produced by `rule*` and `rule+`.
    Vec(Box<ReturnType>),
    /// `()`, produced by `!rule`.
    Unit,
    /// A type the macro cannot determine.
    Unknown,
}
78
79type Input<'a> = InputWrapper<'a>;
80
81fn match_punct<'a>(punct: char) -> impl Parser<Input<'a>, TokenTree, ContextError> {
82 trace(
83 punct,
84 any.verify_map(move |token| match token {
85 TokenTree::Punct(ref p) if p.as_char() == punct => Some(token.clone()),
86 _ => None,
87 }),
88 )
89}
90
91fn group<'a>(input: &mut Input<'a>) -> PResult<Group> {
92 any.verify_map(move |token| match token {
93 TokenTree::Group(ref group) => Some(group.clone()),
94 _ => None,
95 })
96 .parse_next(input)
97}
98
99fn literal<'a>(input: &mut Input<'a>) -> PResult<Literal> {
100 any.verify_map(move |token| match token {
101 TokenTree::Literal(ref lit) => Some(lit.clone()),
102 _ => None,
103 })
104 .parse_next(input)
105}
106
107fn ident<'a>(input: &mut Input<'a>) -> PResult<Ident> {
108 trace(
109 "ident",
110 any.verify_map(move |token| match token {
111 TokenTree::Ident(ref ident) => Some(ident.clone()),
112 _ => None,
113 }),
114 )
115 .parse_next(input)
116}
117
118fn path<'a>(input: &mut Input<'a>) -> PResult<(Span, Path)> {
119 separated(1.., ident, (match_punct(':'), match_punct(':')))
120 .map(|segments: Vec<_>| {
121 let span = segments[1..]
122 .iter()
123 .fold(segments[0].span(), |acc, segment| {
124 acc.join(segment.span()).unwrap()
125 })
126 .unwrap()
127 .into();
128 let path = Path { segments };
129 (span, path)
130 })
131 .parse_next(input)
132}
133
134fn parse_rule(tokens: TokenStream) -> Rule {
135 let i: Vec<TokenTree> = tokens.into_iter().collect();
136 let i = &mut InputWrapper(&i[..]);
137
138 let elems: Vec<_> = repeat(0.., parse_rule_element).parse_next(i).unwrap();
139 let i = i.0;
140 if !i.is_empty() {
141 let rest: TokenStream = i.iter().cloned().collect();
142 abort!(rest, "unable to parse the following rules: {}", rest);
143 }
144
145 let mut iter = elems.into_iter().peekable();
146 let rule = unwrap_pratt(RuleParser.parse(&mut iter));
147 if iter.peek().is_some() {
148 let rest: Vec<_> = iter.collect();
149 abort!(
150 rest[0].span,
151 "unable to parse the following rules: {:?}",
152 rest
153 );
154 }
155
156 rule
157}
158
159fn parse_rule_element<'a>(i: &mut Input<'a>) -> PResult<WithSpan> {
160 let function_call = |i: &mut Input<'a>| {
161 let hashtag = match_punct('#').parse_next(i)?;
162 let (path_span, fn_path) = path(i)?;
163 let args = opt(group).parse_next(i)?;
164 let span = hashtag.span().join(path_span).unwrap();
165 let span = args
166 .as_ref()
167 .map(|args| args.span().join(span).unwrap())
168 .unwrap_or(span);
169
170 Ok(WithSpan {
171 elem: RuleElement::ExternalFunction(fn_path, args),
172 span,
173 })
174 };
175 let context = (match_punct(':'), literal).map(|(colon, msg)| {
176 let span = colon.span().join(msg.span()).unwrap();
177 WithSpan {
178 elem: RuleElement::Context(msg),
179 span,
180 }
181 });
182 alt((
183 match_punct('|').map(|token| WithSpan {
184 span: token.span(),
185 elem: RuleElement::Alt,
186 }),
187 match_punct('*').map(|token| WithSpan {
188 span: token.span(),
189 elem: RuleElement::Many0,
190 }),
191 match_punct('+').map(|token| WithSpan {
192 span: token.span(),
193 elem: RuleElement::Many1,
194 }),
195 match_punct('?').map(|token| WithSpan {
196 span: token.span(),
197 elem: RuleElement::Opt,
198 }),
199 match_punct('^').map(|token| WithSpan {
200 span: token.span(),
201 elem: RuleElement::Cut,
202 }),
203 match_punct('&').map(|token| WithSpan {
204 span: token.span(),
205 elem: RuleElement::Peek,
206 }),
207 match_punct('!').map(|token| WithSpan {
208 span: token.span(),
209 elem: RuleElement::Not,
210 }),
211 match_punct('~').map(|token| WithSpan {
212 span: token.span(),
213 elem: RuleElement::Sequence,
214 }),
215 literal.map(|lit| WithSpan {
216 span: lit.span(),
217 elem: RuleElement::MatchText(lit),
218 }),
219 path.map(|(span, p)| WithSpan {
220 span,
221 elem: RuleElement::MatchToken(p),
222 }),
223 group.map(|group| WithSpan {
224 span: group.span(),
225 elem: RuleElement::SubRule(parse_rule(group.stream())),
226 }),
227 function_call,
228 context,
229 ))
230 .parse_next(i)
231}
232
233fn unwrap_pratt(res: Result<Rule, PrattError<WithSpan, pratt::NoError>>) -> Rule {
234 match res {
235 Ok(res) => res,
236 Err(PrattError::EmptyInput) => abort_call_site!("expected more tokens for rule"),
237 Err(PrattError::UnexpectedNilfix(input)) => {
238 abort!(input.span, "unable to parse the value")
239 }
240 Err(PrattError::UnexpectedPrefix(input)) => {
241 abort!(input.span, "unable to parse the prefix operator")
242 }
243 Err(PrattError::UnexpectedInfix(input)) => {
244 abort!(input.span, "unable to parse the binary operator")
245 }
246 Err(PrattError::UnexpectedPostfix(input)) => {
247 abort!(input.span, "unable to parse the postfix operator")
248 }
249 Err(PrattError::UserError(_)) => unreachable!(),
250 }
251}
252
253struct RuleParser;
254
255impl<I: Iterator<Item = WithSpan>> PrattParser<I> for RuleParser {
256 type Error = pratt::NoError;
257 type Input = WithSpan;
258 type Output = Rule;
259
260 fn query(&mut self, elem: &WithSpan) -> pratt::Result<Affix> {
261 let affix = match elem.elem {
262 RuleElement::Alt => Affix::Infix(Precedence(1), Associativity::Left),
263 RuleElement::Context(_) => Affix::Postfix(Precedence(2)),
264 RuleElement::Sequence => Affix::Infix(Precedence(3), Associativity::Left),
265 RuleElement::Opt => Affix::Postfix(Precedence(4)),
266 RuleElement::Many1 => Affix::Postfix(Precedence(4)),
267 RuleElement::Many0 => Affix::Postfix(Precedence(4)),
268 RuleElement::Cut => Affix::Prefix(Precedence(5)),
269 RuleElement::Peek => Affix::Prefix(Precedence(5)),
270 RuleElement::Not => Affix::Prefix(Precedence(5)),
271 _ => Affix::Nilfix,
272 };
273 Ok(affix)
274 }
275
276 fn primary(&mut self, elem: WithSpan) -> pratt::Result<Rule> {
277 let rule = match elem.elem {
278 RuleElement::SubRule(rule) => rule,
279 RuleElement::MatchText(text) => Rule::MatchText(elem.span, text),
280 RuleElement::MatchToken(token) => Rule::MatchToken(elem.span, token),
281 RuleElement::ExternalFunction(func, args) => {
282 Rule::ExternalFunction(elem.span, func, args)
283 }
284 _ => unreachable!(),
285 };
286 Ok(rule)
287 }
288
289 fn infix(&mut self, lhs: Rule, elem: WithSpan, rhs: Rule) -> pratt::Result<Rule> {
290 let rule = match elem.elem {
291 RuleElement::Sequence => match lhs {
292 Rule::Sequence(span, mut seq) => {
293 let span = span.join(elem.span).unwrap().join(rhs.span()).unwrap();
294 seq.push(rhs);
295 Rule::Sequence(span, seq)
296 }
297 lhs => {
298 let span = lhs.span().join(rhs.span()).unwrap();
299 Rule::Sequence(span, vec![lhs, rhs])
300 }
301 },
302 RuleElement::Alt => match lhs {
303 Rule::Alt(span, mut choices) => {
304 let span = span.join(elem.span).unwrap().join(rhs.span()).unwrap();
305 choices.push(rhs);
306 Rule::Alt(span, choices)
307 }
308 lhs => {
309 let span = lhs.span().join(rhs.span()).unwrap();
310 Rule::Alt(span, vec![lhs, rhs])
311 }
312 },
313 _ => unreachable!(),
314 };
315 Ok(rule)
316 }
317
318 fn prefix(&mut self, elem: WithSpan, rhs: Rule) -> pratt::Result<Rule> {
319 let rule = match elem.elem {
320 RuleElement::Cut => {
321 let span = elem.span.join(rhs.span()).unwrap();
322 Rule::Cut(span, Box::new(rhs))
323 }
324 RuleElement::Peek => {
325 let span = elem.span.join(rhs.span()).unwrap();
326 Rule::Peek(span, Box::new(rhs))
327 }
328 RuleElement::Not => {
329 let span = elem.span.join(rhs.span()).unwrap();
330 Rule::Not(span, Box::new(rhs))
331 }
332 _ => unreachable!(),
333 };
334 Ok(rule)
335 }
336
337 fn postfix(&mut self, lhs: Rule, elem: WithSpan) -> pratt::Result<Rule> {
338 let rule = match elem.elem {
339 RuleElement::Opt => {
340 let span = lhs.span().join(elem.span).unwrap();
341 Rule::Opt(span, Box::new(lhs))
342 }
343 RuleElement::Many0 => {
344 let span = lhs.span().join(elem.span).unwrap();
345 Rule::Many0(span, Box::new(lhs))
346 }
347 RuleElement::Many1 => {
348 let span = lhs.span().join(elem.span).unwrap();
349 Rule::Many1(span, Box::new(lhs))
350 }
351 RuleElement::Context(msg) => {
352 let span = lhs.span().join(elem.span).unwrap();
353 Rule::Context(span, msg, Box::new(lhs))
354 }
355 _ => unreachable!(),
356 };
357 Ok(rule)
358 }
359}
360
361impl std::fmt::Display for ReturnType {
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 match self {
364 ReturnType::Option(ty) => write!(f, "Option<{}>", ty),
365 ReturnType::Vec(ty) => write!(f, "Vec<{}>", ty),
366 ReturnType::Unit => write!(f, "()"),
367 ReturnType::Unknown => write!(f, "_"),
368 }
369 }
370}
371
372impl PartialEq for ReturnType {
373 fn eq(&self, other: &ReturnType) -> bool {
374 match (self, other) {
375 (ReturnType::Option(lhs), ReturnType::Option(rhs)) => lhs == rhs,
376 (ReturnType::Vec(lhs), ReturnType::Vec(rhs)) => lhs == rhs,
377 (ReturnType::Unit, ReturnType::Unit) => true,
378 (ReturnType::Unknown, _) => true,
379 (_, ReturnType::Unknown) => true,
380 _ => false,
381 }
382 }
383}
384
385impl Rule {
386 fn check_return_type(&self) -> ReturnType {
387 match self {
388 Rule::MatchText(_, _) | Rule::MatchToken(_, _) | Rule::ExternalFunction(_, _, _) => {
389 ReturnType::Unknown
390 }
391 Rule::Context(_, _, rule) | Rule::Peek(_, rule) => rule.check_return_type(),
392 Rule::Not(_, _) => ReturnType::Unit,
393 Rule::Opt(_, rule) => ReturnType::Option(Box::new(rule.check_return_type())),
394 Rule::Cut(_, rule) => rule.check_return_type(),
395 Rule::Many0(_, rule) | Rule::Many1(_, rule) => {
396 ReturnType::Vec(Box::new(rule.check_return_type()))
397 }
398 Rule::Sequence(_, rules) => {
399 rules.iter().for_each(|rule| {
400 rule.check_return_type();
401 });
402 ReturnType::Vec(Box::new(ReturnType::Unknown))
403 }
404 Rule::Alt(_, rules) => {
405 for slice in rules.windows(2) {
406 match (slice[0].check_return_type(), slice[1].check_return_type()) {
407 (ReturnType::Option(_), _) => {
408 abort!(
409 slice[0].span(),
410 "optional shouldn't be in a choice because it will shortcut the following branches",
411 )
412 }
413 (a, b) if a != b => abort!(
414 slice[0].span().join(slice[1].span()).unwrap(),
415 "type mismatched between {:} and {:}",
416 a,
417 b,
418 ),
419 _ => (),
420 }
421 }
422 ReturnType::Vec(Box::new(rules[0].check_return_type()))
423 }
424 }
425 }
426
427 fn span(&self) -> Span {
428 match self {
429 Rule::MatchText(span, _)
430 | Rule::MatchToken(span, _)
431 | Rule::ExternalFunction(span, _, _)
432 | Rule::Context(span, _, _)
433 | Rule::Peek(span, _)
434 | Rule::Not(span, _)
435 | Rule::Opt(span, _)
436 | Rule::Cut(span, _)
437 | Rule::Many0(span, _)
438 | Rule::Many1(span, _)
439 | Rule::Sequence(span, _)
440 | Rule::Alt(span, _) => *span,
441 }
442 }
443
444 fn to_tokens(&self, tokens: &mut TokenStream) {
445 let token = match self {
446 Rule::ExternalFunction(_, name, arg) => {
447 quote! { #name #arg }
448 }
449 Rule::Context(_, msg, rule) => {
450 let rule = rule.to_token_stream();
451 quote! { #rule.context(winnow::error::StrContext::Label(#msg)) }
452 }
453 Rule::Peek(_, rule) => {
454 let rule = rule.to_token_stream();
455 quote! { winnow::combinator::peek(#rule) }
456 }
457 Rule::Not(_, rule) => {
458 let rule = rule.to_token_stream();
459 quote! { winnow::combinator::not(#rule) }
460 }
461 Rule::Opt(_, rule) => {
462 let rule = rule.to_token_stream();
463 quote! { winnow::combinator::opt(#rule) }
464 }
465 Rule::Cut(_, rule) => {
466 let rule = rule.to_token_stream();
467 quote! { winnow::combinator::cut_err(#rule) }
468 }
469 Rule::Many0(_, rule) => {
470 let rule = rule.to_token_stream();
471 quote! { winnow::combinator::repeat(0.., #rule) }
472 }
473 Rule::Many1(_, rule) => {
474 let rule = rule.to_token_stream();
475 quote! { winnow::combinator::repeat(1.., #rule) }
476 }
477 Rule::Sequence(_, rules) => {
478 let list: Punctuated<TokenStream, Token![,]> =
479 rules.iter().map(|rule| rule.to_token_stream()).collect();
480 quote! { ((#list)) }
481 }
482 Rule::Alt(_, rules) => {
483 let list: Punctuated<TokenStream, Token![,]> =
484 rules.iter().map(|rule| rule.to_token_stream()).collect();
485 quote! { nom::branch::alt((#list)) }
486 }
487 _ => unimplemented!(),
488 };
489
490 tokens.extend(token);
491 }
492
493 fn to_token_stream(&self) -> TokenStream {
494 let mut tokens = TokenStream::new();
495 self.to_tokens(&mut tokens);
496 tokens
497 }
498}
499
500impl ToTokens for Path {
501 fn to_tokens(&self, tokens: &mut TokenStream) {
502 for (i, segment) in self.segments.iter().enumerate() {
503 if i > 0 {
504 tokens.append(Punct::new(':', Spacing::Joint));
506 tokens.append(Punct::new(':', Spacing::Alone));
507 }
508 segment.to_tokens(tokens);
509 }
510 }
511}