Skip to main content

temporalio_sdk/
workflows.rs

1//! Functionality related to defining and interacting with workflows
2//!
3//! This module contains traits and types for implementing workflows using the
4//! `#[workflow]` and `#[workflow_methods]` macros.
5//!
6//! Example usage:
7//! ```
8//! use temporalio_macros::{workflow, workflow_methods};
9//! use temporalio_sdk::{
10//!     SyncWorkflowContext, WorkflowContext, WorkflowContextView, WorkflowResult,
11//! };
12//!
13//! #[workflow]
14//! pub struct MyWorkflow {
15//!     counter: u32,
16//! }
17//!
18//! #[workflow_methods]
19//! impl MyWorkflow {
20//!     #[init]
21//!     pub fn new(ctx: &WorkflowContextView, input: String) -> Self {
22//!         Self { counter: 0 }
23//!     }
24//!
25//!     // Async run method uses ctx.state() for reading
26//!     #[run]
27//!     pub async fn run(ctx: &mut WorkflowContext<Self>) -> WorkflowResult<String> {
28//!         let counter = ctx.state(|s| s.counter);
29//!         Ok(format!("Done with counter: {}", counter))
30//!     }
31//!
32//!     // Sync signals use &mut self for direct mutations
33//!     #[signal]
34//!     pub fn increment(&mut self, ctx: &mut SyncWorkflowContext<Self>, amount: u32) {
35//!         self.counter += amount;
36//!     }
37//!
38//!     // Queries use &self with read-only context
39//!     #[query]
40//!     pub fn get_counter(&self, ctx: &WorkflowContextView) -> u32 {
41//!         self.counter
42//!     }
43//! }
44//! ```
45
46use crate::{
47    BaseWorkflowContext, SyncWorkflowContext, WorkflowContext, WorkflowContextView,
48    WorkflowTermination,
49};
50use futures_util::future::{Fuse, FutureExt, LocalBoxFuture};
51use std::{
52    cell::RefCell,
53    collections::HashMap,
54    fmt::Debug,
55    pin::Pin,
56    rc::Rc,
57    sync::Arc,
58    task::{Context as TaskContext, Poll},
59};
60use temporalio_common::{
61    QueryDefinition, SignalDefinition, UpdateDefinition, WorkflowDefinition,
62    data_converters::{
63        GenericPayloadConverter, PayloadConversionError, PayloadConverter, SerializationContext,
64        SerializationContextData, TemporalDeserializable, TemporalSerializable,
65    },
66    protos::temporal::api::{
67        common::v1::{Payload, Payloads},
68        failure::v1::Failure,
69    },
70};
71
/// Error type for workflow operations.
///
/// Returned by the handler dispatch helpers in this module:
/// - payload (de)serialization failures become [`WorkflowError::PayloadConversion`];
/// - errors returned by user handlers are wrapped into [`WorkflowError::Execution`]
///   (see [`wrap_handler_error`]).
///
/// Converts into a protobuf `Failure` whose `message` is this error's `Display` text.
#[derive(Debug, thiserror::Error)]
pub enum WorkflowError {
    /// Error during payload conversion (decoding handler input or encoding handler output).
    #[error("Payload conversion error: {0}")]
    PayloadConversion(#[from] PayloadConversionError),

    /// Workflow execution error, e.g. an error returned by a user-defined handler.
    #[error("Workflow execution error: {0}")]
    Execution(#[from] anyhow::Error),
}
83
84impl From<WorkflowError> for Failure {
85    fn from(err: WorkflowError) -> Self {
86        Failure {
87            message: err.to_string(),
88            ..Default::default()
89        }
90    }
91}
92
/// Trait implemented by workflow structs to enable execution by the worker.
///
/// This trait is typically generated by the `#[workflow_methods]` macro and should not
/// be implemented manually in most cases.
///
/// The handler-dispatch methods (`dispatch_update`, `validate_update`, `dispatch_signal`,
/// `dispatch_query`) all have default implementations returning `None`, meaning "no
/// handler with that name"; an implementation overrides only those it needs.
#[doc(hidden)]
pub trait WorkflowImplementation: Sized + 'static {
    /// The marker struct for the run method that implements `WorkflowDefinition`.
    ///
    /// Its `Input` associated type is the workflow input, which is delivered to either
    /// `init` or `run` depending on [`Self::INIT_TAKES_INPUT`].
    type Run: WorkflowDefinition;

    /// Whether this workflow has a user-defined `#[init]` method.
    /// Set to `true` by the macro when `#[init]` is present, `false` otherwise.
    ///
    /// `WorkflowDefinitions::register_workflow_run_with_factory` asserts this is `false`,
    /// because a user factory replaces `init` for instance creation.
    const HAS_INIT: bool;

    /// Whether the init method accepts the workflow input.
    /// If true, input goes to init. If false, input goes to run.
    const INIT_TAKES_INPUT: bool;

    /// Returns the workflow type name.
    ///
    /// Used as the registration key in `WorkflowDefinitions`.
    fn name() -> &'static str;

    /// Initialize the workflow instance.
    ///
    /// This is called when a new workflow execution starts. If `INIT_TAKES_INPUT` is true,
    /// `input` will be `Some`. Otherwise it's `None`.
    ///
    /// Not called for workflows registered with a custom factory; the factory creates the
    /// instance instead.
    fn init(
        ctx: WorkflowContextView,
        input: Option<<Self::Run as WorkflowDefinition>::Input>,
    ) -> Self;

    /// Execute the workflow's main run function.
    ///
    /// If `INIT_TAKES_INPUT` is false, `input` will be `Some`. Otherwise it's `None`.
    /// The returned future is `'static`, which is why it takes the context by value.
    fn run(
        ctx: WorkflowContext<Self>,
        input: Option<<Self::Run as WorkflowDefinition>::Input>,
    ) -> LocalBoxFuture<'static, Result<Payload, WorkflowTermination>>;

    /// Dispatch an update request by name. Returns `None` if no handler for that name.
    ///
    /// Otherwise returns a future resolving to the serialized update result, or to an
    /// error if input decoding, the handler, or output encoding failed.
    fn dispatch_update(
        _ctx: WorkflowContext<Self>,
        _name: &str,
        _payloads: Payloads,
        _converter: &PayloadConverter,
    ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>> {
        None
    }

    /// Validate an update request by name.
    ///
    /// Returns `None` if no handler for that name, `Some(Ok(()))` if valid,
    /// `Some(Err(...))` if validation failed. Takes `&self`: validation only reads state.
    fn validate_update(
        &self,
        _ctx: WorkflowContextView,
        _name: &str,
        _payloads: &Payloads,
        _converter: &PayloadConverter,
    ) -> Option<Result<(), WorkflowError>> {
        None
    }

    /// Dispatch a signal by name.
    ///
    /// Returns `None` if no handler for that name. For sync signals, the mutation happens
    /// immediately and returns a completed future. For async signals, returns a future
    /// that must be polled to completion.
    fn dispatch_signal(
        _ctx: WorkflowContext<Self>,
        _name: &str,
        _payloads: Payloads,
        _converter: &PayloadConverter,
    ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>> {
        None
    }

    /// Dispatch a query by name.
    ///
    /// Returns `None` if no handler for that name, `Some(Ok(payload))` on success,
    /// `Some(Err(...))` on failure. Queries are synchronous and read-only.
    fn dispatch_query(
        &self,
        _ctx: WorkflowContextView,
        _name: &str,
        _payloads: &Payloads,
        _converter: &PayloadConverter,
    ) -> Option<Result<Payload, WorkflowError>> {
        None
    }
}
182
// NOTE: In the traits below, the `dispatch` functions take the context by value, while the
// `handle` methods take it by reference when sync and by value when async. Async handlers must
// own the context because the futures they return must be 'static.
186
187/// Trait for executing synchronous signal handlers on a workflow.
188#[doc(hidden)]
189pub trait ExecutableSyncSignal<S: SignalDefinition>: WorkflowImplementation {
190    /// Handle an incoming signal with the given input.
191    fn handle(&mut self, ctx: &mut SyncWorkflowContext<Self>, input: S::Input);
192
193    /// Dispatch the signal with payload deserialization.
194    fn dispatch(
195        ctx: WorkflowContext<Self>,
196        payloads: Payloads,
197        converter: &PayloadConverter,
198    ) -> LocalBoxFuture<'static, Result<(), WorkflowError>> {
199        match deserialize_input::<S::Input>(payloads.payloads, converter) {
200            Ok(input) => {
201                let mut sync_ctx = ctx.sync_context();
202                ctx.state_mut(|wf| Self::handle(wf, &mut sync_ctx, input));
203                std::future::ready(Ok(())).boxed_local()
204            }
205            Err(e) => std::future::ready(Err(e)).boxed_local(),
206        }
207    }
208}
209
210/// Trait for executing asynchronous signal handlers on a workflow.
211#[doc(hidden)]
212pub trait ExecutableAsyncSignal<S: SignalDefinition>: WorkflowImplementation {
213    /// Handle an incoming signal with the given input.
214    fn handle(ctx: WorkflowContext<Self>, input: S::Input) -> LocalBoxFuture<'static, ()>;
215
216    /// Dispatch the signal with payload deserialization.
217    fn dispatch(
218        ctx: WorkflowContext<Self>,
219        payloads: Payloads,
220        converter: &PayloadConverter,
221    ) -> LocalBoxFuture<'static, Result<(), WorkflowError>> {
222        match deserialize_input::<S::Input>(payloads.payloads, converter) {
223            Ok(input) => Self::handle(ctx, input).map(|()| Ok(())).boxed_local(),
224            Err(e) => std::future::ready(Err(e)).boxed_local(),
225        }
226    }
227}
228
229/// Trait for executing query handlers on a workflow.
230///
231/// Queries are read-only operations that do not mutate workflow state.
232/// They must be synchronous.
233#[doc(hidden)]
234pub trait ExecutableQuery<Q: QueryDefinition>: WorkflowImplementation {
235    /// Handle a query with the given input and return the result.
236    ///
237    /// Queries take `&self` (immutable) and cannot modify workflow state.
238    /// Returning an error will cause the query to fail with that error message.
239    fn handle(
240        &self,
241        ctx: &WorkflowContextView,
242        input: Q::Input,
243    ) -> Result<Q::Output, Box<dyn std::error::Error + Send + Sync>>;
244
245    /// Dispatch the query with payload deserialization and output serialization.
246    fn dispatch(
247        &self,
248        ctx: &WorkflowContextView,
249        payloads: &Payloads,
250        converter: &PayloadConverter,
251    ) -> Result<Payload, WorkflowError> {
252        let input = deserialize_input::<Q::Input>(payloads.payloads.clone(), converter)?;
253        let output = self.handle(ctx, input).map_err(wrap_handler_error)?;
254        serialize_output(&output, converter)
255    }
256}
257
258/// Trait for executing synchronous update handlers on a workflow.
259#[doc(hidden)]
260pub trait ExecutableSyncUpdate<U: UpdateDefinition>: WorkflowImplementation {
261    /// Handle an update with the given input and return the result.
262    /// Returning an error will cause the update to fail with that error message.
263    fn handle(
264        &mut self,
265        ctx: &mut SyncWorkflowContext<Self>,
266        input: U::Input,
267    ) -> Result<U::Output, Box<dyn std::error::Error + Send + Sync>>;
268
269    /// Validate an update before it is applied.
270    fn validate(
271        &self,
272        _ctx: &WorkflowContextView,
273        _input: &U::Input,
274    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
275        Ok(())
276    }
277
278    /// Dispatch the update with payload deserialization and output serialization.
279    fn dispatch(
280        ctx: WorkflowContext<Self>,
281        payloads: Payloads,
282        converter: &PayloadConverter,
283    ) -> LocalBoxFuture<'static, Result<Payload, WorkflowError>> {
284        let input = match deserialize_input::<U::Input>(payloads.payloads, converter) {
285            Ok(v) => v,
286            Err(e) => return std::future::ready(Err(e)).boxed_local(),
287        };
288        let converter = converter.clone();
289        let mut sync_ctx = ctx.sync_context();
290        let result = ctx.state_mut(|wf| Self::handle(wf, &mut sync_ctx, input));
291        match result {
292            Ok(output) => match serialize_output(&output, &converter) {
293                Ok(payload) => std::future::ready(Ok(payload)).boxed_local(),
294                Err(e) => std::future::ready(Err(e)).boxed_local(),
295            },
296            Err(e) => std::future::ready(Err(wrap_handler_error(e))).boxed_local(),
297        }
298    }
299
300    /// Dispatch validation with payload deserialization.
301    fn dispatch_validate(
302        &self,
303        ctx: &WorkflowContextView,
304        payloads: &Payloads,
305        converter: &PayloadConverter,
306    ) -> Result<(), WorkflowError> {
307        let input = deserialize_input::<U::Input>(payloads.payloads.clone(), converter)?;
308        self.validate(ctx, &input).map_err(wrap_handler_error)
309    }
310}
311
312/// Trait for executing asynchronous update handlers on a workflow.
313#[doc(hidden)]
314pub trait ExecutableAsyncUpdate<U: UpdateDefinition>: WorkflowImplementation {
315    /// Handle an update with the given input and return the result.
316    /// Returning an error will cause the update to fail with that error message.
317    fn handle(
318        ctx: WorkflowContext<Self>,
319        input: U::Input,
320    ) -> LocalBoxFuture<'static, Result<U::Output, Box<dyn std::error::Error + Send + Sync>>>;
321
322    /// Validate an update before it is applied.
323    fn validate(
324        &self,
325        _ctx: &WorkflowContextView,
326        _input: &U::Input,
327    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
328        Ok(())
329    }
330
331    /// Dispatch the update with payload deserialization and output serialization.
332    fn dispatch(
333        ctx: WorkflowContext<Self>,
334        payloads: Payloads,
335        converter: &PayloadConverter,
336    ) -> LocalBoxFuture<'static, Result<Payload, WorkflowError>> {
337        let input = match deserialize_input::<U::Input>(payloads.payloads, converter) {
338            Ok(v) => v,
339            Err(e) => return std::future::ready(Err(e)).boxed_local(),
340        };
341        let converter = converter.clone();
342        async move {
343            let output = Self::handle(ctx, input).await.map_err(wrap_handler_error)?;
344            serialize_output(&output, &converter)
345        }
346        .boxed_local()
347    }
348
349    /// Dispatch validation with payload deserialization.
350    fn dispatch_validate(
351        &self,
352        ctx: &WorkflowContextView,
353        payloads: &Payloads,
354        converter: &PayloadConverter,
355    ) -> Result<(), WorkflowError> {
356        let input = deserialize_input::<U::Input>(payloads.payloads.clone(), converter)?;
357        self.validate(ctx, &input).map_err(wrap_handler_error)
358    }
359}
360
/// Data passed to handler dispatch methods (signals, updates, queries).
pub(crate) struct DispatchData<'a> {
    /// Raw handler arguments, decoded by the target handler.
    pub(crate) payloads: Payloads,
    /// Request headers. For signals and updates they are attached to the handler's context
    /// via `with_headers`; query dispatch in this module does not use them.
    pub(crate) headers: HashMap<String, Payload>,
    /// Converter used to decode `payloads` and to encode handler results.
    pub(crate) converter: &'a PayloadConverter,
}
367
/// Trait implemented by workflow types to enable registration with workers.
///
/// This trait is automatically generated by the `#[workflow_methods]` macro.
/// `WorkflowDefinitions::register_workflow` delegates to [`register_all`](Self::register_all).
#[doc(hidden)]
pub trait WorkflowImplementer: WorkflowImplementation {
    /// Register this workflow and all its handlers with the given definitions container.
    fn register_all(defs: &mut WorkflowDefinitions);
}
376
/// Type-erased trait for workflow execution instances.
///
/// Allows executions of different workflow types to be stored uniformly as
/// `Box<dyn DynWorkflowExecution>`, which is what `WorkflowExecutionFactory` produces.
pub(crate) trait DynWorkflowExecution {
    /// Poll the run future.
    ///
    /// Resolves to the workflow's result `Payload`, or to a `WorkflowTermination`.
    fn poll_run(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<Payload, WorkflowTermination>>;

    /// Validate an update request. Returns `None` if no handler.
    fn validate_update(&self, name: &str, data: &DispatchData)
    -> Option<Result<(), WorkflowError>>;

    /// Start an update handler. Returns `None` if no handler for that name.
    ///
    /// The returned future resolves to the serialized update result.
    fn start_update(
        &mut self,
        name: &str,
        data: DispatchData,
    ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>>;

    /// Dispatch a signal by name. Returns `None` if no handler.
    fn dispatch_signal(
        &mut self,
        name: &str,
        data: DispatchData,
    ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>>;

    /// Dispatch a query by name. Returns `None` if no handler.
    ///
    /// Queries are answered synchronously.
    fn dispatch_query(
        &self,
        name: &str,
        data: DispatchData,
    ) -> Option<Result<Payload, WorkflowError>>;
}
407
/// Manages a workflow execution, holding the context and run future.
pub(crate) struct WorkflowExecution<W: WorkflowImplementation> {
    /// Context for this execution; a clone of it is handed to `run`, and handler dispatch
    /// derives per-request contexts from it.
    ctx: WorkflowContext<W>,
    /// The workflow's `run` future. Wrapped in `Fuse` so that polling it again after it has
    /// completed is allowed (per `futures_util`, a fused future then returns `Pending`).
    run_future: Fuse<LocalBoxFuture<'static, Result<Payload, WorkflowTermination>>>,
}
413
414impl<W: WorkflowImplementation> WorkflowExecution<W>
415where
416    <W::Run as WorkflowDefinition>::Input: Send,
417{
418    /// Create a new workflow execution using the workflow's `init` method.
419    pub(crate) fn new(
420        base_ctx: BaseWorkflowContext,
421        init_input: Option<<W::Run as WorkflowDefinition>::Input>,
422        run_input: Option<<W::Run as WorkflowDefinition>::Input>,
423    ) -> Self {
424        let view = base_ctx.view();
425        let workflow = W::init(view, init_input);
426        Self::new_with_workflow(workflow, base_ctx, run_input)
427    }
428
429    /// Create a new workflow execution from an already-created workflow instance.
430    pub(crate) fn new_with_workflow(
431        workflow: W,
432        base_ctx: BaseWorkflowContext,
433        run_input: Option<<W::Run as WorkflowDefinition>::Input>,
434    ) -> Self {
435        let workflow = Rc::new(RefCell::new(workflow));
436        let ctx = WorkflowContext::from_base(base_ctx, workflow);
437        let run_future = W::run(ctx.clone(), run_input).fuse();
438
439        Self { ctx, run_future }
440    }
441}
442
443impl<W: WorkflowImplementation> DynWorkflowExecution for WorkflowExecution<W> {
444    fn poll_run(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<Payload, WorkflowTermination>> {
445        Pin::new(&mut self.run_future).poll(cx)
446    }
447
448    fn validate_update(
449        &self,
450        name: &str,
451        data: &DispatchData,
452    ) -> Option<Result<(), WorkflowError>> {
453        let view = self.ctx.view();
454        self.ctx
455            .state(|wf| wf.validate_update(view, name, &data.payloads, data.converter))
456    }
457
458    fn start_update(
459        &mut self,
460        name: &str,
461        data: DispatchData,
462    ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>> {
463        let ctx = self.ctx.with_headers(data.headers);
464        W::dispatch_update(ctx, name, data.payloads, data.converter)
465    }
466
467    fn dispatch_signal(
468        &mut self,
469        name: &str,
470        data: DispatchData,
471    ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>> {
472        let ctx = self.ctx.with_headers(data.headers);
473        W::dispatch_signal(ctx, name, data.payloads, data.converter)
474    }
475
476    fn dispatch_query(
477        &self,
478        name: &str,
479        data: DispatchData,
480    ) -> Option<Result<Payload, WorkflowError>> {
481        let view = self.ctx.view();
482        self.ctx
483            .state(|wf| wf.dispatch_query(view, name, &data.payloads, data.converter))
484    }
485}
486
/// Type alias for workflow execution factory functions.
///
/// Creates a new `WorkflowExecution` instance from the input payloads and context.
/// Arguments, in order: the workflow input payloads, the payload converter used to decode
/// them, and the base context for the new execution. Fails with `PayloadConversionError`
/// if the input cannot be decoded. Held in an `Arc` so `WorkflowDefinitions` can hand out
/// cheap clones.
pub(crate) type WorkflowExecutionFactory = Arc<
    dyn Fn(
            Vec<Payload>,
            PayloadConverter,
            BaseWorkflowContext,
        ) -> Result<Box<dyn DynWorkflowExecution>, PayloadConversionError>
        + Send
        + Sync,
>;
499
/// Contains workflow registrations in a form ready for execution by workers.
#[derive(Default, Clone)]
pub struct WorkflowDefinitions {
    /// Maps workflow type name to execution factories.
    ///
    /// Keys come from `WorkflowImplementation::name()`; registering a name that is already
    /// present replaces the earlier factory.
    workflows: HashMap<&'static str, WorkflowExecutionFactory>,
}
506
507impl WorkflowDefinitions {
508    /// Creates a new empty `WorkflowDefinitions`.
509    pub fn new() -> Self {
510        Self::default()
511    }
512
513    /// Register a workflow implementation.
514    pub fn register_workflow<W: WorkflowImplementer>(&mut self) -> &mut Self {
515        W::register_all(self);
516        self
517    }
518
519    /// Register a specific workflow's run method.
520    #[doc(hidden)]
521    pub fn register_workflow_run<W: WorkflowImplementation>(&mut self) -> &mut Self
522    where
523        <W::Run as WorkflowDefinition>::Input: Send,
524    {
525        let workflow_name = W::name();
526        let factory: WorkflowExecutionFactory =
527            Arc::new(move |payloads, converter: PayloadConverter, base_ctx| {
528                let ser_ctx = SerializationContext {
529                    data: &SerializationContextData::Workflow,
530                    converter: &converter,
531                };
532                let input = converter.from_payloads(&ser_ctx, payloads)?;
533                let (init_input, run_input) = if W::INIT_TAKES_INPUT {
534                    (Some(input), None)
535                } else {
536                    (None, Some(input))
537                };
538                Ok(
539                    Box::new(WorkflowExecution::<W>::new(base_ctx, init_input, run_input))
540                        as Box<dyn DynWorkflowExecution>,
541                )
542            });
543        self.workflows.insert(workflow_name, factory);
544        self
545    }
546
547    /// Register a workflow with a custom factory for instance creation.
548    pub fn register_workflow_run_with_factory<W, F>(&mut self, user_factory: F) -> &mut Self
549    where
550        W: WorkflowImplementation,
551        <W::Run as WorkflowDefinition>::Input: Send,
552        F: Fn() -> W + Send + Sync + 'static,
553    {
554        assert!(
555            !W::HAS_INIT,
556            "Workflows registered with a factory must not define an #[init] method. \
557             The factory replaces init for instance creation."
558        );
559
560        let workflow_name = W::name();
561        let user_factory = Arc::new(user_factory);
562        let factory: WorkflowExecutionFactory =
563            Arc::new(move |payloads, converter: PayloadConverter, base_ctx| {
564                let ser_ctx = SerializationContext {
565                    data: &SerializationContextData::Workflow,
566                    converter: &converter,
567                };
568                let input: <W::Run as WorkflowDefinition>::Input =
569                    converter.from_payloads(&ser_ctx, payloads)?;
570
571                // User factory creates the instance - input always goes to run()
572                let workflow = user_factory();
573                Ok(Box::new(WorkflowExecution::<W>::new_with_workflow(
574                    workflow,
575                    base_ctx,
576                    Some(input),
577                )) as Box<dyn DynWorkflowExecution>)
578            });
579
580        self.workflows.insert(workflow_name, factory);
581        self
582    }
583
584    /// Check if any workflows are registered.
585    pub fn is_empty(&self) -> bool {
586        self.workflows.is_empty()
587    }
588
589    /// Get the workflow execution factory for a given workflow type.
590    pub(crate) fn get_workflow(&self, workflow_type: &str) -> Option<WorkflowExecutionFactory> {
591        self.workflows.get(workflow_type).cloned()
592    }
593
594    /// Returns an iterator over registered workflow type names.
595    pub fn workflow_types(&self) -> impl Iterator<Item = &'static str> + '_ {
596        self.workflows.keys().copied()
597    }
598}
599
600impl Debug for WorkflowDefinitions {
601    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
602        f.debug_struct("WorkflowDefinitions")
603            .field("workflows", &self.workflows.keys().collect::<Vec<_>>())
604            .finish()
605    }
606}
607
608/// Deserialize handler input from payloads.
609pub fn deserialize_input<I: TemporalDeserializable + 'static>(
610    payloads: Vec<Payload>,
611    converter: &PayloadConverter,
612) -> Result<I, WorkflowError> {
613    let ctx = SerializationContext {
614        data: &SerializationContextData::Workflow,
615        converter,
616    };
617    converter.from_payloads(&ctx, payloads).map_err(Into::into)
618}
619
620/// Serialize handler output to a payload.
621pub fn serialize_output<O: TemporalSerializable + 'static>(
622    output: &O,
623    converter: &PayloadConverter,
624) -> Result<Payload, WorkflowError> {
625    let ctx = SerializationContext {
626        data: &SerializationContextData::Workflow,
627        converter,
628    };
629    converter.to_payload(&ctx, output).map_err(Into::into)
630}
631
632/// Wrap a handler error into WorkflowError.
633pub fn wrap_handler_error(e: Box<dyn std::error::Error + Send + Sync>) -> WorkflowError {
634    WorkflowError::Execution(anyhow::anyhow!(e))
635}
636
637/// Serialize a workflow result value to a payload.
638pub fn serialize_result<T: TemporalSerializable + 'static>(
639    result: T,
640    converter: &PayloadConverter,
641) -> Result<Payload, WorkflowError> {
642    serialize_output(&result, converter)
643}