1use proc_macro::TokenStream;
8use proc_macro2::TokenStream as TokenStream2;
9use quote::quote;
10use syn::parse::{Parse, ParseStream};
11use syn::{
12 FnArg, Ident, ItemFn, LitInt, Pat, PatType, ReturnType, Token, Type, parse_macro_input,
13 spanned::Spanned,
14};
15
/// Options parsed from the `#[durable(...)]` attribute arguments.
///
/// Every field is `None` when the corresponding key was omitted; defaults are
/// applied later by `generate_durable_wrapper` (`max_retries = 3`,
/// `delay = 1000`).
#[derive(Debug, Default)]
struct DurableAttr {
    /// `max_retries = N`: retries allowed after the first attempt.
    max_retries: Option<u32>,
    /// `strategy = Ident`: backoff strategy name. Parsing only accepts
    /// `ExponentialBackoff`, which is also the only behavior code generation
    /// implements, so the value is validated but never consulted outside of
    /// tests. The allow silences the resulting "field is never read" warning
    /// in non-test builds.
    #[allow(dead_code)]
    strategy: Option<String>,
    /// `delay = N`: base backoff delay in milliseconds.
    delay: Option<u64>,
}
26
27impl Parse for DurableAttr {
28 fn parse(input: ParseStream) -> syn::Result<Self> {
29 let mut attr = DurableAttr::default();
30
31 while !input.is_empty() {
32 let ident: Ident = input.parse()?;
33 input.parse::<Token![=]>()?;
34
35 match ident.to_string().as_str() {
36 "max_retries" => {
37 let lit: LitInt = input.parse()?;
38 attr.max_retries = Some(lit.base10_parse()?);
39 }
40 "strategy" => {
41 let strategy_ident: Ident = input.parse()?;
42 let strategy_str = strategy_ident.to_string();
43 if strategy_str != "ExponentialBackoff" {
44 return Err(syn::Error::new(
45 strategy_ident.span(),
46 "Only ExponentialBackoff strategy is currently supported",
47 ));
48 }
49 attr.strategy = Some(strategy_str);
50 }
51 "delay" => {
52 let lit: LitInt = input.parse()?;
53 attr.delay = Some(lit.base10_parse()?);
54 }
55 _ => {
56 return Err(syn::Error::new(
57 ident.span(),
58 format!(
59 "Unknown attribute '{}'. Valid attributes: max_retries, strategy, delay",
60 ident
61 ),
62 ));
63 }
64 }
65
66 if input.peek(Token![,]) {
67 input.parse::<Token![,]>()?;
68 }
69 }
70
71 Ok(attr)
72 }
73}
74
75#[proc_macro_attribute]
118pub fn durable(attr: TokenStream, item: TokenStream) -> TokenStream {
119 let config = parse_macro_input!(attr as DurableAttr);
120 let input = parse_macro_input!(item as ItemFn);
121
122 match generate_durable_wrapper(input, config) {
123 Ok(tokens) => tokens.into(),
124 Err(err) => err.to_compile_error().into(),
125 }
126}
127
128fn generate_durable_wrapper(input: ItemFn, config: DurableAttr) -> syn::Result<TokenStream2> {
129 let fn_name = &input.sig.ident;
130 let fn_name_str = fn_name.to_string();
131 let vis = &input.vis;
132 let attrs = &input.attrs;
133 let sig = &input.sig;
134 let block = &input.block;
135
136 if sig.asyncness.is_none() {
138 return Err(syn::Error::new(
139 sig.fn_token.span,
140 "#[durable] only works with async functions",
141 ));
142 }
143
144 let ok_type = extract_result_ok_type(&sig.output)?;
146
147 let idempotency_key_ident = extract_first_arg_ident(&sig.inputs)?;
149
150 let max_retries = config.max_retries.unwrap_or(3);
152 let base_delay_ms = config.delay.unwrap_or(1000);
153
154 if max_retries == 0 {
156 generate_no_retry_wrapper(
158 fn_name_str,
159 vis,
160 attrs,
161 sig,
162 block,
163 ok_type,
164 idempotency_key_ident,
165 )
166 } else {
167 generate_retry_wrapper(
169 fn_name_str,
170 vis,
171 attrs,
172 sig,
173 block,
174 ok_type,
175 idempotency_key_ident,
176 max_retries,
177 base_delay_ms,
178 )
179 }
180}
181
182fn generate_no_retry_wrapper(
184 fn_name_str: String,
185 vis: &syn::Visibility,
186 attrs: &[syn::Attribute],
187 sig: &syn::Signature,
188 block: &syn::Block,
189 ok_type: Type,
190 idempotency_key_ident: Ident,
191) -> syn::Result<TokenStream2> {
192 Ok(quote! {
193 #(#attrs)*
194 #vis #sig {
195 let __cache_key = format!("durable::{}::{}", #fn_name_str, #idempotency_key_ident);
196
197 {
199 let __sdk = ::runtara_sdk::sdk();
200 let __sdk_guard = __sdk.lock().await;
201
202 match __sdk_guard.get_checkpoint(&__cache_key).await {
203 Ok(Some(cached_bytes)) => {
204 drop(__sdk_guard);
206 match ::serde_json::from_slice::<#ok_type>(&cached_bytes) {
207 Ok(cached_value) => {
208 ::tracing::debug!(
209 function = #fn_name_str,
210 cache_key = %__cache_key,
211 "Returning cached result from checkpoint"
212 );
213 return Ok(cached_value);
214 }
215 Err(e) => {
216 ::tracing::warn!(
217 function = #fn_name_str,
218 error = %e,
219 "Failed to deserialize cached result, re-executing"
220 );
221 }
222 }
223 }
224 Ok(None) => {
225 }
227 Err(e) => {
228 ::tracing::warn!(
230 function = #fn_name_str,
231 error = %e,
232 "Checkpoint lookup failed, executing function"
233 );
234 }
235 }
236 }
237
238 let __result: std::result::Result<_, _> = async #block.await;
240
241 if let Ok(ref value) = __result {
243 match ::serde_json::to_vec(value) {
244 Ok(result_bytes) => {
245 let __sdk = ::runtara_sdk::sdk();
246 let __sdk_guard = __sdk.lock().await;
247
248 match __sdk_guard.checkpoint(&__cache_key, &result_bytes).await {
250 Ok(checkpoint_result) => {
251 ::tracing::debug!(
252 function = #fn_name_str,
253 cache_key = %__cache_key,
254 "Result cached via checkpoint"
255 );
256
257 if checkpoint_result.should_cancel() {
259 ::tracing::info!(
260 function = #fn_name_str,
261 "Cancel signal pending - instance should exit"
262 );
263 } else if checkpoint_result.should_pause() {
264 ::tracing::info!(
265 function = #fn_name_str,
266 "Pause signal pending - instance should exit after returning"
267 );
268 }
269 }
270 Err(e) => {
271 ::tracing::warn!(
272 function = #fn_name_str,
273 error = %e,
274 "Failed to cache result via checkpoint"
275 );
276 }
277 }
278 }
279 Err(e) => {
280 ::tracing::warn!(
281 function = #fn_name_str,
282 error = %e,
283 "Failed to serialize result for caching"
284 );
285 }
286 }
287 }
288
289 __result
290 }
291 })
292}
293
294#[allow(clippy::too_many_arguments)]
296fn generate_retry_wrapper(
297 fn_name_str: String,
298 vis: &syn::Visibility,
299 attrs: &[syn::Attribute],
300 sig: &syn::Signature,
301 block: &syn::Block,
302 ok_type: Type,
303 idempotency_key_ident: Ident,
304 max_retries: u32,
305 base_delay_ms: u64,
306) -> syn::Result<TokenStream2> {
307 let total_attempts = max_retries + 1;
308
309 let clonable_params = extract_clonable_params(&sig.inputs);
311
312 let clone_statements: Vec<TokenStream2> = clonable_params
314 .iter()
315 .map(|(ident, _ty)| {
316 quote! {
317 let #ident = #ident.clone();
318 }
319 })
320 .collect();
321
322 Ok(quote! {
323 #(#attrs)*
324 #vis #sig {
325 let __cache_key = format!("durable::{}::{}", #fn_name_str, #idempotency_key_ident);
326 let __max_retries: u32 = #max_retries;
327 let __base_delay_ms: u64 = #base_delay_ms;
328
329 {
331 let __sdk = ::runtara_sdk::sdk();
332 let __sdk_guard = __sdk.lock().await;
333
334 match __sdk_guard.get_checkpoint(&__cache_key).await {
335 Ok(Some(cached_bytes)) => {
336 drop(__sdk_guard);
338 match ::serde_json::from_slice::<#ok_type>(&cached_bytes) {
339 Ok(cached_value) => {
340 ::tracing::debug!(
341 function = #fn_name_str,
342 cache_key = %__cache_key,
343 "Returning cached result from checkpoint"
344 );
345 return Ok(cached_value);
346 }
347 Err(e) => {
348 ::tracing::warn!(
349 function = #fn_name_str,
350 error = %e,
351 "Failed to deserialize cached result, re-executing"
352 );
353 }
354 }
355 }
356 Ok(None) => {
357 }
359 Err(e) => {
360 ::tracing::warn!(
362 function = #fn_name_str,
363 error = %e,
364 "Checkpoint lookup failed, executing function"
365 );
366 }
367 }
368 }
369
370 let mut __last_error: Option<String> = None;
372 let __total_attempts: u32 = #total_attempts;
373
374 for __attempt in 1..=__total_attempts {
375 #(#clone_statements)*
377
378 if __attempt > 1 {
380 let __delay_multiplier = 2u64.pow(__attempt - 2);
382 let __delay = ::std::time::Duration::from_millis(
383 __base_delay_ms.saturating_mul(__delay_multiplier)
384 );
385
386 ::tracing::info!(
387 function = #fn_name_str,
388 attempt = __attempt,
389 max_retries = __max_retries,
390 delay_ms = __delay.as_millis() as u64,
391 last_error = ?__last_error,
392 "Retrying after backoff"
393 );
394
395 ::tokio::time::sleep(__delay).await;
396
397 {
399 let __sdk = ::runtara_sdk::sdk();
400 let __sdk_guard = __sdk.lock().await;
401 if let Err(e) = __sdk_guard.record_retry_attempt(
402 &__cache_key,
403 __attempt,
404 __last_error.as_deref(),
405 ).await {
406 ::tracing::warn!(
407 function = #fn_name_str,
408 error = %e,
409 "Failed to record retry attempt"
410 );
411 }
412 }
413 }
414
415 let __result: std::result::Result<_, _> = async #block.await;
417
418 match __result {
419 Ok(ref value) => {
420 match ::serde_json::to_vec(value) {
422 Ok(result_bytes) => {
423 let __sdk = ::runtara_sdk::sdk();
424 let __sdk_guard = __sdk.lock().await;
425
426 match __sdk_guard.checkpoint(&__cache_key, &result_bytes).await {
427 Ok(checkpoint_result) => {
428 ::tracing::debug!(
429 function = #fn_name_str,
430 cache_key = %__cache_key,
431 attempts = __attempt,
432 "Result cached via checkpoint"
433 );
434
435 if checkpoint_result.should_cancel() {
437 ::tracing::info!(
438 function = #fn_name_str,
439 "Cancel signal pending - instance should exit"
440 );
441 } else if checkpoint_result.should_pause() {
442 ::tracing::info!(
443 function = #fn_name_str,
444 "Pause signal pending - instance should exit after returning"
445 );
446 }
447 }
448 Err(e) => {
449 ::tracing::warn!(
450 function = #fn_name_str,
451 error = %e,
452 "Failed to cache result via checkpoint"
453 );
454 }
455 }
456 }
457 Err(e) => {
458 ::tracing::warn!(
459 function = #fn_name_str,
460 error = %e,
461 "Failed to serialize result for caching"
462 );
463 }
464 }
465 return __result;
466 }
467 Err(ref e) => {
468 __last_error = Some(format!("{}", e));
469
470 if __attempt < __total_attempts {
471 ::tracing::warn!(
472 function = #fn_name_str,
473 attempt = __attempt,
474 max_retries = __max_retries,
475 error = %e,
476 "Attempt failed, will retry"
477 );
478 continue;
479 } else {
480 ::tracing::error!(
481 function = #fn_name_str,
482 attempts = __attempt,
483 error = %e,
484 "All retry attempts exhausted"
485 );
486 return __result;
487 }
488 }
489 }
490 }
491
492 unreachable!("Retry loop should always return")
494 }
495 })
496}
497
498fn extract_result_ok_type(return_type: &ReturnType) -> syn::Result<Type> {
499 let ReturnType::Type(_, ty) = return_type else {
500 return Err(syn::Error::new(
501 return_type.span(),
502 "#[durable] requires function to return Result<T, E>",
503 ));
504 };
505
506 let Type::Path(type_path) = ty.as_ref() else {
507 return Err(syn::Error::new(
508 ty.span(),
509 "#[durable] requires function to return Result<T, E>",
510 ));
511 };
512
513 let segment = type_path.path.segments.last().ok_or_else(|| {
514 syn::Error::new(
515 ty.span(),
516 "#[durable] requires function to return Result<T, E>",
517 )
518 })?;
519
520 if segment.ident != "Result" {
521 return Err(syn::Error::new(
522 segment.ident.span(),
523 "#[durable] requires function to return Result<T, E>",
524 ));
525 }
526
527 let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
528 return Err(syn::Error::new(
529 segment.span(),
530 "#[durable] requires Result<T, E> with explicit type parameters",
531 ));
532 };
533
534 match args.args.first() {
535 Some(syn::GenericArgument::Type(t)) => Ok(t.clone()),
536 _ => Err(syn::Error::new(
537 args.span(),
538 "#[durable] requires Result<T, E> with explicit type parameters",
539 )),
540 }
541}
542
543fn extract_first_arg_ident(
544 inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
545) -> syn::Result<syn::Ident> {
546 for arg in inputs.iter() {
548 match arg {
549 FnArg::Receiver(_) => continue,
550 FnArg::Typed(pat_type) => {
551 let Pat::Ident(pat_ident) = pat_type.pat.as_ref() else {
552 return Err(syn::Error::new(
553 pat_type.pat.span(),
554 "#[durable] requires the first argument to be a simple identifier",
555 ));
556 };
557 return Ok(pat_ident.ident.clone());
558 }
559 }
560 }
561
562 Err(syn::Error::new(
563 proc_macro2::Span::call_site(),
564 "#[durable] requires at least one argument: the idempotency key (String)",
565 ))
566}
567
568fn is_reference_type(ty: &Type) -> bool {
570 matches!(ty, Type::Reference(_))
571}
572
573fn extract_clonable_params(
576 inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
577) -> Vec<(Ident, Type)> {
578 let mut params = Vec::new();
579 let mut is_first = true;
580
581 for arg in inputs.iter() {
582 match arg {
583 FnArg::Receiver(_) => continue,
584 FnArg::Typed(PatType { pat, ty, .. }) => {
585 if is_first {
587 is_first = false;
588 continue;
589 }
590
591 if is_reference_type(ty) {
593 continue;
594 }
595
596 if let Pat::Ident(pat_ident) = pat.as_ref() {
598 params.push((pat_ident.ident.clone(), (**ty).clone()));
599 }
600 }
601 }
602 }
603
604 params
605}
606
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_durable_attr_parsing_empty() {
        let attr: DurableAttr = syn::parse2(quote! {}).unwrap();
        assert!(attr.max_retries.is_none());
        assert!(attr.strategy.is_none());
        assert!(attr.delay.is_none());
    }

    #[test]
    fn test_durable_attr_parsing_max_retries() {
        let attr: DurableAttr = syn::parse2(quote! { max_retries = 5 }).unwrap();
        assert_eq!(attr.max_retries, Some(5));
        assert!(attr.strategy.is_none());
        assert!(attr.delay.is_none());
    }

    #[test]
    fn test_durable_attr_parsing_delay() {
        let attr: DurableAttr = syn::parse2(quote! { delay = 2000 }).unwrap();
        assert!(attr.max_retries.is_none());
        assert!(attr.strategy.is_none());
        assert_eq!(attr.delay, Some(2000));
    }

    #[test]
    fn test_durable_attr_parsing_strategy() {
        let attr: DurableAttr = syn::parse2(quote! { strategy = ExponentialBackoff }).unwrap();
        assert!(attr.max_retries.is_none());
        assert_eq!(attr.strategy, Some("ExponentialBackoff".to_string()));
        assert!(attr.delay.is_none());
    }

    #[test]
    fn test_durable_attr_parsing_all_options() {
        let attr: DurableAttr =
            syn::parse2(quote! { max_retries = 3, strategy = ExponentialBackoff, delay = 1000 })
                .unwrap();
        assert_eq!(attr.max_retries, Some(3));
        assert_eq!(attr.strategy, Some("ExponentialBackoff".to_string()));
        assert_eq!(attr.delay, Some(1000));
    }

    #[test]
    fn test_durable_attr_parsing_trailing_comma() {
        let attr: DurableAttr = syn::parse2(quote! { max_retries = 2, delay = 10, }).unwrap();
        assert_eq!(attr.max_retries, Some(2));
        assert_eq!(attr.delay, Some(10));
    }

    #[test]
    fn test_durable_attr_parsing_non_integer_fails() {
        let result: Result<DurableAttr, _> = syn::parse2(quote! { max_retries = "three" });
        assert!(result.is_err());
    }

    #[test]
    fn test_durable_attr_parsing_unknown_attribute_fails() {
        let result: Result<DurableAttr, _> = syn::parse2(quote! { unknown = 5 });
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Unknown attribute"));
    }

    #[test]
    fn test_durable_attr_parsing_invalid_strategy_fails() {
        let result: Result<DurableAttr, _> = syn::parse2(quote! { strategy = LinearBackoff });
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Only ExponentialBackoff"));
    }

    #[test]
    fn test_extract_result_ok_type_valid() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str) -> Result<String, Error> {
                Ok("hello".to_string())
            }
        };
        let result = extract_result_ok_type(&fn_item.sig.output);
        assert!(result.is_ok());
    }

    #[test]
    fn test_extract_result_ok_type_returns_ok_type() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str) -> std::io::Result<Vec<u8>> {
                Ok(Vec::new())
            }
        };
        let ok_type = extract_result_ok_type(&fn_item.sig.output).unwrap();
        let expected: Type = parse_quote! { Vec<u8> };
        assert_eq!(quote!(#ok_type).to_string(), quote!(#expected).to_string());
    }

    #[test]
    fn test_extract_result_ok_type_no_return() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str) {
            }
        };
        let result = extract_result_ok_type(&fn_item.sig.output);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_result_ok_type_not_result() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str) -> Option<String> {
                Some("hello".to_string())
            }
        };
        let result = extract_result_ok_type(&fn_item.sig.output);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_first_arg_ident_valid() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str, value: i32) -> Result<(), ()> {
                Ok(())
            }
        };
        let result = extract_first_arg_ident(&fn_item.sig.inputs);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().to_string(), "key");
    }

    #[test]
    fn test_extract_first_arg_ident_no_args() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo() -> Result<(), ()> {
                Ok(())
            }
        };
        let result = extract_first_arg_ident(&fn_item.sig.inputs);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_first_arg_ident_with_self() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(&self, key: &str) -> Result<(), ()> {
                Ok(())
            }
        };
        let result = extract_first_arg_ident(&fn_item.sig.inputs);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().to_string(), "key");
    }

    #[test]
    fn test_extract_first_arg_ident_pattern_fails() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo((a, b): (u32, u32)) -> Result<(), ()> {
                Ok(())
            }
        };
        let result = extract_first_arg_ident(&fn_item.sig.inputs);
        assert!(result.is_err());
    }

    #[test]
    fn test_is_reference_type_reference() {
        let ty: Type = parse_quote! { &str };
        assert!(is_reference_type(&ty));
    }

    #[test]
    fn test_is_reference_type_not_reference() {
        let ty: Type = parse_quote! { String };
        assert!(!is_reference_type(&ty));
    }

    #[test]
    fn test_extract_clonable_params_reference_skipped() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str, value: &[u8]) -> Result<(), ()> {
                Ok(())
            }
        };
        let params = extract_clonable_params(&fn_item.sig.inputs);
        assert!(params.is_empty());
    }

    #[test]
    fn test_extract_clonable_params_owned_included() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str, value: String, count: i32) -> Result<(), ()> {
                Ok(())
            }
        };
        let params = extract_clonable_params(&fn_item.sig.inputs);
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].0.to_string(), "value");
        assert_eq!(params[1].0.to_string(), "count");
    }

    #[test]
    fn test_extract_clonable_params_skips_receiver_and_key() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(&self, key: String, payload: Vec<u8>) -> Result<(), ()> {
                Ok(())
            }
        };
        let params = extract_clonable_params(&fn_item.sig.inputs);
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].0.to_string(), "payload");
    }

    #[test]
    fn test_generate_durable_wrapper_not_async_fails() {
        let fn_item: ItemFn = parse_quote! {
            fn foo(key: &str) -> Result<(), ()> {
                Ok(())
            }
        };
        let config = DurableAttr::default();
        let result = generate_durable_wrapper(fn_item, config);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("async"));
    }

    #[test]
    fn test_generate_durable_wrapper_non_result_fails() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str) -> Option<String> {
                None
            }
        };
        let result = generate_durable_wrapper(fn_item, DurableAttr::default());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Result<T, E>"));
    }

    #[test]
    fn test_generate_durable_wrapper_without_key_fails() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo() -> Result<(), ()> {
                Ok(())
            }
        };
        let result = generate_durable_wrapper(fn_item, DurableAttr::default());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("idempotency key"));
    }

    #[test]
    fn test_generate_durable_wrapper_valid() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str) -> Result<String, String> {
                Ok("hello".to_string())
            }
        };
        let config = DurableAttr::default();
        let result = generate_durable_wrapper(fn_item, config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_generate_durable_wrapper_zero_retries() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str) -> Result<String, String> {
                Ok("hello".to_string())
            }
        };
        let config = DurableAttr {
            max_retries: Some(0),
            strategy: None,
            delay: None,
        };
        let result = generate_durable_wrapper(fn_item, config);
        assert!(result.is_ok());
        let tokens = result.unwrap().to_string();
        // The single-attempt wrapper has no retry bookkeeping.
        assert!(!tokens.contains("__max_retries"));
    }

    #[test]
    fn test_generate_durable_wrapper_with_retries() {
        let fn_item: ItemFn = parse_quote! {
            async fn foo(key: &str) -> Result<String, String> {
                Ok("hello".to_string())
            }
        };
        let config = DurableAttr {
            max_retries: Some(3),
            strategy: None,
            delay: Some(1000),
        };
        let result = generate_durable_wrapper(fn_item, config);
        assert!(result.is_ok());
        let tokens = result.unwrap().to_string();
        // The retrying wrapper carries both retry-count and backoff state.
        assert!(tokens.contains("__max_retries"));
        assert!(tokens.contains("__base_delay_ms"));
    }
}