1#![warn(missing_docs)] #[macro_use]
49extern crate tracing;
50
51mod activity_context;
52mod app_data;
53pub mod interceptors;
54pub mod prelude;
55pub mod workflow;
56mod workflow_context;
57mod workflow_future;
58
59pub use activity_context::ActContext;
60pub use squads_temporal_client::Namespace;
61use tracing::{Instrument, Span, field};
62pub use workflow_context::{
63 ActivityOptions, CancellableFuture, ChildWorkflow, ChildWorkflowOptions, LocalActivityOptions,
64 NexusOperationOptions, PendingChildWorkflow, Signal, SignalData, SignalWorkflowOptions,
65 StartedChildWorkflow, TimerOptions, WfContext,
66};
67
68use crate::{
69 interceptors::WorkerInterceptor,
70 workflow_context::{ChildWfCommon, NexusUnblockData, StartedNexusOperation},
71};
72use anyhow::{Context, anyhow, bail};
73use app_data::AppData;
74use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::BoxFuture};
75use serde::Serialize;
76use std::{
77 any::{Any, TypeId},
78 cell::RefCell,
79 collections::HashMap,
80 fmt::{Debug, Display, Formatter},
81 future::Future,
82 panic::AssertUnwindSafe,
83 sync::Arc,
84 time::Duration,
85};
86use squads_temporal_client::ClientOptionsBuilder;
87use squads_temporal_sdk_core::Url;
88use squads_temporal_sdk_core_api::{Worker as CoreWorker, errors::PollError};
89use squads_temporal_sdk_core_protos::{
90 TaskToken,
91 coresdk::{
92 ActivityTaskCompletion, AsJsonPayloadExt, FromJsonPayloadExt,
93 activity_result::{ActivityExecutionResult, ActivityResolution},
94 activity_task::{ActivityTask, activity_task},
95 child_workflow::ChildWorkflowResult,
96 common::NamespacedWorkflowExecution,
97 nexus::NexusOperationResult,
98 workflow_activation::{
99 WorkflowActivation,
100 resolve_child_workflow_execution_start::Status as ChildWorkflowStartStatus,
101 resolve_nexus_operation_start, workflow_activation_job::Variant,
102 },
103 workflow_commands::{ContinueAsNewWorkflowExecution, WorkflowCommand, workflow_command},
104 workflow_completion::WorkflowActivationCompletion,
105 },
106 temporal::api::{
107 common::v1::Payload,
108 enums::v1::WorkflowTaskFailedCause,
109 failure::v1::{Failure, failure},
110 },
111};
112use tokio::{
113 sync::{
114 Notify,
115 mpsc::{UnboundedSender, unbounded_channel},
116 oneshot,
117 },
118 task::JoinError,
119};
120use tokio_stream::wrappers::UnboundedReceiverStream;
121use tokio_util::sync::CancellationToken;
122
/// Version of this crate, read from the `CARGO_PKG_VERSION` environment variable at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
124
125pub fn sdk_client_options(url: impl Into<Url>) -> ClientOptionsBuilder {
128 let mut builder = ClientOptionsBuilder::default();
129 builder
130 .target_url(url)
131 .client_name("temporal-rust".to_string())
132 .client_version(VERSION.to_string());
133
134 builder
135}
136
/// A worker that polls a core worker for workflow activations and activity tasks and dispatches
/// them to the workflow and activity functions registered on it.
pub struct Worker {
    /// State used by both halves: the core worker, the task queue name, and the interceptor.
    common: CommonWorker,
    /// Workflow-side state: registered workflow functions and currently cached workflow runs.
    workflow_half: WorkflowHalf,
    /// Activity-side state: registered activity functions and per-task cancellation tokens.
    activity_half: ActivityHalf,
    /// User-supplied application data. `run` takes it out of this slot while polling and puts
    /// it back only when it finishes without error.
    app_data: Option<AppData>,
}
145
/// State shared by the workflow and activity sides of a [`Worker`].
struct CommonWorker {
    /// Handle to the underlying core worker used for polling and completing tasks.
    worker: Arc<dyn CoreWorker>,
    /// Name of the task queue this worker was created with.
    task_queue: String,
    /// Optional interceptor notified about activations, completions, and shutdown.
    worker_interceptor: Option<Box<dyn WorkerInterceptor>>,
}
151
/// Workflow-side state of a [`Worker`].
struct WorkflowHalf {
    /// Currently cached workflows, keyed by run id.
    workflows: RefCell<HashMap<String, WorkflowData>>,
    /// Registered workflow functions, keyed by workflow type.
    workflow_fns: RefCell<HashMap<String, WorkflowFunction>>,
    /// Notified every time an entry is removed from `workflows`.
    workflow_removed_from_map: Notify,
}
/// Per-run data kept for a cached workflow.
struct WorkflowData {
    /// Channel used to forward activations to the running workflow.
    activation_chan: UnboundedSender<WorkflowActivation>,
}
163
/// Pairs the join handle of a spawned workflow future with the run id it belongs to.
struct WorkflowFutureHandle<F: Future<Output = Result<WorkflowResult<Payload>, JoinError>>> {
    /// Future resolving to the workflow's result, or to a join error.
    join_handle: F,
    /// Run id of the workflow; used to remove it from the cache once `join_handle` resolves.
    run_id: String,
}
168
/// Activity-side state of a [`Worker`].
struct ActivityHalf {
    /// Registered activity functions, keyed by activity type.
    activity_fns: HashMap<String, ActivityFunction>,
    /// Cancellation tokens of started activities, keyed by their task token.
    task_tokens_to_cancels: HashMap<TaskToken, CancellationToken>,
}
174
175impl Worker {
176 pub fn new_from_core(worker: Arc<dyn CoreWorker>, task_queue: impl Into<String>) -> Self {
178 Self {
179 common: CommonWorker {
180 worker,
181 task_queue: task_queue.into(),
182 worker_interceptor: None,
183 },
184 workflow_half: WorkflowHalf {
185 workflows: Default::default(),
186 workflow_fns: Default::default(),
187 workflow_removed_from_map: Default::default(),
188 },
189 activity_half: ActivityHalf {
190 activity_fns: Default::default(),
191 task_tokens_to_cancels: Default::default(),
192 },
193 app_data: Some(Default::default()),
194 }
195 }
196
197 pub fn task_queue(&self) -> &str {
199 &self.common.task_queue
200 }
201
202 pub fn shutdown_handle(&self) -> impl Fn() + use<> {
205 let w = self.common.worker.clone();
206 move || w.initiate_shutdown()
207 }
208
209 pub fn register_wf(
212 &mut self,
213 workflow_type: impl Into<String>,
214 wf_function: impl Into<WorkflowFunction>,
215 ) {
216 self.workflow_half
217 .workflow_fns
218 .get_mut()
219 .insert(workflow_type.into(), wf_function.into());
220 }
221
222 pub fn register_activity<A, R, O>(
225 &mut self,
226 activity_type: impl Into<String>,
227 act_function: impl IntoActivityFunc<A, R, O>,
228 ) {
229 self.activity_half.activity_fns.insert(
230 activity_type.into(),
231 ActivityFunction {
232 act_func: act_function.into_activity_fn(),
233 },
234 );
235 }
236
237 pub fn insert_app_data<T: Send + Sync + 'static>(&mut self, data: T) {
239 self.app_data.as_mut().map(|a| a.insert(data));
240 }
241
    /// Runs the worker until the core worker reports shutdown, then shuts the core worker down.
    ///
    /// Four futures are driven concurrently with `tokio::try_join!`:
    /// 1. a loop polling workflow activations and dispatching them to workflows,
    /// 2. a loop polling activity tasks (skipped if no activity function is registered),
    /// 3. a joiner that awaits spawned workflow futures and removes finished runs from the cache,
    /// 4. a processor that forwards workflow activation completions to the core worker.
    ///
    /// # Errors
    /// Returns an error if the application data is missing when the run starts, if any of the
    /// four futures above fails, or if other references to the application data still exist
    /// after the core worker has shut down.
    ///
    /// # Panics
    /// Panics if a newly created workflow future cannot be sent to the joiner because the
    /// receiving side of its channel is gone.
    pub async fn run(&mut self) -> Result<(), anyhow::Error> {
        let shutdown_token = CancellationToken::new();
        let (common, wf_half, act_half, app_data) = self.split_apart();
        // Take the app data out of the worker and put it behind an `Arc` so that it can be
        // handed to every activity task. It is put back at the end of a successful run.
        let safe_app_data = Arc::new(
            app_data
                .take()
                .ok_or_else(|| anyhow!("app_data should exist on run"))?,
        );
        let (wf_future_tx, wf_future_rx) = unbounded_channel();
        let (completions_tx, completions_rx) = unbounded_channel();
        // Awaits every spawned workflow future; when one finishes successfully its run is
        // removed from the cache and waiters on `workflow_removed_from_map` are notified.
        let wf_future_joiner = async {
            UnboundedReceiverStream::new(wf_future_rx)
                .map(Result::<_, anyhow::Error>::Ok)
                .try_for_each_concurrent(
                    None,
                    |WorkflowFutureHandle {
                         join_handle,
                         run_id,
                     }| {
                        let wf_half = &*wf_half;
                        async move {
                            // First `?` surfaces the join error, second the workflow's own error.
                            join_handle.await??;
                            debug!(run_id=%run_id, "Removing workflow from cache");
                            wf_half.workflows.borrow_mut().remove(&run_id);
                            wf_half.workflow_removed_from_map.notify_one();
                            Ok(())
                        }
                    },
                )
                .await
                .context("Workflow futures encountered an error")
        };
        // Forwards each activation completion to the core worker, letting the interceptor (if
        // any) observe it first.
        let wf_completion_processor = async {
            UnboundedReceiverStream::new(completions_rx)
                .map(Ok)
                .try_for_each_concurrent(None, |completion| async {
                    if let Some(ref i) = common.worker_interceptor {
                        i.on_workflow_activation_completion(&completion).await;
                    }
                    common.worker.complete_workflow_activation(completion).await
                })
                .map_err(anyhow::Error::from)
                .await
                .context("Workflow completions processor encountered an error")
        };
        tokio::try_join!(
            // Workflow activation polling loop.
            async {
                loop {
                    let activation = match common.worker.poll_workflow_activation().await {
                        Err(PollError::ShutDown) => {
                            break;
                        }
                        o => o?,
                    };
                    if let Some(ref i) = common.worker_interceptor {
                        i.on_workflow_activation(&activation).await?;
                    }
                    // The handler returns a future handle only when it started a new workflow;
                    // that handle is passed on to the joiner above.
                    if let Some(wf_fut) = wf_half
                        .workflow_activation_handler(
                            common,
                            shutdown_token.clone(),
                            activation,
                            &completions_tx,
                        )
                        .await?
                        && wf_future_tx.send(wf_fut).is_err()
                    {
                        panic!("Receive half of completion processor channel cannot be dropped");
                    }
                }
                // Polling reported shutdown: cancel the token so workflow futures that are still
                // running resolve as evicted.
                shutdown_token.cancel();
                // Drop this function's senders so the joiner and completion streams can end once
                // no other senders remain.
                drop(wf_future_tx);
                drop(completions_tx);
                Result::<_, anyhow::Error>::Ok(())
            },
            // Activity task polling loop.
            async {
                if !act_half.activity_fns.is_empty() {
                    loop {
                        let activity = common.worker.poll_activity_task().await;
                        if matches!(activity, Err(PollError::ShutDown)) {
                            break;
                        }
                        // Any poll error other than shutdown is propagated by `activity?`.
                        act_half.activity_task_handler(
                            common.worker.clone(),
                            safe_app_data.clone(),
                            common.task_queue.clone(),
                            activity?,
                        )?;
                    }
                };
                Result::<_, anyhow::Error>::Ok(())
            },
            wf_future_joiner,
            wf_completion_processor,
        )?;

        info!("Polling loops exited");
        if let Some(i) = self.common.worker_interceptor.as_ref() {
            i.on_shutdown(self);
        }
        self.common.worker.shutdown().await;
        debug!("Worker shutdown complete");
        // Restore the app data; this fails if any clone of the `Arc` is still alive.
        self.app_data = Some(
            Arc::try_unwrap(safe_app_data)
                .map_err(|_| anyhow!("some references of AppData exist on worker shutdown"))?,
        );
        Ok(())
    }
