1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse::Parser,
5 parse_macro_input,
6 parse_quote,
7 spanned::Spanned,
8 FnArg,
9 ImplItem,
10 ImplItemFn,
11 ItemImpl,
12 LitStr,
13 Pat,
14 ReturnType,
15 Type,
16};
17
18#[proc_macro_attribute]
53pub fn sync_task(attr: TokenStream, item: TokenStream) -> TokenStream {
54 expand_task(attr, item, false)
55}
56
57#[proc_macro_attribute]
61pub fn async_task(attr: TokenStream, item: TokenStream) -> TokenStream {
62 expand_task(attr, item, true)
63}
64
65fn expand_task(attr: TokenStream, item: TokenStream, expect_async: bool) -> TokenStream {
66 let input_impl = parse_macro_input!(item as ItemImpl);
67 let root_path = match parse_root_path(attr) {
68 Ok(path) => path,
69 Err(err) => return err.to_compile_error().into(),
70 };
71
72 match build_task_impl(&input_impl, expect_async, &root_path) {
73 Ok(expanded) => TokenStream::from(quote! {
74 #input_impl
75 #expanded
76 }),
77 Err(err) => err.to_compile_error().into(),
78 }
79}
80
81fn parse_root_path(attr: TokenStream) -> core::result::Result <syn::Path, syn::Error> {
82 if attr.is_empty() {
83 return Ok(parse_quote!(crate));
84 }
85
86 let mut parsed_path = None::<syn::Path>;
87 let parser = syn::meta::parser(|meta| {
88 if meta.path.is_ident("path") {
89 let lit: LitStr = meta.value()?.parse()?;
90 parsed_path = Some(lit.parse()?);
91 Ok(())
92 } else {
93 Err(meta.error("unsupported argument; expected `path = \"::taskflow\"`"))
94 }
95 });
96
97 parser.parse2(proc_macro2::TokenStream::from(attr))?;
98
99 parsed_path.ok_or_else(|| {
100 syn::Error::new(
101 proc_macro2::Span::call_site(),
102 "missing `path` argument; expected `path = \"::taskflow\"`",
103 )
104 })
105}
106
107fn build_task_impl(
108 input_impl: &ItemImpl,
109 expect_async: bool,
110 root_path: &syn::Path,
111) -> core::result::Result <proc_macro2::TokenStream, syn::Error> {
112 let self_ty = &input_impl.self_ty;
113 let run_fn = find_run_fn(input_impl)?;
114
115 if run_fn.sig.asyncness.is_some() != expect_async {
116 let msg = if expect_async {
117 "#[async_task] requires `async fn run(...)`"
118 } else {
119 "#[sync_task] requires non-async `fn run(...)`"
120 };
121 return Err(syn::Error::new(run_fn.sig.span(), msg));
122 }
123
124 let (receiver_kind, has_ctx, arg_infos) = parse_signature(run_fn)?;
125 let input_ty = build_input_type(&arg_infos);
126 let output_ty = match &run_fn.sig.output {
127 ReturnType::Default => {
128 return Err(syn::Error::new(
129 run_fn.sig.span(),
130 "run method must have an explicit return type",
131 ))
132 }
133 ReturnType::Type(_, ty) => ty.clone(),
134 };
135
136 let destructure = build_destructure(&arg_infos);
137 let call_args: Vec<_> = arg_infos.iter().map(|arg| arg.call_expr.clone()).collect();
138 let (receiver_setup, call_expr) =
139 build_inherent_call(self_ty, receiver_kind, has_ctx, &call_args);
140
141 let ctx_discard = if has_ctx {
145 quote! {}
146 } else {
147 quote! { let _ = __tf_ctx; }
148 };
149
150 let trait_name = if expect_async {
151 quote! { #root_path::tf::traits::AsyncTask }
152 } else {
153 quote! { #root_path::tf::traits::SyncTask }
154 };
155
156 let run_method = if expect_async {
157 quote! {
158 fn run(
159 self,
160 __tf_ctx: &#root_path::tf::component_registry::FlowContext,
161 input: #root_path::tf::task::TaskInput<Self::Input>,
162 ) -> impl std::future::Future<Output = #root_path::tf::task::TaskOutput<Self::Output>> + Send {
163 async move {
164 #ctx_discard
165 #destructure
166 #receiver_setup
167 #root_path::tf::task::TaskOutput(#call_expr.await)
168 }
169 }
170 }
171 } else {
172 quote! {
173 fn run(
174 self,
175 __tf_ctx: &#root_path::tf::component_registry::FlowContext,
176 input: #root_path::tf::task::TaskInput<Self::Input>,
177 ) -> #root_path::tf::task::TaskOutput<Self::Output> {
178 #ctx_discard
179 #destructure
180 #receiver_setup
181 #root_path::tf::task::TaskOutput(#call_expr)
182 }
183 }
184 };
185
186 Ok(quote! {
187 impl #trait_name for #self_ty {
188 type Input = #input_ty;
189 type Output = #output_ty;
190
191 #run_method
192 }
193 })
194}
195
196fn find_run_fn(input_impl: &ItemImpl) -> core::result::Result <&ImplItemFn, syn::Error> {
197 let mut run_fn: Option<&ImplItemFn> = None;
198
199 for item in &input_impl.items {
200 if let ImplItem::Fn(f) = item {
201 if f.sig.ident == "run" {
202 if run_fn.is_some() {
203 return Err(syn::Error::new(
204 f.sig.ident.span(),
205 "only one `run` method is allowed in #[sync_task]/#[async_task] impl",
206 ));
207 }
208 run_fn = Some(f);
209 }
210 }
211 }
212
213 run_fn.ok_or_else(|| {
214 syn::Error::new(
215 input_impl.self_ty.span(),
216 "impl block annotated with #[sync_task]/#[async_task] must define `run`",
217 )
218 })
219}
220
/// How the user's inherent `run` method receives `self`.
#[derive(Clone, Copy)]
enum ReceiverKind {
    /// `run` has no receiver and is called as an associated function.
    None,
    /// `run` takes `self` by value.
    Value,
    /// `run` takes `&self`.
    Ref,
    /// `run` takes `&mut self`.
    RefMut,
}
228
229struct ArgInfo {
230 binding: syn::Ident,
231 input_ty: Type,
232 call_expr: proc_macro2::TokenStream,
233 needs_mut_binding: bool,
234}
235
236fn parse_signature(
237 run_fn: &ImplItemFn,
238) -> core::result::Result <(ReceiverKind, bool, std::vec::Vec <ArgInfo>), syn::Error> {
239 let mut receiver = ReceiverKind::None;
240 let mut args = Vec::new();
241 let mut has_ctx = false;
242 let mut typed_arg_index: usize = 0;
243
244 for arg in &run_fn.sig.inputs {
245 match arg {
246 FnArg::Receiver(rcv) => {
247 receiver = if rcv.reference.is_none() {
248 ReceiverKind::Value
249 } else if rcv.mutability.is_some() {
250 ReceiverKind::RefMut
251 } else {
252 ReceiverKind::Ref
253 };
254 }
255 FnArg::Typed(typed) => {
256 let Pat::Ident(pat_ident) = typed.pat.as_ref() else {
257 return Err(syn::Error::new(
258 typed.pat.span(),
259 "task `run` args must be simple identifiers",
260 ));
261 };
262
263 let ident = pat_ident.ident.clone();
264
265 if typed_arg_index == 0 {
272 if let Type::Reference(r) = typed.ty.as_ref() {
273 if r.mutability.is_none() && is_flow_context_path(r.elem.as_ref()) {
274 has_ctx = true;
275 typed_arg_index += 1;
276 continue;
277 }
278 }
279 }
280 typed_arg_index += 1;
281
282 match typed.ty.as_ref() {
283 Type::Reference(r) if r.mutability.is_none() => {
284 let inner = (*r.elem).clone();
285 args.push(ArgInfo {
286 binding: ident.clone(),
287 input_ty: inner,
288 call_expr: quote! { &*#ident },
289 needs_mut_binding: false,
290 });
291 }
292 Type::Reference(r) if r.mutability.is_some() => {
293 return Err(syn::Error::new(
294 r.span(),
295 "task `run` args must use shared references `&T`; mutable refs `&mut T` are not supported",
296 ));
297 }
298 other_ty => {
299 return Err(syn::Error::new(
300 other_ty.span(),
301 "task `run` args must use shared references `&T`; by-value args are not supported",
302 ));
303 }
304 }
305 }
306 }
307 }
308
309 Ok((receiver, has_ctx, args))
310}
311
312fn is_flow_context_path(ty: &Type) -> bool {
315 if let Type::Path(p) = ty {
316 if let Some(last) = p.path.segments.last() {
317 return last.ident == "FlowContext";
318 }
319 }
320 false
321}
322
323fn build_input_type(args: &[ArgInfo]) -> proc_macro2::TokenStream {
324 match args {
325 [] => quote! { () },
326 _ => {
327 let tys = args.iter().map(|arg| {
328 let ty = &arg.input_ty;
329 quote! { std::sync::Arc<#ty> }
330 });
331 quote! { ( #(#tys,)* ) }
332 }
333 }
334}
335
336fn build_destructure(args: &[ArgInfo]) -> proc_macro2::TokenStream {
337 match args {
338 [] => quote! { let _ = input; },
339 _ => {
340 let bindings = args.iter().map(|arg| {
341 let ident = &arg.binding;
342 if arg.needs_mut_binding {
343 quote! { mut #ident }
344 } else {
345 quote! { #ident }
346 }
347 });
348 quote! { let ( #(#bindings,)* ) = input.0; }
349 }
350 }
351}
352
353fn build_inherent_call(
354 self_ty: &Type,
355 receiver_kind: ReceiverKind,
356 has_ctx: bool,
357 call_args: &[proc_macro2::TokenStream],
358) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
359 let ctx_arg: Vec<proc_macro2::TokenStream> = if has_ctx {
362 vec![quote! { __tf_ctx }]
363 } else {
364 vec![]
365 };
366 let all_args: Vec<proc_macro2::TokenStream> = ctx_arg
367 .into_iter()
368 .chain(call_args.iter().cloned())
369 .collect();
370
371 match receiver_kind {
372 ReceiverKind::None => {
373 let call = if all_args.is_empty() {
374 quote! { <#self_ty>::run() }
375 } else {
376 quote! { <#self_ty>::run(#(#all_args),*) }
377 };
378 (quote! {}, call)
379 }
380 ReceiverKind::Value => {
381 let call = if all_args.is_empty() {
382 quote! { <#self_ty>::run(self) }
383 } else {
384 quote! { <#self_ty>::run(self, #(#all_args),*) }
385 };
386 (quote! {}, call)
387 }
388 ReceiverKind::Ref => {
389 let call = if all_args.is_empty() {
390 quote! { <#self_ty>::run(&self) }
391 } else {
392 quote! { <#self_ty>::run(&self, #(#all_args),*) }
393 };
394 (quote! {}, call)
395 }
396 ReceiverKind::RefMut => {
397 let call = if all_args.is_empty() {
398 quote! { <#self_ty>::run(&mut __task) }
399 } else {
400 quote! { <#self_ty>::run(&mut __task, #(#all_args),*) }
401 };
402 (quote! { let mut __task = self; }, call)
403 }
404 }
405}