squads_temporal_sdk/
lib.rs

#![warn(missing_docs)] // warn if there are missing docs
2
3//! This crate defines an alpha-stage Temporal Rust SDK.
4//!
5//! Currently defining activities and running an activity-only worker is the most stable code.
6//! Workflow definitions exist and running a workflow worker works, but the API is still very
7//! unstable.
8//!
9//! An example of running an activity worker:
10//! ```no_run
11//! use std::{str::FromStr, sync::Arc};
12//! use temporal_sdk::{sdk_client_options, ActContext, Worker};
13//! use squads_temporal_sdk_core::{init_worker, Url, CoreRuntime};
14//! use squads_temporal_sdk_core_api::{
15//!     worker::{WorkerConfigBuilder, WorkerVersioningStrategy},
16//!     telemetry::TelemetryOptionsBuilder
17//! };
18//!
19//! #[tokio::main]
20//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
21//!     let server_options = sdk_client_options(Url::from_str("http://localhost:7233")?).build()?;
22//!
23//!     let client = server_options.connect("default", None).await?;
24//!
25//!     let telemetry_options = TelemetryOptionsBuilder::default().build()?;
26//!     let runtime = CoreRuntime::new_assume_tokio(telemetry_options)?;
27//!
28//!     let worker_config = WorkerConfigBuilder::default()
29//!         .namespace("default")
30//!         .task_queue("task_queue")
31//!         .versioning_strategy(WorkerVersioningStrategy::None { build_id: "rust-sdk".to_owned() })
32//!         .build()?;
33//!
34//!     let core_worker = init_worker(&runtime, worker_config, client)?;
35//!
36//!     let mut worker = Worker::new_from_core(Arc::new(core_worker), "task_queue");
37//!     worker.register_activity(
38//!         "echo_activity",
39//!         |_ctx: ActContext, echo_me: String| async move { Ok(echo_me) },
40//!     );
41//!
42//!     worker.run().await?;
43//!
44//!     Ok(())
45//! }
46//! ```
47
48#[macro_use]
49extern crate tracing;
50
51mod activity_context;
52mod app_data;
53pub mod interceptors;
54pub mod prelude;
55pub mod workflow;
56mod workflow_context;
57mod workflow_future;
58
59pub use activity_context::ActContext;
60pub use squads_temporal_client::Namespace;
61use tracing::{Instrument, Span, field};
62pub use workflow_context::{
63    ActivityOptions, CancellableFuture, ChildWorkflow, ChildWorkflowOptions, LocalActivityOptions,
64    NexusOperationOptions, PendingChildWorkflow, Signal, SignalData, SignalWorkflowOptions,
65    StartedChildWorkflow, TimerOptions, WfContext,
66};
67
68use crate::{
69    interceptors::WorkerInterceptor,
70    workflow_context::{ChildWfCommon, NexusUnblockData, StartedNexusOperation},
71};
72use anyhow::{Context, anyhow, bail};
73use app_data::AppData;
74use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::BoxFuture};
75use serde::Serialize;
76use std::{
77    any::{Any, TypeId},
78    cell::RefCell,
79    collections::HashMap,
80    fmt::{Debug, Display, Formatter},
81    future::Future,
82    panic::AssertUnwindSafe,
83    sync::Arc,
84    time::Duration,
85};
86use squads_temporal_client::ClientOptionsBuilder;
87use squads_temporal_sdk_core::Url;
88use squads_temporal_sdk_core_api::{Worker as CoreWorker, errors::PollError};
89use squads_temporal_sdk_core_protos::{
90    TaskToken,
91    coresdk::{
92        ActivityTaskCompletion, AsJsonPayloadExt, FromJsonPayloadExt,
93        activity_result::{ActivityExecutionResult, ActivityResolution},
94        activity_task::{ActivityTask, activity_task},
95        child_workflow::ChildWorkflowResult,
96        common::NamespacedWorkflowExecution,
97        nexus::NexusOperationResult,
98        workflow_activation::{
99            WorkflowActivation,
100            resolve_child_workflow_execution_start::Status as ChildWorkflowStartStatus,
101            resolve_nexus_operation_start, workflow_activation_job::Variant,
102        },
103        workflow_commands::{ContinueAsNewWorkflowExecution, WorkflowCommand, workflow_command},
104        workflow_completion::WorkflowActivationCompletion,
105    },
106    temporal::api::{
107        common::v1::Payload,
108        enums::v1::WorkflowTaskFailedCause,
109        failure::v1::{Failure, failure},
110    },
111};
112use tokio::{
113    sync::{
114        Notify,
115        mpsc::{UnboundedSender, unbounded_channel},
116        oneshot,
117    },
118    task::JoinError,
119};
120use tokio_stream::wrappers::UnboundedReceiverStream;
121use tokio_util::sync::CancellationToken;
122
123const VERSION: &str = env!("CARGO_PKG_VERSION");
124
125/// Returns a [ClientOptionsBuilder] with required fields set to appropriate values
126/// for the Rust SDK.
127pub fn sdk_client_options(url: impl Into<Url>) -> ClientOptionsBuilder {
128    let mut builder = ClientOptionsBuilder::default();
129    builder
130        .target_url(url)
131        .client_name("temporal-rust".to_string())
132        .client_version(VERSION.to_string());
133
134    builder
135}
136
/// A worker that can poll for and respond to workflow tasks by using [WorkflowFunction]s,
/// and activity tasks by using [ActivityFunction]s
pub struct Worker {
    /// State used by both the workflow and activity sides (core worker, task queue name,
    /// optional interceptor)
    common: CommonWorker,
    /// Workflow-side state: registered workflow functions and cached workflow runs
    workflow_half: WorkflowHalf,
    /// Activity-side state: registered activity functions and per-task cancellation tokens
    activity_half: ActivityHalf,
    /// Application data made available to activities. `run` takes the value out (leaving
    /// `None`) while it executes and stores it back once it finishes successfully.
    app_data: Option<AppData>,
}
145
/// Worker state needed by both workflow and activity processing
struct CommonWorker {
    /// The underlying core worker used to poll for and complete tasks
    worker: Arc<dyn CoreWorker>,
    /// Name of the task queue this worker was created with
    task_queue: String,
    /// Optional interceptor invoked on workflow activations, their completions, and shutdown
    worker_interceptor: Option<Box<dyn WorkerInterceptor>>,
}
151
/// Workflow-side state of a [Worker]
struct WorkflowHalf {
    /// Maps run id to cached workflow state
    workflows: RefCell<HashMap<String, WorkflowData>>,
    /// Maps workflow type to the function for executing workflow runs of that type
    workflow_fns: RefCell<HashMap<String, WorkflowFunction>>,
    /// Notified whenever an entry is removed from `workflows`, so that an activation handler
    /// waiting to insert the same run id again can re-check the map
    workflow_removed_from_map: Notify,
}
/// Per-run state kept in the workflow cache
struct WorkflowData {
    /// Channel used to send the workflow activations
    activation_chan: UnboundedSender<WorkflowActivation>,
}
163
/// Pairs the future that resolves when a spawned workflow task finishes with the run id of
/// the workflow it belongs to, so the run can be dropped from the cache afterwards.
struct WorkflowFutureHandle<F: Future<Output = Result<WorkflowResult<Payload>, JoinError>>> {
    /// Future yielding the workflow's result, or a [JoinError] if the task failed to join
    join_handle: F,
    /// Run id of the workflow this handle belongs to
    run_id: String,
}
168
/// Activity-side state of a [Worker]
struct ActivityHalf {
    /// Maps activity type to the function for executing activities of that type
    activity_fns: HashMap<String, ActivityFunction>,
    /// Maps the task token of each started activity to the token used to signal its
    /// cancellation
    task_tokens_to_cancels: HashMap<TaskToken, CancellationToken>,
}
174
175impl Worker {
176    /// Create a new Rust SDK worker from a core worker
177    pub fn new_from_core(worker: Arc<dyn CoreWorker>, task_queue: impl Into<String>) -> Self {
178        Self {
179            common: CommonWorker {
180                worker,
181                task_queue: task_queue.into(),
182                worker_interceptor: None,
183            },
184            workflow_half: WorkflowHalf {
185                workflows: Default::default(),
186                workflow_fns: Default::default(),
187                workflow_removed_from_map: Default::default(),
188            },
189            activity_half: ActivityHalf {
190                activity_fns: Default::default(),
191                task_tokens_to_cancels: Default::default(),
192            },
193            app_data: Some(Default::default()),
194        }
195    }
196
197    /// Returns the task queue name this worker polls on
198    pub fn task_queue(&self) -> &str {
199        &self.common.task_queue
200    }
201
202    /// Return a handle that can be used to initiate shutdown.
203    /// TODO: Doc better after shutdown changes
204    pub fn shutdown_handle(&self) -> impl Fn() + use<> {
205        let w = self.common.worker.clone();
206        move || w.initiate_shutdown()
207    }
208
209    /// Register a Workflow function to invoke when the Worker is asked to run a workflow of
210    /// `workflow_type`
211    pub fn register_wf(
212        &mut self,
213        workflow_type: impl Into<String>,
214        wf_function: impl Into<WorkflowFunction>,
215    ) {
216        self.workflow_half
217            .workflow_fns
218            .get_mut()
219            .insert(workflow_type.into(), wf_function.into());
220    }
221
222    /// Register an Activity function to invoke when the Worker is asked to run an activity of
223    /// `activity_type`
224    pub fn register_activity<A, R, O>(
225        &mut self,
226        activity_type: impl Into<String>,
227        act_function: impl IntoActivityFunc<A, R, O>,
228    ) {
229        self.activity_half.activity_fns.insert(
230            activity_type.into(),
231            ActivityFunction {
232                act_func: act_function.into_activity_fn(),
233            },
234        );
235    }
236
237    /// Insert Custom App Context for Workflows and Activities
238    pub fn insert_app_data<T: Send + Sync + 'static>(&mut self, data: T) {
239        self.app_data.as_mut().map(|a| a.insert(data));
240    }
241
    /// Runs the worker. Eventually resolves after the worker has been explicitly shut down,
    /// or may return early with an error in the event of some unresolvable problem.
    ///
    /// Four futures are driven concurrently until all finish or one fails:
    /// the workflow activation polling loop, the activity task polling loop (only if at least
    /// one activity function is registered), a joiner that awaits spawned workflow futures and
    /// evicts finished runs from the cache, and a processor that forwards workflow activation
    /// completions to the core worker.
    ///
    /// After those exit successfully, the interceptor's `on_shutdown` is invoked (if set), the
    /// core worker is shut down, and the app data is stored back on the worker.
    ///
    /// # Errors
    /// Returns an error if the app data is missing at start, if polling, an interceptor, or a
    /// task handler fails, if a workflow future or completion fails, or if other references to
    /// the app data are still alive when the worker shuts down.
    ///
    /// NOTE(review): when any of the concurrent futures fails, this returns before the app data
    /// is stored back, so a later call to `run` fails with "app_data should exist on run".
    pub async fn run(&mut self) -> Result<(), anyhow::Error> {
        let shutdown_token = CancellationToken::new();
        let (common, wf_half, act_half, app_data) = self.split_apart();
        // Move the app data into an Arc so it can be shared with spawned activity tasks
        let safe_app_data = Arc::new(
            app_data
                .take()
                .ok_or_else(|| anyhow!("app_data should exist on run"))?,
        );
        let (wf_future_tx, wf_future_rx) = unbounded_channel();
        let (completions_tx, completions_rx) = unbounded_channel();
        // Awaits every spawned workflow future; when one finishes successfully its run is
        // removed from the cache and any handler waiting on that run id is woken.
        let wf_future_joiner = async {
            UnboundedReceiverStream::new(wf_future_rx)
                .map(Result::<_, anyhow::Error>::Ok)
                .try_for_each_concurrent(
                    None,
                    |WorkflowFutureHandle {
                         join_handle,
                         run_id,
                     }| {
                        let wf_half = &*wf_half;
                        async move {
                            // First `?` is the join error, second is the workflow's own error
                            join_handle.await??;
                            debug!(run_id=%run_id, "Removing workflow from cache");
                            wf_half.workflows.borrow_mut().remove(&run_id);
                            wf_half.workflow_removed_from_map.notify_one();
                            Ok(())
                        }
                    },
                )
                .await
                .context("Workflow futures encountered an error")
        };
        // Forwards each workflow activation completion to the core worker, giving the
        // interceptor (if any) a look at it first.
        let wf_completion_processor = async {
            UnboundedReceiverStream::new(completions_rx)
                .map(Ok)
                .try_for_each_concurrent(None, |completion| async {
                    if let Some(ref i) = common.worker_interceptor {
                        i.on_workflow_activation_completion(&completion).await;
                    }
                    common.worker.complete_workflow_activation(completion).await
                })
                .map_err(anyhow::Error::from)
                .await
                .context("Workflow completions processor encountered an error")
        };
        tokio::try_join!(
            // Workflow polling loop
            async {
                loop {
                    let activation = match common.worker.poll_workflow_activation().await {
                        Err(PollError::ShutDown) => {
                            break;
                        }
                        o => o?,
                    };
                    if let Some(ref i) = common.worker_interceptor {
                        i.on_workflow_activation(&activation).await?;
                    }
                    // A handle is only returned when the activation started a new workflow;
                    // hand it to the joiner.
                    // NOTE(review): the panic message below mentions the completion processor
                    // channel, but the channel being sent on here is the workflow future one.
                    if let Some(wf_fut) = wf_half
                        .workflow_activation_handler(
                            common,
                            shutdown_token.clone(),
                            activation,
                            &completions_tx,
                        )
                        .await?
                        && wf_future_tx.send(wf_fut).is_err()
                    {
                        panic!("Receive half of completion processor channel cannot be dropped");
                    }
                }
                // Tell still-alive workflows to evict themselves
                shutdown_token.cancel();
                // It's important to drop these so the future and completion processors will
                // terminate.
                drop(wf_future_tx);
                drop(completions_tx);
                Result::<_, anyhow::Error>::Ok(())
            },
            // Only poll on the activity queue if activity functions have been registered. This
            // makes tests which use mocks dramatically more manageable.
            async {
                if !act_half.activity_fns.is_empty() {
                    loop {
                        let activity = common.worker.poll_activity_task().await;
                        if matches!(activity, Err(PollError::ShutDown)) {
                            break;
                        }
                        // Any other poll error, or a handler error, ends the whole run
                        act_half.activity_task_handler(
                            common.worker.clone(),
                            safe_app_data.clone(),
                            common.task_queue.clone(),
                            activity?,
                        )?;
                    }
                };
                Result::<_, anyhow::Error>::Ok(())
            },
            wf_future_joiner,
            wf_completion_processor,
        )?;

        info!("Polling loops exited");
        if let Some(i) = self.common.worker_interceptor.as_ref() {
            i.on_shutdown(self);
        }
        self.common.worker.shutdown().await;
        debug!("Worker shutdown complete");
        // Reclaim sole ownership of the app data; fails if any clone of the Arc is still alive
        self.app_data = Some(
            Arc::try_unwrap(safe_app_data)
                .map_err(|_| anyhow!("some references of AppData exist on worker shutdown"))?,
        );
        Ok(())
    }