358
359 pub fn set_worker_interceptor(&mut self, interceptor: impl WorkerInterceptor + 'static) {
361 self.common.worker_interceptor = Some(Box::new(interceptor));
362 }
363
364 pub fn with_new_core_worker(&mut self, new_core_worker: Arc<dyn CoreWorker>) {
368 self.common.worker = new_core_worker;
369 }
370
371 pub fn cached_workflows(&self) -> usize {
374 self.workflow_half.workflows.borrow().len()
375 }
376
377 fn split_apart(
378 &mut self,
379 ) -> (
380 &mut CommonWorker,
381 &mut WorkflowHalf,
382 &mut ActivityHalf,
383 &mut Option<AppData>,
384 ) {
385 (
386 &mut self.common,
387 &mut self.workflow_half,
388 &mut self.activity_half,
389 &mut self.app_data,
390 )
391 }
392}
393
impl WorkflowHalf {
    /// Dispatches one workflow activation.
    ///
    /// If the activation contains an `InitializeWorkflow` job, the matching registered workflow
    /// function is started, its future is spawned, and the run is inserted into the cache; the
    /// handle of the spawned future is then returned as `Some`. In every case the activation is
    /// afterwards forwarded to the cached workflow with the activation's run id.
    ///
    /// Returns `Ok(None)` when no new workflow was started, including the cases where the
    /// workflow type is unknown (a failure completion is sent instead) and where the run is not
    /// cached but the activation consists solely of a `RemoveFromCache` job (an empty completion
    /// is sent instead).
    ///
    /// # Errors
    /// Returns an error if the activation targets a run that is not cached and is anything other
    /// than a lone `RemoveFromCache` job.
    ///
    /// # Panics
    /// Panics if the completion channel or the workflow's activation channel is closed.
    #[allow(clippy::type_complexity)]
    async fn workflow_activation_handler(
        &self,
        common: &CommonWorker,
        shutdown_token: CancellationToken,
        mut activation: WorkflowActivation,
        completions_tx: &UnboundedSender<WorkflowActivationCompletion>,
    ) -> Result<
        Option<
            WorkflowFutureHandle<
                impl Future<Output = Result<WorkflowResult<Payload>, JoinError>> + use<>,
            >,
        >,
        anyhow::Error,
    > {
        let mut res = None;
        let run_id = activation.run_id.clone();

        // Look for an initialize job; if present this activation starts a new workflow run.
        if let Some(sw) = activation.jobs.iter_mut().find_map(|j| match j.variant {
            Some(Variant::InitializeWorkflow(ref mut sw)) => Some(sw),
            _ => None,
        }) {
            let workflow_type = &sw.workflow_type;
            let (wff, activations) = {
                let wf_fns_borrow = self.workflow_fns.borrow();

                let Some(wf_function) = wf_fns_borrow.get(workflow_type) else {
                    warn!("Workflow type {workflow_type} not found");

                    // Unknown workflow type: fail the activation and do not start anything.
                    completions_tx
                        .send(WorkflowActivationCompletion::fail(
                            run_id,
                            format!("Workflow type {workflow_type} not found").into(),
                            Some(WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure),
                        ))
                        .expect("Completion channel intact");
                    return Ok(None);
                };

                // `std::mem::take` moves the initialize job's contents out, leaving a default
                // value behind in the activation.
                wf_function.start_workflow(
                    common.worker.get_config().namespace.clone(),
                    common.task_queue.clone(),
                    std::mem::take(sw),
                    completions_tx.clone(),
                )
            };
            // Run the workflow future until it finishes or the worker shuts down, whichever
            // comes first; on shutdown the run is reported as evicted.
            let jh = tokio::spawn(async move {
                tokio::select! {
                    r = wff.fuse() => r,
                    _ = shutdown_token.cancelled() => {
                        Ok(WfExitValue::Evicted)
                    }
                }
            });
            res = Some(WorkflowFutureHandle {
                join_handle: jh,
                run_id: run_id.clone(),
            });
            // If a run with the same id is still cached, wait until it has been removed before
            // inserting the new one.
            loop {
                if self.workflows.borrow_mut().contains_key(&run_id) {
                    self.workflow_removed_from_map.notified().await;
                } else {
                    break;
                }
            }
            self.workflows.borrow_mut().insert(
                run_id.clone(),
                WorkflowData {
                    activation_chan: activations,
                },
            );
        }

        // Forward the activation to the cached workflow, if there is one.
        if let Some(dat) = self.workflows.borrow_mut().get_mut(&run_id) {
            dat.activation_chan
                .send(activation)
                .expect("Workflow should exist if we're sending it an activation");
        } else {
            // A lone eviction request for a run that is not cached is acknowledged with an empty
            // completion instead of being treated as an error.
            if activation.jobs.len() == 1
                && matches!(
                    activation.jobs.first().map(|j| &j.variant),
                    Some(Some(Variant::RemoveFromCache(_)))
                )
            {
                completions_tx
                    .send(WorkflowActivationCompletion::from_cmds(run_id, vec![]))
                    .expect("Completion channel intact");
                return Ok(None);
            }

            bail!(
                "Got activation {:?} for unknown workflow {}",
                activation,
                run_id
            );
        };

        Ok(res)
    }
}
512
513impl ActivityHalf {
514 fn activity_task_handler(
516 &mut self,
517 worker: Arc<dyn CoreWorker>,
518 app_data: Arc<AppData>,
519 task_queue: String,
520 activity: ActivityTask,
521 ) -> Result<(), anyhow::Error> {
522 match activity.variant {
523 Some(activity_task::Variant::Start(start)) => {
524 let act_fn = self
525 .activity_fns
526 .get(&start.activity_type)
527 .ok_or_else(|| {
528 anyhow!(
529 "No function registered for activity type {}",
530 start.activity_type
531 )
532 })?
533 .clone();
534 let span = info_span!(
535 "RunActivity",
536 "otel.name" = format!("RunActivity:{}", start.activity_type),
537 "otel.kind" = "server",
538 "temporalActivityID" = start.activity_id,
539 "temporalWorkflowID" = field::Empty,
540 "temporalRunID" = field::Empty,
541 );
542 let ct = CancellationToken::new();
543 let task_token = activity.task_token;
544 self.task_tokens_to_cancels
545 .insert(task_token.clone().into(), ct.clone());
546
547 let (ctx, arg) = ActContext::new(
548 worker.clone(),
549 app_data,
550 ct,
551 task_queue,
552 task_token.clone(),
553 start,
554 );
555
556 tokio::spawn(async move {
557 let act_fut = async move {
558 if let Some(info) = &ctx.get_info().workflow_execution {
559 Span::current()
560 .record("temporalWorkflowID", &info.workflow_id)
561 .record("temporalRunID", &info.run_id);
562 }
563 (act_fn.act_func)(ctx, arg).await
564 }
565 .instrument(span);
566 let output = AssertUnwindSafe(act_fut).catch_unwind().await;
567 let result = match output {
568 Err(e) => ActivityExecutionResult::fail(Failure::application_failure(
569 format!("Activity function panicked: {}", panic_formatter(e)),
570 true,
571 )),
572 Ok(Ok(ActExitValue::Normal(p))) => ActivityExecutionResult::ok(p),
573 Ok(Ok(ActExitValue::WillCompleteAsync)) => {
574 ActivityExecutionResult::will_complete_async()
575 }
576 Ok(Err(err)) => match err {
577 ActivityError::Retryable {
578 source,
579 explicit_delay,
580 } => ActivityExecutionResult::fail({
581 let mut f = Failure::application_failure_from_error(source, false);
582 if let Some(d) = explicit_delay
583 && let Some(failure::FailureInfo::ApplicationFailureInfo(fi)) =
584 f.failure_info.as_mut()
585 {
586 fi.next_retry_delay = d.try_into().ok();
587 }
588 f
589 }),
590 ActivityError::Cancelled { details } => {
591 ActivityExecutionResult::cancel_from_details(details)
592 }
593 ActivityError::NonRetryable(nre) => ActivityExecutionResult::fail(
594 Failure::application_failure_from_error(nre, true),
595 ),
596 },
597 };
598 worker
599 .complete_activity_task(ActivityTaskCompletion {
600 task_token,
601 result: Some(result),
602 })
603 .await?;
604 Ok::<_, anyhow::Error>(())
605 });
606 }
607 Some(activity_task::Variant::Cancel(_)) => {
608 if let Some(ct) = self
609 .task_tokens_to_cancels
610 .get(activity.task_token.as_slice())
611 {
612 ct.cancel();
613 }
614 }
615 None => bail!("Undefined activity task variant"),
616 }
617 Ok(())
618 }
619}
620
/// Events sent through an unblocker channel to resolve a pending workflow operation.
///
/// NOTE(review): the leading `u32` of every variant is presumably the sequence number of the
/// operation being resolved — confirm against the code that constructs these events.
#[derive(Debug)]
enum UnblockEvent {
    /// A timer resolved with the given result.
    Timer(u32, TimerResult),
    /// An activity resolved.
    Activity(u32, Box<ActivityResolution>),
    /// A child workflow start resolved with the given status.
    WorkflowStart(u32, Box<ChildWorkflowStartStatus>),
    /// A child workflow completed with the given result.
    WorkflowComplete(u32, Box<ChildWorkflowResult>),
    /// Signalling an external workflow resolved; `Some` carries the failure, `None` is success.
    SignalExternal(u32, Option<Failure>),
    /// Cancelling an external workflow resolved; `Some` carries the failure, `None` is success.
    CancelExternal(u32, Option<Failure>),
    /// A nexus operation start resolved with the given status.
    NexusOperationStart(u32, Box<resolve_nexus_operation_start::Status>),
    /// A nexus operation completed with the given result.
    NexusOperationComplete(u32, Box<NexusOperationResult>),
}
632
/// Outcome of a timer.
#[derive(Debug, Copy, Clone)]
pub enum TimerResult {
    /// The timer was cancelled.
    Cancelled,
    /// The timer fired.
    Fired,
}
641
/// Marker value for a successfully delivered signal to an external workflow.
#[derive(Debug)]
pub struct SignalExternalOk;
/// Result of signalling an external workflow: the success marker, or the reported failure.
pub type SignalExternalWfResult = Result<SignalExternalOk, Failure>;
647
/// Marker value for a successful cancellation request to an external workflow.
#[derive(Debug)]
pub struct CancelExternalOk;
/// Result of cancelling an external workflow: the success marker, or the reported failure.
pub type CancelExternalWfResult = Result<CancelExternalOk, Failure>;
653
/// Implemented by result types that can be built from an [`UnblockEvent`].
trait Unblockable {
    /// Extra data, beyond the event itself, needed to build the result.
    type OtherDat;

