subplotlib_derive/
step.rs

1use culpa::{throw, throws};
2use proc_macro2::Span;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use quote::ToTokens;
6use syn::parse_quote;
7use syn::Block;
8use syn::{
9    parse::Parse, Attribute, Error, FnArg, Ident, Pat, PathArguments, ReturnType, Signature, Type,
10    Visibility,
11};
12
13use std::fmt::Write;
14
15pub(crate) fn ty_is_borrow_str(ty: &Type) -> bool {
16    if let Type::Reference(ty) = ty {
17        if ty.mutability.is_none() && ty.lifetime.is_none() {
18            if let Type::Path(pp) = &*ty.elem {
19                pp.path.is_ident("str")
20            } else {
21                // not a path, so not &str
22                false
23            }
24        } else {
25            // mutable, or a lifetime stated, so not &str
26            false
27        }
28    } else {
29        // Not & so not &str
30        false
31    }
32}
33
34pub(crate) fn ty_is_borrow_path(ty: &Type) -> bool {
35    if let Type::Reference(ty) = ty {
36        if ty.mutability.is_none() && ty.lifetime.is_none() {
37            if let Type::Path(pp) = &*ty.elem {
38                pp.path.is_ident("Path")
39            } else {
40                // not a path, so not &Path
41                false
42            }
43        } else {
44            // mutable, or a lifetime stated, so not &Path
45            false
46        }
47    } else {
48        // Not & so not &Path
49        false
50    }
51}
52
53pub(crate) fn ty_is_datafile(ty: &Type) -> bool {
54    if let Type::Path(ty) = ty {
55        ty.path.is_ident("SubplotDataFile")
56    } else {
57        false
58    }
59}
60
61pub(crate) fn ty_is_scenariocontext(ty: &Type) -> bool {
62    if let Type::Path(ty) = ty {
63        ty.path.is_ident("ScenarioContext")
64    } else {
65        false
66    }
67}
68
69#[throws(Error)]
70pub(crate) fn ty_as_path(ty: &Type) -> String {
71    if let Type::Path(p) = ty {
72        let mut ret = String::new();
73        let mut colons = p.path.leading_colon.is_some();
74        for seg in &p.path.segments {
75            if !matches!(seg.arguments, PathArguments::None) {
76                throw!(Error::new_spanned(seg, "unexpected path segment arguments"));
77            }
78            if colons {
79                ret.push_str("::");
80            }
81            colons = true;
82            ret.push_str(&seg.ident.to_string());
83        }
84        ret
85    } else {
86        throw!(Error::new_spanned(ty, "expected a type path"));
87    }
88}
89
90#[throws(Error)]
91pub(crate) fn check_step_declaration(step: &StepFn) {
92    // Step functions must be declared very simply as:
93    // fn stepfunctionname(context: &mut Context)
94    // the `mut` is optional, but the type of the context argument must
95    // be a borrow of some kind, its name is not important.
96    // If the step function takes any arguments, then they must come next
97    // and should be named in the usual way.  If the argument starts with
98    // an underscore then that will be stripped during argument conversion
99    // so that if you're just ignoring an argument from your step you can.
100    // Additionally, step functions must **NOT** have a return type declared
101    // and must not be generic, non-rust ABI, unsafe, etc. in any way.
102    // Finally const makes no sense, though we won't deny it for now.
103    // Visibility will be taken into account when constructing the associated
104    // content for the step
105    let sig = &step.sig;
106    if let Some(syncness) = sig.asyncness.as_ref() {
107        throw!(Error::new_spanned(
108            syncness,
109            "Step functions may not be async",
110        ));
111    }
112    if let Some(unsafeness) = sig.unsafety.as_ref() {
113        throw!(Error::new_spanned(
114            unsafeness,
115            "Step functions may not be unsafe",
116        ));
117    }
118    if let Some(abi) = sig.abi.as_ref() {
119        throw!(Error::new_spanned(
120            abi,
121            "Step functions may not specify an ABI",
122        ));
123    }
124    if !matches!(sig.output, ReturnType::Default) {
125        throw!(Error::new_spanned(
126            &sig.output,
127            "Step functions may not specify a return value",
128        ));
129    }
130    if let Some(variadic) = sig.variadic.as_ref() {
131        throw!(Error::new_spanned(
132            variadic,
133            "Step functions may not be variadic",
134        ));
135    }
136    if !sig.generics.params.is_empty() || sig.generics.where_clause.is_some() {
137        throw!(Error::new_spanned(
138            &sig.generics,
139            "Step functions may not be generic",
140        ));
141    }
142    if let Some(arg) = sig.inputs.first() {
143        if let FnArg::Typed(pat) = arg {
144            if let Type::Reference(tr) = &*pat.ty {
145                if let Some(lifetime) = tr.lifetime.as_ref() {
146                    throw!(Error::new_spanned(
147                        lifetime,
148                        "Step function context borrow should not be given a lifetime marker",
149                    ));
150                }
151            } else {
152                throw!(Error::new_spanned(
153                    pat,
154                    "Step function context must be taken as a borrow",
155                ));
156            }
157        } else {
158            throw!(Error::new_spanned(
159                arg,
160                "Step functions do not take a method receiver",
161            ));
162        }
163    } else {
164        throw!(Error::new_spanned(
165            &sig.inputs,
166            "Step functions must have at least 1 argument (context)",
167        ));
168    }
169}
170
171#[throws(Error)]
172pub(crate) fn process_step(mut input: StepFn) -> proc_macro2::TokenStream {
173    // Processing a step involves constructing a step builder for
174    // the function which returns a step object to be passed into the
175    // scenario system
176
177    // A step builder consists of a struct whose fields are of the
178    // appropriate type, a set of pub methods to set those fields
179    // and then a build call which constructs the step instance with
180    // an appropriate closure in it
181
182    let vis = input.vis.clone();
183    let stepname = input.sig.ident.clone();
184    let mutablectx = {
185        if let FnArg::Typed(pt) = &input.sig.inputs[0] {
186            if let Type::Reference(pp) = &*pt.ty {
187                pp.mutability.is_some()
188            } else {
189                unreachable!()
190            }
191        } else {
192            unreachable!()
193        }
194    };
195
196    let contexttype = if let Some(ty) = input.sig.inputs.first() {
197        match ty {
198            FnArg::Typed(pt) => {
199                if let Type::Reference(rt) = &*pt.ty {
200                    *(rt.elem).clone()
201                } else {
202                    unreachable!()
203                }
204            }
205            _ => unreachable!(),
206        }
207    } else {
208        unreachable!()
209    };
210
211    let contexts: Vec<Type> = input
212        .attrs
213        .iter()
214        .filter(|attr| attr.path().is_ident("context"))
215        .map(|attr| {
216            let ty: Type = attr.parse_args()?;
217            Ok(ty)
218        })
219        .collect::<Result<_, Error>>()?;
220
221    input.attrs.retain(|f| !f.path().is_ident("context"));
222
223    let docs: Vec<_> = input
224        .attrs
225        .iter()
226        .filter(|attr| attr.path().is_ident("doc"))
227        .collect();
228
229    let fields = input
230        .sig
231        .inputs
232        .iter()
233        .skip(1)
234        .map(|a| {
235            if let FnArg::Typed(pat) = a {
236                if let Pat::Ident(ident) = &*pat.pat {
237                    if let Some(r) = ident.by_ref.as_ref() {
238                        Err(Error::new_spanned(r, "ref not valid here"))
239                    } else if let Some(subpat) = ident.subpat.as_ref() {
240                        Err(Error::new_spanned(&subpat.1, "subpattern not valid here"))
241                    } else {
242                        let identstr = ident.ident.to_string();
243                        Ok((
244                            Ident::new(identstr.trim_start_matches('_'), ident.ident.span()),
245                            (*pat.ty).clone(),
246                        ))
247                    }
248                } else {
249                    Err(Error::new_spanned(pat, "expected a simple name here"))
250                }
251            } else {
252                Err(Error::new_spanned(
253                    a,
254                    "receiver argument unexpected in this position",
255                ))
256            }
257        })
258        .collect::<Result<Vec<_>, _>>()?;
259
260    let structdef = {
261        let structfields: Vec<_> = fields
262            .iter()
263            .map(|(id, ty)| {
264                let ty = if ty_is_borrow_str(ty) {
265                    parse_quote!(::std::string::String)
266                } else if ty_is_borrow_path(ty) {
267                    parse_quote!(::std::path::PathBuf)
268                } else {
269                    ty.clone()
270                };
271                quote! {
272                    #id : #ty
273                }
274            })
275            .collect();
276        quote! {
277            #[allow(non_camel_case_types)]
278            #[allow(unused)]
279            #[derive(Default)]
280            #[doc(hidden)]
281            pub struct Builder {
282                #(#structfields),*
283            }
284        }
285    };
286
287    let withfn = if mutablectx {
288        Ident::new("with_mut", Span::call_site())
289    } else {
290        Ident::new("with", Span::call_site())
291    };
292
293    let structimpl = {
294        let fieldfns: Vec<_> = fields
295            .iter()
296            .map(|(id, ty)| {
297                if ty_is_borrow_str(ty) {
298                    quote! {
299                        pub fn #id(mut self, value: &str) -> Self {
300                            self.#id = value.to_string();
301                            self
302                        }
303                    }
304                } else if ty_is_borrow_path(ty) {
305                    quote! {
306                        pub fn #id<P: Into<std::path::PathBuf>>(mut self, value: P) -> Self {
307                            self.#id = value.into();
308                            self
309                        }
310                    }
311                } else {
312                    quote! {
313                        pub fn #id(mut self, value: #ty) -> Self {
314                            self.#id = value;
315                            self
316                        }
317                    }
318                }
319            })
320            .collect();
321
322        let buildargs: Vec<_> = fields
323            .iter()
324            .map(|(id, ty)| {
325                if ty_is_borrow_str(ty) || ty_is_borrow_path(ty) {
326                    quote! {
327                       &self.#id
328                    }
329                } else if ty_is_datafile(ty) {
330                    quote! {
331                        self.#id.clone()
332                    }
333                } else {
334                    quote! {
335                        self.#id
336                    }
337                }
338            })
339            .collect();
340
341        let builder_body = if ty_is_scenariocontext(&contexttype) {
342            quote! {
343                #stepname(ctx,#(#buildargs),*)
344            }
345        } else {
346            quote! {
347                ctx.#withfn (|ctx| #stepname(ctx, #(#buildargs),*), _defuse_poison)
348            }
349        };
350
351        quote! {
352            impl Builder {
353                #(#fieldfns)*
354
355                pub fn build(self, step_text: String, location: &'static str) -> ScenarioStep {
356                    ScenarioStep::new(step_text, move |ctx, _defuse_poison|
357                        #builder_body,
358                        |scenario| register_contexts(scenario),
359                        location,
360                    )
361                }
362            }
363        }
364    };
365
366    let inputargs: Vec<_> = fields.iter().map(|(i, t)| quote!(#i : #t)).collect();
367    let argnames: Vec<_> = fields.iter().map(|(i, _)| i).collect();
368
369    let call_body = if ty_is_scenariocontext(&contexttype) {
370        quote! {
371            #stepname(___context___,#(#argnames),*)
372        }
373    } else {
374        quote! {
375            ___context___.#withfn (move |ctx| #stepname(ctx, #(#argnames),*),false)
376        }
377    };
378
379    let extra_registers: Vec<_> = contexts
380        .iter()
381        .map(|ty| {
382            quote! {
383                scenario.register_context_type::<#ty>();
384            }
385        })
386        .collect();
387
388    let register_fn_body = if ty_is_scenariocontext(&contexttype) {
389        quote! {
390            #(#extra_registers)*
391        }
392    } else {
393        quote! {
394            scenario.register_context_type::<#contexttype>();
395            #(#extra_registers)*
396        }
397    };
398
399    let call_docs = {
400        let mut contextattrs = String::new();
401        let outer_ctx = if ty_is_scenariocontext(&contexttype) {
402            None
403        } else {
404            Some(&contexttype)
405        };
406        for context in outer_ctx.into_iter().chain(contexts.iter()) {
407            write!(contextattrs, "\n    #[context({:?})]", ty_as_path(context)?).unwrap();
408        }
409        let func_args: Vec<_> = fields.iter().map(|(ident, _)| format!("{ident}")).collect();
410        let func_args = func_args.join(", ");
411        format!(
412            r#"
413    Call [this step][self] function from another.
414
415    If you want to call this step function from another, you will
416    need to do something like this:
417
418    ```rust,ignore
419    #[step]{contextattrs}
420    fn defer_to_{stepname}(context: &ScenarioContext) {{
421        //...
422        {stepname}::call(context, {func_args})?;
423        // ...
424    }}
425    ```
426    "#,
427        )
428    };
429    let throws = if input.body_good {
430        quote! {
431            #[throws(StepError)]
432        }
433    } else {
434        quote! {}
435    };
436    let ret = quote! {
437        #(#docs)*
438        #vis mod #stepname {
439            use super::*;
440            pub(crate) use super::#contexttype;
441
442            #structdef
443            #structimpl
444
445            #throws
446            #[allow(dead_code)] // It's okay for step functions to not be used
447            #[deny(unused_must_use)]
448            #[doc(hidden)]
449            #input
450
451            #[doc = #call_docs]
452            pub fn call(___context___: &ScenarioContext, #(#inputargs),*) -> StepResult {
453                #call_body
454            }
455
456            #[allow(unused_variables)]
457            #[doc(hidden)]
458            pub fn register_contexts(scenario: &Scenario) {
459                #register_fn_body
460            }
461        }
462    };
463
464    ret
465}
466
467// This is essentially syn::ItemFn only where the body is not parsed
468pub(crate) struct StepFn {
469    pub(crate) attrs: Vec<Attribute>,
470    pub(crate) vis: Visibility,
471    pub(crate) sig: Signature,
472    pub(crate) block: TokenStream2,
473    pub(crate) body_good: bool,
474}
475
476impl Parse for StepFn {
477    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
478        let attrs = input.call(Attribute::parse_outer)?;
479        let vis: Visibility = input.parse()?;
480        let sig: Signature = input.parse()?;
481        let block = input.fork().parse()?;
482        let body_good = Block::parse(input).is_ok();
483        Ok(Self {
484            attrs,
485            vis,
486            sig,
487            block,
488            body_good,
489        })
490    }
491}
492
493impl ToTokens for StepFn {
494    fn to_tokens(&self, tokens: &mut TokenStream2) {
495        for attr in &self.attrs {
496            attr.to_tokens(tokens);
497        }
498        self.vis.to_tokens(tokens);
499        if self.body_good {
500            self.sig.to_tokens(tokens);
501        } else {
502            syn::Signature {
503                output: parse_quote!(-> Result<(), StepError>),
504                ..self.sig.clone()
505            }
506            .to_tokens(tokens);
507        }
508        self.block.to_tokens(tokens);
509    }
510}