358
359    /// Set a [WorkerInterceptor]
360    pub fn set_worker_interceptor(&mut self, interceptor: impl WorkerInterceptor + 'static) {
361        self.common.worker_interceptor = Some(Box::new(interceptor));
362    }
363
364    /// Turns this rust worker into a new worker with all the same workflows and activities
365    /// registered, but with a new underlying core worker. Can be used to swap the worker for
366    /// a replay worker, change task queues, etc.
367    pub fn with_new_core_worker(&mut self, new_core_worker: Arc<dyn CoreWorker>) {
368        self.common.worker = new_core_worker;
369    }
370
371    /// Returns number of currently cached workflows as understood by the SDK. Importantly, this
372    /// is not the same as understood by core, though they *should* always be in sync.
373    pub fn cached_workflows(&self) -> usize {
374        self.workflow_half.workflows.borrow().len()
375    }
376
377    fn split_apart(
378        &mut self,
379    ) -> (
380        &mut CommonWorker,
381        &mut WorkflowHalf,
382        &mut ActivityHalf,
383        &mut Option<AppData>,
384    ) {
385        (
386            &mut self.common,
387            &mut self.workflow_half,
388            &mut self.activity_half,
389            &mut self.app_data,
390        )
391    }
392}
393
impl WorkflowHalf {
    /// Routes one workflow activation to the workflow run it belongs to.
    ///
    /// If the activation contains an `InitializeWorkflow` job, the registered function for
    /// that workflow type is started on a spawned task, the run is inserted into the cache,
    /// and a [WorkflowFutureHandle] for the spawned task is returned. The activation is then
    /// sent to the run's activation channel.
    ///
    /// Returns `Ok(None)` when no new workflow was started. That includes the cases where the
    /// workflow type is not registered (a failure completion is sent instead) and where a lone
    /// `RemoveFromCache` job arrives for a run that is not cached (an empty completion is
    /// sent).
    ///
    /// # Errors
    /// Returns an error when any other activation arrives for a run that is not in the cache.
    ///
    /// # Panics
    /// Panics if the completion channel or the run's activation channel is closed.
    #[allow(clippy::type_complexity)]
    async fn workflow_activation_handler(
        &self,
        common: &CommonWorker,
        shutdown_token: CancellationToken,
        mut activation: WorkflowActivation,
        completions_tx: &UnboundedSender<WorkflowActivationCompletion>,
    ) -> Result<
        Option<
            WorkflowFutureHandle<
                impl Future<Output = Result<WorkflowResult<Payload>, JoinError>> + use<>,
            >,
        >,
        anyhow::Error,
    > {
        let mut res = None;
        let run_id = activation.run_id.clone();

        // If the activation is to init a workflow, create a new workflow driver for it,
        // using the function associated with that workflow type
        if let Some(sw) = activation.jobs.iter_mut().find_map(|j| match j.variant {
            Some(Variant::InitializeWorkflow(ref mut sw)) => Some(sw),
            _ => None,
        }) {
            let workflow_type = &sw.workflow_type;
            let (wff, activations) = {
                let wf_fns_borrow = self.workflow_fns.borrow();

                let Some(wf_function) = wf_fns_borrow.get(workflow_type) else {
                    warn!("Workflow type {workflow_type} not found");

                    // Report the failure to core instead of erroring out of the worker
                    completions_tx
                        .send(WorkflowActivationCompletion::fail(
                            run_id,
                            format!("Workflow type {workflow_type} not found").into(),
                            Some(WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure),
                        ))
                        .expect("Completion channel intact");
                    return Ok(None);
                };

                // `mem::take` moves the init job's contents out, leaving a default in the
                // activation
                wf_function.start_workflow(
                    common.worker.get_config().namespace.clone(),
                    common.task_queue.clone(),
                    std::mem::take(sw),
                    completions_tx.clone(),
                )
            };
            // Drive the workflow future on its own task until it completes or the worker
            // signals shutdown
            let jh = tokio::spawn(async move {
                tokio::select! {
                    r = wff.fuse() => r,
                    // TODO: This probably shouldn't abort early, as it could cause an in-progress
                    //  complete to abort. Send synthetic remove activation
                    _ = shutdown_token.cancelled() => {
                        Ok(WfExitValue::Evicted)
                    }
                }
            });
            res = Some(WorkflowFutureHandle {
                join_handle: jh,
                run_id: run_id.clone(),
            });
            loop {
                // It's possible that we've got a new initialize workflow action before the last
                // future for this run finished evicting, as a result of how futures might be
                // interleaved. In that case, just wait until it's not in the map, which should be
                // a matter of only a few `poll` calls.
                if self.workflows.borrow_mut().contains_key(&run_id) {
                    self.workflow_removed_from_map.notified().await;
                } else {
                    break;
                }
            }
            self.workflows.borrow_mut().insert(
                run_id.clone(),
                WorkflowData {
                    activation_chan: activations,
                },
            );
        }

        // The activation is expected to apply to some workflow we know about. Use it to
        // unblock things and advance the workflow.
        if let Some(dat) = self.workflows.borrow_mut().get_mut(&run_id) {
            dat.activation_chan
                .send(activation)
                .expect("Workflow should exist if we're sending it an activation");
        } else {
            // When we failed to start a workflow, we never inserted it into the cache. But core
            // sends us a `RemoveFromCache` job when we mark the StartWorkflow workflow activation
            // as a failure, which we need to complete. Other SDKs add the workflow to the cache
            // even when the workflow type is unknown/not found. To circumvent this, we simply mark
            // any RemoveFromCache job for workflows that are not in the cache as complete.
            if activation.jobs.len() == 1
                && matches!(
                    activation.jobs.first().map(|j| &j.variant),
                    Some(Some(Variant::RemoveFromCache(_)))
                )
            {
                completions_tx
                    .send(WorkflowActivationCompletion::from_cmds(run_id, vec![]))
                    .expect("Completion channel intact");
                return Ok(None);
            }

            // In all other cases, we want to error as the runtime could be in an inconsistent state
            // at this point.
            bail!(
                "Got activation {:?} for unknown workflow {}",
                activation,
                run_id
            );
        };

        Ok(res)
    }
}
512
513impl ActivityHalf {
514    /// Spawns off a task to handle the provided activity task
515    fn activity_task_handler(
516        &mut self,
517        worker: Arc<dyn CoreWorker>,
518        app_data: Arc<AppData>,
519        task_queue: String,
520        activity: ActivityTask,
521    ) -> Result<(), anyhow::Error> {
522        match activity.variant {
523            Some(activity_task::Variant::Start(start)) => {
524                let act_fn = self
525                    .activity_fns
526                    .get(&start.activity_type)
527                    .ok_or_else(|| {
528                        anyhow!(
529                            "No function registered for activity type {}",
530                            start.activity_type
531                        )
532                    })?
533                    .clone();
534                let span = info_span!(
535                    "RunActivity",
536                    "otel.name" = format!("RunActivity:{}", start.activity_type),
537                    "otel.kind" = "server",
538                    "temporalActivityID" = start.activity_id,
539                    "temporalWorkflowID" = field::Empty,
540                    "temporalRunID" = field::Empty,
541                );
542                let ct = CancellationToken::new();
543                let task_token = activity.task_token;
544                self.task_tokens_to_cancels
545                    .insert(task_token.clone().into(), ct.clone());
546
547                let (ctx, arg) = ActContext::new(
548                    worker.clone(),
549                    app_data,
550                    ct,
551                    task_queue,
552                    task_token.clone(),
553                    start,
554                );
555
556                tokio::spawn(async move {
557                    let act_fut = async move {
558                        if let Some(info) = &ctx.get_info().workflow_execution {
559                            Span::current()
560                                .record("temporalWorkflowID", &info.workflow_id)
561                                .record("temporalRunID", &info.run_id);
562                        }
563                        (act_fn.act_func)(ctx, arg).await
564                    }
565                    .instrument(span);
566                    let output = AssertUnwindSafe(act_fut).catch_unwind().await;
567                    let result = match output {
568                        Err(e) => ActivityExecutionResult::fail(Failure::application_failure(
569                            format!("Activity function panicked: {}", panic_formatter(e)),
570                            true,
571                        )),
572                        Ok(Ok(ActExitValue::Normal(p))) => ActivityExecutionResult::ok(p),
573                        Ok(Ok(ActExitValue::WillCompleteAsync)) => {
574                            ActivityExecutionResult::will_complete_async()
575                        }
576                        Ok(Err(err)) => match err {
577                            ActivityError::Retryable {
578                                source,
579                                explicit_delay,
580                            } => ActivityExecutionResult::fail({
581                                let mut f = Failure::application_failure_from_error(source, false);
582                                if let Some(d) = explicit_delay
583                                    && let Some(failure::FailureInfo::ApplicationFailureInfo(fi)) =
584                                        f.failure_info.as_mut()
585                                {
586                                    fi.next_retry_delay = d.try_into().ok();
587                                }
588                                f
589                            }),
590                            ActivityError::Cancelled { details } => {
591                                ActivityExecutionResult::cancel_from_details(details)
592                            }
593                            ActivityError::NonRetryable(nre) => ActivityExecutionResult::fail(
594                                Failure::application_failure_from_error(nre, true),
595                            ),
596                        },
597                    };
598                    worker
599                        .complete_activity_task(ActivityTaskCompletion {
600                            task_token,
601                            result: Some(result),
602                        })
603                        .await?;
604                    Ok::<_, anyhow::Error>(())
605                });
606            }
607            Some(activity_task::Variant::Cancel(_)) => {
608                if let Some(ct) = self
609                    .task_tokens_to_cancels
610                    .get(activity.task_token.as_slice())
611                {
612                    ct.cancel();
613                }
614            }
615            None => bail!("Undefined activity task variant"),
616        }
617        Ok(())
618    }
619}
620
/// Event used to resolve a pending operation; each [Unblockable] type converts the variant
/// that matches it into its own result value.
///
/// Every variant carries a `u32` first (presumably the sequence number of the operation
/// being resolved — confirm against the code that constructs these), followed by the result.
#[derive(Debug)]
enum UnblockEvent {
    /// Resolves a timer
    Timer(u32, TimerResult),
    /// Resolves an activity
    Activity(u32, Box<ActivityResolution>),
    /// Resolves the start of a child workflow
    WorkflowStart(u32, Box<ChildWorkflowStartStatus>),
    /// Resolves the completion of a child workflow
    WorkflowComplete(u32, Box<ChildWorkflowResult>),
    /// Resolves signalling an external workflow; `Some` carries the failure
    SignalExternal(u32, Option<Failure>),
    /// Resolves cancelling an external workflow; `Some` carries the failure
    CancelExternal(u32, Option<Failure>),
    /// Resolves the start of a nexus operation
    NexusOperationStart(u32, Box<resolve_nexus_operation_start::Status>),
    /// Resolves the completion of a nexus operation
    NexusOperationComplete(u32, Box<NexusOperationResult>),
}
632
/// Result of awaiting on a timer: it either fired or was cancelled first
#[derive(Debug, Copy, Clone)]
pub enum TimerResult {
    /// The timer was cancelled
    Cancelled,
    /// The timer elapsed and fired
    Fired,
}
641
/// Successful result of sending a signal to an external workflow. Carries no data.
#[derive(Debug)]
pub struct SignalExternalOk;
/// Result of awaiting on sending a signal to an external workflow: [SignalExternalOk] on
/// success, or the [Failure] describing why the signal could not be delivered.
pub type SignalExternalWfResult = Result<SignalExternalOk, Failure>;
647
/// Successful result of sending a cancel request to an external workflow. Carries no data.
#[derive(Debug)]
pub struct CancelExternalOk;
/// Result of awaiting on sending a cancel request to an external workflow:
/// [CancelExternalOk] on success, or the [Failure] describing why the request failed.
pub type CancelExternalWfResult = Result<CancelExternalOk, Failure>;
653
/// Implemented by result types that can be produced from an [UnblockEvent].
trait Unblockable {
    /// Extra data, beyond the event itself, needed to build the result. `()` when none is
    /// needed.
    type OtherDat;