    /// Builds `Self` from the event and the extra data. The implementations in this file panic
    /// when given an event variant that does not match the implementing type.
    fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self;
}
659
660impl Unblockable for TimerResult {
661 type OtherDat = ();
662 fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
663 match ue {
664 UnblockEvent::Timer(_, result) => result,
665 _ => panic!("Invalid unblock event for timer"),
666 }
667 }
668}
669
670impl Unblockable for ActivityResolution {
671 type OtherDat = ();
672 fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
673 match ue {
674 UnblockEvent::Activity(_, result) => *result,
675 _ => panic!("Invalid unblock event for activity"),
676 }
677 }
678}
679
680impl Unblockable for PendingChildWorkflow {
681 type OtherDat = ChildWfCommon;
683 fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self {
684 match ue {
685 UnblockEvent::WorkflowStart(_, result) => Self {
686 status: *result,
687 common: od,
688 },
689 _ => panic!("Invalid unblock event for child workflow start"),
690 }
691 }
692}
693
694impl Unblockable for ChildWorkflowResult {
695 type OtherDat = ();
696 fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
697 match ue {
698 UnblockEvent::WorkflowComplete(_, result) => *result,
699 _ => panic!("Invalid unblock event for child workflow complete"),
700 }
701 }
702}
703
704impl Unblockable for SignalExternalWfResult {
705 type OtherDat = ();
706 fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
707 match ue {
708 UnblockEvent::SignalExternal(_, maybefail) => {
709 maybefail.map_or(Ok(SignalExternalOk), Err)
710 }
711 _ => panic!("Invalid unblock event for signal external workflow result"),
712 }
713 }
714}
715
716impl Unblockable for CancelExternalWfResult {
717 type OtherDat = ();
718 fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
719 match ue {
720 UnblockEvent::CancelExternal(_, maybefail) => {
721 maybefail.map_or(Ok(CancelExternalOk), Err)
722 }
723 _ => panic!("Invalid unblock event for signal external workflow result"),
724 }
725 }
726}
727
728type NexusStartResult = Result<StartedNexusOperation, Failure>;
729impl Unblockable for NexusStartResult {
730 type OtherDat = NexusUnblockData;
731 fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self {
732 match ue {
733 UnblockEvent::NexusOperationStart(_, result) => match *result {
734 resolve_nexus_operation_start::Status::OperationToken(op_token) => {
735 Ok(StartedNexusOperation {
736 operation_token: Some(op_token),
737 unblock_dat: od,
738 })
739 }
740 resolve_nexus_operation_start::Status::StartedSync(_) => {
741 Ok(StartedNexusOperation {
742 operation_token: None,
743 unblock_dat: od,
744 })
745 }
746 resolve_nexus_operation_start::Status::CancelledBeforeStart(f) => Err(f),
747 },
748 _ => panic!("Invalid unblock event for nexus operation"),
749 }
750 }
751}
752
753impl Unblockable for NexusOperationResult {
754 type OtherDat = ();
755
756 fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
757 match ue {
758 UnblockEvent::NexusOperationComplete(_, result) => *result,
759 _ => panic!("Invalid unblock event for nexus operation complete"),
760 }
761 }
762}
763
/// Identifies an operation that can be cancelled.
///
/// NOTE(review): the `u32` / `seqnum` values are presumably the sequence numbers of the
/// commands that started the operations — confirm against the producers of these ids.
#[derive(Debug, Clone)]
pub(crate) enum CancellableID {
    /// A timer.
    Timer(u32),
    /// An activity.
    Activity(u32),
    /// A local activity.
    LocalActivity(u32),
    /// A child workflow, together with the reason for cancelling it.
    ChildWorkflow {
        seqnum: u32,
        reason: String,
    },
    /// A signal sent to an external workflow.
    SignalExternalWorkflow(u32),
    /// An external workflow execution, together with the reason for cancelling it.
    ExternalWorkflow {
        seqnum: u32,
        execution: NamespacedWorkflowExecution,
        reason: String,
    },
    /// A nexus operation.
    NexusOp(u32),
}
783
/// Implemented by cancellable ids that become a full [`CancellableID`] once a reason is supplied.
pub(crate) trait SupportsCancelReason {
    /// Consumes `self` and returns the corresponding [`CancellableID`] carrying `reason`.
    fn with_reason(self, reason: String) -> CancellableID;
}
/// The [`CancellableID`] variants that take a cancellation reason, without the reason filled in.
#[derive(Debug, Clone)]
pub(crate) enum CancellableIDWithReason {
    /// Counterpart of [`CancellableID::ChildWorkflow`].
    ChildWorkflow {
        seqnum: u32,
    },
    /// Counterpart of [`CancellableID::ExternalWorkflow`].
    ExternalWorkflow {
        seqnum: u32,
        execution: NamespacedWorkflowExecution,
    },
}
799impl CancellableIDWithReason {
800 pub(crate) fn seq_num(&self) -> u32 {
801 match self {
802 CancellableIDWithReason::ChildWorkflow { seqnum } => *seqnum,
803 CancellableIDWithReason::ExternalWorkflow { seqnum, .. } => *seqnum,
804 }
805 }
806}
807impl SupportsCancelReason for CancellableIDWithReason {
808 fn with_reason(self, reason: String) -> CancellableID {
809 match self {
810 CancellableIDWithReason::ChildWorkflow { seqnum } => {
811 CancellableID::ChildWorkflow { seqnum, reason }
812 }
813 CancellableIDWithReason::ExternalWorkflow { seqnum, execution } => {
814 CancellableID::ExternalWorkflow {
815 seqnum,
816 execution,
817 reason,
818 }
819 }
820 }
821 }
822}
823impl From<CancellableIDWithReason> for CancellableID {
824 fn from(v: CancellableIDWithReason) -> Self {
825 v.with_reason("".to_string())
826 }
827}
828
/// Requests issued on behalf of workflow code.
///
/// `From` conversions are derived for every variant except `Cancel`, which is excluded with
/// `#[from(ignore)]`.
#[derive(derive_more::From)]
#[allow(clippy::large_enum_variant)]
enum RustWfCmd {
    /// Cancel the operation identified by the id.
    #[from(ignore)]
    Cancel(CancellableID),
    /// Force the workflow task to fail with the given error.
    ForceWFTFailure(anyhow::Error),
    /// Issue a new command whose resolution is delivered through the request's unblocker.
    NewCmd(CommandCreateRequest),
    /// Issue a new command that has no unblocker attached.
    NewNonblockingCmd(workflow_command::Variant),
    /// Subscribe to the completion of a child workflow.
    SubscribeChildWorkflowCompletion(CommandSubscribeChildWorkflowCompletion),
    /// Subscribe the sender to signals with the given name.
    SubscribeSignal(String, UnboundedSender<SignalData>),
    /// Register validator and handler functions for the update with the given name.
    RegisterUpdate(String, UpdateFunctions),
    /// Subscribe to the completion of the nexus operation identified by `seq`.
    SubscribeNexusOperationCompletion {
        seq: u32,
        unblocker: oneshot::Sender<UnblockEvent>,
    },
}
845
/// A workflow command paired with the channel used to report its resolution.
struct CommandCreateRequest {
    /// The command to issue.
    cmd: WorkflowCommand,
    /// One-shot sender over which the resolving [`UnblockEvent`] is delivered.
    unblocker: oneshot::Sender<UnblockEvent>,
}
850
/// A subscription to the completion of a child workflow.
struct CommandSubscribeChildWorkflowCompletion {
    /// Identifies the child workflow (presumably its sequence number — TODO confirm).
    seq: u32,
    /// One-shot sender over which the completion [`UnblockEvent`] is delivered.
    unblocker: oneshot::Sender<UnblockEvent>,
}
855
/// Type-erased workflow function: takes a [`WfContext`] and returns a boxed future resolving to
/// the workflow's exit value, with the result already converted to a [`Payload`], or an error.
type WfFunc = dyn Fn(WfContext) -> BoxFuture<'static, Result<WfExitValue<Payload>, anyhow::Error>>
    + Send
    + Sync
    + 'static;
