1use crate::events::SharedEventBus;
2use crate::expression::ExpressionEngineRegistry;
3use crate::handler::HandlerRegistry;
4use crate::listener::{WorkflowEvent, WorkflowExecutionListener};
5use crate::secret::SecretManager;
6use crate::status::{StatusPhase, StatusPhaseLog};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Mutex};
11use swf_core::models::task::TaskDefinition;
12use swf_core::models::workflow::WorkflowDefinition;
13use tokio::sync::Notify;
14
/// Generates a setter, a borrowing getter and a cloning getter for a field of
/// type `Option<Arc<T>>`.
///
/// Arguments, in order: the field name, the setter name, the getter name, the
/// clone-getter name, and the pointee type `T` (which may be unsized, e.g.
/// `dyn Trait`). `Arc` must be in scope where the macro is invoked.
macro_rules! arc_accessors {
    ($slot:ident, $set:ident, $get:ident, $share:ident, $inner:ty) => {
        #[doc = concat!("Stores `value` in `", stringify!($slot), "`, replacing any previous one.")]
        pub fn $set(&mut self, value: Arc<$inner>) {
            self.$slot = Some(value);
        }

        #[doc = concat!("Borrows the value held in `", stringify!($slot), "`, if any.")]
        pub fn $get(&self) -> Option<&$inner> {
            self.$slot.as_deref()
        }

        #[doc = concat!("Returns another `Arc` handle to `", stringify!($slot), "`, if set.")]
        pub fn $share(&self) -> Option<Arc<$inner>> {
            // Only the reference count is bumped; the pointee is not copied.
            self.$slot.as_ref().map(Arc::clone)
        }
    };
}
29
/// Generates a setter, a borrowing getter and a cloning getter for a field of
/// type `Option<T>` where `T: Clone`.
///
/// Arguments, in order: the field name, the setter name, the getter name, the
/// clone-getter name, and the value type `T`.
macro_rules! option_accessors {
    ($slot:ident, $set:ident, $get:ident, $dup:ident, $inner:ty) => {
        #[doc = concat!("Stores `value` in `", stringify!($slot), "`, replacing any previous one.")]
        pub fn $set(&mut self, value: $inner) {
            self.$slot = Some(value);
        }

        #[doc = concat!("Borrows the value held in `", stringify!($slot), "`, if any.")]
        pub fn $get(&self) -> Option<&$inner> {
            self.$slot.as_ref()
        }

        #[doc = concat!("Returns a clone of the value held in `", stringify!($slot), "`, if any.")]
        pub fn $dup(&self) -> Option<$inner> {
            self.$slot.clone()
        }
    };
}
44
45#[derive(Clone)]
48pub(crate) struct SuspendState {
49 suspended: Arc<AtomicBool>,
50 resume_notify: Arc<Notify>,
51}
52
53impl SuspendState {
54 pub(crate) fn new() -> Self {
55 Self {
56 suspended: Arc::new(AtomicBool::new(false)),
57 resume_notify: Arc::new(Notify::new()),
58 }
59 }
60
61 pub fn suspend(&self) -> bool {
63 self.suspended
64 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
65 .is_ok()
66 }
67
68 pub fn resume(&self) -> bool {
70 if self
71 .suspended
72 .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
73 .is_ok()
74 {
75 self.resume_notify.notify_waiters();
76 true
77 } else {
78 false
79 }
80 }
81
82 pub fn is_suspended(&self) -> bool {
84 self.suspended.load(Ordering::SeqCst)
85 }
86
87 pub(crate) fn resume_notify(&self) -> &Arc<Notify> {
89 &self.resume_notify
90 }
91}
92use tokio_util::sync::CancellationToken;
93
/// Names of the built-in variables that `WorkflowContext::get_vars` places in
/// the expression variable map.
pub mod vars {
    // Data flowing through the workflow.
    /// Bound to the context's current input (`null` when unset).
    pub const INPUT: &str = "$input";
    /// Bound to the context's current output (`null` when unset).
    pub const OUTPUT: &str = "$output";
    /// Bound to the instance context value (`null` when unset).
    pub const CONTEXT: &str = "$context";

