Skip to main content

taskflow_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse::Parser,
5    parse_macro_input,
6    parse_quote,
7    spanned::Spanned,
8    FnArg,
9    ImplItem,
10    ImplItemFn,
11    ItemImpl,
12    LitStr,
13    Pat,
14    ReturnType,
15    Type,
16};
17
18/// Turn an inherent `impl Block { fn run(...) -> ... }` into a
19/// `SyncTask`/`AsyncTask` trait implementation for the taskflow scheduler.
20///
21/// ## Accepted `run` signatures
22///
23/// ```ignore
24/// // 1. No inputs, no context (source task).
25/// fn run(self) -> Out;
26///
27/// // 2. Only DAG inputs.
28/// fn run(self, a: &A, b: &B) -> Out;
29///
30/// // 3. With the runtime FlowContext. MUST be the first non-self parameter,
31/// //    typed as `&FlowContext` (match by trailing path segment, so any
32/// //    import alias that still names the type `FlowContext` works).
33/// fn run(self, ctx: &FlowContext, a: &A) -> Out;
34/// ```
35///
36/// DAG inputs must be shared references `&T` (the scheduler stores outputs as
37/// `Arc<T>` and hands out a borrow). Owned and `&mut` parameters are rejected.
38///
39/// ## Context injection details
40///
41/// The generated trait impl always takes `ctx: &FlowContext`. If the user did
42/// not declare one, the generated body discards it with `let _ = ctx;`. If the
43/// user did declare one, it is forwarded as the first argument to the inherent
44/// `run` call. Nothing else in the user's signature changes.
45///
46/// ## The `path = "..."` attribute
47///
48/// When the macro is used outside the taskflow crate itself, pass
49/// `path = "::taskflow"` (or the relevant re-export root) so the generated
50/// code can refer to the runtime traits. Inside the taskflow crate the
51/// default `crate` path is used.
52#[proc_macro_attribute]
53pub fn sync_task(attr: TokenStream, item: TokenStream) -> TokenStream {
54    expand_task(attr, item, false)
55}
56
57/// Async counterpart of [`macro@sync_task`]. The `run` method must be
58/// `async fn` and all the rules about parameters (shared references, optional
59/// leading `ctx: &FlowContext`) are identical.
60#[proc_macro_attribute]
61pub fn async_task(attr: TokenStream, item: TokenStream) -> TokenStream {
62    expand_task(attr, item, true)
63}
64
65fn expand_task(attr: TokenStream, item: TokenStream, expect_async: bool) -> TokenStream {
66    let input_impl = parse_macro_input!(item as ItemImpl);
67    let root_path = match parse_root_path(attr) {
68        Ok(path) => path,
69        Err(err) => return err.to_compile_error().into(),
70    };
71
72    match build_task_impl(&input_impl, expect_async, &root_path) {
73        Ok(expanded) => TokenStream::from(quote! {
74            #input_impl
75            #expanded
76        }),
77        Err(err) => err.to_compile_error().into(),
78    }
79}
80
81fn parse_root_path(attr: TokenStream) -> core::result::Result <syn::Path, syn::Error> {
82    if attr.is_empty() {
83        return Ok(parse_quote!(crate));
84    }
85
86    let mut parsed_path = None::<syn::Path>;
87    let parser = syn::meta::parser(|meta| {
88        if meta.path.is_ident("path") {
89            let lit: LitStr = meta.value()?.parse()?;
90            parsed_path = Some(lit.parse()?);
91            Ok(())
92        } else {
93            Err(meta.error("unsupported argument; expected `path = \"::taskflow\"`"))
94        }
95    });
96
97    parser.parse2(proc_macro2::TokenStream::from(attr))?;
98
99    parsed_path.ok_or_else(|| {
100        syn::Error::new(
101            proc_macro2::Span::call_site(),
102            "missing `path` argument; expected `path = \"::taskflow\"`",
103        )
104    })
105}
106
107fn build_task_impl(
108    input_impl: &ItemImpl,
109    expect_async: bool,
110    root_path: &syn::Path,
111) -> core::result::Result <proc_macro2::TokenStream, syn::Error> {
112    let self_ty = &input_impl.self_ty;
113    let run_fn = find_run_fn(input_impl)?;
114
115    if run_fn.sig.asyncness.is_some() != expect_async {
116        let msg = if expect_async {
117            "#[async_task] requires `async fn run(...)`"
118        } else {
119            "#[sync_task] requires non-async `fn run(...)`"
120        };
121        return Err(syn::Error::new(run_fn.sig.span(), msg));
122    }
123
124    let (receiver_kind, has_ctx, arg_infos) = parse_signature(run_fn)?;
125    let input_ty = build_input_type(&arg_infos);
126    let output_ty = match &run_fn.sig.output {
127        ReturnType::Default => {
128            return Err(syn::Error::new(
129                run_fn.sig.span(),
130                "run method must have an explicit return type",
131            ))
132        }
133        ReturnType::Type(_, ty) => ty.clone(),
134    };
135
136    let destructure = build_destructure(&arg_infos);
137    let call_args: Vec<_> = arg_infos.iter().map(|arg| arg.call_expr.clone()).collect();
138    let (receiver_setup, call_expr) =
139        build_inherent_call(self_ty, receiver_kind, has_ctx, &call_args);
140
141    // If the user's `run` does not declare `ctx: &FlowContext`, we still must
142    // accept it in the generated trait impl — silence the unused warning with
143    // a discard binding.
144    let ctx_discard = if has_ctx {
145        quote! {}
146    } else {
147        quote! { let _ = __tf_ctx; }
148    };
149
150    let trait_name = if expect_async {
151        quote! { #root_path::tf::traits::AsyncTask }
152    } else {
153        quote! { #root_path::tf::traits::SyncTask }
154    };
155
156    let run_method = if expect_async {
157        quote! {
158            fn run(
159                self,
160                __tf_ctx: &#root_path::tf::component_registry::FlowContext,
161                input: #root_path::tf::task::TaskInput<Self::Input>,
162            ) -> impl std::future::Future<Output = #root_path::tf::task::TaskOutput<Self::Output>> + Send {
163                async move {
164                    #ctx_discard
165                    #destructure
166                    #receiver_setup
167                    #root_path::tf::task::TaskOutput(#call_expr.await)
168                }
169            }
170        }
171    } else {
172        quote! {
173            fn run(
174                self,
175                __tf_ctx: &#root_path::tf::component_registry::FlowContext,
176                input: #root_path::tf::task::TaskInput<Self::Input>,
177            ) -> #root_path::tf::task::TaskOutput<Self::Output> {
178                #ctx_discard
179                #destructure
180                #receiver_setup
181                #root_path::tf::task::TaskOutput(#call_expr)
182            }
183        }
184    };
185
186    Ok(quote! {
187        impl #trait_name for #self_ty {
188            type Input = #input_ty;
189            type Output = #output_ty;
190
191            #run_method
192        }
193    })
194}
195
196fn find_run_fn(input_impl: &ItemImpl) -> core::result::Result <&ImplItemFn, syn::Error> {
197    let mut run_fn: Option<&ImplItemFn> = None;
198
199    for item in &input_impl.items {
200        if let ImplItem::Fn(f) = item {
201            if f.sig.ident == "run" {
202                if run_fn.is_some() {
203                    return Err(syn::Error::new(
204                        f.sig.ident.span(),
205                        "only one `run` method is allowed in #[sync_task]/#[async_task] impl",
206                    ));
207                }
208                run_fn = Some(f);
209            }
210        }
211    }
212
213    run_fn.ok_or_else(|| {
214        syn::Error::new(
215            input_impl.self_ty.span(),
216            "impl block annotated with #[sync_task]/#[async_task] must define `run`",
217        )
218    })
219}
220
/// How the user's inherent `run` takes its receiver, which decides how the
/// generated trait method hands `self` over.
#[derive(Clone, Copy)]
enum ReceiverKind {
    /// No `self` parameter at all (associated function).
    None,
    /// `self` taken by value.
    Value,
    /// `&self`.
    Ref,
    /// `&mut self`.
    RefMut,
}
228
229struct ArgInfo {
230    binding: syn::Ident,
231    input_ty: Type,
232    call_expr: proc_macro2::TokenStream,
233    needs_mut_binding: bool,
234}
235
236fn parse_signature(
237    run_fn: &ImplItemFn,
238) -> core::result::Result <(ReceiverKind, bool, std::vec::Vec <ArgInfo>), syn::Error> {
239    let mut receiver = ReceiverKind::None;
240    let mut args = Vec::new();
241    let mut has_ctx = false;
242    let mut typed_arg_index: usize = 0;
243
244    for arg in &run_fn.sig.inputs {
245        match arg {
246            FnArg::Receiver(rcv) => {
247                receiver = if rcv.reference.is_none() {
248                    ReceiverKind::Value
249                } else if rcv.mutability.is_some() {
250                    ReceiverKind::RefMut
251                } else {
252                    ReceiverKind::Ref
253                };
254            }
255            FnArg::Typed(typed) => {
256                let Pat::Ident(pat_ident) = typed.pat.as_ref() else {
257                    return Err(syn::Error::new(
258                        typed.pat.span(),
259                        "task `run` args must be simple identifiers",
260                    ));
261                };
262
263                let ident = pat_ident.ident.clone();
264
265                // Detect a leading `ctx: &FlowContext` argument. It must be
266                // the first non-`self` typed parameter and is routed to the
267                // runtime-provided FlowContext rather than being treated as a
268                // DAG input. Match by the trailing `FlowContext` identifier so
269                // users are free to `use ... as Foo` if they wish — but see
270                // the macro docs for the recommended convention.
271                if typed_arg_index == 0 {
272                    if let Type::Reference(r) = typed.ty.as_ref() {
273                        if r.mutability.is_none() && is_flow_context_path(r.elem.as_ref()) {
274                            has_ctx = true;
275                            typed_arg_index += 1;
276                            continue;
277                        }
278                    }
279                }
280                typed_arg_index += 1;
281
282                match typed.ty.as_ref() {
283                    Type::Reference(r) if r.mutability.is_none() => {
284                        let inner = (*r.elem).clone();
285                        args.push(ArgInfo {
286                            binding: ident.clone(),
287                            input_ty: inner,
288                            call_expr: quote! { &*#ident },
289                            needs_mut_binding: false,
290                        });
291                    }
292                    Type::Reference(r) if r.mutability.is_some() => {
293                        return Err(syn::Error::new(
294                            r.span(),
295                            "task `run` args must use shared references `&T`; mutable refs `&mut T` are not supported",
296                        ));
297                    }
298                    other_ty => {
299                        return Err(syn::Error::new(
300                            other_ty.span(),
301                            "task `run` args must use shared references `&T`; by-value args are not supported",
302                        ));
303                    }
304                }
305            }
306        }
307    }
308
309    Ok((receiver, has_ctx, args))
310}
311
312/// Matches `FlowContext` as the final path segment. Accepts `FlowContext`,
313/// `taskflow::FlowContext`, `crate::tf::component_registry::FlowContext`, etc.
314fn is_flow_context_path(ty: &Type) -> bool {
315    if let Type::Path(p) = ty {
316        if let Some(last) = p.path.segments.last() {
317            return last.ident == "FlowContext";
318        }
319    }
320    false
321}
322
323fn build_input_type(args: &[ArgInfo]) -> proc_macro2::TokenStream {
324    match args {
325        [] => quote! { () },
326        _ => {
327            let tys = args.iter().map(|arg| {
328                let ty = &arg.input_ty;
329                quote! { std::sync::Arc<#ty> }
330            });
331            quote! { ( #(#tys,)* ) }
332        }
333    }
334}
335
336fn build_destructure(args: &[ArgInfo]) -> proc_macro2::TokenStream {
337    match args {
338        [] => quote! { let _ = input; },
339        _ => {
340            let bindings = args.iter().map(|arg| {
341                let ident = &arg.binding;
342                if arg.needs_mut_binding {
343                    quote! { mut #ident }
344                } else {
345                    quote! { #ident }
346                }
347            });
348            quote! { let ( #(#bindings,)* ) = input.0; }
349        }
350    }
351}
352
353fn build_inherent_call(
354    self_ty: &Type,
355    receiver_kind: ReceiverKind,
356    has_ctx: bool,
357    call_args: &[proc_macro2::TokenStream],
358) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
359    // If the user's run declared `ctx: &FlowContext`, prepend it to the
360    // argument list so it flows from the runtime into their function.
361    let ctx_arg: Vec<proc_macro2::TokenStream> = if has_ctx {
362        vec![quote! { __tf_ctx }]
363    } else {
364        vec![]
365    };
366    let all_args: Vec<proc_macro2::TokenStream> = ctx_arg
367        .into_iter()
368        .chain(call_args.iter().cloned())
369        .collect();
370
371    match receiver_kind {
372        ReceiverKind::None => {
373            let call = if all_args.is_empty() {
374                quote! { <#self_ty>::run() }
375            } else {
376                quote! { <#self_ty>::run(#(#all_args),*) }
377            };
378            (quote! {}, call)
379        }
380        ReceiverKind::Value => {
381            let call = if all_args.is_empty() {
382                quote! { <#self_ty>::run(self) }
383            } else {
384                quote! { <#self_ty>::run(self, #(#all_args),*) }
385            };
386            (quote! {}, call)
387        }
388        ReceiverKind::Ref => {
389            let call = if all_args.is_empty() {
390                quote! { <#self_ty>::run(&self) }
391            } else {
392                quote! { <#self_ty>::run(&self, #(#all_args),*) }
393            };
394            (quote! {}, call)
395        }
396        ReceiverKind::RefMut => {
397            let call = if all_args.is_empty() {
398                quote! { <#self_ty>::run(&mut __task) }
399            } else {
400                quote! { <#self_ty>::run(&mut __task, #(#all_args),*) }
401            };
402            (quote! { let mut __task = self; }, call)
403        }
404    }
405}