860
/// A workflow function that can be registered on a [`Worker`].
pub struct WorkflowFunction {
    /// The type-erased function that produces the workflow's future.
    wf_func: Box<WfFunc>,
}
865
866impl<F, Fut, O> From<F> for WorkflowFunction
867where
868 F: Fn(WfContext) -> Fut + Send + Sync + 'static,
869 Fut: Future<Output = Result<WfExitValue<O>, anyhow::Error>> + Send + 'static,
870 O: Serialize,
871{
872 fn from(wf_func: F) -> Self {
873 Self::new(wf_func)
874 }
875}
876
877impl WorkflowFunction {
878 pub fn new<F, Fut, O>(f: F) -> Self
880 where
881 F: Fn(WfContext) -> Fut + Send + Sync + 'static,
882 Fut: Future<Output = Result<WfExitValue<O>, anyhow::Error>> + Send + 'static,
883 O: Serialize,
884 {
885 Self {
886 wf_func: Box::new(move |ctx: WfContext| {
887 (f)(ctx)
888 .map(|r| {
889 r.and_then(|r| {
890 Ok(match r {
891 WfExitValue::ContinueAsNew(b) => WfExitValue::ContinueAsNew(b),
892 WfExitValue::Cancelled => WfExitValue::Cancelled,
893 WfExitValue::Evicted => WfExitValue::Evicted,
894 WfExitValue::Normal(o) => WfExitValue::Normal(o.as_json_payload()?),
895 })
896 })
897 })
898 .boxed()
899 }),
900 }
901 }
902}
903
/// The result of running a workflow: an exit value on success, or an error.
pub type WorkflowResult<T> = Result<WfExitValue<T>, anyhow::Error>;
906
/// The ways a workflow function can exit without an error.
///
/// A `From<T>` conversion is derived for the `Normal` variant only; the other variants are
/// excluded with `#[from(ignore)]`.
#[derive(Debug, derive_more::From)]
pub enum WfExitValue<T> {
    /// Continue the workflow as a new execution described by the boxed command.
    #[from(ignore)]
    ContinueAsNew(Box<ContinueAsNewWorkflowExecution>),
    /// The workflow was cancelled.
    #[from(ignore)]
    Cancelled,
    /// The workflow was evicted. This is the value produced for a still-running workflow when
    /// the worker's shutdown token is cancelled.
    #[from(ignore)]
    Evicted,
    /// The workflow finished with a result.
    Normal(T),
}
922
923impl<T> WfExitValue<T> {
924 pub fn continue_as_new(can: ContinueAsNewWorkflowExecution) -> Self {
926 Self::ContinueAsNew(Box::new(can))
927 }
928}
929
/// The ways an activity function can exit without an error.
pub enum ActExitValue<T> {
    /// The activity will be completed asynchronously; no result is available from this call.
    WillCompleteAsync,
    /// The activity finished with a result.
    Normal(T),
}
937
938impl<T: AsJsonPayloadExt> From<T> for ActExitValue<T> {
939 fn from(t: T) -> Self {
940 Self::Normal(t)
941 }
942}
943
/// Type-erased, shareable activity function: takes an [`ActContext`] and the raw input
/// [`Payload`], and returns a boxed future resolving to the activity's exit value (with the
/// result as a [`Payload`]) or an [`ActivityError`].
type BoxActFn = Arc<
    dyn Fn(ActContext, Payload) -> BoxFuture<'static, Result<ActExitValue<Payload>, ActivityError>>
        + Send
        + Sync,
