// sayiir_core/context.rs

//! Workflow execution context.
//!
//! [`WorkflowContext`] carries the workflow ID, codec, and user-supplied
//! metadata through every task execution.
//!
//! [`TaskExecutionContext`] provides read-only access to workflow and task
//! metadata from within running tasks. It is set automatically by the
//! runtime and can be retrieved via [`get_task_context()`] or the
//! [`task_context!`](crate::task_context) macro.
10
11use std::sync::Arc;
12
13use crate::task::TaskMetadata;
14
15/// Execution context available to a running task.
16///
17/// Provides read-only access to workflow and task metadata. Accessible
18/// from within task functions via task-local storage (Rust) or
19/// language-specific context APIs (Python/Node.js).
20#[derive(Clone, Debug)]
21pub struct TaskExecutionContext {
22    /// The workflow definition identifier.
23    pub workflow_id: Arc<str>,
24    /// The workflow instance identifier.
25    pub instance_id: Arc<str>,
26    /// The current task identifier.
27    pub task_id: Arc<str>,
28    /// Task metadata (timeout, retry policy, version, etc.).
29    pub metadata: TaskMetadata,
30    /// Optional JSON-encoded workflow-level metadata.
31    pub workflow_metadata_json: Option<Arc<str>>,
32}
33
/// Bundle of workflow identity, codec and metadata shared across a workflow.
///
/// It is a plain struct: used while a workflow is being built and by the
/// runner whenever it needs the codec or the metadata. Every field is
/// reference-counted, so copies are cheap and share the same underlying data.
pub struct WorkflowContext<C, M> {
    /// Unique identifier of the workflow.
    pub workflow_id: Arc<str>,
    /// Codec responsible for serialization and deserialization.
    pub codec: Arc<C>,
    /// Immutable, user-supplied metadata attached to the workflow.
    pub metadata: Arc<M>,
    /// JSON rendering of the workflow metadata for the task context, if any.
    pub metadata_json: Option<Arc<str>>,
}

// Implemented by hand so that cloning does not require `C: Clone` or
// `M: Clone`; only the `Arc` handles are duplicated.
impl<C, M> Clone for WorkflowContext<C, M> {
    fn clone(&self) -> Self {
        let Self {
            workflow_id,
            codec,
            metadata,
            metadata_json,
        } = self;

        Self {
            workflow_id: Arc::clone(workflow_id),
            codec: Arc::clone(codec),
            metadata: Arc::clone(metadata),
            metadata_json: metadata_json.clone(),
        }
    }
}

impl<C, M> WorkflowContext<C, M> {
    /// Build a context from an identifier, a codec and metadata.
    ///
    /// `metadata_json` starts out as `None`.
    pub fn new(workflow_id: impl Into<Arc<str>>, codec: Arc<C>, metadata: Arc<M>) -> Self {
        let workflow_id = workflow_id.into();
        Self {
            workflow_id,
            codec,
            metadata,
            metadata_json: None,
        }
    }

    /// Borrow the workflow identifier as a string slice.
    #[must_use]
    pub fn workflow_id(&self) -> &str {
        &*self.workflow_id
    }

    /// Hand out another `Arc` handle to the codec.
    #[must_use]
    pub fn codec(&self) -> Arc<C> {
        Arc::clone(&self.codec)
    }

    /// Hand out another `Arc` handle to the metadata.
    #[must_use]
    pub fn metadata(&self) -> Arc<M> {
        Arc::clone(&self.metadata)
    }
}
89
90use std::cell::RefCell;
91
92std::thread_local! {
93    /// Thread-local fallback for `TaskExecutionContext`.
94    ///
95    /// Used by sync executor paths (Python GIL, Node.js main thread) where
96    /// tokio task-locals are not available.
97    static THREAD_LOCAL_TASK_CTX: RefCell<Option<TaskExecutionContext>> = const { RefCell::new(None) };
98}
99
100/// Set the task execution context in thread-local storage for the duration
101/// of the closure. Clears the context when the closure returns (even on panic).
102pub fn with_thread_local_task_context<R>(ctx: TaskExecutionContext, f: impl FnOnce() -> R) -> R {
103    THREAD_LOCAL_TASK_CTX.with(|cell| {
104        let prev = cell.borrow_mut().replace(ctx);
105        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
106        *cell.borrow_mut() = prev;
107        match result {
108            Ok(r) => r,
109            Err(e) => std::panic::resume_unwind(e),
110        }
111    })
112}
113
114/// Get the task execution context from thread-local storage.
115#[must_use]
116pub fn get_thread_local_task_context() -> Option<TaskExecutionContext> {
117    THREAD_LOCAL_TASK_CTX.with(|cell| cell.borrow().clone())
118}
119
// ── Task-local context storage (requires tokio) ─────────────────────────
121
#[cfg(feature = "tokio")]
mod task_local_ctx {
    use std::future::Future;

    use super::TaskExecutionContext;

    tokio::task_local! {
        /// Tokio task-local slot for the active task execution context.
        static TASK_EXEC_CTX: Option<TaskExecutionContext>;
    }

