runtara_sdk_macros/
lib.rs

1// Copyright (C) 2025 SyncMyOrders Sp. z o.o.
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//! Proc macros for runtara-sdk.
4//!
5//! Provides the `#[durable]` attribute macro for transparent durability.
6
7use proc_macro::TokenStream;
8use proc_macro2::TokenStream as TokenStream2;
9use quote::quote;
10use syn::{FnArg, ItemFn, Pat, ReturnType, Type, parse_macro_input, spanned::Spanned};
11
12/// Makes an async function durable by wrapping it with checkpoint-based caching.
13///
14/// The macro automatically:
15/// - Checks for existing checkpoint before execution
16/// - Returns cached result if checkpoint exists
17/// - Executes function and saves result as checkpoint if no cache
18///
19/// # Requirements
20///
21/// - Function must be async
22/// - **First parameter is the idempotency key** (any type that implements `Display`)
23/// - Function must return `Result<T, E>` where `T: Serialize + DeserializeOwned`
24/// - SDK must be registered via `RuntaraSdk::init()` before calling
25///
26/// # Example
27///
28/// ```ignore
29/// use runtara_sdk::durable;
30///
31/// #[durable]
32/// pub async fn fetch_order(key: &str, order_id: &str) -> Result<Order, OrderError> {
33///     // The key determines caching - same key = same cached result
34///     db.fetch_order(order_id).await
35/// }
36///
37/// // Usage:
38/// fetch_order("order-123", "123").await
39/// ```
40#[proc_macro_attribute]
41pub fn durable(_attr: TokenStream, item: TokenStream) -> TokenStream {
42    let input = parse_macro_input!(item as ItemFn);
43
44    match generate_durable_wrapper(input) {
45        Ok(tokens) => tokens.into(),
46        Err(err) => err.to_compile_error().into(),
47    }
48}
49
50fn generate_durable_wrapper(input: ItemFn) -> syn::Result<TokenStream2> {
51    let fn_name = &input.sig.ident;
52    let fn_name_str = fn_name.to_string();
53    let vis = &input.vis;
54    let attrs = &input.attrs;
55    let sig = &input.sig;
56    let block = &input.block;
57
58    // Must be async
59    if sig.asyncness.is_none() {
60        return Err(syn::Error::new(
61            sig.fn_token.span,
62            "#[durable] only works with async functions",
63        ));
64    }
65
66    // Validate return type is Result<T, E>
67    let ok_type = extract_result_ok_type(&sig.output)?;
68
69    // Extract the first argument as the idempotency key
70    let idempotency_key_ident = extract_first_arg_ident(&sig.inputs)?;
71
72    Ok(quote! {
73        #(#attrs)*
74        #vis #sig {
75            let __cache_key = format!("durable::{}::{}", #fn_name_str, #idempotency_key_ident);
76
77            // Step 1: Check if we have a cached result (read-only lookup)
78            {
79                let __sdk = ::runtara_sdk::sdk();
80                let __sdk_guard = __sdk.lock().await;
81
82                match __sdk_guard.get_checkpoint(&__cache_key).await {
83                    Ok(Some(cached_bytes)) => {
84                        // Found cached result - deserialize and return
85                        drop(__sdk_guard);
86                        match ::serde_json::from_slice::<#ok_type>(&cached_bytes) {
87                            Ok(cached_value) => {
88                                ::tracing::debug!(
89                                    function = #fn_name_str,
90                                    cache_key = %__cache_key,
91                                    "Returning cached result from checkpoint"
92                                );
93                                return Ok(cached_value);
94                            }
95                            Err(e) => {
96                                ::tracing::warn!(
97                                    function = #fn_name_str,
98                                    error = %e,
99                                    "Failed to deserialize cached result, re-executing"
100                                );
101                            }
102                        }
103                    }
104                    Ok(None) => {
105                        // No cached result - will execute function
106                    }
107                    Err(e) => {
108                        // Checkpoint lookup error - log and continue with execution
109                        ::tracing::warn!(
110                            function = #fn_name_str,
111                            error = %e,
112                            "Checkpoint lookup failed, executing function"
113                        );
114                    }
115                }
116            }
117
118            // Step 2: Execute original function body
119            let __result: Result<_, _> = (|| async #block)().await;
120
121            // Step 3: Cache successful result
122            if let Ok(ref value) = __result {
123                match ::serde_json::to_vec(value) {
124                    Ok(result_bytes) => {
125                        let __sdk = ::runtara_sdk::sdk();
126                        let __sdk_guard = __sdk.lock().await;
127
128                        // Use checkpoint to save - it won't overwrite if already exists
129                        match __sdk_guard.checkpoint(&__cache_key, &result_bytes).await {
130                            Ok(checkpoint_result) => {
131                                ::tracing::debug!(
132                                    function = #fn_name_str,
133                                    cache_key = %__cache_key,
134                                    "Result cached via checkpoint"
135                                );
136
137                                // Check for pending pause/cancel signals
138                                if checkpoint_result.should_cancel() {
139                                    ::tracing::info!(
140                                        function = #fn_name_str,
141                                        "Cancel signal pending - instance should exit"
142                                    );
143                                    // Let the caller handle cancellation via signal polling
144                                } else if checkpoint_result.should_pause() {
145                                    ::tracing::info!(
146                                        function = #fn_name_str,
147                                        "Pause signal pending - instance should exit after returning"
148                                    );
149                                    // Let the caller handle pause via signal polling
150                                }
151                            }
152                            Err(e) => {
153                                ::tracing::warn!(
154                                    function = #fn_name_str,
155                                    error = %e,
156                                    "Failed to cache result via checkpoint"
157                                );
158                            }
159                        }
160                    }
161                    Err(e) => {
162                        ::tracing::warn!(
163                            function = #fn_name_str,
164                            error = %e,
165                            "Failed to serialize result for caching"
166                        );
167                    }
168                }
169            }
170
171            __result
172        }
173    })
174}
175
176fn extract_result_ok_type(return_type: &ReturnType) -> syn::Result<Type> {
177    let ReturnType::Type(_, ty) = return_type else {
178        return Err(syn::Error::new(
179            return_type.span(),
180            "#[durable] requires function to return Result<T, E>",
181        ));
182    };
183
184    let Type::Path(type_path) = ty.as_ref() else {
185        return Err(syn::Error::new(
186            ty.span(),
187            "#[durable] requires function to return Result<T, E>",
188        ));
189    };
190
191    let segment = type_path.path.segments.last().ok_or_else(|| {
192        syn::Error::new(
193            ty.span(),
194            "#[durable] requires function to return Result<T, E>",
195        )
196    })?;
197
198    if segment.ident != "Result" {
199        return Err(syn::Error::new(
200            segment.ident.span(),
201            "#[durable] requires function to return Result<T, E>",
202        ));
203    }
204
205    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
206        return Err(syn::Error::new(
207            segment.span(),
208            "#[durable] requires Result<T, E> with explicit type parameters",
209        ));
210    };
211
212    match args.args.first() {
213        Some(syn::GenericArgument::Type(t)) => Ok(t.clone()),
214        _ => Err(syn::Error::new(
215            args.span(),
216            "#[durable] requires Result<T, E> with explicit type parameters",
217        )),
218    }
219}
220
221fn extract_first_arg_ident(
222    inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
223) -> syn::Result<syn::Ident> {
224    // Skip `self` receiver if present, get first real argument
225    for arg in inputs.iter() {
226        match arg {
227            FnArg::Receiver(_) => continue,
228            FnArg::Typed(pat_type) => {
229                let Pat::Ident(pat_ident) = pat_type.pat.as_ref() else {
230                    return Err(syn::Error::new(
231                        pat_type.pat.span(),
232                        "#[durable] requires the first argument to be a simple identifier",
233                    ));
234                };
235                return Ok(pat_ident.ident.clone());
236            }
237        }
238    }
239
240    Err(syn::Error::new(
241        proc_macro2::Span::call_site(),
242        "#[durable] requires at least one argument: the idempotency key (String)",
243    ))
244}