>;
949
/// An activity function registered on a [`Worker`]. Cloning only clones the shared handle.
#[derive(Clone)]
pub struct ActivityFunction {
    /// The type-erased function that produces the activity's future.
    act_func: BoxActFn,
}
955
/// Errors an activity function can return.
#[derive(Debug)]
pub enum ActivityError {
    /// The activity failed; the failure is reported without the non-retryable flag.
    Retryable {
        /// The underlying error.
        source: anyhow::Error,
        /// If set, reported as the failure's next retry delay.
        explicit_delay: Option<Duration>,
    },
    /// The activity was cancelled.
    Cancelled {
        /// Optional details reported along with the cancellation.
        details: Option<Payload>,
    },
    /// The activity failed; the failure is reported with the non-retryable flag set.
    NonRetryable(anyhow::Error),
}
978
979impl<E> From<E> for ActivityError
980where
981 E: Into<anyhow::Error>,
982{
983 fn from(source: E) -> Self {
984 Self::Retryable {
985 source: source.into(),
986 explicit_delay: None,
987 }
988 }
989}
990
991impl ActivityError {
992 pub fn cancelled() -> Self {
994 Self::Cancelled { details: None }
995 }
996}
997
/// Conversion of a typed activity function into the type-erased form stored by the worker.
pub trait IntoActivityFunc<Args, Res, Out> {
    /// Consumes `self` and returns the type-erased activity function.
    fn into_activity_fn(self) -> BoxActFn;
}
1003
1004impl<A, Rf, R, O, F> IntoActivityFunc<A, Rf, O> for F
1005where
1006 F: (Fn(ActContext, A) -> Rf) + Sync + Send + 'static,
1007 A: FromJsonPayloadExt + Send,
1008 Rf: Future<Output = Result<R, ActivityError>> + Send + 'static,
1009 R: Into<ActExitValue<O>>,
1010 O: AsJsonPayloadExt,
1011{
1012 fn into_activity_fn(self) -> BoxActFn {
1013 let wrapper = move |ctx: ActContext, input: Payload| {
1014 match A::from_json_payload(&input) {
1016 Ok(deser) => self(ctx, deser)
1017 .map(|r| {
1018 r.and_then(|r| {
1019 let exit_val: ActExitValue<O> = r.into();
1020 match exit_val {
1021 ActExitValue::WillCompleteAsync => {
1022 Ok(ActExitValue::WillCompleteAsync)
1023 }
1024 ActExitValue::Normal(x) => match x.as_json_payload() {
1025 Ok(v) => Ok(ActExitValue::Normal(v)),
1026 Err(e) => Err(ActivityError::NonRetryable(e)),
1027 },
1028 }
1029 })
1030 })
1031 .boxed(),
1032 Err(e) => async move { Err(ActivityError::NonRetryable(e.into())) }.boxed(),
1033 }
1034 };
1035 Arc::new(wrapper)
1036 }
1037}
1038
/// Information about an update, passed to update validators and handlers.
#[derive(Clone)]
pub struct UpdateInfo {
    /// The id of the update.
    pub update_id: String,
    /// Headers attached to the update, keyed by name.
    pub headers: HashMap<String, Payload>,
}
1047
/// Context passed to an update handler.
pub struct UpdateContext {
    /// The workflow context available to the handler.
    pub wf_ctx: WfContext,
    /// Information about the update being handled.
    pub info: UpdateInfo,
}
1055
/// The type-erased validator and handler registered for one update.
struct UpdateFunctions {
    /// Synchronous validator; returns `Err` to reject the update input.
    validator: BoxUpdateValidatorFn,
    /// Asynchronous handler producing the update's result payload.
    handler: BoxUpdateHandlerFn,
}
1060
1061impl UpdateFunctions {
1062 pub(crate) fn new<Arg, Res>(
1063 v: impl IntoUpdateValidatorFunc<Arg> + Sized,
1064 h: impl IntoUpdateHandlerFunc<Arg, Res> + Sized,
1065 ) -> Self {
1066 Self {
1067 validator: v.into_update_validator_fn(),
1068 handler: h.into_update_handler_fn(),
1069 }
1070 }
1071}
1072
/// Type-erased update validator: receives the update info and the raw input payload and returns
/// `Ok(())` to accept the update or an error to reject it.
type BoxUpdateValidatorFn = Box<dyn Fn(&UpdateInfo, &Payload) -> Result<(), anyhow::Error> + Send>;
/// Conversion of a typed update validator into its type-erased form.
pub trait IntoUpdateValidatorFunc<Arg> {
    /// Consumes `self` and returns the type-erased validator.
    fn into_update_validator_fn(self) -> BoxUpdateValidatorFn;
}
1079impl<A, F> IntoUpdateValidatorFunc<A> for F
1080where
1081 A: FromJsonPayloadExt + Send,
1082 F: (for<'a> Fn(&'a UpdateInfo, A) -> Result<(), anyhow::Error>) + Send + 'static,
1083{
1084 fn into_update_validator_fn(self) -> BoxUpdateValidatorFn {
1085 let wrapper = move |ctx: &UpdateInfo, input: &Payload| match A::from_json_payload(input) {
1086 Ok(deser) => (self)(ctx, deser),
1087 Err(e) => Err(e.into()),
1088 };
1089 Box::new(wrapper)
1090 }
1091}
/// Type-erased update handler: receives the update context and the raw input payload and
/// returns a boxed future resolving to the result payload or an error.
type BoxUpdateHandlerFn = Box<
    dyn FnMut(UpdateContext, &Payload) -> BoxFuture<'static, Result<Payload, anyhow::Error>> + Send,
