1use crate::events::SharedEventBus;
2use crate::handler::HandlerRegistry;
3use crate::listener::{WorkflowEvent, WorkflowExecutionListener};
4use crate::secret::SecretManager;
5use crate::status::{StatusPhase, StatusPhaseLog};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10use swf_core::models::task::TaskDefinition;
11use swf_core::models::workflow::WorkflowDefinition;
12use tokio::sync::Notify;
13
/// Generates `set_*` / `get_*` / `clone_*` accessors for a field of type
/// `Option<Arc<T>>`, where `T` may be unsized (e.g. `dyn Trait`).
///
/// Arguments, in order: the field, the setter name, the borrowing getter
/// name, the `Arc`-cloning getter name, and the pointee type `T`.
macro_rules! arc_accessors {
    ($slot:ident, $set:ident, $get:ident, $share:ident, $pointee:ty) => {
        #[doc = concat!("Stores `", stringify!($slot), "`, replacing any previous value.")]
        pub fn $set(&mut self, handle: Arc<$pointee>) {
            self.$slot = Some(handle);
        }

        #[doc = concat!("Borrows `", stringify!($slot), "`, or `None` if it was never set.")]
        pub fn $get(&self) -> Option<&$pointee> {
            Option::as_deref(&self.$slot)
        }

        #[doc = concat!("Returns another `Arc` handle to `", stringify!($slot), "`, if set.")]
        pub fn $share(&self) -> Option<Arc<$pointee>> {
            self.$slot.as_ref().cloned()
        }
    };
}
28
/// Generates `set_*` / `get_*` / `clone_*` accessors for a field of type
/// `Option<T>` where `T: Clone`.
///
/// Arguments, in order: the field, the setter name, the borrowing getter
/// name, the cloning getter name, and the value type `T`.
macro_rules! option_accessors {
    ($slot:ident, $set:ident, $get:ident, $copy:ident, $item:ty) => {
        #[doc = concat!("Stores `", stringify!($slot), "`, replacing any previous value.")]
        pub fn $set(&mut self, item: $item) {
            self.$slot = Some(item);
        }

        #[doc = concat!("Borrows `", stringify!($slot), "`, or `None` if it was never set.")]
        pub fn $get(&self) -> Option<&$item> {
            Option::as_ref(&self.$slot)
        }

        #[doc = concat!("Returns a clone of `", stringify!($slot), "`, if set.")]
        pub fn $copy(&self) -> Option<$item> {
            self.$slot.as_ref().cloned()
        }
    };
}
43
44#[derive(Clone)]
47pub(crate) struct SuspendState {
48 suspended: Arc<AtomicBool>,
49 resume_notify: Arc<Notify>,
50}
51
52impl SuspendState {
53 pub(crate) fn new() -> Self {
54 Self {
55 suspended: Arc::new(AtomicBool::new(false)),
56 resume_notify: Arc::new(Notify::new()),
57 }
58 }
59
60 pub fn suspend(&self) -> bool {
62 self.suspended
63 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
64 .is_ok()
65 }
66
67 pub fn resume(&self) -> bool {
69 if self
70 .suspended
71 .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
72 .is_ok()
73 {
74 self.resume_notify.notify_waiters();
75 true
76 } else {
77 false
78 }
79 }
80
81 pub fn is_suspended(&self) -> bool {
83 self.suspended.load(Ordering::SeqCst)
84 }
85
86 pub(crate) fn resume_notify(&self) -> &Arc<Notify> {
88 &self.resume_notify
89 }
90}
91use tokio_util::sync::CancellationToken;
92
/// Names of the built-in variables that runtime expressions can reference.
pub mod vars {
    /// Workflow instance context.
    pub const CONTEXT: &str = "$context";
    /// Input of the current workflow or task.
    pub const INPUT: &str = "$input";
    /// Output of the current workflow or task.
    pub const OUTPUT: &str = "$output";
    /// Workflow descriptor (instance id, definition, raw input).
    pub const WORKFLOW: &str = "$workflow";
    /// Runtime identity (name and version).
    pub const RUNTIME: &str = "$runtime";
    /// Descriptor of the task currently being executed.
    pub const TASK: &str = "$task";
    /// Secrets supplied by the configured secret manager.
    pub const SECRET: &str = "$secret";
    /// Authorization scheme and parameter of the current call.
    pub const AUTHORIZATION: &str = "$authorization";
}
104
105pub mod runtime_info {
107 pub const NAME: &str = "CNCF Serverless Workflow Specification Rust SDK";
108 pub const VERSION: &str = env!("CARGO_PKG_VERSION");
109
110 static RUNTIME_INFO: std::sync::LazyLock<serde_json::Value> = std::sync::LazyLock::new(|| {
112 serde_json::json!({
113 "name": NAME,
114 "version": VERSION,
115 })
116 });
117
118 pub fn runtime_info_value() -> &'static serde_json::Value {
119 &RUNTIME_INFO
120 }
121}
122
123pub struct WorkflowContext {
125 input: Option<Value>,
127 output: Option<Value>,
129 instance_ctx: Option<Value>,
131 workflow_descriptor: Arc<Value>,
133 task_descriptor: Value,
135 local_expr_vars: HashMap<String, Value>,
137 authorization: Option<Value>,
139 secret_manager: Option<Arc<dyn SecretManager>>,
141 listener: Option<Arc<dyn WorkflowExecutionListener>>,
143 event_bus: Option<SharedEventBus>,
145 sub_workflows: HashMap<String, WorkflowDefinition>,
147 cancellation_token: CancellationToken,
149 suspend_state: SuspendState,
151 handler_registry: HandlerRegistry,
153 functions: HashMap<String, TaskDefinition>,
155 status_log: Vec<StatusPhaseLog>,
157 task_status: HashMap<String, Vec<StatusPhaseLog>>,
159 iterations: HashMap<String, u32>,
161 vars_cache: Mutex<Option<HashMap<String, Value>>>,
163 vars_dirty: AtomicBool,
165}
166
167impl Clone for WorkflowContext {
168 fn clone(&self) -> Self {
169 Self {
170 input: self.input.clone(),
171 output: self.output.clone(),
172 instance_ctx: self.instance_ctx.clone(),
173 workflow_descriptor: Arc::clone(&self.workflow_descriptor),
174 task_descriptor: self.task_descriptor.clone(),
175 local_expr_vars: self.local_expr_vars.clone(),
176 authorization: self.authorization.clone(),
177 secret_manager: self.secret_manager.clone(),
178 listener: self.listener.clone(),
179 event_bus: self.event_bus.clone(),
180 sub_workflows: self.sub_workflows.clone(),
181 cancellation_token: self.cancellation_token.clone(),
182 suspend_state: self.suspend_state.clone(),
183 handler_registry: self.handler_registry.clone(),
184 functions: self.functions.clone(),
185 status_log: self.status_log.clone(),
186 task_status: self.task_status.clone(),
187 iterations: self.iterations.clone(),
188 vars_cache: Mutex::new(self.vars_cache.lock().unwrap().clone()),
189 vars_dirty: AtomicBool::new(self.vars_dirty.load(Ordering::Acquire)),
190 }
191 }
192}
193
194impl std::fmt::Debug for WorkflowContext {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 f.debug_struct("WorkflowContext")
197 .field("input", &self.input)
198 .field("output", &self.output)
199 .field("instance_ctx", &self.instance_ctx)
200 .field("workflow_descriptor", &self.workflow_descriptor)
201 .field("task_descriptor", &self.task_descriptor)
202 .field("local_expr_vars", &self.local_expr_vars)
203 .field(
204 "secret_manager",
205 &self.secret_manager.as_ref().map(|_| "..."),
206 )
207 .field("listener", &self.listener.as_ref().map(|_| "..."))
208 .field("event_bus", &self.event_bus.as_ref().map(|_| "..."))
209 .field("status_log", &self.status_log)
210 .field("task_status", &self.task_status)
211 .field("iterations", &self.iterations)
212 .finish()
213 }
214}
215
216impl WorkflowContext {
217 pub fn new(
219 workflow: &swf_core::models::workflow::WorkflowDefinition,
220 ) -> crate::error::WorkflowResult<Self> {
221 let workflow_json = serde_json::to_value(workflow).map_err(|e| {
222 crate::error::WorkflowError::runtime(
223 format!("failed to serialize workflow definition: {}", e),
224 "/",
225 "/",
226 )
227 })?;
228
229 let workflow_descriptor = Arc::new(serde_json::json!({
230 "id": uuid::Uuid::new_v4().to_string(),
231 "definition": workflow_json,
232 }));
233
234 let mut ctx = Self {
235 input: None,
236 output: None,
237 instance_ctx: None,
238 workflow_descriptor,
239 task_descriptor: Value::Object(Default::default()),
240 local_expr_vars: HashMap::new(),
241 authorization: None,
242 secret_manager: None,
243 listener: None,
244 event_bus: None,
245 sub_workflows: HashMap::new(),
246 cancellation_token: CancellationToken::new(),
247 suspend_state: SuspendState::new(),
248 handler_registry: HandlerRegistry::new(),
249 functions: HashMap::new(),
250 status_log: Vec::new(),
251 task_status: HashMap::new(),
252 iterations: HashMap::new(),
253 vars_cache: Mutex::new(None),
254 vars_dirty: AtomicBool::new(true),
255 };
256 ctx.set_status(StatusPhase::Pending);
257 Ok(ctx)
258 }
259
260 pub fn set_status(&mut self, status: StatusPhase) {
264 self.status_log.push(StatusPhaseLog::new(status));
265 }
266
267 pub fn instance_id(&self) -> &str {
269 self.workflow_descriptor
270 .as_object()
271 .and_then(|obj| obj.get("id"))
272 .and_then(|id| id.as_str())
273 .unwrap_or("unknown")
274 }
275
276 pub fn get_status(&self) -> StatusPhase {
278 self.status_log
279 .last()
280 .map(|log| log.status)
281 .unwrap_or(StatusPhase::Pending)
282 }
283
284 pub fn set_task_status(&mut self, task: &str, status: StatusPhase) {
286 self.task_status
287 .entry(task.to_string())
288 .or_default()
289 .push(StatusPhaseLog::new(status));
290 }
291
292 pub fn get_task_status(&self, task: &str) -> Option<StatusPhase> {
294 self.task_status
295 .get(task)
296 .and_then(|logs| logs.last())
297 .map(|log| log.status)
298 }
299
300 pub fn set_input(&mut self, value: Value) {
303 self.input = Some(value);
304 self.invalidate_vars_cache();
305 }
306 pub fn get_input(&self) -> Option<&Value> {
307 self.input.as_ref()
308 }
309 pub fn set_output(&mut self, value: Value) {
310 self.output = Some(value);
311 self.invalidate_vars_cache();
312 }
313 pub fn get_output(&self) -> Option<&Value> {
314 self.output.as_ref()
315 }
316 pub fn set_instance_ctx(&mut self, value: Value) {
317 self.instance_ctx = Some(value);
318 self.invalidate_vars_cache();
319 }
320 pub fn get_instance_ctx(&self) -> Option<&Value> {
321 self.instance_ctx.as_ref()
322 }
323
324 pub fn set_raw_input(&mut self, input: &Value) {
328 let mut desc = (*self.workflow_descriptor).clone();
329 if let Some(obj) = desc.as_object_mut() {
330 obj.insert("input".to_string(), input.clone());
331 }
332 self.workflow_descriptor = Arc::new(desc);
333 self.invalidate_vars_cache();
334 }
335
336 fn task_descriptor_insert(&mut self, key: &str, value: Value) {
340 if let Some(obj) = self.task_descriptor.as_object_mut() {
341 obj.insert(key.to_string(), value);
342 }
343 self.invalidate_vars_cache();
344 }
345
346 pub fn set_task_name(&mut self, name: &str) {
348 self.task_descriptor_insert("name", Value::String(name.to_string()));
349 }
350
351 pub fn set_task_raw_input(&mut self, input: &Value) {
353 self.task_descriptor_insert("input", input.clone());
354 }
355
356 pub fn set_task_raw_output(&mut self, output: &Value) {
358 self.task_descriptor_insert("output", output.clone());
359 }
360
361 pub fn set_task_started_at(&mut self) {
364 let now = chrono::Utc::now();
365 let iso8601 = now.to_rfc3339();
366 let epoch_seconds = now.timestamp();
367 let epoch_millis = now.timestamp_millis();
368 self.task_descriptor_insert(
369 "startedAt",
370 serde_json::json!({
371 "iso8601": iso8601,
372 "epoch": {
373 "seconds": epoch_seconds,
374 "milliseconds": epoch_millis,
375 }
376 }),
377 );
378 }
379
380 pub fn set_task_reference(&mut self, reference: &str) {
382 self.task_descriptor_insert("reference", Value::String(reference.to_string()));
383 }
384
385 pub fn get_task_reference(&self) -> Option<&str> {
387 self.task_descriptor
388 .as_object()
389 .and_then(|obj| obj.get("reference"))
390 .and_then(|v| v.as_str())
391 }
392
393 pub fn get_workflow_json(&self) -> Option<&Value> {
395 self.workflow_descriptor
396 .as_object()
397 .and_then(|obj| obj.get("definition"))
398 }
399
400 pub fn set_task_def(&mut self, task: &Value) {
403 self.task_descriptor_insert("definition", task.clone());
404 }
405
406 pub fn inc_iteration(&mut self, position: &str) -> u32 {
409 let count = self.iterations.entry(position.to_string()).or_insert(0);
410 *count += 1;
411 let value = *count;
412 self.task_descriptor_insert("iteration", serde_json::json!(value));
413 value
414 }
415
416 pub fn set_retry_attempt(&mut self, attempt: u32) {
418 self.task_descriptor_insert("retryAttempt", serde_json::json!(attempt));
419 }
420
421 pub fn clear_task_context(&mut self) {
423 self.task_descriptor = Value::Object(Default::default());
424 }
425
426 arc_accessors!(
429 secret_manager,
430 set_secret_manager,
431 get_secret_manager,
432 clone_secret_manager,
433 dyn SecretManager
434 );
435
436 arc_accessors!(
439 listener,
440 set_listener,
441 get_listener,
442 clone_listener,
443 dyn WorkflowExecutionListener
444 );
445
446 pub fn emit_event(&self, event: WorkflowEvent) {
450 if let Some(ref listener) = self.listener {
452 listener.on_event(&event);
453 }
454
455 if let Some(ref event_bus) = self.event_bus {
457 let cloud_event = event.to_cloud_event();
458 let bus = event_bus.clone();
459 tokio::spawn(async move {
460 bus.publish(cloud_event).await;
461 });
462 }
463 }
464
465 option_accessors!(
468 event_bus,
469 set_event_bus,
470 get_event_bus,
471 clone_event_bus,
472 SharedEventBus
473 );
474
475 pub fn set_sub_workflows(&mut self, sub_workflows: HashMap<String, WorkflowDefinition>) {
479 self.sub_workflows = sub_workflows;
480 }
481
482 pub fn get_sub_workflow(
484 &self,
485 namespace: &str,
486 name: &str,
487 version: &str,
488 ) -> Option<&WorkflowDefinition> {
489 let key = format!("{}/{}/{}", namespace, name, version);
490 self.sub_workflows.get(&key)
491 }
492
493 pub fn clone_sub_workflows(&self) -> HashMap<String, WorkflowDefinition> {
495 self.sub_workflows.clone()
496 }
497
498 pub fn set_handler_registry(&mut self, registry: HandlerRegistry) {
502 self.handler_registry = registry;
503 }
504
505 pub fn get_handler_registry(&self) -> &HandlerRegistry {
507 &self.handler_registry
508 }
509
510 pub fn clone_handler_registry(&self) -> HandlerRegistry {
512 self.handler_registry.clone()
513 }
514
515 pub fn set_functions(&mut self, functions: HashMap<String, TaskDefinition>) {
519 self.functions = functions;
520 }
521
522 pub fn get_function(&self, name: &str) -> Option<&TaskDefinition> {
524 self.functions.get(name)
525 }
526
527 pub fn cancellation_token(&self) -> CancellationToken {
531 self.cancellation_token.clone()
532 }
533
534 pub fn cancel(&self) {
536 self.cancellation_token.cancel();
537 }
538
539 pub fn is_cancelled(&self) -> bool {
541 self.cancellation_token.is_cancelled()
542 }
543
544 pub fn suspend(&self) -> bool {
551 self.suspend_state.suspend()
552 }
553
554 pub fn resume(&self) -> bool {
559 self.suspend_state.resume()
560 }
561
562 pub fn is_suspended(&self) -> bool {
564 self.suspend_state.is_suspended()
565 }
566
567 pub async fn wait_for_resume(&self) {
572 if self.is_suspended() {
573 tokio::select! {
574 _ = self.suspend_state.resume_notify().notified() => {}
575 _ = self.cancellation_token.cancelled() => {}
576 }
577 }
578 }
579
580 pub(crate) fn set_suspend_state(&mut self, state: SuspendState) {
587 self.suspend_state = state;
588 }
589
590 pub fn set_authorization(&mut self, scheme: &str, parameter: &str) {
595 self.authorization = Some(serde_json::json!({
596 "scheme": scheme,
597 "parameter": parameter,
598 }));
599 self.invalidate_vars_cache();
600 }
601
602 pub fn clear_authorization(&mut self) {
604 self.authorization = None;
605 self.invalidate_vars_cache();
606 }
607
608 pub fn set_local_expr_vars(&mut self, vars: HashMap<String, Value>) {
612 self.local_expr_vars = vars;
613 self.invalidate_vars_cache();
614 }
615
616 pub fn add_local_expr_vars(&mut self, vars: HashMap<String, Value>) {
618 for (k, v) in vars {
619 self.local_expr_vars.entry(k).or_insert(v);
620 }
621 self.invalidate_vars_cache();
622 }
623
624 pub fn remove_local_expr_vars(&mut self, keys: &[&str]) {
626 for key in keys {
627 self.local_expr_vars.remove(*key);
628 }
629 self.invalidate_vars_cache();
630 }
631
632 fn invalidate_vars_cache(&self) {
636 self.vars_dirty.store(true, Ordering::Release);
637 }
638
639 pub fn get_vars(&self) -> HashMap<String, Value> {
642 if self.vars_dirty.load(Ordering::Acquire) {
643 let mut vars = HashMap::new();
644
645 vars.insert(
646 vars::INPUT.to_string(),
647 self.input.clone().unwrap_or(Value::Null),
648 );
649 vars.insert(
650 vars::OUTPUT.to_string(),
651 self.output.clone().unwrap_or(Value::Null),
652 );
653 vars.insert(
654 vars::CONTEXT.to_string(),
655 self.instance_ctx.clone().unwrap_or(Value::Null),
656 );
657 vars.insert(vars::TASK.to_string(), self.task_descriptor.clone());
658 vars.insert(
659 vars::WORKFLOW.to_string(),
660 (*self.workflow_descriptor).clone(),
661 );
662 vars.insert(
663 vars::RUNTIME.to_string(),
664 runtime_info::runtime_info_value().clone(),
665 );
666
667 if let Some(ref mgr) = self.secret_manager {
668 vars.insert(vars::SECRET.to_string(), mgr.get_all_secrets());
669 }
670
671 if let Some(ref auth) = self.authorization {
672 vars.insert(vars::AUTHORIZATION.to_string(), auth.clone());
673 }
674
675 for (k, v) in &self.local_expr_vars {
676 vars.insert(k.clone(), v.clone());
677 }
678
679 *self.vars_cache.lock().unwrap() = Some(vars);
680 self.vars_dirty.store(false, Ordering::Release);
681 }
682 self.vars_cache.lock().unwrap().as_ref().unwrap().clone()
683 }
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use swf_core::models::workflow::WorkflowDefinition;

    /// Builds a context for an empty default workflow definition.
    fn new_context() -> WorkflowContext {
        let workflow = WorkflowDefinition::default();
        WorkflowContext::new(&workflow).unwrap()
    }

    #[test]
    fn test_context_new() {
        let ctx = new_context();
        assert!(ctx.get_input().is_none());
        assert!(ctx.get_output().is_none());
        assert_eq!(ctx.get_status(), StatusPhase::Pending);
    }

    #[test]
    fn test_context_instance_id_is_uuid() {
        let ctx = new_context();
        assert!(uuid::Uuid::parse_str(ctx.instance_id()).is_ok());
    }

    #[test]
    fn test_context_set_input_output() {
        let mut ctx = new_context();
        ctx.set_input(json!({"key": "value"}));
        assert_eq!(ctx.get_input(), Some(&json!({"key": "value"})));

        ctx.set_output(json!(42));
        assert_eq!(ctx.get_output(), Some(&json!(42)));
    }

    #[test]
    fn test_context_status_transitions() {
        let mut ctx = new_context();
        assert_eq!(ctx.get_status(), StatusPhase::Pending);

        ctx.set_status(StatusPhase::Running);
        assert_eq!(ctx.get_status(), StatusPhase::Running);

        ctx.set_status(StatusPhase::Completed);
        assert_eq!(ctx.get_status(), StatusPhase::Completed);
    }

    #[test]
    fn test_context_instance_ctx() {
        let mut ctx = new_context();
        assert!(ctx.get_instance_ctx().is_none());

        ctx.set_instance_ctx(json!({"exported": "data"}));
        assert_eq!(ctx.get_instance_ctx(), Some(&json!({"exported": "data"})));
    }

    #[test]
    fn test_context_local_expr_vars() {
        let mut ctx = new_context();
        let mut vars = HashMap::new();
        vars.insert("$item".to_string(), json!("hello"));
        vars.insert("$index".to_string(), json!(0));
        ctx.add_local_expr_vars(vars);

        let all_vars = ctx.get_vars();
        assert_eq!(all_vars.get("$item"), Some(&json!("hello")));
        assert_eq!(all_vars.get("$index"), Some(&json!(0)));

        ctx.remove_local_expr_vars(&["$item", "$index"]);
        let all_vars = ctx.get_vars();
        assert!(!all_vars.contains_key("$item"));
        assert!(!all_vars.contains_key("$index"));
    }

    #[test]
    fn test_context_get_vars_includes_runtime() {
        let ctx = new_context();
        let vars = ctx.get_vars();
        assert!(vars.contains_key(vars::RUNTIME));
        assert!(vars.contains_key(vars::WORKFLOW));
        assert!(vars.contains_key(vars::TASK));
        assert_eq!(vars[vars::RUNTIME]["name"], json!(runtime_info::NAME));
    }

    #[test]
    fn test_context_raw_input_exposed_on_workflow() {
        let mut ctx = new_context();
        ctx.set_raw_input(&json!({"a": 1}));
        let all = ctx.get_vars();
        assert_eq!(all[vars::WORKFLOW]["input"], json!({"a": 1}));
        assert_eq!(all[vars::WORKFLOW]["id"].as_str(), Some(ctx.instance_id()));
    }

    #[test]
    fn test_context_task_status() {
        let mut ctx = new_context();
        ctx.set_task_status("task1", StatusPhase::Running);
        ctx.set_task_status("task1", StatusPhase::Completed);
        ctx.set_task_status("task2", StatusPhase::Pending);

        assert_eq!(ctx.get_task_status("task1"), Some(StatusPhase::Completed));
        assert_eq!(ctx.get_task_status("task2"), Some(StatusPhase::Pending));
        assert_eq!(ctx.get_task_status("missing"), None);
    }

    #[test]
    fn test_context_iteration_and_reference() {
        let mut ctx = new_context();
        assert_eq!(ctx.inc_iteration("/do/0/loop"), 1);
        assert_eq!(ctx.inc_iteration("/do/0/loop"), 2);
        assert_eq!(ctx.inc_iteration("/do/1/other"), 1);

        ctx.set_task_reference("/do/0/loop");
        assert_eq!(ctx.get_task_reference(), Some("/do/0/loop"));

        let all = ctx.get_vars();
        assert_eq!(all[vars::TASK]["iteration"].as_u64(), Some(1));
        assert_eq!(all[vars::TASK]["reference"], json!("/do/0/loop"));
    }

    #[test]
    fn test_context_sub_workflow_lookup() {
        let mut ctx = new_context();
        let mut subs = HashMap::new();
        subs.insert("ns/child/1.0.0".to_string(), WorkflowDefinition::default());
        ctx.set_sub_workflows(subs);

        assert!(ctx.get_sub_workflow("ns", "child", "1.0.0").is_some());
        assert!(ctx.get_sub_workflow("ns", "child", "2.0.0").is_none());
    }

    #[test]
    fn test_context_suspend_resume() {
        let ctx = new_context();
        assert!(!ctx.is_suspended());

        assert!(ctx.suspend());
        assert!(!ctx.suspend(), "second suspend must be a no-op");
        assert!(ctx.is_suspended());

        assert!(ctx.resume());
        assert!(!ctx.resume(), "second resume must be a no-op");
        assert!(!ctx.is_suspended());
    }

    #[test]
    fn test_context_clone_shares_cancellation_and_suspension() {
        let ctx = new_context();
        let copy = ctx.clone();

        ctx.suspend();
        assert!(copy.is_suspended());

        ctx.cancel();
        assert!(copy.is_cancelled());
    }

    #[test]
    fn test_context_clone_is_independent_for_data() {
        let mut ctx = new_context();
        ctx.set_input(json!(1));
        // Warm the cache before cloning.
        assert_eq!(ctx.get_vars()[vars::INPUT], json!(1));

        let mut copy = ctx.clone();
        copy.set_input(json!(2));

        assert_eq!(ctx.get_vars()[vars::INPUT], json!(1));
        assert_eq!(copy.get_vars()[vars::INPUT], json!(2));
    }

    #[test]
    fn test_context_authorization() {
        let mut ctx = new_context();

        let vars = ctx.get_vars();
        assert!(!vars.contains_key("$authorization"));

        ctx.set_authorization("Bearer", "my-token-123");
        let vars = ctx.get_vars();
        let auth = vars
            .get("$authorization")
            .expect("$authorization should be set");
        assert_eq!(auth["scheme"], "Bearer");
        assert_eq!(auth["parameter"], "my-token-123");

        ctx.clear_authorization();
        let vars = ctx.get_vars();
        assert!(!vars.contains_key("$authorization"));
    }
}