1use std::marker::PhantomData;
2
3use heck::AsPascalCase;
4use proc_macro::TokenStream;
5use proc_macro2::{Delimiter, Group, Span};
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::{
8 parse::{Parse, ParseStream, Parser},
9 parse_macro_input,
10 punctuated::Punctuated,
11 spanned::Spanned,
12 FnArg, Ident, Pat, PatType, Token,
13};
14
15fn pascalize(ident: &Ident) -> Ident {
16 Ident::new(&AsPascalCase(&ident.to_string()).to_string(), ident.span())
17}
18
19trait GotoSemantics {
20 fn transform_goto(input: ParseStream) -> syn::Result<proc_macro2::TokenTree>;
21}
22
23#[derive(Debug)]
24struct GotoBlockContents<T: GotoSemantics> {
25 stream: proc_macro2::TokenStream,
26 goto_semantics: PhantomData<T>,
27}
28
29impl<T: GotoSemantics> Parse for GotoBlockContents<T> {
30 fn parse(input: ParseStream) -> syn::Result<Self> {
31 let mut tokens = proc_macro2::TokenStream::new();
32 while let Ok(token) = input.parse::<proc_macro2::TokenTree>() {
33 let tt = match token {
34 proc_macro2::TokenTree::Group(grp) => {
35 let delim = grp.delimiter();
36 let span = grp.span();
37 let contents: GotoBlockContents<T> = syn::parse2(grp.stream())?;
38 let mut grp = Group::new(delim, contents.stream);
39 grp.set_span(span);
40 proc_macro2::TokenTree::Group(grp)
41 }
42 proc_macro2::TokenTree::Ident(ref ident) => {
43 if ident == "goto" {
44 T::transform_goto(input)?
45 } else if ident == "safe_goto" {
46 return Err(syn::Error::new(
47 ident.span(),
48 "using safe_goto inside safe_goto is not allowed",
49 ));
50 } else {
51 proc_macro2::TokenTree::Ident(ident.clone())
52 }
53 }
54 tt => tt,
55 };
56 tokens.append(tt);
57 }
58 Ok(GotoBlockContents {
59 stream: tokens,
60 goto_semantics: PhantomData,
61 })
62 }
63}
64
65impl<T: GotoSemantics> ToTokens for GotoBlock<T> {
66 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
67 let GotoBlock {
68 contents,
69 delimiter,
70 } = self;
71 tokens.append(Group::new(*delimiter, contents.stream.clone()));
72 }
73}
74
75#[derive(Debug)]
77struct GotoBlock<T: GotoSemantics> {
78 delimiter: Delimiter,
79 contents: GotoBlockContents<T>,
80}
81
82impl<T: GotoSemantics> From<GotoBlock<T>> for Group {
83 fn from(gtb: GotoBlock<T>) -> Self {
84 Group::new(gtb.delimiter, gtb.contents.stream)
85 }
86}
87
88impl<T: GotoSemantics> Parse for GotoBlock<T> {
89 fn parse(input: ParseStream) -> syn::Result<Self> {
90 let group: Group = input.parse()?;
91 let delimiter = group.delimiter();
92 let contents: GotoBlockContents<T> = syn::parse2(group.stream())?;
93 Ok(GotoBlock {
94 delimiter,
95 contents,
96 })
97 }
98}
99
100struct VariantArgsDelimited {
102 contents: Punctuated<PatType, Token!(,)>,
103}
104
105impl Parse for VariantArgsDelimited {
106 fn parse(input: ParseStream) -> syn::Result<Self> {
107 let group: Group = input.parse()?;
108 let contents = if group.delimiter() == Delimiter::Parenthesis {
109 let parser = Punctuated::<FnArg, Token![,]>::parse_terminated;
110 parser.parse2(group.stream())?
111 } else {
112 return Err(syn::Error::new(group.span_open(), "expected `(`"));
113 };
114 let mut new_contents = Punctuated::<PatType, Token!(,)>::new();
115 for pair in contents.pairs() {
116 if let FnArg::Typed(pat) = pair.value() {
117 new_contents.push_value(pat.clone())
118 } else {
119 return Err(syn::Error::new(contents.span(), "unexpected `self`"));
120 }
121 if let Some(&&punct) = pair.punct() {
122 new_contents.push_punct(punct)
123 }
124 }
125 Ok(VariantArgsDelimited {
126 contents: new_contents,
127 })
128 }
129}
130
131impl ToTokens for VariantArgsDelimited {
132 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
133 if !self.contents.is_empty() {
134 let args = &self.contents;
135 tokens.append_all(quote!(
136 (#args)
137 ))
138 }
139 }
140}
141
142struct GotoBranch<T: GotoSemantics> {
144 id: Ident,
145 block: GotoBlock<T>,
146 variant_args: VariantArgsDelimited,
147}
148
149impl<T: GotoSemantics> Parse for GotoBranch<T> {
150 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
151 let id = input.parse()?;
152 let variant_args = input.parse()?;
153 let block = input.parse()?;
154 Ok(GotoBranch {
155 id,
156 block,
157 variant_args,
158 })
159 }
160}
161
162struct VariantTypesDelimited {
164 contents: Punctuated<Box<syn::Type>, Token!(,)>,
165}
166
167impl ToTokens for VariantTypesDelimited {
168 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
169 if !self.contents.is_empty() {
170 let args = &self.contents;
171 tokens.append_all(quote!(
172 (#args)
173 ))
174 }
175 }
176}
177
178struct VariantPatsDelimited {
180 contents: Punctuated<Box<Pat>, Token!(,)>,
181}
182
183impl ToTokens for VariantPatsDelimited {
184 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
185 if !self.contents.is_empty() {
186 let args = &self.contents;
187 tokens.append_all(quote!(
188 (#args)
189 ))
190 }
191 }
192}
193
194struct Initial;
196impl GotoSemantics for Initial {
197 fn transform_goto(input: ParseStream) -> syn::Result<proc_macro2::TokenTree> {
198 let id: Ident = input
199 .parse()
200 .map_err(|e| syn::Error::new(e.span(), "Invalid syntax for goto statement"))?;
201 let variant = pascalize(&id).clone();
202 let call: Group = input.parse()?;
203 if call.delimiter() != Delimiter::Parenthesis {
204 return Err(syn::Error::new(call.span_open(), "expected `(`"));
205 }
206 let call = if call.stream().is_empty() {
207 proc_macro2::TokenStream::new()
208 } else {
209 quote!(#call)
210 };
211 Ok(syn::parse2(quote!(
212 {
213 break 'goto States::#variant #call
214 }
215 ))
216 .expect("This should parse as a group"))
217 }
218}
219
220struct Other;
222impl GotoSemantics for Other {
223 fn transform_goto(input: ParseStream) -> syn::Result<proc_macro2::TokenTree> {
224 let id: Ident = input
225 .parse()
226 .map_err(|e| syn::Error::new(e.span(), "Invalid syntax for goto statement"))?;
227 let variant = pascalize(&id).clone();
228 let call: Group = input.parse()?;
229 if call.delimiter() != Delimiter::Parenthesis {
230 return Err(syn::Error::new(call.span_open(), "expected `(`"));
231 }
232 let call = if call.stream().is_empty() {
233 proc_macro2::TokenStream::new()
234 } else {
235 quote!(#call)
236 };
237 Ok(syn::parse2(quote!(
238 {
239 goto = States::#variant #call;
240 continue 'goto
241 }
242 ))
243 .expect("This should parse as a group"))
244 }
245}
246
247struct SafeGoto {
249 begin_branch: GotoBranch<Initial>,
250 branches: Punctuated<GotoBranch<Other>, Token!(,)>,
251}
252
253impl SafeGoto {
254 fn idents(&self) -> impl Iterator<Item = &Ident> {
255 self.branches.iter().map(|branch| &branch.id)
256 }
257
258 fn variant_types(&self) -> impl Iterator<Item = VariantTypesDelimited> + '_ {
259 self.branches.iter().map(|branch| {
260 let mut ret = Punctuated::new();
261 for pair in branch.variant_args.contents.pairs() {
262 ret.push_value(pair.value().ty.clone());
263 if let Some(&&punct) = pair.punct() {
264 ret.push_punct(punct)
265 }
266 }
267 VariantTypesDelimited { contents: ret }
268 })
269 }
270
271 fn variant_pats(&self) -> impl Iterator<Item = VariantPatsDelimited> + '_ {
272 self.branches.iter().map(|branch| {
273 let mut ret = Punctuated::new();
274 for pair in branch.variant_args.contents.pairs() {
275 ret.push_value(pair.value().pat.clone());
276 if let Some(&&punct) = pair.punct() {
277 ret.push_punct(punct)
278 }
279 }
280 VariantPatsDelimited { contents: ret }
281 })
282 }
283
284 fn blocks(&self) -> impl Iterator<Item = &GotoBlock<Other>> {
285 self.branches.iter().map(|branch| &branch.block)
286 }
287
288 fn begin_block(&self) -> &GotoBlock<Initial> {
289 &self.begin_branch.block
290 }
291}
292
293impl Parse for SafeGoto {
294 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
295 let begin_branch = input.parse()?;
296 if input.peek(Token!(,)) {
297 let _comma: Token!(,) = input.parse()?;
298 let ret = SafeGoto {
299 begin_branch,
300 branches: input
301 .parse_terminated::<GotoBranch<Other>, Token!(,)>(GotoBranch::parse)?,
302 };
303 let lifetimes: Vec<_> = ret.idents().collect();
304 for i in 0..lifetimes.len() {
305 if lifetimes[i + 1..].contains(&lifetimes[i]) {
306 return Err(syn::Error::new(
307 lifetimes[i].span(),
308 "block label occurs more than once",
309 ));
310 }
311 }
312 Ok(ret)
313 } else {
314 Ok(SafeGoto {
315 begin_branch,
316 branches: Punctuated::new(),
317 })
318 }
319 }
320}
321
322#[proc_macro]
369pub fn safe_goto(t: TokenStream) -> TokenStream {
370 let input = parse_macro_input!(t as SafeGoto);
371 if input.idents().any(|id| id == "begin") {
372 return syn::Error::new(Span::call_site(), "`begin` block should be first")
373 .to_compile_error()
374 .into();
375 }
376 let states_enum = Ident::new("States", Span::call_site());
377 let variants: Vec<_> = input.idents().map(pascalize).collect();
378 let variant_pats = input.variant_pats();
379 let variant_types = input.variant_types();
380 let blocks = input.blocks();
381 let begin_branch = input.begin_block();
382 quote!(
383 {
384 'outer_goto: {enum #states_enum {
385 #(#variants #variant_types),*
386 }
387 let mut goto: #states_enum = 'goto: {
388 let break_val = #begin_branch;
389 #[allow(unreachable_code)]
390 {break 'outer_goto break_val;}
391 };
392
393 'goto: loop {
394 let ret = match goto {
395 #(#states_enum::#variants #variant_pats => #blocks),*
396 };
397 #[allow(unreachable_code)]
398 {break ret}
399 }}
400 }
401 )
402 .into()
403}