safe_goto/
lib.rs

1use std::marker::PhantomData;
2
3use heck::AsPascalCase;
4use proc_macro::TokenStream;
5use proc_macro2::{Delimiter, Group, Span};
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::{
8    parse::{Parse, ParseStream, Parser},
9    parse_macro_input,
10    punctuated::Punctuated,
11    spanned::Spanned,
12    FnArg, Ident, Pat, PatType, Token,
13};
14
15fn pascalize(ident: &Ident) -> Ident {
16    Ident::new(&AsPascalCase(&ident.to_string()).to_string(), ident.span())
17}
18
19trait GotoSemantics {
20    fn transform_goto(input: ParseStream) -> syn::Result<proc_macro2::TokenTree>;
21}
22
23#[derive(Debug)]
24struct GotoBlockContents<T: GotoSemantics> {
25    stream: proc_macro2::TokenStream,
26    goto_semantics: PhantomData<T>,
27}
28
29impl<T: GotoSemantics> Parse for GotoBlockContents<T> {
30    fn parse(input: ParseStream) -> syn::Result<Self> {
31        let mut tokens = proc_macro2::TokenStream::new();
32        while let Ok(token) = input.parse::<proc_macro2::TokenTree>() {
33            let tt = match token {
34                proc_macro2::TokenTree::Group(grp) => {
35                    let delim = grp.delimiter();
36                    let span = grp.span();
37                    let contents: GotoBlockContents<T> = syn::parse2(grp.stream())?;
38                    let mut grp = Group::new(delim, contents.stream);
39                    grp.set_span(span);
40                    proc_macro2::TokenTree::Group(grp)
41                }
42                proc_macro2::TokenTree::Ident(ref ident) => {
43                    if ident == "goto" {
44                        T::transform_goto(input)?
45                    } else if ident == "safe_goto" {
46                        return Err(syn::Error::new(
47                            ident.span(),
48                            "using safe_goto inside safe_goto is not allowed",
49                        ));
50                    } else {
51                        proc_macro2::TokenTree::Ident(ident.clone())
52                    }
53                }
54                tt => tt,
55            };
56            tokens.append(tt);
57        }
58        Ok(GotoBlockContents {
59            stream: tokens,
60            goto_semantics: PhantomData,
61        })
62    }
63}
64
65impl<T: GotoSemantics> ToTokens for GotoBlock<T> {
66    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
67        let GotoBlock {
68            contents,
69            delimiter,
70        } = self;
71        tokens.append(Group::new(*delimiter, contents.stream.clone()));
72    }
73}
74
75/// A possibly invalid Rust block possibly containing goto statements
76#[derive(Debug)]
77struct GotoBlock<T: GotoSemantics> {
78    delimiter: Delimiter,
79    contents: GotoBlockContents<T>,
80}
81
82impl<T: GotoSemantics> From<GotoBlock<T>> for Group {
83    fn from(gtb: GotoBlock<T>) -> Self {
84        Group::new(gtb.delimiter, gtb.contents.stream)
85    }
86}
87
88impl<T: GotoSemantics> Parse for GotoBlock<T> {
89    fn parse(input: ParseStream) -> syn::Result<Self> {
90        let group: Group = input.parse()?;
91        let delimiter = group.delimiter();
92        let contents: GotoBlockContents<T> = syn::parse2(group.stream())?;
93        Ok(GotoBlock {
94            delimiter,
95            contents,
96        })
97    }
98}
99
100/// Comma separated list of typed patterns used as arguments for each goto block
101struct VariantArgsDelimited {
102    contents: Punctuated<PatType, Token!(,)>,
103}
104
105impl Parse for VariantArgsDelimited {
106    fn parse(input: ParseStream) -> syn::Result<Self> {
107        let group: Group = input.parse()?;
108        let contents = if group.delimiter() == Delimiter::Parenthesis {
109            let parser = Punctuated::<FnArg, Token![,]>::parse_terminated;
110            parser.parse2(group.stream())?
111        } else {
112            return Err(syn::Error::new(group.span_open(), "expected `(`"));
113        };
114        let mut new_contents = Punctuated::<PatType, Token!(,)>::new();
115        for pair in contents.pairs() {
116            if let FnArg::Typed(pat) = pair.value() {
117                new_contents.push_value(pat.clone())
118            } else {
119                return Err(syn::Error::new(contents.span(), "unexpected `self`"));
120            }
121            if let Some(&&punct) = pair.punct() {
122                new_contents.push_punct(punct)
123            }
124        }
125        Ok(VariantArgsDelimited {
126            contents: new_contents,
127        })
128    }
129}
130
131impl ToTokens for VariantArgsDelimited {
132    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
133        if !self.contents.is_empty() {
134            let args = &self.contents;
135            tokens.append_all(quote!(
136                (#args)
137            ))
138        }
139    }
140}
141
142/// A branch that can be a target of a goto statement
143struct GotoBranch<T: GotoSemantics> {
144    id: Ident,
145    block: GotoBlock<T>,
146    variant_args: VariantArgsDelimited,
147}
148
149impl<T: GotoSemantics> Parse for GotoBranch<T> {
150    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
151        let id = input.parse()?;
152        let variant_args = input.parse()?;
153        let block = input.parse()?;
154        Ok(GotoBranch {
155            id,
156            block,
157            variant_args,
158        })
159    }
160}
161
162/// Comma separated list of types that are arguments to a goto branch. Used for constructing enum
163struct VariantTypesDelimited {
164    contents: Punctuated<Box<syn::Type>, Token!(,)>,
165}
166
167impl ToTokens for VariantTypesDelimited {
168    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
169        if !self.contents.is_empty() {
170            let args = &self.contents;
171            tokens.append_all(quote!(
172                (#args)
173            ))
174        }
175    }
176}
177
178/// Comma separated list of patterns that are inputs to a goto branch. Used for matching
179struct VariantPatsDelimited {
180    contents: Punctuated<Box<Pat>, Token!(,)>,
181}
182
183impl ToTokens for VariantPatsDelimited {
184    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
185        if !self.contents.is_empty() {
186            let args = &self.contents;
187            tokens.append_all(quote!(
188                (#args)
189            ))
190        }
191    }
192}
193
194/// Strategy for transforming goto statements in the begin branch
195struct Initial;
196impl GotoSemantics for Initial {
197    fn transform_goto(input: ParseStream) -> syn::Result<proc_macro2::TokenTree> {
198        let id: Ident = input
199            .parse()
200            .map_err(|e| syn::Error::new(e.span(), "Invalid syntax for goto statement"))?;
201        let variant = pascalize(&id).clone();
202        let call: Group = input.parse()?;
203        if call.delimiter() != Delimiter::Parenthesis {
204            return Err(syn::Error::new(call.span_open(), "expected `(`"));
205        }
206        let call = if call.stream().is_empty() {
207            proc_macro2::TokenStream::new()
208        } else {
209            quote!(#call)
210        };
211        Ok(syn::parse2(quote!(
212            {
213                break 'goto States::#variant #call
214            }
215        ))
216        .expect("This should parse as a group"))
217    }
218}
219
220/// Strategy for transforming goto "statements" in the proper goto branches
221struct Other;
222impl GotoSemantics for Other {
223    fn transform_goto(input: ParseStream) -> syn::Result<proc_macro2::TokenTree> {
224        let id: Ident = input
225            .parse()
226            .map_err(|e| syn::Error::new(e.span(), "Invalid syntax for goto statement"))?;
227        let variant = pascalize(&id).clone();
228        let call: Group = input.parse()?;
229        if call.delimiter() != Delimiter::Parenthesis {
230            return Err(syn::Error::new(call.span_open(), "expected `(`"));
231        }
232        let call = if call.stream().is_empty() {
233            proc_macro2::TokenStream::new()
234        } else {
235            quote!(#call)
236        };
237        Ok(syn::parse2(quote!(
238            {
239                goto = States::#variant #call;
240                continue 'goto
241            }
242        ))
243        .expect("This should parse as a group"))
244    }
245}
246
247/// half-parsed valid input of the `safe_goto` macro
248struct SafeGoto {
249    begin_branch: GotoBranch<Initial>,
250    branches: Punctuated<GotoBranch<Other>, Token!(,)>,
251}
252
253impl SafeGoto {
254    fn idents(&self) -> impl Iterator<Item = &Ident> {
255        self.branches.iter().map(|branch| &branch.id)
256    }
257
258    fn variant_types(&self) -> impl Iterator<Item = VariantTypesDelimited> + '_ {
259        self.branches.iter().map(|branch| {
260            let mut ret = Punctuated::new();
261            for pair in branch.variant_args.contents.pairs() {
262                ret.push_value(pair.value().ty.clone());
263                if let Some(&&punct) = pair.punct() {
264                    ret.push_punct(punct)
265                }
266            }
267            VariantTypesDelimited { contents: ret }
268        })
269    }
270
271    fn variant_pats(&self) -> impl Iterator<Item = VariantPatsDelimited> + '_ {
272        self.branches.iter().map(|branch| {
273            let mut ret = Punctuated::new();
274            for pair in branch.variant_args.contents.pairs() {
275                ret.push_value(pair.value().pat.clone());
276                if let Some(&&punct) = pair.punct() {
277                    ret.push_punct(punct)
278                }
279            }
280            VariantPatsDelimited { contents: ret }
281        })
282    }
283
284    fn blocks(&self) -> impl Iterator<Item = &GotoBlock<Other>> {
285        self.branches.iter().map(|branch| &branch.block)
286    }
287
288    fn begin_block(&self) -> &GotoBlock<Initial> {
289        &self.begin_branch.block
290    }
291}
292
293impl Parse for SafeGoto {
294    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
295        let begin_branch = input.parse()?;
296        if input.peek(Token!(,)) {
297            let _comma: Token!(,) = input.parse()?;
298            let ret = SafeGoto {
299                begin_branch,
300                branches: input
301                    .parse_terminated::<GotoBranch<Other>, Token!(,)>(GotoBranch::parse)?,
302            };
303            let lifetimes: Vec<_> = ret.idents().collect();
304            for i in 0..lifetimes.len() {
305                if lifetimes[i + 1..].contains(&lifetimes[i]) {
306                    return Err(syn::Error::new(
307                        lifetimes[i].span(),
308                        "block label occurs more than once",
309                    ));
310                }
311            }
312            Ok(ret)
313        } else {
314            Ok(SafeGoto {
315                begin_branch,
316                branches: Punctuated::new(),
317            })
318        }
319    }
320}
321
322/// Executes the contained Rust code with possibly irreducible control flow
323///
324/// # Example
325/// ```
326/// use safe_goto::safe_goto;
327/// safe_goto!{
328///     begin() {
329///         goto s1(3)
330///     },
331///     s1(n: i32) {
332///         n + 1
333///     }
334/// };
335/// ```
336/// The invocation above generates the following code:
337/// ```
338/// 'outer_goto: {
339///     enum States {
340///         S1(i32)
341///     }
342///     let mut goto: States = 'goto: {
343///         let break_val = {break 'goto States::S1(3)};
344///         break 'outer_goto break_val;
345///     };
346///     'goto: loop {
347///         let ret = match goto {
348///             States::S1(n) => {
349///                 n + 1
350///             }
351///         };
352///         break ret;
353///     }
354/// };
355/// ```
356///
357/// There must be a begin block with no arguments. The begin block cannot be
358/// a target of a goto, but can be used for one-time moves.
359/// Nested safe_goto's are not allowed,
360/// though function calls can be used to get around this limitation.
361/// Execution that exits any of the goto blocks will return from the macro
362/// with the value at the end of the final block executed.
363///
364/// # Safety
365///
366/// The macro does not generate unsafe code unless given unsafe code as input.
367/// There are no guarantees for how the macro will interact with unsafe code.
368#[proc_macro]
369pub fn safe_goto(t: TokenStream) -> TokenStream {
370    let input = parse_macro_input!(t as SafeGoto);
371    if input.idents().any(|id| id == "begin") {
372        return syn::Error::new(Span::call_site(), "`begin` block should be first")
373            .to_compile_error()
374            .into();
375    }
376    let states_enum = Ident::new("States", Span::call_site());
377    let variants: Vec<_> = input.idents().map(pascalize).collect();
378    let variant_pats = input.variant_pats();
379    let variant_types = input.variant_types();
380    let blocks = input.blocks();
381    let begin_branch = input.begin_block();
382    quote!(
383        {
384            'outer_goto: {enum #states_enum {
385                #(#variants #variant_types),*
386            }
387            let mut goto: #states_enum = 'goto: {
388                let break_val = #begin_branch;
389                #[allow(unreachable_code)]
390                {break 'outer_goto break_val;}
391            };
392
393            'goto: loop {
394                let ret = match goto {
395                    #(#states_enum::#variants #variant_pats => #blocks),*
396                };
397                #[allow(unreachable_code)]
398                {break ret}
399            }}
400        }
401    )
402    .into()
403}