    /// Await `fut` with `ctx` installed in the task-local slot.
    ///
    /// Resolves to whatever `fut` resolves to.
    pub async fn with_task_context<F>(ctx: TaskExecutionContext, fut: F) -> F::Output
    where
        F: Future,
    {
        let scoped = TASK_EXEC_CTX.scope(Some(ctx), fut);
        scoped.await
    }

    /// Look up the current task execution context.
    ///
    /// The tokio task-local is consulted first. If it cannot be read or holds
    /// `None`, the thread-local slot is used instead.
    #[must_use]
    pub fn get_task_context() -> Option<TaskExecutionContext> {
        match TASK_EXEC_CTX.try_with(|stored| stored.clone()) {
            Ok(Some(ctx)) => Some(ctx),
            Ok(None) | Err(_) => super::get_thread_local_task_context(),
        }
    }
}
151
152#[cfg(feature = "tokio")]
153pub use task_local_ctx::{get_task_context, with_task_context};
154
155/// Get the task execution context (non-tokio fallback).
156///
157/// Delegates to thread-local storage only.
158#[cfg(not(feature = "tokio"))]
159#[must_use]
160pub fn get_task_context() -> Option<TaskExecutionContext> {
161    get_thread_local_task_context()
162}
163
/// Fetch the task execution context from inside a task.
///
/// Expands to a call to `$crate::context::get_task_context()`, so it
/// evaluates to `Option<TaskExecutionContext>`. The result is `None` when
/// invoked outside of a task execution context.
///
/// # Examples
///
/// ```rust,ignore
/// if let Some(ctx) = task_context!() {
///     println!("workflow: {}, task: {}", ctx.workflow_id, ctx.task_id);
/// }
/// ```
#[macro_export]
macro_rules! task_context {
    () => (
        $crate::context::get_task_context()
    );
}
181
#[cfg(all(test, feature = "tokio"))]
// Test code unwraps and panics on purpose to signal failures.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::task::TaskMetadata;

    /// Build the fixed context (`wf-1` / `inst-1` / `task-a`) used by every test.
    fn make_task_ctx() -> TaskExecutionContext {
        TaskExecutionContext {
            workflow_id: Arc::from("wf-1"),
            instance_id: Arc::from("inst-1"),
            task_id: Arc::from("task-a"),
            metadata: TaskMetadata::default(),
            workflow_metadata_json: None,
        }
    }

    /// The context is visible inside the closure and gone afterwards.
    #[test]
    fn thread_local_roundtrip() {
        assert!(get_thread_local_task_context().is_none());

        let ctx = make_task_ctx();
        let result = with_thread_local_task_context(ctx.clone(), || {
            let inner = get_thread_local_task_context().unwrap();
            assert_eq!(&*inner.workflow_id, "wf-1");
            assert_eq!(&*inner.instance_id, "inst-1");
            assert_eq!(&*inner.task_id, "task-a");
            42
        });
        assert_eq!(result, 42);

        // Cleared after scope
        assert!(get_thread_local_task_context().is_none());
    }

    /// A panic inside the closure must not leave the context installed.
    #[test]
    fn thread_local_restores_on_panic() {
        let ctx = make_task_ctx();
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            with_thread_local_task_context(ctx, || {
                panic!("boom");
            })
        }));
        assert!(result.is_err());
        assert!(get_thread_local_task_context().is_none());
    }

    /// Leaving an inner scope brings back the outer context, not `None`.
    #[test]
    fn thread_local_nested_scopes_restore_outer() {
        let outer = make_task_ctx();
        let mut inner = make_task_ctx();
        inner.task_id = Arc::from("task-b");

        with_thread_local_task_context(outer, || {
            with_thread_local_task_context(inner, || {
                let seen = get_thread_local_task_context().unwrap();
                assert_eq!(&*seen.task_id, "task-b");
            });

            let seen = get_thread_local_task_context().unwrap();
            assert_eq!(&*seen.task_id, "task-a");
        });

        assert!(get_thread_local_task_context().is_none());
    }

    /// The tokio task-local is visible inside the scoped future.
    #[test]
    fn task_local_roundtrip() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            assert!(get_task_context().is_none());

            let ctx = make_task_ctx();
            let inner = with_task_context(ctx, async {
                let c = get_task_context().unwrap();
                assert_eq!(&*c.task_id, "task-a");
                c
            })
            .await;

            assert_eq!(&*inner.workflow_id, "wf-1");
        });
    }

    /// Without a task-local value, `get_task_context` reads the thread-local.
    #[test]
    fn task_local_falls_back_to_thread_local() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            // Set only thread-local, no task-local — should still find it
            let ctx = make_task_ctx();
            let result = with_thread_local_task_context(ctx, get_task_context);
            assert!(result.is_some());
            assert_eq!(&*result.unwrap().instance_id, "inst-1");
        });
    }

    /// `task_context!` resolves to the same lookup as `get_task_context`.
    #[test]
    fn macro_works() {
        let ctx = make_task_ctx();
        with_thread_local_task_context(ctx, || {
            let c = task_context!().unwrap();
            assert_eq!(&*c.task_id, "task-a");
        });
    }
}