    // Descriptors.
    /// Bound to the workflow descriptor object.
    pub const WORKFLOW: &str = "$workflow";
    /// Bound to the current task descriptor object.
    pub const TASK: &str = "$task";
    /// Bound to the runtime name/version object.
    pub const RUNTIME: &str = "$runtime";

    // Only present when the corresponding source is configured.
    /// Bound to the secrets returned by the configured secret manager.
    pub const SECRET: &str = "$secret";
    /// Bound to the `{scheme, parameter}` authorization object.
    pub const AUTHORIZATION: &str = "$authorization";
}
105
106pub mod runtime_info {
108 pub const NAME: &str = "CNCF Serverless Workflow Specification Rust SDK";
109 pub const VERSION: &str = env!("CARGO_PKG_VERSION");
110
111 static RUNTIME_INFO: std::sync::LazyLock<serde_json::Value> = std::sync::LazyLock::new(|| {
113 serde_json::json!({
114 "name": NAME,
115 "version": VERSION,
116 })
117 });
118
119 pub fn runtime_info_value() -> &'static serde_json::Value {
120 &RUNTIME_INFO
121 }
122}
123
124pub struct WorkflowContext {
126 input: Option<Value>,
128 output: Option<Value>,
130 instance_ctx: Option<Value>,
132 workflow_descriptor: Arc<Value>,
134 task_descriptor: Value,
136 local_expr_vars: HashMap<String, Value>,
138 authorization: Option<Value>,
140 secret_manager: Option<Arc<dyn SecretManager>>,
142 listener: Option<Arc<dyn WorkflowExecutionListener>>,
144 event_bus: Option<SharedEventBus>,
146 sub_workflows: HashMap<String, WorkflowDefinition>,
148 cancellation_token: CancellationToken,
150 suspend_state: SuspendState,
152 handler_registry: HandlerRegistry,
154 expression_engines: ExpressionEngineRegistry,
156 functions: HashMap<String, TaskDefinition>,
158 status_log: Vec<StatusPhaseLog>,
160 task_status: HashMap<String, Vec<StatusPhaseLog>>,
162 iterations: HashMap<String, u32>,
164 vars_cache: Mutex<Option<HashMap<String, Value>>>,
166 vars_dirty: AtomicBool,
168}
169
170impl Clone for WorkflowContext {
171 fn clone(&self) -> Self {
172 Self {
173 input: self.input.clone(),
174 output: self.output.clone(),
175 instance_ctx: self.instance_ctx.clone(),
176 workflow_descriptor: Arc::clone(&self.workflow_descriptor),
177 task_descriptor: self.task_descriptor.clone(),
178 local_expr_vars: self.local_expr_vars.clone(),
179 authorization: self.authorization.clone(),
180 secret_manager: self.secret_manager.clone(),
181 listener: self.listener.clone(),
182 event_bus: self.event_bus.clone(),
183 sub_workflows: self.sub_workflows.clone(),
184 cancellation_token: self.cancellation_token.clone(),
185 suspend_state: self.suspend_state.clone(),
186 handler_registry: self.handler_registry.clone(),
187 expression_engines: self.expression_engines.clone(),
188 functions: self.functions.clone(),
189 status_log: self.status_log.clone(),
190 task_status: self.task_status.clone(),
191 iterations: self.iterations.clone(),
192 vars_cache: Mutex::new(self.vars_cache.lock().unwrap().clone()),
193 vars_dirty: AtomicBool::new(self.vars_dirty.load(Ordering::Acquire)),
194 }
195 }
196}
197
198impl std::fmt::Debug for WorkflowContext {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 f.debug_struct("WorkflowContext")
201 .field("input", &self.input)
202 .field("output", &self.output)
203 .field("instance_ctx", &self.instance_ctx)
204 .field("workflow_descriptor", &self.workflow_descriptor)
205 .field("task_descriptor", &self.task_descriptor)
206 .field("local_expr_vars", &self.local_expr_vars)
207 .field(
208 "secret_manager",
209 &self.secret_manager.as_ref().map(|_| "..."),
210 )
211 .field("listener", &self.listener.as_ref().map(|_| "..."))
212 .field("event_bus", &self.event_bus.as_ref().map(|_| "..."))
213 .field("status_log", &self.status_log)
214 .field("task_status", &self.task_status)
215 .field("iterations", &self.iterations)
216 .finish()
217 }
218}
219
220impl WorkflowContext {
221 pub fn new(
223 workflow: &swf_core::models::workflow::WorkflowDefinition,
224 ) -> crate::error::WorkflowResult<Self> {
225 let workflow_json = serde_json::to_value(workflow).map_err(|e| {
226 crate::error::WorkflowError::runtime(
227 format!("failed to serialize workflow definition: {}", e),
228 "/",
229 "/",
230 )
231 })?;
232
233 let workflow_descriptor = Arc::new(serde_json::json!({
234 "id": uuid::Uuid::new_v4().to_string(),
235 "definition": workflow_json,
236 }));
237
238 let mut ctx = Self {
239 input: None,
240 output: None,
241 instance_ctx: None,
242 workflow_descriptor,
243 task_descriptor: Value::Object(Default::default()),
244 local_expr_vars: HashMap::new(),
245 authorization: None,
246 secret_manager: None,
247 listener: None,
248 event_bus: None,
249 sub_workflows: HashMap::new(),
250 cancellation_token: CancellationToken::new(),
251 suspend_state: SuspendState::new(),
252 handler_registry: HandlerRegistry::new(),
253 expression_engines: ExpressionEngineRegistry::new(),
254 functions: HashMap::new(),
255 status_log: Vec::new(),
256 task_status: HashMap::new(),
257 iterations: HashMap::new(),
258 vars_cache: Mutex::new(None),
259 vars_dirty: AtomicBool::new(true),
260 };
261 ctx.set_status(StatusPhase::Pending);
262 Ok(ctx)
263 }
264
265 pub fn set_status(&mut self, status: StatusPhase) {
269 self.status_log.push(StatusPhaseLog::new(status));
270 }
271
272 pub fn instance_id(&self) -> &str {
274 self.workflow_descriptor
275 .as_object()
276 .and_then(|obj| obj.get("id"))
277 .and_then(|id| id.as_str())
278 .unwrap_or("unknown")
279 }
280
281 pub fn get_status(&self) -> StatusPhase {
283 self.status_log
284 .last()
285 .map(|log| log.status)
286 .unwrap_or(StatusPhase::Pending)
287 }
288
289 pub fn set_task_status(&mut self, task: &str, status: StatusPhase) {
291 self.task_status
292 .entry(task.to_string())
293 .or_default()
294 .push(StatusPhaseLog::new(status));
295 }
296
297 pub fn get_task_status(&self, task: &str) -> Option<StatusPhase> {
299 self.task_status
300 .get(task)
301 .and_then(|logs| logs.last())
302 .map(|log| log.status)
303 }
304
305 pub fn set_input(&mut self, value: Value) {
308 self.input = Some(value);
309 self.invalidate_vars_cache();
310 }
311 pub fn get_input(&self) -> Option<&Value> {
312 self.input.as_ref()
313 }
314 pub fn set_output(&mut self, value: Value) {
315 self.output = Some(value);
316 self.invalidate_vars_cache();
317 }
318 pub fn get_output(&self) -> Option<&Value> {
319 self.output.as_ref()
320 }
321 pub fn set_instance_ctx(&mut self, value: Value) {
322 self.instance_ctx = Some(value);
323 self.invalidate_vars_cache();
324 }
325 pub fn get_instance_ctx(&self) -> Option<&Value> {
326 self.instance_ctx.as_ref()
327 }
328
329 pub fn set_raw_input(&mut self, input: &Value) {
333 let mut desc = (*self.workflow_descriptor).clone();
334 if let Some(obj) = desc.as_object_mut() {
335 obj.insert("input".to_string(), input.clone());
336 }
337 self.workflow_descriptor = Arc::new(desc);
338 self.invalidate_vars_cache();
339 }
340
341 fn task_descriptor_insert(&mut self, key: &str, value: Value) {
345 if let Some(obj) = self.task_descriptor.as_object_mut() {
346 obj.insert(key.to_string(), value);
347 }
348 self.invalidate_vars_cache();
349 }
350
351 pub fn set_task_name(&mut self, name: &str) {
353 self.task_descriptor_insert("name", Value::String(name.to_string()));
354 }
355
356 pub fn set_task_raw_input(&mut self, input: &Value) {
358 self.task_descriptor_insert("input", input.clone());
359 }
360
361 pub fn set_task_raw_output(&mut self, output: &Value) {
363 self.task_descriptor_insert("output", output.clone());
364 }
365
366 pub fn set_task_started_at(&mut self) {
369 let now = chrono::Utc::now();
370 let iso8601 = now.to_rfc3339();
371 let epoch_seconds = now.timestamp();
372 let epoch_millis = now.timestamp_millis();
373 self.task_descriptor_insert(
374 "startedAt",
375 serde_json::json!({
376 "iso8601": iso8601,
377 "epoch": {
378 "seconds": epoch_seconds,
379 "milliseconds": epoch_millis,
380 }
381 }),
382 );
383 }
384
385 pub fn set_task_reference(&mut self, reference: &str) {
387 self.task_descriptor_insert("reference", Value::String(reference.to_string()));
388 }
389
390 pub fn get_task_reference(&self) -> Option<&str> {
392 self.task_descriptor
393 .as_object()
394 .and_then(|obj| obj.get("reference"))
395 .and_then(|v| v.as_str())
396 }
397
398 pub fn get_workflow_json(&self) -> Option<&Value> {
400 self.workflow_descriptor
401 .as_object()
402 .and_then(|obj| obj.get("definition"))
403 }
404
405 pub fn set_task_def(&mut self, task: &Value) {
408 self.task_descriptor_insert("definition", task.clone());
409 }
410
411 pub fn inc_iteration(&mut self, position: &str) -> u32 {
414 let count = self.iterations.entry(position.to_string()).or_insert(0);
415 *count += 1;
416 let value = *count;
417 self.task_descriptor_insert("iteration", serde_json::json!(value));
418 value
419 }
420
421 pub fn set_retry_attempt(&mut self, attempt: u32) {
423 self.task_descriptor_insert("retryAttempt", serde_json::json!(attempt));
424 }
425
426 pub fn clear_task_context(&mut self) {
428 self.task_descriptor = Value::Object(Default::default());
429 }
430
431 arc_accessors!(
434 secret_manager,
435 set_secret_manager,
436 get_secret_manager,
437 clone_secret_manager,
438 dyn SecretManager
439 );
440
441 arc_accessors!(
444 listener,
445 set_listener,
446 get_listener,
447 clone_listener,
448 dyn WorkflowExecutionListener
449 );
450
451 pub fn emit_event(&self, event: WorkflowEvent) {
455 if let Some(ref listener) = self.listener {
457 listener.on_event(&event);
458 }
459
460 if let Some(ref event_bus) = self.event_bus {
462 let cloud_event = event.to_cloud_event();
463 let bus = event_bus.clone();
464 tokio::spawn(async move {
465 bus.publish(cloud_event).await;
466 });
467 }
468 }
469
470 option_accessors!(
473 event_bus,
474 set_event_bus,
475 get_event_bus,
476 clone_event_bus,
477 SharedEventBus
478 );
479
480 pub fn set_sub_workflows(&mut self, sub_workflows: HashMap<String, WorkflowDefinition>) {
484 self.sub_workflows = sub_workflows;
485 }
486
487 pub fn get_sub_workflow(
489 &self,
490 namespace: &str,
491 name: &str,
492 version: &str,
493 ) -> Option<&WorkflowDefinition> {
494 let key = format!("{}/{}/{}", namespace, name, version);
495 self.sub_workflows.get(&key)
496 }
497
498 pub fn clone_sub_workflows(&self) -> HashMap<String, WorkflowDefinition> {
500 self.sub_workflows.clone()
501 }
502
503 pub fn set_handler_registry(&mut self, registry: HandlerRegistry) {
507 self.handler_registry = registry;
508 }
509
510 pub fn get_handler_registry(&self) -> &HandlerRegistry {
512 &self.handler_registry
513 }
514
515 pub fn clone_handler_registry(&self) -> HandlerRegistry {
517 self.handler_registry.clone()
518 }
519
520 pub(crate) fn set_expression_engines(&mut self, engines: ExpressionEngineRegistry) {
524 self.expression_engines = engines;
525 }
526
527 pub(crate) fn get_expression_engines(&self) -> &ExpressionEngineRegistry {
529 &self.expression_engines
530 }
531
532 pub(crate) fn clone_expression_engines(&self) -> ExpressionEngineRegistry {
534 self.expression_engines.clone()
535 }
536
537 pub fn set_functions(&mut self, functions: HashMap<String, TaskDefinition>) {
541 self.functions = functions;
542 }
543
544 pub fn get_function(&self, name: &str) -> Option<&TaskDefinition> {
546 self.functions.get(name)
547 }
548
549 pub fn cancellation_token(&self) -> CancellationToken {
553 self.cancellation_token.clone()
554 }
555
556 pub fn cancel(&self) {
558 self.cancellation_token.cancel();
559 }
560
561 pub fn is_cancelled(&self) -> bool {
563 self.cancellation_token.is_cancelled()
564 }
565
566 pub fn suspend(&self) -> bool {
573 self.suspend_state.suspend()
574 }
575
576 pub fn resume(&self) -> bool {
581 self.suspend_state.resume()
582 }
583
584 pub fn is_suspended(&self) -> bool {
586 self.suspend_state.is_suspended()
587 }
588
589 pub async fn wait_for_resume(&self) {
594 if self.is_suspended() {
595 tokio::select! {
596 _ = self.suspend_state.resume_notify().notified() => {}
597 _ = self.cancellation_token.cancelled() => {}
598 }
599 }
600 }
601
602 pub(crate) fn set_suspend_state(&mut self, state: SuspendState) {
609 self.suspend_state = state;
610 }
611
612 pub fn set_authorization(&mut self, scheme: &str, parameter: &str) {
617 self.authorization = Some(serde_json::json!({
618 "scheme": scheme,
619 "parameter": parameter,
620 }));
621 self.invalidate_vars_cache();
622 }
623
624 pub fn clear_authorization(&mut self) {
626 self.authorization = None;
627 self.invalidate_vars_cache();
628 }
629
630 pub fn set_local_expr_vars(&mut self, vars: HashMap<String, Value>) {
634 self.local_expr_vars = vars;
635 self.invalidate_vars_cache();
636 }
637
638 pub fn add_local_expr_vars(&mut self, vars: HashMap<String, Value>) {
640 for (k, v) in vars {
641 self.local_expr_vars.entry(k).or_insert(v);
642 }
643 self.invalidate_vars_cache();
644 }
645
646 pub fn remove_local_expr_vars(&mut self, keys: &[&str]) {
648 for key in keys {
649 self.local_expr_vars.remove(*key);
650 }
651 self.invalidate_vars_cache();
652 }
653
654 fn invalidate_vars_cache(&self) {
658 self.vars_dirty.store(true, Ordering::Release);
659 }
660
661 pub fn get_vars(&self) -> HashMap<String, Value> {
664 if self.vars_dirty.load(Ordering::Acquire) {
665 let mut vars = HashMap::new();
666
667 vars.insert(
668 vars::INPUT.to_string(),
669 self.input.clone().unwrap_or(Value::Null),
670 );
671 vars.insert(
672 vars::OUTPUT.to_string(),
673 self.output.clone().unwrap_or(Value::Null),
674 );
675 vars.insert(
676 vars::CONTEXT.to_string(),
677 self.instance_ctx.clone().unwrap_or(Value::Null),
678 );
679 vars.insert(vars::TASK.to_string(), self.task_descriptor.clone());
680 vars.insert(
681 vars::WORKFLOW.to_string(),
682 (*self.workflow_descriptor).clone(),
683 );
684 vars.insert(
685 vars::RUNTIME.to_string(),
686 runtime_info::runtime_info_value().clone(),
687 );
688
689 if let Some(ref mgr) = self.secret_manager {
690 vars.insert(vars::SECRET.to_string(), mgr.get_all_secrets());
691 }
692
693 if let Some(ref auth) = self.authorization {
694 vars.insert(vars::AUTHORIZATION.to_string(), auth.clone());
695 }
696
697 for (k, v) in &self.local_expr_vars {
698 vars.insert(k.clone(), v.clone());
699 }
700
701 *self.vars_cache.lock().unwrap() = Some(vars);
702 self.vars_dirty.store(false, Ordering::Release);
703 }
704 self.vars_cache.lock().unwrap().as_ref().unwrap().clone()
705 }
706}
707
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use swf_core::models::workflow::WorkflowDefinition;

    /// Builds a context around a default (empty) workflow definition.
    fn new_context() -> WorkflowContext {
        WorkflowContext::new(&WorkflowDefinition::default()).unwrap()
    }

    /// A fresh context has no input/output and is pending.
    #[test]
    fn test_context_new() {
        let fresh = new_context();
        assert!(fresh.get_input().is_none());
        assert!(fresh.get_output().is_none());
        assert_eq!(fresh.get_status(), StatusPhase::Pending);
    }

    /// Input and output setters are reflected by their getters.
    #[test]
    fn test_context_set_input_output() {
        let mut context = new_context();

        let input = json!({"key": "value"});
        context.set_input(input.clone());
        assert_eq!(context.get_input(), Some(&input));

        let output = json!(42);
        context.set_output(output.clone());
        assert_eq!(context.get_output(), Some(&output));
    }

    /// The current status is always the last one set.
    #[test]
    fn test_context_status_transitions() {
        let mut context = new_context();
        assert_eq!(context.get_status(), StatusPhase::Pending);

        for phase in [StatusPhase::Running, StatusPhase::Completed] {
            context.set_status(phase);
            assert_eq!(context.get_status(), phase);
        }
    }

    /// The instance context starts unset and can be assigned.
    #[test]
    fn test_context_instance_ctx() {
        let mut context = new_context();
        assert!(context.get_instance_ctx().is_none());

        let exported = json!({"exported": "data"});
        context.set_instance_ctx(exported.clone());
        assert_eq!(context.get_instance_ctx(), Some(&exported));
    }

    /// Local variables show up in the variable map and disappear on removal.
    #[test]
    fn test_context_local_expr_vars() {
        let mut context = new_context();
        let locals = HashMap::from([
            ("$item".to_string(), json!("hello")),
            ("$index".to_string(), json!(0)),
        ]);
        context.add_local_expr_vars(locals);

        let with_locals = context.get_vars();
        assert_eq!(with_locals.get("$item"), Some(&json!("hello")));
        assert_eq!(with_locals.get("$index"), Some(&json!(0)));

        context.remove_local_expr_vars(&["$item", "$index"]);
        let without_locals = context.get_vars();
        assert!(!without_locals.contains_key("$item"));
        assert!(!without_locals.contains_key("$index"));
    }

    /// The always-present built-ins are part of the variable map.
    #[test]
    fn test_context_get_vars_includes_runtime() {
        let snapshot = new_context().get_vars();
        assert!(snapshot.contains_key(vars::RUNTIME));
        assert!(snapshot.contains_key(vars::WORKFLOW));
        assert!(snapshot.contains_key(vars::TASK));
    }

    /// Per-task status reports the last status recorded for that task.
    #[test]
    fn test_context_task_status() {
        let mut context = new_context();
        context.set_task_status("task1", StatusPhase::Running);
        context.set_task_status("task1", StatusPhase::Completed);
        context.set_task_status("task2", StatusPhase::Pending);

        assert_eq!(
            context.get_task_status("task1"),
            Some(StatusPhase::Completed)
        );
    }

    /// `$authorization` appears only while an authorization is set.
    #[test]
    fn test_context_authorization() {
        let mut context = new_context();

        // Absent by default.
        assert!(!context.get_vars().contains_key("$authorization"));

        // Present, with both parts, after being set.
        context.set_authorization("Bearer", "my-token-123");
        let snapshot = context.get_vars();
        let auth = snapshot
            .get("$authorization")
            .expect("$authorization should be set");
        assert_eq!(auth["scheme"], "Bearer");
        assert_eq!(auth["parameter"], "my-token-123");

        // Gone again after being cleared.
        context.clear_authorization();
        assert!(!context.get_vars().contains_key("$authorization"));
    }
}