1use culpa::{throw, throws};
2use proc_macro2::Span;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use quote::ToTokens;
6use syn::parse_quote;
7use syn::Block;
8use syn::{
9 parse::Parse, Attribute, Error, FnArg, Ident, Pat, PathArguments, ReturnType, Signature, Type,
10 Visibility,
11};
12
13use std::fmt::Write;
14
15pub(crate) fn ty_is_borrow_str(ty: &Type) -> bool {
16 if let Type::Reference(ty) = ty {
17 if ty.mutability.is_none() && ty.lifetime.is_none() {
18 if let Type::Path(pp) = &*ty.elem {
19 pp.path.is_ident("str")
20 } else {
21 false
23 }
24 } else {
25 false
27 }
28 } else {
29 false
31 }
32}
33
34pub(crate) fn ty_is_borrow_path(ty: &Type) -> bool {
35 if let Type::Reference(ty) = ty {
36 if ty.mutability.is_none() && ty.lifetime.is_none() {
37 if let Type::Path(pp) = &*ty.elem {
38 pp.path.is_ident("Path")
39 } else {
40 false
42 }
43 } else {
44 false
46 }
47 } else {
48 false
50 }
51}
52
53pub(crate) fn ty_is_datafile(ty: &Type) -> bool {
54 if let Type::Path(ty) = ty {
55 ty.path.is_ident("SubplotDataFile")
56 } else {
57 false
58 }
59}
60
61pub(crate) fn ty_is_scenariocontext(ty: &Type) -> bool {
62 if let Type::Path(ty) = ty {
63 ty.path.is_ident("ScenarioContext")
64 } else {
65 false
66 }
67}
68
69#[throws(Error)]
70pub(crate) fn ty_as_path(ty: &Type) -> String {
71 if let Type::Path(p) = ty {
72 let mut ret = String::new();
73 let mut colons = p.path.leading_colon.is_some();
74 for seg in &p.path.segments {
75 if !matches!(seg.arguments, PathArguments::None) {
76 throw!(Error::new_spanned(seg, "unexpected path segment arguments"));
77 }
78 if colons {
79 ret.push_str("::");
80 }
81 colons = true;
82 ret.push_str(&seg.ident.to_string());
83 }
84 ret
85 } else {
86 throw!(Error::new_spanned(ty, "expected a type path"));
87 }
88}
89
90#[throws(Error)]
91pub(crate) fn check_step_declaration(step: &StepFn) {
92 let sig = &step.sig;
106 if let Some(syncness) = sig.asyncness.as_ref() {
107 throw!(Error::new_spanned(
108 syncness,
109 "Step functions may not be async",
110 ));
111 }
112 if let Some(unsafeness) = sig.unsafety.as_ref() {
113 throw!(Error::new_spanned(
114 unsafeness,
115 "Step functions may not be unsafe",
116 ));
117 }
118 if let Some(abi) = sig.abi.as_ref() {
119 throw!(Error::new_spanned(
120 abi,
121 "Step functions may not specify an ABI",
122 ));
123 }
124 if !matches!(sig.output, ReturnType::Default) {
125 throw!(Error::new_spanned(
126 &sig.output,
127 "Step functions may not specify a return value",
128 ));
129 }
130 if let Some(variadic) = sig.variadic.as_ref() {
131 throw!(Error::new_spanned(
132 variadic,
133 "Step functions may not be variadic",
134 ));
135 }
136 if !sig.generics.params.is_empty() || sig.generics.where_clause.is_some() {
137 throw!(Error::new_spanned(
138 &sig.generics,
139 "Step functions may not be generic",
140 ));
141 }
142 if let Some(arg) = sig.inputs.first() {
143 if let FnArg::Typed(pat) = arg {
144 if let Type::Reference(tr) = &*pat.ty {
145 if let Some(lifetime) = tr.lifetime.as_ref() {
146 throw!(Error::new_spanned(
147 lifetime,
148 "Step function context borrow should not be given a lifetime marker",
149 ));
150 }
151 } else {
152 throw!(Error::new_spanned(
153 pat,
154 "Step function context must be taken as a borrow",
155 ));
156 }
157 } else {
158 throw!(Error::new_spanned(
159 arg,
160 "Step functions do not take a method receiver",
161 ));
162 }
163 } else {
164 throw!(Error::new_spanned(
165 &sig.inputs,
166 "Step functions must have at least 1 argument (context)",
167 ));
168 }
169}
170
171#[throws(Error)]
172pub(crate) fn process_step(mut input: StepFn) -> proc_macro2::TokenStream {
173 let vis = input.vis.clone();
183 let stepname = input.sig.ident.clone();
184 let mutablectx = {
185 if let FnArg::Typed(pt) = &input.sig.inputs[0] {
186 if let Type::Reference(pp) = &*pt.ty {
187 pp.mutability.is_some()
188 } else {
189 unreachable!()
190 }
191 } else {
192 unreachable!()
193 }
194 };
195
196 let contexttype = if let Some(ty) = input.sig.inputs.first() {
197 match ty {
198 FnArg::Typed(pt) => {
199 if let Type::Reference(rt) = &*pt.ty {
200 *(rt.elem).clone()
201 } else {
202 unreachable!()
203 }
204 }
205 _ => unreachable!(),
206 }
207 } else {
208 unreachable!()
209 };
210
211 let contexts: Vec<Type> = input
212 .attrs
213 .iter()
214 .filter(|attr| attr.path().is_ident("context"))
215 .map(|attr| {
216 let ty: Type = attr.parse_args()?;
217 Ok(ty)
218 })
219 .collect::<Result<_, Error>>()?;
220
221 input.attrs.retain(|f| !f.path().is_ident("context"));
222
223 let docs: Vec<_> = input
224 .attrs
225 .iter()
226 .filter(|attr| attr.path().is_ident("doc"))
227 .collect();
228
229 let fields = input
230 .sig
231 .inputs
232 .iter()
233 .skip(1)
234 .map(|a| {
235 if let FnArg::Typed(pat) = a {
236 if let Pat::Ident(ident) = &*pat.pat {
237 if let Some(r) = ident.by_ref.as_ref() {
238 Err(Error::new_spanned(r, "ref not valid here"))
239 } else if let Some(subpat) = ident.subpat.as_ref() {
240 Err(Error::new_spanned(&subpat.1, "subpattern not valid here"))
241 } else {
242 let identstr = ident.ident.to_string();
243 Ok((
244 Ident::new(identstr.trim_start_matches('_'), ident.ident.span()),
245 (*pat.ty).clone(),
246 ))
247 }
248 } else {
249 Err(Error::new_spanned(pat, "expected a simple name here"))
250 }
251 } else {
252 Err(Error::new_spanned(
253 a,
254 "receiver argument unexpected in this position",
255 ))
256 }
257 })
258 .collect::<Result<Vec<_>, _>>()?;
259
260 let structdef = {
261 let structfields: Vec<_> = fields
262 .iter()
263 .map(|(id, ty)| {
264 let ty = if ty_is_borrow_str(ty) {
265 parse_quote!(::std::string::String)
266 } else if ty_is_borrow_path(ty) {
267 parse_quote!(::std::path::PathBuf)
268 } else {
269 ty.clone()
270 };
271 quote! {
272 #id : #ty
273 }
274 })
275 .collect();
276 quote! {
277 #[allow(non_camel_case_types)]
278 #[allow(unused)]
279 #[derive(Default)]
280 #[doc(hidden)]
281 pub struct Builder {
282 #(#structfields),*
283 }
284 }
285 };
286
287 let withfn = if mutablectx {
288 Ident::new("with_mut", Span::call_site())
289 } else {
290 Ident::new("with", Span::call_site())
291 };
292
293 let structimpl = {
294 let fieldfns: Vec<_> = fields
295 .iter()
296 .map(|(id, ty)| {
297 if ty_is_borrow_str(ty) {
298 quote! {
299 pub fn #id(mut self, value: &str) -> Self {
300 self.#id = value.to_string();
301 self
302 }
303 }
304 } else if ty_is_borrow_path(ty) {
305 quote! {
306 pub fn #id<P: Into<std::path::PathBuf>>(mut self, value: P) -> Self {
307 self.#id = value.into();
308 self
309 }
310 }
311 } else {
312 quote! {
313 pub fn #id(mut self, value: #ty) -> Self {
314 self.#id = value;
315 self
316 }
317 }
318 }
319 })
320 .collect();
321
322 let buildargs: Vec<_> = fields
323 .iter()
324 .map(|(id, ty)| {
325 if ty_is_borrow_str(ty) || ty_is_borrow_path(ty) {
326 quote! {
327 &self.#id
328 }
329 } else if ty_is_datafile(ty) {
330 quote! {
331 self.#id.clone()
332 }
333 } else {
334 quote! {
335 self.#id
336 }
337 }
338 })
339 .collect();
340
341 let builder_body = if ty_is_scenariocontext(&contexttype) {
342 quote! {
343 #stepname(ctx,#(#buildargs),*)
344 }
345 } else {
346 quote! {
347 ctx.#withfn (|ctx| #stepname(ctx, #(#buildargs),*), _defuse_poison)
348 }
349 };
350
351 quote! {
352 impl Builder {
353 #(#fieldfns)*
354
355 pub fn build(self, step_text: String, location: &'static str) -> ScenarioStep {
356 ScenarioStep::new(step_text, move |ctx, _defuse_poison|
357 #builder_body,
358 |scenario| register_contexts(scenario),
359 location,
360 )
361 }
362 }
363 }
364 };
365
366 let inputargs: Vec<_> = fields.iter().map(|(i, t)| quote!(#i : #t)).collect();
367 let argnames: Vec<_> = fields.iter().map(|(i, _)| i).collect();
368
369 let call_body = if ty_is_scenariocontext(&contexttype) {
370 quote! {
371 #stepname(___context___,#(#argnames),*)
372 }
373 } else {
374 quote! {
375 ___context___.#withfn (move |ctx| #stepname(ctx, #(#argnames),*),false)
376 }
377 };
378
379 let extra_registers: Vec<_> = contexts
380 .iter()
381 .map(|ty| {
382 quote! {
383 scenario.register_context_type::<#ty>();
384 }
385 })
386 .collect();
387
388 let register_fn_body = if ty_is_scenariocontext(&contexttype) {
389 quote! {
390 #(#extra_registers)*
391 }
392 } else {
393 quote! {
394 scenario.register_context_type::<#contexttype>();
395 #(#extra_registers)*
396 }
397 };
398
399 let call_docs = {
400 let mut contextattrs = String::new();
401 let outer_ctx = if ty_is_scenariocontext(&contexttype) {
402 None
403 } else {
404 Some(&contexttype)
405 };
406 for context in outer_ctx.into_iter().chain(contexts.iter()) {
407 write!(contextattrs, "\n #[context({:?})]", ty_as_path(context)?).unwrap();
408 }
409 let func_args: Vec<_> = fields.iter().map(|(ident, _)| format!("{ident}")).collect();
410 let func_args = func_args.join(", ");
411 format!(
412 r#"
413 Call [this step][self] function from another.
414
415 If you want to call this step function from another, you will
416 need to do something like this:
417
418 ```rust,ignore
419 #[step]{contextattrs}
420 fn defer_to_{stepname}(context: &ScenarioContext) {{
421 //...
422 {stepname}::call(context, {func_args})?;
423 // ...
424 }}
425 ```
426 "#,
427 )
428 };
429 let throws = if input.body_good {
430 quote! {
431 #[throws(StepError)]
432 }
433 } else {
434 quote! {}
435 };
436 let ret = quote! {
437 #(#docs)*
438 #vis mod #stepname {
439 use super::*;
440 pub(crate) use super::#contexttype;
441
442 #structdef
443 #structimpl
444
445 #throws
446 #[allow(dead_code)] #[deny(unused_must_use)]
448 #[doc(hidden)]
449 #input
450
451 #[doc = #call_docs]
452 pub fn call(___context___: &ScenarioContext, #(#inputargs),*) -> StepResult {
453 #call_body
454 }
455
456 #[allow(unused_variables)]
457 #[doc(hidden)]
458 pub fn register_contexts(scenario: &Scenario) {
459 #register_fn_body
460 }
461 }
462 };
463
464 ret
465}
466
/// A function item as parsed for step processing.
///
/// Unlike a fully parsed function item, the body is kept as raw tokens so
/// that it can be re-emitted unchanged even when it does not parse as a
/// block.
pub(crate) struct StepFn {
    /// Outer attributes written on the function.
    pub(crate) attrs: Vec<Attribute>,
    /// Visibility of the function.
    pub(crate) vis: Visibility,
    /// The function signature.
    pub(crate) sig: Signature,
    /// Every token following the signature, captured verbatim.
    pub(crate) block: TokenStream2,
    /// Whether the tokens following the signature parsed as a `syn::Block`.
    pub(crate) body_good: bool,
}
475
476impl Parse for StepFn {
477 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
478 let attrs = input.call(Attribute::parse_outer)?;
479 let vis: Visibility = input.parse()?;
480 let sig: Signature = input.parse()?;
481 let block = input.fork().parse()?;
482 let body_good = Block::parse(input).is_ok();
483 Ok(Self {
484 attrs,
485 vis,
486 sig,
487 block,
488 body_good,
489 })
490 }
491}
492
493impl ToTokens for StepFn {
494 fn to_tokens(&self, tokens: &mut TokenStream2) {
495 for attr in &self.attrs {
496 attr.to_tokens(tokens);
497 }
498 self.vis.to_tokens(tokens);
499 if self.body_good {
500 self.sig.to_tokens(tokens);
501 } else {
502 syn::Signature {
503 output: parse_quote!(-> Result<(), StepError>),
504 ..self.sig.clone()
505 }
506 .to_tokens(tokens);
507 }
508 self.block.to_tokens(tokens);
509 }
510}