    /// Builds the result from the event `ue` and the extra data `od`. The implementations in
    /// this file panic when `ue` is not the variant matching the implementing type.
    fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self;
}
659
660impl Unblockable for TimerResult {
661    type OtherDat = ();
662    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
663        match ue {
664            UnblockEvent::Timer(_, result) => result,
665            _ => panic!("Invalid unblock event for timer"),
666        }
667    }
668}
669
670impl Unblockable for ActivityResolution {
671    type OtherDat = ();
672    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
673        match ue {
674            UnblockEvent::Activity(_, result) => *result,
675            _ => panic!("Invalid unblock event for activity"),
676        }
677    }
678}
679
680impl Unblockable for PendingChildWorkflow {
681    // Other data here is workflow id
682    type OtherDat = ChildWfCommon;
683    fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self {
684        match ue {
685            UnblockEvent::WorkflowStart(_, result) => Self {
686                status: *result,
687                common: od,
688            },
689            _ => panic!("Invalid unblock event for child workflow start"),
690        }
691    }
692}
693
694impl Unblockable for ChildWorkflowResult {
695    type OtherDat = ();
696    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
697        match ue {
698            UnblockEvent::WorkflowComplete(_, result) => *result,
699            _ => panic!("Invalid unblock event for child workflow complete"),
700        }
701    }
702}
703
704impl Unblockable for SignalExternalWfResult {
705    type OtherDat = ();
706    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
707        match ue {
708            UnblockEvent::SignalExternal(_, maybefail) => {
709                maybefail.map_or(Ok(SignalExternalOk), Err)
710            }
711            _ => panic!("Invalid unblock event for signal external workflow result"),
712        }
713    }
714}
715
716impl Unblockable for CancelExternalWfResult {
717    type OtherDat = ();
718    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
719        match ue {
720            UnblockEvent::CancelExternal(_, maybefail) => {
721                maybefail.map_or(Ok(CancelExternalOk), Err)
722            }
723            _ => panic!("Invalid unblock event for signal external workflow result"),
724        }
725    }
726}
727
728type NexusStartResult = Result<StartedNexusOperation, Failure>;
729impl Unblockable for NexusStartResult {
730    type OtherDat = NexusUnblockData;
731    fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self {
732        match ue {
733            UnblockEvent::NexusOperationStart(_, result) => match *result {
734                resolve_nexus_operation_start::Status::OperationToken(op_token) => {
735                    Ok(StartedNexusOperation {
736                        operation_token: Some(op_token),
737                        unblock_dat: od,
738                    })
739                }
740                resolve_nexus_operation_start::Status::StartedSync(_) => {
741                    Ok(StartedNexusOperation {
742                        operation_token: None,
743                        unblock_dat: od,
744                    })
745                }
746                resolve_nexus_operation_start::Status::CancelledBeforeStart(f) => Err(f),
747            },
748            _ => panic!("Invalid unblock event for nexus operation"),
749        }
750    }
751}
752
753impl Unblockable for NexusOperationResult {
754    type OtherDat = ();
755
756    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
757        match ue {
758            UnblockEvent::NexusOperationComplete(_, result) => *result,
759            _ => panic!("Invalid unblock event for nexus operation complete"),
760        }
761    }
762}
763
/// Identifier for cancellable operations
///
/// NOTE(review): the bare `u32` payloads are presumably the same kind of sequence number as
/// the named `seqnum` fields — confirm against the code that constructs these.
#[derive(Debug, Clone)]
pub(crate) enum CancellableID {
    /// A timer
    Timer(u32),
    /// An activity
    Activity(u32),
    /// A local activity
    LocalActivity(u32),
    /// A child workflow, with the reason it is being cancelled
    ChildWorkflow {
        seqnum: u32,
        reason: String,
    },
    /// A signal sent to an external workflow
    SignalExternalWorkflow(u32),
    /// An external workflow, identified by `execution`, with the reason it is being cancelled
    ExternalWorkflow {
        seqnum: u32,
        execution: NamespacedWorkflowExecution,
        reason: String,
    },
    /// A nexus operation (waiting for start)
    NexusOp(u32),
}
783
/// Cancellation IDs that support a reason.
pub(crate) trait SupportsCancelReason {
    /// Returns a new version of this ID with the provided cancellation reason.
    /// Consumes `self` and yields the corresponding [CancellableID].
    fn with_reason(self, reason: String) -> CancellableID;
}
/// The subset of cancellable operations whose cancellation can carry a reason. Each variant
/// mirrors the [CancellableID] variant of the same name, minus the `reason` field.
#[derive(Debug, Clone)]
pub(crate) enum CancellableIDWithReason {
    /// A child workflow
    ChildWorkflow {
        seqnum: u32,
    },
    /// An external workflow, identified by `execution`
    ExternalWorkflow {
        seqnum: u32,
        execution: NamespacedWorkflowExecution,
    },
}
799impl CancellableIDWithReason {
800    pub(crate) fn seq_num(&self) -> u32 {
801        match self {
802            CancellableIDWithReason::ChildWorkflow { seqnum } => *seqnum,
803            CancellableIDWithReason::ExternalWorkflow { seqnum, .. } => *seqnum,
804        }
805    }
806}
807impl SupportsCancelReason for CancellableIDWithReason {
808    fn with_reason(self, reason: String) -> CancellableID {
809        match self {
810            CancellableIDWithReason::ChildWorkflow { seqnum } => {
811                CancellableID::ChildWorkflow { seqnum, reason }
812            }
813            CancellableIDWithReason::ExternalWorkflow { seqnum, execution } => {
814                CancellableID::ExternalWorkflow {
815                    seqnum,
816                    execution,
817                    reason,
818                }
819            }
820        }
821    }
822}
823impl From<CancellableIDWithReason> for CancellableID {
824    fn from(v: CancellableIDWithReason) -> Self {
825        v.with_reason("".to_string())
826    }
827}
828
829#[derive(derive_more::From)]
830#[allow(clippy::large_enum_variant)]
831enum RustWfCmd {
832    #[from(ignore)]
833    Cancel(CancellableID),
834    ForceWFTFailure(anyhow::Error),
835    NewCmd(CommandCreateRequest),
836    NewNonblockingCmd(workflow_command::Variant),
837    SubscribeChildWorkflowCompletion(CommandSubscribeChildWorkflowCompletion),
838    SubscribeSignal(String, UnboundedSender<SignalData>),
839    RegisterUpdate(String, UpdateFunctions),
840    SubscribeNexusOperationCompletion {
841        seq: u32,
842        unblocker: oneshot::Sender<UnblockEvent>,
843    },
844}
845
846struct CommandCreateRequest {
847    cmd: WorkflowCommand,
848    unblocker: oneshot::Sender<UnblockEvent>,
849}
850
851struct CommandSubscribeChildWorkflowCompletion {
852    seq: u32,
853    unblocker: oneshot::Sender<UnblockEvent>,
854}
855
856type WfFunc = dyn Fn(WfContext) -> BoxFuture<'static, Result<WfExitValue<Payload>, anyhow::Error>>
857    + Send
858    + Sync
859    + 'static;
860
861/// The user's async function / workflow code
862pub struct WorkflowFunction {
863    wf_func: Box<WfFunc>,
864}
865
866impl<F, Fut, O> From<F> for WorkflowFunction
867where
868    F: Fn(WfContext) -> Fut + Send + Sync + 'static,
869    Fut: Future<Output = Result<WfExitValue<O>, anyhow::Error>> + Send + 'static,
870    O: Serialize,
871{
872    fn from(wf_func: F) -> Self {
873        Self::new(wf_func)
874    }
875}
876
877impl WorkflowFunction {
878    /// Build a workflow function from a closure or function pointer which accepts a [WfContext]
879    pub fn new<F, Fut, O>(f: F) -> Self
880    where
881        F: Fn(WfContext) -> Fut + Send + Sync + 'static,
882        Fut: Future<Output = Result<WfExitValue<O>, anyhow::Error>> + Send + 'static,
883        O: Serialize,
884    {
885        Self {
886            wf_func: Box::new(move |ctx: WfContext| {
887                (f)(ctx)
888                    .map(|r| {
889                        r.and_then(|r| {
890                            Ok(match r {
891                                WfExitValue::ContinueAsNew(b) => WfExitValue::ContinueAsNew(b),
892                                WfExitValue::Cancelled => WfExitValue::Cancelled,
893                                WfExitValue::Evicted => WfExitValue::Evicted,
894                                WfExitValue::Normal(o) => WfExitValue::Normal(o.as_json_payload()?),
895                            })
896                        })
897                    })
898                    .boxed()
899            }),
900        }
901    }
902}
903
904/// The result of running a workflow
905pub type WorkflowResult<T> = Result<WfExitValue<T>, anyhow::Error>;
906
907/// Workflow functions may return these values when exiting
908#[derive(Debug, derive_more::From)]
909pub enum WfExitValue<T> {
910    /// Continue the workflow as a new execution
911    #[from(ignore)]
912    ContinueAsNew(Box<ContinueAsNewWorkflowExecution>),
913    /// Confirm the workflow was cancelled (can be automatic in a more advanced iteration)
914    #[from(ignore)]
915    Cancelled,
916    /// The run was evicted
917    #[from(ignore)]
918    Evicted,
919    /// Finish with a result
920    Normal(T),
921}
922
923impl<T> WfExitValue<T> {
924    /// Construct a [WfExitValue::ContinueAsNew] variant (handles boxing)
925    pub fn continue_as_new(can: ContinueAsNewWorkflowExecution) -> Self {
926        Self::ContinueAsNew(Box::new(can))
927    }
928}
929
/// Activity functions may return these values when exiting
///
/// Implements [Debug] (when `T` does) so that exit values can be logged and used in assertions,
/// matching [WfExitValue] and [ActivityError].
#[derive(Debug)]
pub enum ActExitValue<T> {
    /// Completion requires an asynchronous callback
    WillCompleteAsync,
    /// Finish with a result
    Normal(T),
}
937
938impl<T: AsJsonPayloadExt> From<T> for ActExitValue<T> {
939    fn from(t: T) -> Self {
940        Self::Normal(t)
941    }
942}
943
944type BoxActFn = Arc<
945    dyn Fn(ActContext, Payload) -> BoxFuture<'static, Result<ActExitValue<Payload>, ActivityError>>
946        + Send
947        + Sync,
948>;
949
950/// Container for user-defined activity functions
951#[derive(Clone)]
952pub struct ActivityFunction {
953    act_func: BoxActFn,
954}
955
956/// Returned as errors from activity functions
957#[derive(Debug)]
958pub enum ActivityError {
959    /// This error can be returned from activities to allow the explicit configuration of certain
960    /// error properties. It's also the default error type that arbitrary errors will be converted
961    /// into.
962    Retryable {
963        /// The underlying error
964        source: anyhow::Error,
965        /// If specified, the next retry (if there is one) will occur after this delay
966        explicit_delay: Option<Duration>,
967    },
968    /// Return this error to indicate your activity is cancelling
969    Cancelled {
970        /// Some data to save as the cancellation reason
971        details: Option<Payload>,
972    },
973    /// Return this error to indicate that your activity non-retryable
974    /// this is a transparent wrapper around anyhow Error so essentially any type of error
975    /// could be used here.
976    NonRetryable(anyhow::Error),
977}
978
979impl<E> From<E> for ActivityError
980where
981    E: Into<anyhow::Error>,
982{
983    fn from(source: E) -> Self {
984        Self::Retryable {
985            source: source.into(),
986            explicit_delay: None,
987        }
988    }
989}
990
991impl ActivityError {
992    /// Construct a cancelled error without details
993    pub fn cancelled() -> Self {
994        Self::Cancelled { details: None }
995    }
996}
997
998/// Closures / functions which can be turned into activity functions implement this trait
999pub trait IntoActivityFunc<Args, Res, Out> {
1000    /// Consume the closure or fn pointer and turned it into a boxed activity function
1001    fn into_activity_fn(self) -> BoxActFn;
1002}
1003
1004impl<A, Rf, R, O, F> IntoActivityFunc<A, Rf, O> for F
1005where
1006    F: (Fn(ActContext, A) -> Rf) + Sync + Send + 'static,
1007    A: FromJsonPayloadExt + Send,
1008    Rf: Future<Output = Result<R, ActivityError>> + Send + 'static,
1009    R: Into<ActExitValue<O>>,
1010    O: AsJsonPayloadExt,
1011{
1012    fn into_activity_fn(self) -> BoxActFn {
1013        let wrapper = move |ctx: ActContext, input: Payload| {
1014            // Some minor gymnastics are required to avoid needing to clone the function
1015            match A::from_json_payload(&input) {
1016                Ok(deser) => self(ctx, deser)
1017                    .map(|r| {
1018                        r.and_then(|r| {
1019                            let exit_val: ActExitValue<O> = r.into();
1020                            match exit_val {
1021                                ActExitValue::WillCompleteAsync => {
1022                                    Ok(ActExitValue::WillCompleteAsync)
1023                                }
1024                                ActExitValue::Normal(x) => match x.as_json_payload() {
1025                                    Ok(v) => Ok(ActExitValue::Normal(v)),
1026                                    Err(e) => Err(ActivityError::NonRetryable(e)),
1027                                },
1028                            }
1029                        })
1030                    })
1031                    .boxed(),
1032                Err(e) => async move { Err(ActivityError::NonRetryable(e.into())) }.boxed(),
1033            }
1034        };
1035        Arc::new(wrapper)
1036    }
1037}
1038
1039/// Extra information attached to workflow updates
1040#[derive(Clone)]
1041pub struct UpdateInfo {
1042    /// The update's id, unique within the workflow
1043    pub update_id: String,
1044    /// Headers attached to the update
1045    pub headers: HashMap<String, Payload>,
1046}
1047
1048/// Context for a workflow update
1049pub struct UpdateContext {
1050    /// The workflow context, can be used to do normal workflow things inside the update handler
1051    pub wf_ctx: WfContext,
1052    /// Additional update info
1053    pub info: UpdateInfo,
1054}
1055
1056struct UpdateFunctions {
1057    validator: BoxUpdateValidatorFn,
1058    handler: BoxUpdateHandlerFn,
1059}
1060
1061impl UpdateFunctions {
1062    pub(crate) fn new<Arg, Res>(
1063        v: impl IntoUpdateValidatorFunc<Arg> + Sized,
1064        h: impl IntoUpdateHandlerFunc<Arg, Res> + Sized,
1065    ) -> Self {
1066        Self {
1067            validator: v.into_update_validator_fn(),
1068            handler: h.into_update_handler_fn(),
1069        }
1070    }
1071}
1072
1073type BoxUpdateValidatorFn = Box<dyn Fn(&UpdateInfo, &Payload) -> Result<(), anyhow::Error> + Send>;
1074/// Closures / functions which can be turned into update validation functions implement this trait
1075pub trait IntoUpdateValidatorFunc<Arg> {
1076    /// Consume the closure/fn pointer and turn it into an update validator
1077    fn into_update_validator_fn(self) -> BoxUpdateValidatorFn;
1078}
1079impl<A, F> IntoUpdateValidatorFunc<A> for F
1080where
1081    A: FromJsonPayloadExt + Send,
1082    F: (for<'a> Fn(&'a UpdateInfo, A) -> Result<(), anyhow::Error>) + Send + 'static,
1083{
1084    fn into_update_validator_fn(self) -> BoxUpdateValidatorFn {
1085        let wrapper = move |ctx: &UpdateInfo, input: &Payload| match A::from_json_payload(input) {
1086            Ok(deser) => (self)(ctx, deser),
1087            Err(e) => Err(e.into()),
1088        };
1089        Box::new(wrapper)
1090    }
1091}
1092type BoxUpdateHandlerFn = Box<
1093    dyn FnMut(UpdateContext, &Payload) -> BoxFuture<'static, Result<Payload, anyhow::Error>> + Send,
1094>;
1095/// Closures / functions which can be turned into update handler functions implement this trait
1096pub trait IntoUpdateHandlerFunc<Arg, Res> {
1097    /// Consume the closure/fn pointer and turn it into an update handler
1098    fn into_update_handler_fn(self) -> BoxUpdateHandlerFn;
1099}
1100impl<A, F, Rf, R> IntoUpdateHandlerFunc<A, R> for F
1101where
1102    A: FromJsonPayloadExt + Send,
1103    F: (FnMut(UpdateContext, A) -> Rf) + Send + 'static,
1104    Rf: Future<Output = Result<R, anyhow::Error>> + Send + 'static,
1105    R: AsJsonPayloadExt,
1106{
1107    fn into_update_handler_fn(mut self) -> BoxUpdateHandlerFn {
1108        let wrapper = move |ctx: UpdateContext, input: &Payload| match A::from_json_payload(input) {
1109            Ok(deser) => (self)(ctx, deser)
1110                .map(|r| r.and_then(|r| r.as_json_payload()))
1111                .boxed(),
1112            Err(e) => async move { Err(e.into()) }.boxed(),
1113        };
1114        Box::new(wrapper)
1115    }
1116}
1117
/// Attempts to turn a caught panic payload into something printable.
///
/// `&str` and `String` payloads are returned as-is; any other payload type yields a fixed
/// fallback message.
fn panic_formatter(panic: Box<dyn Any>) -> Box<dyn Display> {
    // `&str` is the head of the chain of printable types to try.
    _panic_formatter::<&str>(panic)
}