>;
/// Conversion of a typed update handler into its type-erased form.
pub trait IntoUpdateHandlerFunc<Arg, Res> {
    /// Consumes `self` and returns the type-erased handler.
    fn into_update_handler_fn(self) -> BoxUpdateHandlerFn;
}
1100impl<A, F, Rf, R> IntoUpdateHandlerFunc<A, R> for F
1101where
1102 A: FromJsonPayloadExt + Send,
1103 F: (FnMut(UpdateContext, A) -> Rf) + Send + 'static,
1104 Rf: Future<Output = Result<R, anyhow::Error>> + Send + 'static,
1105 R: AsJsonPayloadExt,
1106{
1107 fn into_update_handler_fn(mut self) -> BoxUpdateHandlerFn {
1108 let wrapper = move |ctx: UpdateContext, input: &Payload| match A::from_json_payload(input) {
1109 Ok(deser) => (self)(ctx, deser)
1110 .map(|r| r.and_then(|r| r.as_json_payload()))
1111 .boxed(),
1112 Err(e) => async move { Err(e.into()) }.boxed(),
1113 };
1114 Box::new(wrapper)
1115 }
1116}
1117
/// Turns a panic payload into something printable.
///
/// Payloads of type `&str` and `String` are returned as-is; for any other payload type a fixed
/// fallback message is returned.
fn panic_formatter(panic: Box<dyn Any>) -> Box<dyn Display> {
    _panic_formatter::<&str>(panic)
}

