use std::collections::HashMap;

use bytes::Bytes;
use chrono;
use futures::FutureExt;
use sayiir_core::codec::Codec;
use sayiir_core::codec::sealed;
use sayiir_core::context::with_context;
use sayiir_core::error::{BoxError, WorkflowError};
use sayiir_core::registry::TaskRegistry;
use sayiir_core::snapshot::{
    ExecutionPosition, SignalKind, SignalRequest, TaskDeadline, WorkflowSnapshot,
};
use sayiir_core::task_claim::AvailableTask;
use sayiir_core::workflow::{Workflow, WorkflowContinuation, WorkflowStatus};
use sayiir_persistence::{PersistentBackend, SignalStore, TaskClaimStore};
use std::num::NonZeroUsize;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time;

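/// Workflows a worker can execute in-process, stored as
/// `(definition hash, workflow)` pairs.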
pub type WorkflowRegistry<C, Input, M> = Vec<(String, Arc<Workflow<C, Input, M>>)>;

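/// A workflow whose task bodies are run through an [`ExternalTaskExecutor`]
/// rather than in-process task functions.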
pub struct ExternalWorkflow {
    pub continuation: Arc<WorkflowContinuation>,
}

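/// Externally executed workflows indexed by their workflow definition hash.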
pub type WorkflowIndex = HashMap<String, ExternalWorkflow>;

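/// Callback that executes a single task outside the worker: it receives the
/// task id and the task's input bytes and resolves to the task's output bytes.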
pub type ExternalTaskExecutor = Arc<
    dyn Fn(
            &str,
            Bytes,
        ) -> Pin<Box<dyn std::future::Future<Output = Result<Bytes, BoxError>> + Send>>
        + Send
        + Sync,
>;

enum WorkerCommand {
    Shutdown,
}

struct WorkerHandleInner<B> {
    backend: Arc<B>,
    shutdown_tx: mpsc::Sender<WorkerCommand>,
    join_handle:
        tokio::sync::Mutex<Option<tokio::task::JoinHandle<Result<(), crate::error::RuntimeError>>>>,
}

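/// Cloneable handle to a spawned worker: request shutdown, await completion,
/// signal workflows, and access the shared backend.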
pub struct WorkerHandle<B> {
    inner: Arc<WorkerHandleInner<B>>,
}

impl<B> Clone for WorkerHandle<B> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<B> WorkerHandle<B> {
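    /// Requests a graceful shutdown. Non-blocking; if the command channel is
    /// full or closed, the request is silently dropped.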
    pub fn shutdown(&self) {
        let _ = self.inner.shutdown_tx.try_send(WorkerCommand::Shutdown);
    }

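    /// Waits for the spawned worker task to finish. Calling this again after
    /// the handle has already been joined returns `Ok(())` immediately.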
    pub async fn join(&self) -> Result<(), crate::error::RuntimeError> {
        let jh = self.inner.join_handle.lock().await.take();
        match jh {
            Some(jh) => Ok(jh.await??),
            None => Ok(()),
        }
    }

    #[must_use]
    pub fn backend(&self) -> &Arc<B> {
        &self.inner.backend
    }
}

impl<B: SignalStore> WorkerHandle<B> {
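    /// Stores a cancel signal for the given workflow instance in the backend.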
    pub async fn cancel_workflow(
        &self,
        instance_id: &str,
        reason: Option<String>,
        cancelled_by: Option<String>,
    ) -> Result<(), crate::error::RuntimeError> {
        self.inner
            .backend
            .store_signal(
                instance_id,
                SignalKind::Cancel,
                SignalRequest::new(reason, cancelled_by),
            )
            .await?;
        Ok(())
    }

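    /// Stores a pause signal for the given workflow instance, mirroring
    /// [`WorkerHandle::cancel_workflow`].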
    pub async fn pause_workflow(
        &self,
        instance_id: &str,
        reason: Option<String>,
        paused_by: Option<String>,
    ) -> Result<(), crate::error::RuntimeError> {
        self.inner
            .backend
            .store_signal(
                instance_id,
                SignalKind::Pause,
                SignalRequest::new(reason, paused_by),
            )
            .await?;
        Ok(())
    }
}

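/// A claim this worker holds on a task while executing it; released explicitly
/// once the execution outcome has been settled.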
struct ActiveTaskClaim<'a, B> {
    backend: &'a B,
    instance_id: String,
    task_id: String,
    worker_id: String,
}

impl<B: TaskClaimStore> ActiveTaskClaim<'_, B> {
    async fn release(self) -> Result<(), crate::error::RuntimeError> {
        self.backend
            .release_task_claim(&self.instance_id, &self.task_id, &self.worker_id)
            .await?;
        Ok(())
    }

    async fn release_quietly(self) {
        let _ = self.release().await;
    }
}

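/// Raw outcome of running a task body, before retry scheduling and claim
/// release are applied.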
enum ExecutionOutcome {
    Success(Bytes),
    TaskError(crate::error::RuntimeError),
    Panic(Box<dyn std::any::Any + Send>),
    Timeout(crate::error::RuntimeError),
}

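/// Best-effort extraction of a human-readable message from a panic payload.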
fn extract_panic_message(payload: &Box<dyn std::any::Any + Send>) -> String {
    if let Some(s) = payload.downcast_ref::<&str>() {
        s.to_string()
    } else if let Some(s) = payload.downcast_ref::<String>() {
        s.clone()
    } else {
        "Task panicked with unknown payload".to_string()
    }
}

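/// A polling worker that claims available tasks from a shared persistence
/// backend, executes them, and records the results in workflow snapshots.
///
/// A minimal usage sketch (assumes an in-memory backend, the JSON codec, and an
/// empty workflow registry, as in the tests at the bottom of this file):
///
/// ```ignore
/// let worker = PooledWorker::new("worker-1", InMemoryBackend::new(), TaskRegistry::new());
/// let workflows: WorkflowRegistry<JsonCodec, (), ()> = Vec::new();
/// let handle = worker.spawn(Duration::from_millis(50), workflows);
/// // ...
/// handle.shutdown();
/// handle.join().await?;
/// ```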
pub struct PooledWorker<B> {
    worker_id: String,
    backend: Arc<B>,
    #[allow(unused)]
    registry: Arc<TaskRegistry>,
    claim_ttl: Option<Duration>,
    batch_size: NonZeroUsize,
}

impl<B> PooledWorker<B>
where
    B: PersistentBackend + TaskClaimStore + 'static,
{
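    /// Creates a worker with a default claim TTL of five minutes and a batch
    /// size of one.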
    pub fn new(worker_id: impl Into<String>, backend: B, registry: TaskRegistry) -> Self {
        Self {
            worker_id: worker_id.into(),
            backend: Arc::new(backend),
            registry: Arc::new(registry),
            claim_ttl: Some(Duration::from_secs(5 * 60)),
            batch_size: NonZeroUsize::MIN,
        }
    }

    #[must_use]
    pub fn with_claim_ttl(mut self, ttl: Option<Duration>) -> Self {
        self.claim_ttl = ttl;
        self
    }

    #[must_use]
    pub fn with_batch_size(mut self, size: NonZeroUsize) -> Self {
        self.batch_size = size;
        self
    }

    pub async fn cancel_workflow(
        &self,
        instance_id: &str,
        reason: Option<String>,
        cancelled_by: Option<String>,
    ) -> Result<(), crate::error::RuntimeError> {
        self.backend
            .store_signal(
                instance_id,
                SignalKind::Cancel,
                SignalRequest::new(reason, cancelled_by),
            )
            .await?;

        Ok(())
    }

    pub async fn pause_workflow(
        &self,
        instance_id: &str,
        reason: Option<String>,
        paused_by: Option<String>,
    ) -> Result<(), crate::error::RuntimeError> {
        self.backend
            .store_signal(
                instance_id,
                SignalKind::Pause,
                SignalRequest::new(reason, paused_by),
            )
            .await?;

        Ok(())
    }

    #[must_use]
    pub fn backend(&self) -> &Arc<B> {
        &self.backend
    }

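    /// Spawns the polling loop on the Tokio runtime and returns a
    /// [`WorkerHandle`] for shutdown, joining, and signalling. Tasks are
    /// executed in-process using the given workflow registry.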
    #[must_use]
    pub fn spawn<C, Input, M>(
        self,
        poll_interval: Duration,
        workflows: WorkflowRegistry<C, Input, M>,
    ) -> WorkerHandle<B>
    where
        Input: Send + Sync + 'static,
        M: Send + Sync + 'static,
        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
    {
        let (tx, rx) = mpsc::channel(1);
        let backend = Arc::clone(&self.backend);
        let join_handle =
            tokio::spawn(async move { self.run_actor_loop(poll_interval, workflows, rx).await });
        WorkerHandle {
            inner: Arc::new(WorkerHandleInner {
                backend,
                shutdown_tx: tx,
                join_handle: tokio::sync::Mutex::new(Some(join_handle)),
            }),
        }
    }

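    /// Like [`PooledWorker::spawn`], but task bodies are delegated to the
    /// provided [`ExternalTaskExecutor`] instead of in-process task functions.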
    #[must_use]
    pub fn spawn_with_executor(
        self,
        poll_interval: Duration,
        workflows: WorkflowIndex,
        executor: ExternalTaskExecutor,
    ) -> WorkerHandle<B> {
        let (tx, rx) = mpsc::channel(1);
        let backend = Arc::clone(&self.backend);
        let join_handle = tokio::spawn(async move {
            self.run_external_actor_loop(poll_interval, workflows, executor, rx)
                .await
        });
        WorkerHandle {
            inner: Arc::new(WorkerHandleInner {
                backend,
                shutdown_tx: tx,
                join_handle: tokio::sync::Mutex::new(Some(join_handle)),
            }),
        }
    }

    async fn run_external_actor_loop(
        &self,
        poll_interval: Duration,
        workflows: WorkflowIndex,
        executor: ExternalTaskExecutor,
        mut cmd_rx: mpsc::Receiver<WorkerCommand>,
    ) -> Result<(), crate::error::RuntimeError> {
        let mut interval = time::interval(poll_interval);

        loop {
            tokio::select! {
                biased;

                cmd = cmd_rx.recv() => {
                    match cmd {
                        Some(WorkerCommand::Shutdown) | None => {
                            tracing::info!(worker_id = %self.worker_id, "Worker shutting down");
                            return Ok(());
                        }
                    }
                }

                _ = interval.tick() => {
                    tracing::trace!(worker_id = %self.worker_id, "Will poll for available tasks");
                }
            }

            let available_tasks = self
                .backend
                .find_available_tasks(&self.worker_id, self.batch_size.get())
                .await?;

            for task in available_tasks {
                if let Ok(WorkerCommand::Shutdown) | Err(mpsc::error::TryRecvError::Disconnected) =
                    cmd_rx.try_recv()
                {
                    tracing::info!(worker_id = %self.worker_id, "Worker shutting down mid-batch");
                    return Ok(());
                }

                if let Some(ext_wf) = workflows.get(&task.workflow_definition_hash) {
                    match self
                        .execute_external_task(
                            &ext_wf.continuation,
                            &task.workflow_definition_hash,
                            &executor,
                            &task,
                        )
                        .await
                    {
                        Err(ref e) if e.is_timeout() => {
                            tracing::error!(
                                worker_id = %self.worker_id,
                                error = %e,
                                "Task timed out — worker shutting down"
                            );
                            return Ok(());
                        }
                        Ok(_) => {
                            tracing::info!("Worker {} completed a task", self.worker_id);
                        }
                        Err(e) => {
                            tracing::error!(
                                "Worker {} task execution failed: {}",
                                self.worker_id,
                                e
                            );
                        }
                    }
                }
            }
        }
    }

    async fn execute_external_task(
        &self,
        continuation: &WorkflowContinuation,
        definition_hash: &str,
        executor: &ExternalTaskExecutor,
        available_task: &AvailableTask,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
        let mut snapshot = self
            .backend
            .load_snapshot(&available_task.instance_id)
            .await?;
        let already_completed = Self::validate_task_preconditions(
            definition_hash,
            continuation,
            available_task,
            &snapshot,
        )?;
        if already_completed {
            return Ok(WorkflowStatus::InProgress);
        }

        let Some(claim) = self.claim_task(available_task).await? else {
            return Ok(WorkflowStatus::InProgress);
        };

        if let Some(status) = self.check_post_claim_guards(available_task).await? {
            claim.release_quietly().await;
            return Ok(status);
        }

        tracing::debug!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            "Executing task (external)"
        );

        let execution_result = self
            .execute_with_deadline_ext(
                continuation,
                executor,
                available_task,
                &mut snapshot,
                &claim,
            )
            .await;

        self.settle_execution_result_ext(
            execution_result,
            continuation,
            available_task,
            &mut snapshot,
            claim,
        )
        .await
    }

    async fn execute_with_deadline_ext(
        &self,
        continuation: &WorkflowContinuation,
        executor: &ExternalTaskExecutor,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        claim: &ActiveTaskClaim<'_, B>,
    ) -> ExecutionOutcome {
        let task_id = available_task.task_id.clone();
        let input = available_task.input.clone();

        let deadline = if let Some(timeout) = continuation.get_task_timeout(&task_id) {
            snapshot.set_task_deadline(task_id.clone(), timeout);
            let _ = self.backend.save_snapshot(snapshot).await;
            snapshot.refresh_task_deadline();
            snapshot.task_deadline.clone()
        } else {
            None
        };

        let execution_future = executor(&task_id, input);

        let heartbeat_result = self
            .run_with_heartbeat(
                claim,
                deadline.as_ref(),
                AssertUnwindSafe(execution_future).catch_unwind(),
            )
            .await;

        snapshot.clear_task_deadline();

        match heartbeat_result {
            Err(timeout_err) => ExecutionOutcome::Timeout(timeout_err),
            Ok(Err(panic_payload)) => ExecutionOutcome::Panic(panic_payload),
            Ok(Ok(Err(e))) => ExecutionOutcome::TaskError(e.into()),
            Ok(Ok(Ok(output))) => ExecutionOutcome::Success(output),
        }
    }

    async fn settle_execution_result_ext(
        &self,
        outcome: ExecutionOutcome,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        claim: ActiveTaskClaim<'_, B>,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
        match outcome {
            ExecutionOutcome::Timeout(err) => {
                if let Ok(Some(status)) = self
                    .try_schedule_retry(continuation, available_task, snapshot, &err.to_string())
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::warn!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    error = %err,
                    "Task timed out via heartbeat — marking workflow failed, shutting down"
                );
                snapshot.mark_failed(err.to_string());
                let _ = self.backend.save_snapshot(snapshot).await;
                claim.release_quietly().await;
                Err(err)
            }
            ExecutionOutcome::Panic(panic_payload) => {
                let panic_msg = extract_panic_message(&panic_payload);

                if let Ok(Some(status)) = self
                    .try_schedule_retry(continuation, available_task, snapshot, &panic_msg)
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::error!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    panic = %panic_msg,
                    "Task panicked - releasing claim"
                );
                claim.release_quietly().await;
                Err(WorkflowError::TaskPanicked(panic_msg).into())
            }
            ExecutionOutcome::TaskError(e) => {
                if let Ok(Some(status)) = self
                    .try_schedule_retry(continuation, available_task, snapshot, &e.to_string())
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::error!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    error = %e,
                    "Task execution failed"
                );
                claim.release_quietly().await;
                Err(e)
            }
            ExecutionOutcome::Success(output) => {
                snapshot.clear_retry_state(&available_task.task_id);
                self.commit_task_result(
                    continuation,
                    available_task,
                    snapshot,
                    output.clone(),
                    claim,
                )
                .await?;
                self.determine_post_task_status(continuation, available_task, snapshot, output)
                    .await
            }
        }
    }

    async fn run_actor_loop<C, Input, M>(
        &self,
        poll_interval: Duration,
        workflows: WorkflowRegistry<C, Input, M>,
        mut cmd_rx: mpsc::Receiver<WorkerCommand>,
    ) -> Result<(), crate::error::RuntimeError>
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
    {
        let mut interval = time::interval(poll_interval);

        loop {
            tokio::select! {
                biased;

                cmd = cmd_rx.recv() => {
                    match cmd {
                        Some(WorkerCommand::Shutdown) | None => {
                            tracing::info!(worker_id = %self.worker_id, "Worker shutting down");
                            return Ok(());
                        }
                    }
                }

                _ = interval.tick() => {
                    tracing::trace!(worker_id = %self.worker_id, "Will poll for available tasks");
                }
            }

            let available_tasks = self
                .backend
                .find_available_tasks(&self.worker_id, self.batch_size.get())
                .await?;

            for task in available_tasks {
                if let Ok(WorkerCommand::Shutdown) | Err(mpsc::error::TryRecvError::Disconnected) =
                    cmd_rx.try_recv()
                {
                    tracing::info!(worker_id = %self.worker_id, "Worker shutting down mid-batch");
                    return Ok(());
                }

                if let Some((_, workflow)) = workflows
                    .iter()
                    .find(|(hash, _)| *hash == task.workflow_definition_hash)
                {
                    match self.execute_task(workflow.as_ref(), task).await {
                        Err(ref e) if e.is_timeout() => {
                            tracing::error!(
                                worker_id = %self.worker_id,
                                error = %e,
                                "Task timed out — worker shutting down"
                            );
                            return Ok(());
                        }
                        Ok(_) => {
                            tracing::info!("Worker {} completed a task", self.worker_id);
                        }
                        Err(e) => {
                            tracing::error!(
                                "Worker {} task execution failed: {}",
                                self.worker_id,
                                e
                            );
                        }
                    }
                }
            }
        }
    }

    async fn load_cancelled_status(&self, instance_id: &str) -> WorkflowStatus {
        if let Ok(snapshot) = self.backend.load_snapshot(instance_id).await
            && let Some((reason, cancelled_by)) = snapshot.state.cancellation_details()
        {
            return WorkflowStatus::Cancelled {
                reason,
                cancelled_by,
            };
        }
        WorkflowStatus::Cancelled {
            reason: None,
            cancelled_by: None,
        }
    }

    async fn load_paused_status(&self, instance_id: &str) -> WorkflowStatus {
        if let Ok(snapshot) = self.backend.load_snapshot(instance_id).await
            && let Some((reason, paused_by)) = snapshot.state.pause_details()
        {
            return WorkflowStatus::Paused { reason, paused_by };
        }
        WorkflowStatus::Paused {
            reason: None,
            paused_by: None,
        }
    }

    pub async fn execute_task<C, Input, M>(
        &self,
        workflow: &Workflow<C, Input, M>,
        available_task: AvailableTask,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError>
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
    {
        let mut snapshot = self
            .backend
            .load_snapshot(&available_task.instance_id)
            .await?;
        let already_completed = Self::validate_task_preconditions(
            workflow.definition_hash(),
            workflow.continuation(),
            &available_task,
            &snapshot,
        )?;
        if already_completed {
            return Ok(WorkflowStatus::InProgress);
        }

        let Some(claim) = self.claim_task(&available_task).await? else {
            return Ok(WorkflowStatus::InProgress);
        };

        if let Some(status) = self.check_post_claim_guards(&available_task).await? {
            claim.release_quietly().await;
            return Ok(status);
        }

        tracing::debug!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            "Executing task"
        );

        let execution_result = self
            .execute_with_deadline(workflow, &available_task, &mut snapshot, &claim)
            .await;

        self.settle_execution_result(
            execution_result,
            workflow,
            &available_task,
            &mut snapshot,
            claim,
        )
        .await
    }

    async fn execute_with_deadline<C, Input, M>(
        &self,
        workflow: &Workflow<C, Input, M>,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        claim: &ActiveTaskClaim<'_, B>,
    ) -> ExecutionOutcome
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
    {
        let continuation = workflow.continuation();
        let task_id = available_task.task_id.clone();
        let input = available_task.input.clone();

        let deadline = if let Some(timeout) = continuation.get_task_timeout(&task_id) {
            snapshot.set_task_deadline(task_id.clone(), timeout);
            let _ = self.backend.save_snapshot(snapshot).await;
            snapshot.refresh_task_deadline();
            snapshot.task_deadline.clone()
        } else {
            None
        };

        let context = workflow.context().clone();
        let execution_future = with_context(context, || async move {
            Self::execute_task_by_id(continuation, &task_id, input).await
        });

        let heartbeat_result = self
            .run_with_heartbeat(
                claim,
                deadline.as_ref(),
                AssertUnwindSafe(execution_future).catch_unwind(),
            )
            .await;

        snapshot.clear_task_deadline();

        match heartbeat_result {
            Err(timeout_err) => ExecutionOutcome::Timeout(timeout_err),
            Ok(Err(panic_payload)) => ExecutionOutcome::Panic(panic_payload),
            Ok(Ok(Err(e))) => ExecutionOutcome::TaskError(e),
            Ok(Ok(Ok(output))) => ExecutionOutcome::Success(output),
        }
    }

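    /// If the task has a retry policy and retries are not exhausted, records a
    /// retry attempt in the snapshot and returns `Ok(Some(InProgress))`;
    /// otherwise returns `Ok(None)` so the caller can fail the task.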
    async fn try_schedule_retry(
        &self,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        error_msg: &str,
    ) -> Result<Option<WorkflowStatus>, crate::error::RuntimeError> {
        let Some(policy) = continuation.get_task_retry_policy(&available_task.task_id) else {
            return Ok(None);
        };

        if snapshot.retries_exhausted(&available_task.task_id) {
            return Ok(None);
        }

        let next_retry_at = snapshot.record_retry(
            &available_task.task_id,
            policy,
            error_msg,
            Some(&self.worker_id),
        );
        snapshot.clear_task_deadline();
        let _ = self.backend.save_snapshot(snapshot).await;

        tracing::info!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            attempt = snapshot.get_retry_state(&available_task.task_id).map_or(0, |rs| rs.attempts),
            max_retries = policy.max_retries,
            %next_retry_at,
            "Scheduling retry"
        );

        Ok(Some(WorkflowStatus::InProgress))
    }

    async fn settle_execution_result<C, Input, M>(
        &self,
        outcome: ExecutionOutcome,
        workflow: &Workflow<C, Input, M>,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        claim: ActiveTaskClaim<'_, B>,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError>
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
    {
        match outcome {
            ExecutionOutcome::Timeout(err) => {
                if let Ok(Some(status)) = self
                    .try_schedule_retry(
                        workflow.continuation(),
                        available_task,
                        snapshot,
                        &err.to_string(),
                    )
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::warn!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    error = %err,
                    "Task timed out via heartbeat — marking workflow failed, shutting down"
                );
                snapshot.mark_failed(err.to_string());
                let _ = self.backend.save_snapshot(snapshot).await;
                claim.release_quietly().await;
                Err(err)
            }
            ExecutionOutcome::Panic(panic_payload) => {
                let panic_msg = extract_panic_message(&panic_payload);

                if let Ok(Some(status)) = self
                    .try_schedule_retry(
                        workflow.continuation(),
                        available_task,
                        snapshot,
                        &panic_msg,
                    )
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::error!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    panic = %panic_msg,
                    "Task panicked - releasing claim"
                );
                claim.release_quietly().await;
                Err(WorkflowError::TaskPanicked(panic_msg).into())
            }
            ExecutionOutcome::TaskError(e) => {
                if let Ok(Some(status)) = self
                    .try_schedule_retry(
                        workflow.continuation(),
                        available_task,
                        snapshot,
                        &e.to_string(),
                    )
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::error!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    error = %e,
                    "Task execution failed"
                );
                claim.release_quietly().await;
                Err(e)
            }
            ExecutionOutcome::Success(output) => {
                snapshot.clear_retry_state(&available_task.task_id);
                self.commit_task_result(
                    workflow.continuation(),
                    available_task,
                    snapshot,
                    output.clone(),
                    claim,
                )
                .await?;
                self.determine_post_task_status(
                    workflow.continuation(),
                    available_task,
                    snapshot,
                    output,
                )
                .await
            }
        }
    }

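    /// Checks that the claimed task belongs to this workflow definition and
    /// still needs to run. Returns `Ok(true)` when the task already has a
    /// recorded result and can be skipped.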
    fn validate_task_preconditions(
        definition_hash: &str,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &WorkflowSnapshot,
    ) -> Result<bool, crate::error::RuntimeError> {
        if available_task.workflow_definition_hash != definition_hash {
            return Err(WorkflowError::DefinitionMismatch {
                expected: definition_hash.to_string(),
                found: available_task.workflow_definition_hash.clone(),
            }
            .into());
        }

        if !Self::find_task_id_in_continuation(continuation, &available_task.task_id) {
            tracing::error!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Task does not exist in workflow"
            );
            return Err(WorkflowError::TaskNotFound(available_task.task_id.clone()).into());
        }

        if snapshot.get_task_result(&available_task.task_id).is_some() {
            tracing::debug!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Task already completed, skipping"
            );
            return Ok(true);
        }

        Ok(false)
    }

    async fn claim_task(
        &self,
        available_task: &AvailableTask,
    ) -> Result<Option<ActiveTaskClaim<'_, B>>, crate::error::RuntimeError> {
        let claim = self
            .backend
            .claim_task(
                &available_task.instance_id,
                &available_task.task_id,
                &self.worker_id,
                self.claim_ttl
                    .and_then(|d| chrono::Duration::from_std(d).ok()),
            )
            .await?;

        if claim.is_some() {
            tracing::debug!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Claim successful"
            );
            Ok(Some(ActiveTaskClaim {
                backend: &self.backend,
                instance_id: available_task.instance_id.clone(),
                task_id: available_task.task_id.clone(),
                worker_id: self.worker_id.clone(),
            }))
        } else {
            tracing::debug!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Task was already claimed by another worker"
            );
            Ok(None)
        }
    }

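    /// After claiming, re-checks for pending cancel/pause signals and returns
    /// the corresponding status if the workflow should not proceed.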
    async fn check_post_claim_guards(
        &self,
        available_task: &AvailableTask,
    ) -> Result<Option<WorkflowStatus>, crate::error::RuntimeError> {
        if self
            .backend
            .check_and_cancel(&available_task.instance_id, Some(&available_task.task_id))
            .await?
        {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow was cancelled, releasing claim"
            );
            return Ok(Some(
                self.load_cancelled_status(&available_task.instance_id)
                    .await,
            ));
        }

        if self
            .backend
            .check_and_pause(&available_task.instance_id)
            .await?
        {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow was paused, releasing claim"
            );
            return Ok(Some(
                self.load_paused_status(&available_task.instance_id).await,
            ));
        }

        Ok(None)
    }

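    /// Drives `future` to completion while periodically extending the task
    /// claim (every half claim TTL). If a task deadline expires between
    /// heartbeats, returns a `TaskTimedOut` error instead. With no claim TTL
    /// configured, the future is awaited directly.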
    async fn run_with_heartbeat<F, T>(
        &self,
        claim: &ActiveTaskClaim<'_, B>,
        deadline: Option<&TaskDeadline>,
        future: F,
    ) -> Result<T, crate::error::RuntimeError>
    where
        F: std::future::Future<Output = T>,
    {
        let Some(ttl) = self.claim_ttl else {
            return Ok(future.await);
        };
        let Some(chrono_ttl) = chrono::Duration::from_std(ttl).ok() else {
            return Ok(future.await);
        };

        let interval_duration = ttl / 2;
        let mut heartbeat_timer = time::interval(interval_duration);
        heartbeat_timer.tick().await;
        tokio::pin!(future);

        loop {
            tokio::select! {
                result = &mut future => break Ok(result),
                _ = heartbeat_timer.tick() => {
                    if let Some(dl) = deadline
                        && chrono::Utc::now() >= dl.deadline
                    {
                        tracing::warn!(
                            instance_id = %claim.instance_id,
                            task_id = %dl.task_id,
                            "Task deadline expired during heartbeat, cancelling"
                        );
                        return Err(WorkflowError::TaskTimedOut {
                            task_id: dl.task_id.clone(),
                            timeout: std::time::Duration::from_millis(dl.timeout_ms),
                        }
                        .into());
                    }

                    tracing::trace!(
                        instance_id = %claim.instance_id,
                        task_id = %claim.task_id,
                        "Extending task claim via heartbeat"
                    );
                    if let Err(e) = self.backend
                        .extend_task_claim(
                            &claim.instance_id,
                            &claim.task_id,
                            &claim.worker_id,
                            chrono_ttl,
                        )
                        .await
                    {
                        tracing::warn!(
                            instance_id = %claim.instance_id,
                            task_id = %claim.task_id,
                            error = %e,
                            "Failed to extend task claim"
                        );
                    }
                }
            }
        }
    }

    async fn commit_task_result(
        &self,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        output: Bytes,
        claim: ActiveTaskClaim<'_, B>,
    ) -> Result<(), crate::error::RuntimeError> {
        snapshot.mark_task_completed(available_task.task_id.clone(), output);
        tracing::debug!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            "Task completed"
        );

        Self::update_position_after_task(continuation, &available_task.task_id, snapshot);
        self.backend.save_snapshot(snapshot).await?;
        claim.release().await?;
        Ok(())
    }

    async fn determine_post_task_status(
        &self,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        output: Bytes,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
        if self
            .backend
            .check_and_cancel(&available_task.instance_id, None)
            .await?
        {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow was cancelled after task completion"
            );
            return Ok(self
                .load_cancelled_status(&available_task.instance_id)
                .await);
        }

        if self
            .backend
            .check_and_pause(&available_task.instance_id)
            .await?
        {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow was paused after task completion"
            );
            return Ok(self.load_paused_status(&available_task.instance_id).await);
        }

        if Self::is_workflow_complete(continuation, snapshot) {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow complete"
            );
            snapshot.mark_completed(output);
            self.backend.save_snapshot(snapshot).await?;
            Ok(WorkflowStatus::Completed)
        } else {
            tracing::debug!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Task completed, workflow continues"
            );
            Ok(WorkflowStatus::InProgress)
        }
    }

    fn find_task_id_in_continuation(continuation: &WorkflowContinuation, task_id: &str) -> bool {
        match continuation {
            WorkflowContinuation::Task { id, next, .. }
            | WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if id == task_id {
                    return true;
                }
                next.as_ref()
                    .is_some_and(|n| Self::find_task_id_in_continuation(n, task_id))
            }
            WorkflowContinuation::Fork { branches, join, .. } => {
                for branch in branches {
                    if Self::find_task_id_in_continuation(branch, task_id) {
                        return true;
                    }
                }
                if let Some(join_cont) = join {
                    Self::find_task_id_in_continuation(join_cont, task_id)
                } else {
                    false
                }
            }
        }
    }

    #[allow(clippy::manual_async_fn)]
    fn execute_task_by_id<'a>(
        continuation: &'a WorkflowContinuation,
        task_id: &'a str,
        input: Bytes,
    ) -> impl std::future::Future<Output = Result<Bytes, crate::error::RuntimeError>> + Send + 'a
    {
        async move {
            let mut current = continuation;

            loop {
                match current {
                    WorkflowContinuation::Task { id, func, next, .. } => {
                        if id == task_id {
                            let func = func
                                .as_ref()
                                .ok_or_else(|| WorkflowError::TaskNotImplemented(id.clone()))?;
                            return Ok(func.run(input).await?);
                        } else if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::Delay { next, .. }
                    | WorkflowContinuation::AwaitSignal { next, .. } => {
                        if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::Fork { branches, join, .. } => {
                        let mut found_in_branch = false;
                        for branch in branches {
                            if Self::find_task_id_in_continuation(branch, task_id) {
                                current = branch;
                                found_in_branch = true;
                                break;
                            }
                        }
                        if found_in_branch {
                            continue;
                        }
                        if let Some(join_cont) = join {
                            current = join_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                }
            }
        }
    }

    fn update_position_after_task(
        continuation: &WorkflowContinuation,
        completed_task_id: &str,
        snapshot: &mut WorkflowSnapshot,
    ) {
        match continuation {
            WorkflowContinuation::Task { id, next, .. }
            | WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if id == completed_task_id {
                    if let Some(next_cont) = next {
                        snapshot.update_position(ExecutionPosition::AtTask {
                            task_id: next_cont.first_task_id().to_string(),
                        });
                    }
                } else if let Some(next_cont) = next {
                    Self::update_position_after_task(next_cont, completed_task_id, snapshot);
                }
            }
            WorkflowContinuation::Fork { branches, join, .. } => {
                for branch in branches {
                    Self::update_position_after_task(branch, completed_task_id, snapshot);
                }
                if let Some(join_cont) = join {
                    Self::update_position_after_task(join_cont, completed_task_id, snapshot);
                }
            }
        }
    }

    fn is_workflow_complete(
        continuation: &WorkflowContinuation,
        snapshot: &WorkflowSnapshot,
    ) -> bool {
        match continuation {
            WorkflowContinuation::Task { id, next, .. } => {
                if snapshot.get_task_result(id).is_none() {
                    return false;
                }
                if let Some(next_cont) = next {
                    Self::is_workflow_complete(next_cont, snapshot)
                } else {
                    true
                }
            }
            WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if snapshot.get_task_result(id).is_none() {
                    return false;
                }
                next.as_ref()
                    .is_none_or(|n| Self::is_workflow_complete(n, snapshot))
            }
            WorkflowContinuation::Fork { branches, join, .. } => {
                for branch in branches {
                    if !Self::is_workflow_complete(branch, snapshot) {
                        return false;
                    }
                }
                if let Some(join_cont) = join {
                    Self::is_workflow_complete(join_cont, snapshot)
                } else {
                    true
                }
            }
        }
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::serialization::JsonCodec;
    use sayiir_core::registry::TaskRegistry;
    use sayiir_core::snapshot::WorkflowSnapshot;
    use sayiir_persistence::{InMemoryBackend, SignalStore, SnapshotStore};

    type EmptyWorkflows = WorkflowRegistry<JsonCodec, (), ()>;

    fn make_worker() -> PooledWorker<InMemoryBackend> {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();
        PooledWorker::new("test-worker", backend, registry)
    }

    #[tokio::test]
    async fn test_spawn_and_shutdown() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        handle.shutdown();

        let result = tokio::time::timeout(Duration::from_secs(5), handle.join()).await;
        assert!(result.is_ok(), "Worker should exit cleanly after shutdown");
        assert!(result.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_handle_is_clone_and_send() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        let handle2 = handle.clone();
        let remote = tokio::spawn(async move {
            handle2.shutdown();
        });
        remote.await.ok();

        let result = tokio::time::timeout(Duration::from_secs(5), handle.join()).await;
        assert!(result.is_ok_and(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_cancel_via_handle() {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();

        let snapshot = WorkflowSnapshot::new("wf-1".to_string(), "hash-1".to_string());
        backend.save_snapshot(&snapshot).await.ok();

        let worker = PooledWorker::new("test-worker", backend, registry);
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        handle
            .cancel_workflow(
                "wf-1",
                Some("test reason".to_string()),
                Some("tester".to_string()),
            )
            .await
            .ok();

        let signal = handle
            .backend()
            .get_signal("wf-1", SignalKind::Cancel)
            .await;
        assert!(signal.is_ok_and(|s| s.is_some()));

        handle.shutdown();
        tokio::time::timeout(Duration::from_secs(5), handle.join())
            .await
            .ok();
    }

    #[tokio::test]
    async fn test_dropped_handle_shuts_down_worker() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        let join_handle = handle.inner.join_handle.lock().await.take().unwrap();
        drop(handle);

        let result = tokio::time::timeout(Duration::from_secs(5), join_handle)
            .await
            .ok()
            .and_then(Result::ok);
        assert!(
            result.is_some(),
            "Worker should exit when all handles are dropped"
        );
        assert!(result.is_some_and(|r| r.is_ok()));
    }
}