/// Tries to downcast `panic` to `T`, then to each following `NextType` in turn, giving up with a
/// fallback message once the next candidate is [EndPrintingAttempts].
fn _panic_formatter<T: 'static + PrintablePanicType>(panic: Box<dyn Any>) -> Box<dyn Display> {
    let unmatched = match panic.downcast::<T>() {
        Ok(printable) => return printable,
        Err(unmatched) => unmatched,
    };
    let chain_exhausted = TypeId::of::<T::NextType>() == TypeId::of::<EndPrintingAttempts>();
    if chain_exhausted {
        Box::new("Couldn't turn panic into a string")
    } else {
        _panic_formatter::<T::NextType>(unmatched)
    }
}

/// A displayable type a panic payload may be downcast to, linked to the next type to try.
trait PrintablePanicType: Display {
    /// The type to attempt when the payload is not `Self`
    type NextType: PrintablePanicType;
}
impl PrintablePanicType for &str {
    type NextType = String;
}
impl PrintablePanicType for String {
    type NextType = EndPrintingAttempts;
}

/// Sentinel that terminates the chain of printable types; it points back at itself.
struct EndPrintingAttempts {}
impl Display for EndPrintingAttempts {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("Will never be printed")
    }
}
impl PrintablePanicType for EndPrintingAttempts {
    type NextType = EndPrintingAttempts;
}