/// Tries to downcast `panic` to `T`; on failure moves on to `T::NextType`, giving up with a
/// fallback message once the next type in the chain is the [`EndPrintingAttempts`] sentinel.
fn _panic_formatter<T: 'static + PrintablePanicType>(panic: Box<dyn Any>) -> Box<dyn Display> {
    let unmatched = match panic.downcast::<T>() {
        Ok(printable) => return printable,
        Err(unmatched) => unmatched,
    };
    let chain_exhausted = TypeId::of::<T::NextType>() == TypeId::of::<EndPrintingAttempts>();
    if chain_exhausted {
        return Box::new("Couldn't turn panic into a string");
    }
    _panic_formatter::<T::NextType>(unmatched)
}

/// A printable type that a panic payload may be downcast to, linked to the next type to try.
trait PrintablePanicType: Display {
    /// The type to try when the payload is not of this type.
    type NextType: PrintablePanicType;
}

impl PrintablePanicType for &str {
    type NextType = String;
}

impl PrintablePanicType for String {
    type NextType = EndPrintingAttempts;
}

/// Sentinel marking the end of the chain of printable panic types.
struct EndPrintingAttempts {}

impl Display for EndPrintingAttempts {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("Will never be printed")
    }
}

impl PrintablePanicType for EndPrintingAttempts {
    // Points at itself so the chain of associated types is well-formed.
    type NextType = EndPrintingAttempts;
}