// swf_runtime/context.rs

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, PoisonError};

use serde_json::Value;
use swf_core::models::task::TaskDefinition;
use swf_core::models::workflow::WorkflowDefinition;
use tokio::sync::Notify;

use crate::events::SharedEventBus;
use crate::expression::ExpressionEngineRegistry;
use crate::handler::HandlerRegistry;
use crate::listener::{WorkflowEvent, WorkflowExecutionListener};
use crate::secret::SecretManager;
use crate::status::{StatusPhase, StatusPhaseLog};
14
/// Expands to three accessor methods for a field of type `Option<Arc<T>>`:
/// a clone-getter that hands out another `Arc` (a reference-count bump), a
/// getter that borrows the pointee, and a setter that stores a new `Arc`.
///
/// Arguments, in order: field, setter name, borrowing getter name, cloning
/// getter name, pointee type (may be unsized, e.g. `dyn Trait`).
macro_rules! arc_accessors {
    ($slot:ident, $set_fn:ident, $borrow_fn:ident, $share_fn:ident, $target:ty) => {
        pub fn $share_fn(&self) -> Option<Arc<$target>> {
            Option::clone(&self.$slot)
        }
        pub fn $borrow_fn(&self) -> Option<&$target> {
            Option::as_deref(&self.$slot)
        }
        pub fn $set_fn(&mut self, value: Arc<$target>) {
            self.$slot = Some(value);
        }
    };
}
29
/// Expands to three accessor methods for a field of type `Option<T>` (with `T: Clone`):
/// a clone-getter returning an owned copy, a getter borrowing the stored value,
/// and a setter that stores a new value.
///
/// Arguments, in order: field, setter name, borrowing getter name, cloning
/// getter name, value type.
macro_rules! option_accessors {
    ($slot:ident, $set_fn:ident, $borrow_fn:ident, $copy_fn:ident, $value_ty:ty) => {
        pub fn $copy_fn(&self) -> Option<$value_ty> {
            Option::clone(&self.$slot)
        }
        pub fn $borrow_fn(&self) -> Option<&$value_ty> {
            Option::as_ref(&self.$slot)
        }
        pub fn $set_fn(&mut self, value: $value_ty) {
            self.$slot = Some(value);
        }
    };
}
44
45/// Shared suspend/resume state for workflow execution.
46/// Cloned between WorkflowHandle and WorkflowContext to avoid duplicating logic.
47#[derive(Clone)]
48pub(crate) struct SuspendState {
49    suspended: Arc<AtomicBool>,
50    resume_notify: Arc<Notify>,
51}
52
53impl SuspendState {
54    pub(crate) fn new() -> Self {
55        Self {
56            suspended: Arc::new(AtomicBool::new(false)),
57            resume_notify: Arc::new(Notify::new()),
58        }
59    }
60
61    /// Suspends the workflow. Returns true if suspended, false if already suspended.
62    pub fn suspend(&self) -> bool {
63        self.suspended
64            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
65            .is_ok()
66    }
67
68    /// Resumes a suspended workflow. Returns true if resumed, false if not suspended.
69    pub fn resume(&self) -> bool {
70        if self
71            .suspended
72            .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
73            .is_ok()
74        {
75            self.resume_notify.notify_waiters();
76            true
77        } else {
78            false
79        }
80    }
81
82    /// Checks if the workflow is currently suspended.
83    pub fn is_suspended(&self) -> bool {
84        self.suspended.load(Ordering::SeqCst)
85    }
86
87    /// Returns the cancellation-aware resume notifier.
88    pub(crate) fn resume_notify(&self) -> &Arc<Notify> {
89        &self.resume_notify
90    }
91}
92use tokio_util::sync::CancellationToken;
93
/// Names under which runtime state is exposed to JQ expressions.
///
/// Each name includes the leading `$` sigil used in expression text.
pub mod vars {
    /// Workflow input.
    pub const INPUT: &str = "$input";
    /// Workflow output.
    pub const OUTPUT: &str = "$output";
    /// Instance context, as set by `export.as`.
    pub const CONTEXT: &str = "$context";
    /// Workflow descriptor.
    pub const WORKFLOW: &str = "$workflow";
    /// Current task descriptor.
    pub const TASK: &str = "$task";
    /// Runtime name/version descriptor.
    pub const RUNTIME: &str = "$runtime";
    /// Secrets exposed through the configured secret manager.
    pub const SECRET: &str = "$secret";
    /// Authorization descriptor of the current task.
    pub const AUTHORIZATION: &str = "$authorization";
}
105
106/// Runtime name and version constants
107pub mod runtime_info {
108    pub const NAME: &str = "CNCF Serverless Workflow Specification Rust SDK";
109    pub const VERSION: &str = env!("CARGO_PKG_VERSION");
110
111    /// Cached runtime info JSON value (constructed once)
112    static RUNTIME_INFO: std::sync::LazyLock<serde_json::Value> = std::sync::LazyLock::new(|| {
113        serde_json::json!({
114            "name": NAME,
115            "version": VERSION,
116        })
117    });
118
119    pub fn runtime_info_value() -> &'static serde_json::Value {
120        &RUNTIME_INFO
121    }
122}
123
124/// Holds the runtime context for a workflow execution
125pub struct WorkflowContext {
126    /// The workflow input ($input)
127    input: Option<Value>,
128    /// The workflow output ($output)
129    output: Option<Value>,
130    /// The instance context ($context) - set by export.as
131    instance_ctx: Option<Value>,
132    /// The workflow descriptor ($workflow)
133    workflow_descriptor: Arc<Value>,
134    /// The current task descriptor ($task)
135    task_descriptor: Value,
136    /// Local expression variables (e.g., $item, $index in for loops)
137    local_expr_vars: HashMap<String, Value>,
138    /// The authorization descriptor ($authorization) — set after HTTP auth
139    authorization: Option<Value>,
140    /// The secret manager ($secret)
141    secret_manager: Option<Arc<dyn SecretManager>>,
142    /// The execution listener
143    listener: Option<Arc<dyn WorkflowExecutionListener>>,
144    /// The event bus for publish/subscribe (used by emit and listen tasks)
145    event_bus: Option<SharedEventBus>,
146    /// Sub-workflow registry keyed by "namespace/name/version"
147    sub_workflows: HashMap<String, WorkflowDefinition>,
148    /// Cancellation token for graceful shutdown (e.g., workflow timeout)
149    cancellation_token: CancellationToken,
150    /// Suspend flag: true when the workflow is suspended
151    suspend_state: SuspendState,
152    /// Handler registry for custom call/run handlers
153    handler_registry: HandlerRegistry,
154    /// Expression engine registry for pluggable expression evaluation
155    expression_engines: ExpressionEngineRegistry,
156    /// Registered function definitions for call.function resolution (catalog mechanism)
157    functions: HashMap<String, TaskDefinition>,
158    /// Overall workflow status log
159    status_log: Vec<StatusPhaseLog>,
160    /// Per-task status log
161    task_status: HashMap<String, Vec<StatusPhaseLog>>,
162    /// Per-task iteration counter (incremented each time a task executes)
163    iterations: HashMap<String, u32>,
164    /// Cached vars map for JQ expression evaluation (rebuilt when dirty)
165    vars_cache: Mutex<Option<HashMap<String, Value>>>,
166    /// Whether vars_cache is stale and needs rebuilding
167    vars_dirty: AtomicBool,
168}
169
170impl Clone for WorkflowContext {
171    fn clone(&self) -> Self {
172        Self {
173            input: self.input.clone(),
174            output: self.output.clone(),
175            instance_ctx: self.instance_ctx.clone(),
176            workflow_descriptor: Arc::clone(&self.workflow_descriptor),
177            task_descriptor: self.task_descriptor.clone(),
178            local_expr_vars: self.local_expr_vars.clone(),
179            authorization: self.authorization.clone(),
180            secret_manager: self.secret_manager.clone(),
181            listener: self.listener.clone(),
182            event_bus: self.event_bus.clone(),
183            sub_workflows: self.sub_workflows.clone(),
184            cancellation_token: self.cancellation_token.clone(),
185            suspend_state: self.suspend_state.clone(),
186            handler_registry: self.handler_registry.clone(),
187            expression_engines: self.expression_engines.clone(),
188            functions: self.functions.clone(),
189            status_log: self.status_log.clone(),
190            task_status: self.task_status.clone(),
191            iterations: self.iterations.clone(),
192            vars_cache: Mutex::new(self.vars_cache.lock().unwrap().clone()),
193            vars_dirty: AtomicBool::new(self.vars_dirty.load(Ordering::Acquire)),
194        }
195    }
196}
197
198impl std::fmt::Debug for WorkflowContext {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        f.debug_struct("WorkflowContext")
201            .field("input", &self.input)
202            .field("output", &self.output)
203            .field("instance_ctx", &self.instance_ctx)
204            .field("workflow_descriptor", &self.workflow_descriptor)
205            .field("task_descriptor", &self.task_descriptor)
206            .field("local_expr_vars", &self.local_expr_vars)
207            .field(
208                "secret_manager",
209                &self.secret_manager.as_ref().map(|_| "..."),
210            )
211            .field("listener", &self.listener.as_ref().map(|_| "..."))
212            .field("event_bus", &self.event_bus.as_ref().map(|_| "..."))
213            .field("status_log", &self.status_log)
214            .field("task_status", &self.task_status)
215            .field("iterations", &self.iterations)
216            .finish()
217    }
218}
219
220impl WorkflowContext {
221    /// Creates a new workflow context from a workflow definition
222    pub fn new(
223        workflow: &swf_core::models::workflow::WorkflowDefinition,
224    ) -> crate::error::WorkflowResult<Self> {
225        let workflow_json = serde_json::to_value(workflow).map_err(|e| {
226            crate::error::WorkflowError::runtime(
227                format!("failed to serialize workflow definition: {}", e),
228                "/",
229                "/",
230            )
231        })?;
232
233        let workflow_descriptor = Arc::new(serde_json::json!({
234            "id": uuid::Uuid::new_v4().to_string(),
235            "definition": workflow_json,
236        }));
237
238        let mut ctx = Self {
239            input: None,
240            output: None,
241            instance_ctx: None,
242            workflow_descriptor,
243            task_descriptor: Value::Object(Default::default()),
244            local_expr_vars: HashMap::new(),
245            authorization: None,
246            secret_manager: None,
247            listener: None,
248            event_bus: None,
249            sub_workflows: HashMap::new(),
250            cancellation_token: CancellationToken::new(),
251            suspend_state: SuspendState::new(),
252            handler_registry: HandlerRegistry::new(),
253            expression_engines: ExpressionEngineRegistry::new(),
254            functions: HashMap::new(),
255            status_log: Vec::new(),
256            task_status: HashMap::new(),
257            iterations: HashMap::new(),
258            vars_cache: Mutex::new(None),
259            vars_dirty: AtomicBool::new(true),
260        };
261        ctx.set_status(StatusPhase::Pending);
262        Ok(ctx)
263    }
264
265    // ---- Status ----
266
267    /// Sets the overall workflow status
268    pub fn set_status(&mut self, status: StatusPhase) {
269        self.status_log.push(StatusPhaseLog::new(status));
270    }
271
272    /// Gets the workflow instance ID
273    pub fn instance_id(&self) -> &str {
274        self.workflow_descriptor
275            .as_object()
276            .and_then(|obj| obj.get("id"))
277            .and_then(|id| id.as_str())
278            .unwrap_or("unknown")
279    }
280
281    /// Gets the current overall workflow status
282    pub fn get_status(&self) -> StatusPhase {
283        self.status_log
284            .last()
285            .map(|log| log.status)
286            .unwrap_or(StatusPhase::Pending)
287    }
288
289    /// Sets the status for a specific task
290    pub fn set_task_status(&mut self, task: &str, status: StatusPhase) {
291        self.task_status
292            .entry(task.to_string())
293            .or_default()
294            .push(StatusPhaseLog::new(status));
295    }
296
297    /// Gets the current status for a specific task
298    pub fn get_task_status(&self, task: &str) -> Option<StatusPhase> {
299        self.task_status
300            .get(task)
301            .and_then(|logs| logs.last())
302            .map(|log| log.status)
303    }
304
305    // ---- Input / Output / Instance Context ----
306
307    pub fn set_input(&mut self, value: Value) {
308        self.input = Some(value);
309        self.invalidate_vars_cache();
310    }
311    pub fn get_input(&self) -> Option<&Value> {
312        self.input.as_ref()
313    }
314    pub fn set_output(&mut self, value: Value) {
315        self.output = Some(value);
316        self.invalidate_vars_cache();
317    }
318    pub fn get_output(&self) -> Option<&Value> {
319        self.output.as_ref()
320    }
321    pub fn set_instance_ctx(&mut self, value: Value) {
322        self.instance_ctx = Some(value);
323        self.invalidate_vars_cache();
324    }
325    pub fn get_instance_ctx(&self) -> Option<&Value> {
326        self.instance_ctx.as_ref()
327    }
328
329    // ---- Raw Input (in workflow descriptor) ----
330
331    /// Sets the raw input in the workflow descriptor
332    pub fn set_raw_input(&mut self, input: &Value) {
333        let mut desc = (*self.workflow_descriptor).clone();
334        if let Some(obj) = desc.as_object_mut() {
335            obj.insert("input".to_string(), input.clone());
336        }
337        self.workflow_descriptor = Arc::new(desc);
338        self.invalidate_vars_cache();
339    }
340
341    // ---- Task Descriptor ----
342
343    /// Inserts a key-value pair into the task descriptor object.
344    fn task_descriptor_insert(&mut self, key: &str, value: Value) {
345        if let Some(obj) = self.task_descriptor.as_object_mut() {
346            obj.insert(key.to_string(), value);
347        }
348        self.invalidate_vars_cache();
349    }
350
351    /// Sets the task name in the current task descriptor
352    pub fn set_task_name(&mut self, name: &str) {
353        self.task_descriptor_insert("name", Value::String(name.to_string()));
354    }
355
356    /// Sets the task raw input
357    pub fn set_task_raw_input(&mut self, input: &Value) {
358        self.task_descriptor_insert("input", input.clone());
359    }
360
361    /// Sets the task raw output
362    pub fn set_task_raw_output(&mut self, output: &Value) {
363        self.task_descriptor_insert("output", output.clone());
364    }
365
366    /// Sets the task startedAt timestamp with nested structure:
367    /// { iso8601: "...", epoch: { seconds: 123, milliseconds: 123456 } }
368    pub fn set_task_started_at(&mut self) {
369        let now = chrono::Utc::now();
370        let iso8601 = now.to_rfc3339();
371        let epoch_seconds = now.timestamp();
372        let epoch_millis = now.timestamp_millis();
373        self.task_descriptor_insert(
374            "startedAt",
375            serde_json::json!({
376                "iso8601": iso8601,
377                "epoch": {
378                    "seconds": epoch_seconds,
379                    "milliseconds": epoch_millis,
380                }
381            }),
382        );
383    }
384
385    /// Sets the task reference (JSON Pointer)
386    pub fn set_task_reference(&mut self, reference: &str) {
387        self.task_descriptor_insert("reference", Value::String(reference.to_string()));
388    }
389
390    /// Gets the task reference
391    pub fn get_task_reference(&self) -> Option<&str> {
392        self.task_descriptor
393            .as_object()
394            .and_then(|obj| obj.get("reference"))
395            .and_then(|v| v.as_str())
396    }
397
398    /// Gets the serialized workflow JSON value (for json_pointer resolution)
399    pub fn get_workflow_json(&self) -> Option<&Value> {
400        self.workflow_descriptor
401            .as_object()
402            .and_then(|obj| obj.get("definition"))
403    }
404
405    /// Gets the workflow instance ID
406    /// Sets the task definition in the task descriptor
407    pub fn set_task_def(&mut self, task: &Value) {
408        self.task_descriptor_insert("definition", task.clone());
409    }
410
411    /// Increments and returns the iteration counter for the given task position.
412    /// Each time a task executes, this counter is incremented, starting at 1.
413    pub fn inc_iteration(&mut self, position: &str) -> u32 {
414        let count = self.iterations.entry(position.to_string()).or_insert(0);
415        *count += 1;
416        let value = *count;
417        self.task_descriptor_insert("iteration", serde_json::json!(value));
418        value
419    }
420
421    /// Sets the retry attempt count in the task descriptor
422    pub fn set_retry_attempt(&mut self, attempt: u32) {
423        self.task_descriptor_insert("retryAttempt", serde_json::json!(attempt));
424    }
425
426    /// Clears the current task context
427    pub fn clear_task_context(&mut self) {
428        self.task_descriptor = Value::Object(Default::default());
429    }
430
431    // ---- Secret Manager ----
432
433    arc_accessors!(
434        secret_manager,
435        set_secret_manager,
436        get_secret_manager,
437        clone_secret_manager,
438        dyn SecretManager
439    );
440
441    // ---- Execution Listener ----
442
443    arc_accessors!(
444        listener,
445        set_listener,
446        get_listener,
447        clone_listener,
448        dyn WorkflowExecutionListener
449    );
450
451    // ---- Event Emission ----
452
453    /// Emits an event to the listener if configured, and publishes as CloudEvent to EventBus
454    pub fn emit_event(&self, event: WorkflowEvent) {
455        // Notify the synchronous listener
456        if let Some(ref listener) = self.listener {
457            listener.on_event(&event);
458        }
459
460        // Publish lifecycle CloudEvent to EventBus if configured
461        if let Some(ref event_bus) = self.event_bus {
462            let cloud_event = event.to_cloud_event();
463            let bus = event_bus.clone();
464            tokio::spawn(async move {
465                bus.publish(cloud_event).await;
466            });
467        }
468    }
469
470    // ---- Event Bus ----
471
472    option_accessors!(
473        event_bus,
474        set_event_bus,
475        get_event_bus,
476        clone_event_bus,
477        SharedEventBus
478    );
479
480    // ---- Sub-Workflow Registry ----
481
482    /// Sets the sub-workflow registry
483    pub fn set_sub_workflows(&mut self, sub_workflows: HashMap<String, WorkflowDefinition>) {
484        self.sub_workflows = sub_workflows;
485    }
486
487    /// Looks up a sub-workflow by namespace/name/version key
488    pub fn get_sub_workflow(
489        &self,
490        namespace: &str,
491        name: &str,
492        version: &str,
493    ) -> Option<&WorkflowDefinition> {
494        let key = format!("{}/{}/{}", namespace, name, version);
495        self.sub_workflows.get(&key)
496    }
497
498    /// Clones the entire sub-workflow registry (for propagating to child runners)
499    pub fn clone_sub_workflows(&self) -> HashMap<String, WorkflowDefinition> {
500        self.sub_workflows.clone()
501    }
502
503    // ---- Handler Registry ----
504
505    /// Sets the handler registry (replaces all handlers)
506    pub fn set_handler_registry(&mut self, registry: HandlerRegistry) {
507        self.handler_registry = registry;
508    }
509
510    /// Gets a reference to the handler registry
511    pub fn get_handler_registry(&self) -> &HandlerRegistry {
512        &self.handler_registry
513    }
514
515    /// Clones the handler registry (for propagating to child runners)
516    pub fn clone_handler_registry(&self) -> HandlerRegistry {
517        self.handler_registry.clone()
518    }
519
520    // ---- Expression Engines ----
521
522    /// Sets the expression engine registry
523    pub(crate) fn set_expression_engines(&mut self, engines: ExpressionEngineRegistry) {
524        self.expression_engines = engines;
525    }
526
527    /// Gets a reference to the expression engine registry
528    pub(crate) fn get_expression_engines(&self) -> &ExpressionEngineRegistry {
529        &self.expression_engines
530    }
531
532    /// Clones the expression engine registry (for propagating to child runners)
533    pub(crate) fn clone_expression_engines(&self) -> ExpressionEngineRegistry {
534        self.expression_engines.clone()
535    }
536
537    // ---- Functions (Catalog) ----
538
539    /// Sets the registered function definitions (for call.function resolution)
540    pub fn set_functions(&mut self, functions: HashMap<String, TaskDefinition>) {
541        self.functions = functions;
542    }
543
544    /// Looks up a registered function definition by name
545    pub fn get_function(&self, name: &str) -> Option<&TaskDefinition> {
546        self.functions.get(name)
547    }
548
549    // ---- Cancellation ----
550
551    /// Gets a clone of the cancellation token (for use in tokio::select!)
552    pub fn cancellation_token(&self) -> CancellationToken {
553        self.cancellation_token.clone()
554    }
555
556    /// Cancels the workflow (triggers cancellation for all wait points)
557    pub fn cancel(&self) {
558        self.cancellation_token.cancel();
559    }
560
561    /// Checks if cancellation has been requested
562    pub fn is_cancelled(&self) -> bool {
563        self.cancellation_token.is_cancelled()
564    }
565
566    // ---- Suspend / Resume ----
567
568    /// Suspends the workflow execution
569    ///
570    /// Returns `true` if the workflow was successfully suspended,
571    /// `false` if it was already suspended.
572    pub fn suspend(&self) -> bool {
573        self.suspend_state.suspend()
574    }
575
576    /// Resumes a suspended workflow execution
577    ///
578    /// Returns `true` if the workflow was resumed from a suspended state,
579    /// `false` if it was not suspended.
580    pub fn resume(&self) -> bool {
581        self.suspend_state.resume()
582    }
583
584    /// Checks if the workflow is currently suspended
585    pub fn is_suspended(&self) -> bool {
586        self.suspend_state.is_suspended()
587    }
588
589    /// Waits until the workflow is resumed (or cancelled)
590    ///
591    /// Should be called from task runners at cooperative yield points
592    /// when the workflow is detected as suspended.
593    pub async fn wait_for_resume(&self) {
594        if self.is_suspended() {
595            tokio::select! {
596                _ = self.suspend_state.resume_notify().notified() => {}
597                _ = self.cancellation_token.cancelled() => {}
598            }
599        }
600    }
601
602    // ---- Suspend State Sharing ----
603
604    /// Sets the shared suspend/resume state from the WorkflowRunner
605    ///
606    /// This allows the WorkflowHandle to share the same AtomicBool and Notify
607    /// as the context, enabling external suspend/resume control.
608    pub(crate) fn set_suspend_state(&mut self, state: SuspendState) {
609        self.suspend_state = state;
610    }
611
612    // ---- Authorization ----
613
614    /// Sets the authorization descriptor for the current task
615    /// Called after HTTP authentication succeeds (Basic, Bearer, Digest, OAuth2, OIDC)
616    pub fn set_authorization(&mut self, scheme: &str, parameter: &str) {
617        self.authorization = Some(serde_json::json!({
618            "scheme": scheme,
619            "parameter": parameter,
620        }));
621        self.invalidate_vars_cache();
622    }
623
624    /// Clears the authorization descriptor (called after task completes)
625    pub fn clear_authorization(&mut self) {
626        self.authorization = None;
627        self.invalidate_vars_cache();
628    }
629
630    // ---- Local Expression Variables ----
631
632    /// Sets local expression variables (replaces all)
633    pub fn set_local_expr_vars(&mut self, vars: HashMap<String, Value>) {
634        self.local_expr_vars = vars;
635        self.invalidate_vars_cache();
636    }
637
638    /// Adds local expression variables (merges, does not overwrite existing keys)
639    pub fn add_local_expr_vars(&mut self, vars: HashMap<String, Value>) {
640        for (k, v) in vars {
641            self.local_expr_vars.entry(k).or_insert(v);
642        }
643        self.invalidate_vars_cache();
644    }
645
646    /// Removes specified local expression variables
647    pub fn remove_local_expr_vars(&mut self, keys: &[&str]) {
648        for key in keys {
649            self.local_expr_vars.remove(*key);
650        }
651        self.invalidate_vars_cache();
652    }
653
654    // ---- Variable Aggregation ----
655
656    /// Marks the vars cache as dirty (needs rebuild on next access)
657    fn invalidate_vars_cache(&self) {
658        self.vars_dirty.store(true, Ordering::Release);
659    }
660
661    /// Returns all variables for JQ expression evaluation, using a cache
662    /// to avoid rebuilding the map on every call.
663    pub fn get_vars(&self) -> HashMap<String, Value> {
664        if self.vars_dirty.load(Ordering::Acquire) {
665            let mut vars = HashMap::new();
666
667            vars.insert(
668                vars::INPUT.to_string(),
669                self.input.clone().unwrap_or(Value::Null),
670            );
671            vars.insert(
672                vars::OUTPUT.to_string(),
673                self.output.clone().unwrap_or(Value::Null),
674            );
675            vars.insert(
676                vars::CONTEXT.to_string(),
677                self.instance_ctx.clone().unwrap_or(Value::Null),
678            );
679            vars.insert(vars::TASK.to_string(), self.task_descriptor.clone());
680            vars.insert(
681                vars::WORKFLOW.to_string(),
682                (*self.workflow_descriptor).clone(),
683            );
684            vars.insert(
685                vars::RUNTIME.to_string(),
686                runtime_info::runtime_info_value().clone(),
687            );
688
689            if let Some(ref mgr) = self.secret_manager {
690                vars.insert(vars::SECRET.to_string(), mgr.get_all_secrets());
691            }
692
693            if let Some(ref auth) = self.authorization {
694                vars.insert(vars::AUTHORIZATION.to_string(), auth.clone());
695            }
696
697            for (k, v) in &self.local_expr_vars {
698                vars.insert(k.clone(), v.clone());
699            }
700
701            *self.vars_cache.lock().unwrap() = Some(vars);
702            self.vars_dirty.store(false, Ordering::Release);
703        }
704        self.vars_cache.lock().unwrap().as_ref().unwrap().clone()
705    }
706}
707
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use swf_core::models::workflow::WorkflowDefinition;

    fn new_context() -> WorkflowContext {
        let workflow = WorkflowDefinition::default();
        WorkflowContext::new(&workflow).unwrap()
    }

    #[test]
    fn test_context_new() {
        let ctx = new_context();
        assert!(ctx.get_input().is_none());
        assert!(ctx.get_output().is_none());
        assert_eq!(ctx.get_status(), StatusPhase::Pending);
        assert_ne!(ctx.instance_id(), "unknown");
    }

    #[test]
    fn test_context_set_input_output() {
        let mut ctx = new_context();
        ctx.set_input(json!({"key": "value"}));
        assert_eq!(ctx.get_input(), Some(&json!({"key": "value"})));

        ctx.set_output(json!(42));
        assert_eq!(ctx.get_output(), Some(&json!(42)));
    }

    #[test]
    fn test_context_status_transitions() {
        let mut ctx = new_context();
        assert_eq!(ctx.get_status(), StatusPhase::Pending);

        ctx.set_status(StatusPhase::Running);
        assert_eq!(ctx.get_status(), StatusPhase::Running);

        ctx.set_status(StatusPhase::Completed);
        assert_eq!(ctx.get_status(), StatusPhase::Completed);
    }

    #[test]
    fn test_context_instance_ctx() {
        let mut ctx = new_context();
        assert!(ctx.get_instance_ctx().is_none());

        ctx.set_instance_ctx(json!({"exported": "data"}));
        assert_eq!(ctx.get_instance_ctx(), Some(&json!({"exported": "data"})));
    }

    #[test]
    fn test_context_local_expr_vars() {
        let mut ctx = new_context();
        let mut vars = HashMap::new();
        vars.insert("$item".to_string(), json!("hello"));
        vars.insert("$index".to_string(), json!(0));
        ctx.add_local_expr_vars(vars);

        let all_vars = ctx.get_vars();
        assert_eq!(all_vars.get("$item"), Some(&json!("hello")));
        assert_eq!(all_vars.get("$index"), Some(&json!(0)));

        ctx.remove_local_expr_vars(&["$item", "$index"]);
        let all_vars = ctx.get_vars();
        assert!(!all_vars.contains_key("$item"));
        assert!(!all_vars.contains_key("$index"));
    }

    #[test]
    fn test_context_get_vars_includes_runtime() {
        let ctx = new_context();
        let vars = ctx.get_vars();
        assert!(vars.contains_key(vars::RUNTIME));
        assert!(vars.contains_key(vars::WORKFLOW));
        assert!(vars.contains_key(vars::TASK));
    }

    #[test]
    fn test_context_vars_cache_tracks_updates() {
        let mut ctx = new_context();
        ctx.set_input(json!(1));
        assert_eq!(ctx.get_vars()[vars::INPUT], json!(1));
        ctx.set_input(json!(2));
        assert_eq!(ctx.get_vars()[vars::INPUT], json!(2));
    }

    #[test]
    fn test_context_task_status() {
        let mut ctx = new_context();
        ctx.set_task_status("task1", StatusPhase::Running);
        ctx.set_task_status("task1", StatusPhase::Completed);
        ctx.set_task_status("task2", StatusPhase::Pending);

        let task1_status = ctx.get_task_status("task1");
        assert_eq!(task1_status, Some(StatusPhase::Completed));
        assert_eq!(ctx.get_task_status("task2"), Some(StatusPhase::Pending));
        assert_eq!(ctx.get_task_status("missing"), None);
    }

    #[test]
    fn test_context_inc_iteration() {
        let mut ctx = new_context();
        assert_eq!(ctx.inc_iteration("/do/0"), 1);
        assert_eq!(ctx.inc_iteration("/do/0"), 2);
        assert_eq!(ctx.inc_iteration("/do/1"), 1);
        assert_eq!(ctx.get_vars()[vars::TASK]["iteration"], json!(1));
    }

    #[test]
    fn test_context_raw_input_in_workflow_descriptor() {
        let mut ctx = new_context();
        let raw = json!({"a": 1});
        ctx.set_raw_input(&raw);
        assert_eq!(ctx.get_vars()[vars::WORKFLOW]["input"], raw);
        assert!(ctx.get_workflow_json().is_some());
    }

    #[test]
    fn test_context_suspend_resume_shared_with_clone() {
        let ctx = new_context();
        assert!(!ctx.is_suspended());
        assert!(ctx.suspend());
        assert!(!ctx.suspend());

        let sibling = ctx.clone();
        assert!(sibling.is_suspended());
        assert!(sibling.resume());
        assert!(!ctx.is_suspended());
        assert!(!ctx.resume());
    }

    #[test]
    fn test_context_cancel() {
        let ctx = new_context();
        assert!(!ctx.is_cancelled());
        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(ctx.cancellation_token().is_cancelled());
    }

    #[test]
    fn test_context_authorization() {
        let mut ctx = new_context();

        // No authorization by default
        let vars = ctx.get_vars();
        assert!(!vars.contains_key("$authorization"));

        // Set authorization
        ctx.set_authorization("Bearer", "my-token-123");
        let vars = ctx.get_vars();
        let auth = vars
            .get("$authorization")
            .expect("$authorization should be set");
        assert_eq!(auth["scheme"], "Bearer");
        assert_eq!(auth["parameter"], "my-token-123");

        // Clear authorization
        ctx.clear_authorization();
        let vars = ctx.get_vars();
        assert!(!vars.contains_key("$authorization"));
    }
}