1#[doc(inline)]
70pub use crate::__temporal_select as select;
71
72#[doc(inline)]
90pub use crate::__temporal_join as join;
91
92use crate::{
93 BaseWorkflowContext, SyncWorkflowContext, WorkflowContext, WorkflowContextView,
94 WorkflowTermination, workflow_executor::SdkGuardedFuture,
95};
96use futures_util::future::{Fuse, FutureExt, LocalBoxFuture};
97use std::{
98 cell::RefCell,
99 collections::HashMap,
100 fmt::Debug,
101 pin::Pin,
102 rc::Rc,
103 sync::Arc,
104 task::{Context as TaskContext, Poll},
105};
106use temporalio_common::{
107 QueryDefinition, SignalDefinition, UpdateDefinition, WorkflowDefinition,
108 data_converters::{
109 GenericPayloadConverter, PayloadConversionError, PayloadConverter, SerializationContext,
110 SerializationContextData, TemporalDeserializable, TemporalSerializable,
111 },
112 protos::temporal::api::{
113 common::v1::{Payload, Payloads},
114 failure::v1::Failure,
115 },
116};
117
/// Error produced while dispatching an update, signal, or query handler.
///
/// Can be turned into a protocol [`Failure`] (message only) via `From`.
#[derive(Debug, thiserror::Error)]
pub enum WorkflowError {
    /// A payload could not be converted to or from the handler's input/output type.
    #[error("Payload conversion error: {0}")]
    PayloadConversion(#[from] PayloadConversionError),

    /// The handler itself failed; handler errors are wrapped here by `wrap_handler_error`.
    #[error("Workflow execution error: {0}")]
    Execution(#[from] anyhow::Error),
}
129
130impl From<WorkflowError> for Failure {
131 fn from(err: WorkflowError) -> Self {
132 Failure {
133 message: err.to_string(),
134 ..Default::default()
135 }
136 }
137}
138
#[doc(hidden)]
/// Bridge between a user workflow type and the SDK's execution machinery.
///
/// The handler hooks (`dispatch_update`, `validate_update`, `dispatch_signal`,
/// `dispatch_query`) all return `None` unless overridden.
pub trait WorkflowImplementation: Sized + 'static {
    /// Workflow definition that fixes this workflow's input type.
    type Run: WorkflowDefinition;

    /// Whether this workflow defines an `#[init]` method. Factory-based registration
    /// (`WorkflowDefinitions::register_workflow_run_with_factory`) panics if this is `true`.
    const HAS_INIT: bool;

    /// Whether `init` receives the workflow input. With `register_workflow_run`, a `true`
    /// value sends the decoded input to `init` (and `None` to `run`); `false` does the reverse.
    const INIT_TAKES_INPUT: bool;

    /// Workflow type name; used as the key when registering in `WorkflowDefinitions`.
    fn name() -> &'static str;

    /// Constructs the workflow instance.
    ///
    /// `input` is the decoded workflow input when it is routed to `init`
    /// (see `INIT_TAKES_INPUT`), otherwise `None`.
    fn init(
        ctx: WorkflowContextView,
        input: Option<<Self::Run as WorkflowDefinition>::Input>,
    ) -> Self;

    /// Builds the workflow's main future from its shared context.
    ///
    /// `input` is `None` when the decoded input was routed to `init` instead.
    fn run(
        ctx: WorkflowContext<Self>,
        input: Option<<Self::Run as WorkflowDefinition>::Input>,
    ) -> LocalBoxFuture<'static, Result<Payload, WorkflowTermination>>;

    /// Starts the update named `name`, decoding `_payloads` with `_converter`.
    ///
    /// Returns `None` by default; presumably `None` means no update handler matches
    /// `name` — confirm against the generated implementations.
    fn dispatch_update(
        _ctx: WorkflowContext<Self>,
        _name: &str,
        _payloads: Payloads,
        _converter: &PayloadConverter,
    ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>> {
        None
    }

    /// Runs the validator for the update named `name` against the current instance.
    ///
    /// Returns `None` by default (no validator dispatched).
    fn validate_update(
        &self,
        _ctx: WorkflowContextView,
        _name: &str,
        _payloads: &Payloads,
        _converter: &PayloadConverter,
    ) -> Option<Result<(), WorkflowError>> {
        None
    }

    /// Dispatches the signal named `name`; the returned future completes once the
    /// handler has run (immediately for synchronous handlers).
    ///
    /// Returns `None` by default (no signal dispatched).
    fn dispatch_signal(
        _ctx: WorkflowContext<Self>,
        _name: &str,
        _payloads: Payloads,
        _converter: &PayloadConverter,
    ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>> {
        None
    }

    /// Answers the query named `name` from the current instance state.
    ///
    /// Returns `None` by default (no query dispatched).
    fn dispatch_query(
        &self,
        _ctx: WorkflowContextView,
        _name: &str,
        _payloads: &Payloads,
        _converter: &PayloadConverter,
    ) -> Option<Result<Payload, WorkflowError>> {
        None
    }
}
228
229#[doc(hidden)]
235pub trait ExecutableSyncSignal<S: SignalDefinition>: WorkflowImplementation {
236 fn handle(&mut self, ctx: &mut SyncWorkflowContext<Self>, input: S::Input);
238
239 fn dispatch(
241 ctx: WorkflowContext<Self>,
242 payloads: Payloads,
243 converter: &PayloadConverter,
244 ) -> LocalBoxFuture<'static, Result<(), WorkflowError>> {
245 match deserialize_input::<S::Input>(payloads.payloads, converter) {
246 Ok(input) => {
247 let mut sync_ctx = ctx.sync_context();
248 ctx.state_mut(|wf| Self::handle(wf, &mut sync_ctx, input));
249 std::future::ready(Ok(())).boxed_local()
250 }
251 Err(e) => std::future::ready(Err(e)).boxed_local(),
252 }
253 }
254}
255
256#[doc(hidden)]
258pub trait ExecutableAsyncSignal<S: SignalDefinition>: WorkflowImplementation {
259 fn handle(ctx: WorkflowContext<Self>, input: S::Input) -> LocalBoxFuture<'static, ()>;
261
262 fn dispatch(
264 ctx: WorkflowContext<Self>,
265 payloads: Payloads,
266 converter: &PayloadConverter,
267 ) -> LocalBoxFuture<'static, Result<(), WorkflowError>> {
268 match deserialize_input::<S::Input>(payloads.payloads, converter) {
269 Ok(input) => Self::handle(ctx, input).map(|()| Ok(())).boxed_local(),
270 Err(e) => std::future::ready(Err(e)).boxed_local(),
271 }
272 }
273}
274
275#[doc(hidden)]
280pub trait ExecutableQuery<Q: QueryDefinition>: WorkflowImplementation {
281 fn handle(
286 &self,
287 ctx: &WorkflowContextView,
288 input: Q::Input,
289 ) -> Result<Q::Output, Box<dyn std::error::Error + Send + Sync>>;
290
291 fn dispatch(
293 &self,
294 ctx: &WorkflowContextView,
295 payloads: &Payloads,
296 converter: &PayloadConverter,
297 ) -> Result<Payload, WorkflowError> {
298 let input = deserialize_input::<Q::Input>(payloads.payloads.clone(), converter)?;
299 let output = self.handle(ctx, input).map_err(wrap_handler_error)?;
300 serialize_output(&output, converter)
301 }
302}
303
304#[doc(hidden)]
306pub trait ExecutableSyncUpdate<U: UpdateDefinition>: WorkflowImplementation {
307 fn handle(
310 &mut self,
311 ctx: &mut SyncWorkflowContext<Self>,
312 input: U::Input,
313 ) -> Result<U::Output, Box<dyn std::error::Error + Send + Sync>>;
314
315 fn validate(
317 &self,
318 _ctx: &WorkflowContextView,
319 _input: &U::Input,
320 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
321 Ok(())
322 }
323
324 fn dispatch(
326 ctx: WorkflowContext<Self>,
327 payloads: Payloads,
328 converter: &PayloadConverter,
329 ) -> LocalBoxFuture<'static, Result<Payload, WorkflowError>> {
330 let input = match deserialize_input::<U::Input>(payloads.payloads, converter) {
331 Ok(v) => v,
332 Err(e) => return std::future::ready(Err(e)).boxed_local(),
333 };
334 let converter = converter.clone();
335 let mut sync_ctx = ctx.sync_context();
336 let result = ctx.state_mut(|wf| Self::handle(wf, &mut sync_ctx, input));
337 match result {
338 Ok(output) => match serialize_output(&output, &converter) {
339 Ok(payload) => std::future::ready(Ok(payload)).boxed_local(),
340 Err(e) => std::future::ready(Err(e)).boxed_local(),
341 },
342 Err(e) => std::future::ready(Err(wrap_handler_error(e))).boxed_local(),
343 }
344 }
345
346 fn dispatch_validate(
348 &self,
349 ctx: &WorkflowContextView,
350 payloads: &Payloads,
351 converter: &PayloadConverter,
352 ) -> Result<(), WorkflowError> {
353 let input = deserialize_input::<U::Input>(payloads.payloads.clone(), converter)?;
354 self.validate(ctx, &input).map_err(wrap_handler_error)
355 }
356}
357
358#[doc(hidden)]
360pub trait ExecutableAsyncUpdate<U: UpdateDefinition>: WorkflowImplementation {
361 fn handle(
364 ctx: WorkflowContext<Self>,
365 input: U::Input,
366 ) -> LocalBoxFuture<'static, Result<U::Output, Box<dyn std::error::Error + Send + Sync>>>;
367
368 fn validate(
370 &self,
371 _ctx: &WorkflowContextView,
372 _input: &U::Input,
373 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
374 Ok(())
375 }
376
377 fn dispatch(
379 ctx: WorkflowContext<Self>,
380 payloads: Payloads,
381 converter: &PayloadConverter,
382 ) -> LocalBoxFuture<'static, Result<Payload, WorkflowError>> {
383 let input = match deserialize_input::<U::Input>(payloads.payloads, converter) {
384 Ok(v) => v,
385 Err(e) => return std::future::ready(Err(e)).boxed_local(),
386 };
387 let converter = converter.clone();
388 async move {
389 let output = Self::handle(ctx, input).await.map_err(wrap_handler_error)?;
390 serialize_output(&output, &converter)
391 }
392 .boxed_local()
393 }
394
395 fn dispatch_validate(
397 &self,
398 ctx: &WorkflowContextView,
399 payloads: &Payloads,
400 converter: &PayloadConverter,
401 ) -> Result<(), WorkflowError> {
402 let input = deserialize_input::<U::Input>(payloads.payloads.clone(), converter)?;
403 self.validate(ctx, &input).map_err(wrap_handler_error)
404 }
405}
406
/// Raw inputs for dispatching one update, signal, or query to a workflow execution.
pub(crate) struct DispatchData<'a> {
    /// Serialized handler arguments.
    pub(crate) payloads: Payloads,
    /// Headers from the incoming request. `WorkflowExecution` attaches them to the context
    /// for `start_update` and `dispatch_signal`; validation and queries do not read them.
    pub(crate) headers: HashMap<String, Payload>,
    /// Converter used to decode `payloads` and encode any result.
    pub(crate) converter: &'a PayloadConverter,
}
413
#[doc(hidden)]
/// A workflow type that knows how to add itself to a [`WorkflowDefinitions`] registry.
pub trait WorkflowImplementer: WorkflowImplementation {
    /// Adds this workflow's registrations to `defs`; invoked by
    /// `WorkflowDefinitions::register_workflow`.
    fn register_all(defs: &mut WorkflowDefinitions);
}
422
/// Type-erased interface to a running workflow, so executions of different workflow types
/// can be driven through `Box<dyn DynWorkflowExecution>`.
pub(crate) trait DynWorkflowExecution {
    /// Polls the workflow's main `run` future.
    fn poll_run(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<Payload, WorkflowTermination>>;

    /// Runs the validator for the update named `name`.
    ///
    /// `None` presumably means no validator matched `name` — confirm against implementors.
    fn validate_update(&self, name: &str, data: &DispatchData)
    -> Option<Result<(), WorkflowError>>;

    /// Starts the update named `name`, returning the future that produces its result.
    fn start_update(
        &mut self,
        name: &str,
        data: DispatchData,
    ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>>;

    /// Dispatches the signal named `name`, returning the future that runs its handler.
    fn dispatch_signal(
        &mut self,
        name: &str,
        data: DispatchData,
    ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>>;

    /// Answers the query named `name` synchronously.
    fn dispatch_query(
        &self,
        name: &str,
        data: DispatchData,
    ) -> Option<Result<Payload, WorkflowError>>;
}
453
/// A live workflow instance: its shared context plus the main `run` future.
pub(crate) struct WorkflowExecution<W: WorkflowImplementation> {
    /// Context wrapping the shared workflow instance; a clone is handed to `W::run`.
    ctx: WorkflowContext<W>,
    /// The future returned by `W::run`, wrapped in `Fuse`.
    run_future: Fuse<LocalBoxFuture<'static, Result<Payload, WorkflowTermination>>>,
}
459
460impl<W: WorkflowImplementation> WorkflowExecution<W>
461where
462 <W::Run as WorkflowDefinition>::Input: Send,
463{
464 pub(crate) fn new(
466 base_ctx: BaseWorkflowContext,
467 init_input: Option<<W::Run as WorkflowDefinition>::Input>,
468 run_input: Option<<W::Run as WorkflowDefinition>::Input>,
469 ) -> Self {
470 let view = base_ctx.view();
471 let workflow = W::init(view, init_input);
472 Self::new_with_workflow(workflow, base_ctx, run_input)
473 }
474
475 pub(crate) fn new_with_workflow(
477 workflow: W,
478 base_ctx: BaseWorkflowContext,
479 run_input: Option<<W::Run as WorkflowDefinition>::Input>,
480 ) -> Self {
481 let workflow = Rc::new(RefCell::new(workflow));
482 let ctx = WorkflowContext::from_base(base_ctx, workflow);
483 let run_future = W::run(ctx.clone(), run_input).fuse();
484
485 Self { ctx, run_future }
486 }
487}
488
489impl<W: WorkflowImplementation> DynWorkflowExecution for WorkflowExecution<W> {
490 fn poll_run(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<Payload, WorkflowTermination>> {
491 Pin::new(&mut self.run_future).poll(cx)
492 }
493
494 fn validate_update(
495 &self,
496 name: &str,
497 data: &DispatchData,
498 ) -> Option<Result<(), WorkflowError>> {
499 let view = self.ctx.view();
500 self.ctx
501 .state(|wf| wf.validate_update(view, name, &data.payloads, data.converter))
502 }
503
504 fn start_update(
505 &mut self,
506 name: &str,
507 data: DispatchData,
508 ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>> {
509 let ctx = self.ctx.with_headers(data.headers);
510 W::dispatch_update(ctx, name, data.payloads, data.converter)
511 }
512
513 fn dispatch_signal(
514 &mut self,
515 name: &str,
516 data: DispatchData,
517 ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>> {
518 let ctx = self.ctx.with_headers(data.headers);
519 W::dispatch_signal(ctx, name, data.payloads, data.converter)
520 }
521
522 fn dispatch_query(
523 &self,
524 name: &str,
525 data: DispatchData,
526 ) -> Option<Result<Payload, WorkflowError>> {
527 let view = self.ctx.view();
528 self.ctx
529 .state(|wf| wf.dispatch_query(view, name, &data.payloads, data.converter))
530 }
531}
532
/// Type-erased constructor for one registered workflow type.
///
/// Given the raw input payloads, the payload converter, and the base context, it decodes the
/// input and returns a boxed [`DynWorkflowExecution`]; decoding failures are reported as
/// [`PayloadConversionError`].
pub(crate) type WorkflowExecutionFactory = Arc<
    dyn Fn(
            Vec<Payload>,
            PayloadConverter,
            BaseWorkflowContext,
        ) -> Result<Box<dyn DynWorkflowExecution>, PayloadConversionError>
        + Send
        + Sync,
>;
545
/// Registry of workflow types, keyed by workflow type name.
#[derive(Default, Clone)]
pub struct WorkflowDefinitions {
    /// One execution factory per workflow type name (as returned by
    /// `WorkflowImplementation::name`).
    workflows: HashMap<&'static str, WorkflowExecutionFactory>,
}
552
553impl WorkflowDefinitions {
554 pub fn new() -> Self {
556 Self::default()
557 }
558
559 pub fn register_workflow<W: WorkflowImplementer>(&mut self) -> &mut Self {
561 W::register_all(self);
562 self
563 }
564
565 #[doc(hidden)]
567 pub fn register_workflow_run<W: WorkflowImplementation>(&mut self) -> &mut Self
568 where
569 <W::Run as WorkflowDefinition>::Input: Send,
570 {
571 let workflow_name = W::name();
572 let factory: WorkflowExecutionFactory =
573 Arc::new(move |payloads, converter: PayloadConverter, base_ctx| {
574 let ser_ctx = SerializationContext {
575 data: &SerializationContextData::Workflow,
576 converter: &converter,
577 };
578 let input = converter.from_payloads(&ser_ctx, payloads)?;
579 let (init_input, run_input) = if W::INIT_TAKES_INPUT {
580 (Some(input), None)
581 } else {
582 (None, Some(input))
583 };
584 Ok(
585 Box::new(WorkflowExecution::<W>::new(base_ctx, init_input, run_input))
586 as Box<dyn DynWorkflowExecution>,
587 )
588 });
589 self.workflows.insert(workflow_name, factory);
590 self
591 }
592
593 pub fn register_workflow_run_with_factory<W, F>(&mut self, user_factory: F) -> &mut Self
595 where
596 W: WorkflowImplementation,
597 <W::Run as WorkflowDefinition>::Input: Send,
598 F: Fn() -> W + Send + Sync + 'static,
599 {
600 assert!(
601 !W::HAS_INIT,
602 "Workflows registered with a factory must not define an #[init] method. \
603 The factory replaces init for instance creation."
604 );
605
606 let workflow_name = W::name();
607 let user_factory = Arc::new(user_factory);
608 let factory: WorkflowExecutionFactory =
609 Arc::new(move |payloads, converter: PayloadConverter, base_ctx| {
610 let ser_ctx = SerializationContext {
611 data: &SerializationContextData::Workflow,
612 converter: &converter,
613 };
614 let input: <W::Run as WorkflowDefinition>::Input =
615 converter.from_payloads(&ser_ctx, payloads)?;
616
617 let workflow = user_factory();
619 Ok(Box::new(WorkflowExecution::<W>::new_with_workflow(
620 workflow,
621 base_ctx,
622 Some(input),
623 )) as Box<dyn DynWorkflowExecution>)
624 });
625
626 self.workflows.insert(workflow_name, factory);
627 self
628 }
629
630 pub fn is_empty(&self) -> bool {
632 self.workflows.is_empty()
633 }
634
635 pub(crate) fn get_workflow(&self, workflow_type: &str) -> Option<WorkflowExecutionFactory> {
637 self.workflows.get(workflow_type).cloned()
638 }
639
640 pub fn workflow_types(&self) -> impl Iterator<Item = &'static str> + '_ {
642 self.workflows.keys().copied()
643 }
644}
645
646impl Debug for WorkflowDefinitions {
647 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
648 f.debug_struct("WorkflowDefinitions")
649 .field("workflows", &self.workflows.keys().collect::<Vec<_>>())
650 .finish()
651 }
652}
653
654pub fn deserialize_input<I: TemporalDeserializable + 'static>(
656 payloads: Vec<Payload>,
657 converter: &PayloadConverter,
658) -> Result<I, WorkflowError> {
659 let ctx = SerializationContext {
660 data: &SerializationContextData::Workflow,
661 converter,
662 };
663 converter.from_payloads(&ctx, payloads).map_err(Into::into)
664}
665
666pub fn serialize_output<O: TemporalSerializable + 'static>(
668 output: &O,
669 converter: &PayloadConverter,
670) -> Result<Payload, WorkflowError> {
671 let ctx = SerializationContext {
672 data: &SerializationContextData::Workflow,
673 converter,
674 };
675 converter.to_payload(&ctx, output).map_err(Into::into)
676}
677
678pub fn wrap_handler_error(e: Box<dyn std::error::Error + Send + Sync>) -> WorkflowError {
680 WorkflowError::Execution(anyhow::anyhow!(e))
681}
682
683pub fn serialize_result<T: TemporalSerializable + 'static>(
685 result: T,
686 converter: &PayloadConverter,
687) -> Result<Payload, WorkflowError> {
688 serialize_output(&result, converter)
689}
690
691pub fn join_all<I>(iter: I) -> JoinAll<I::Item>
712where
713 I: IntoIterator,
714 I::Item: std::future::Future,
715{
716 JoinAll(SdkGuardedFuture(futures_util::future::join_all(iter)))
717}
718
/// Future returned by [`join_all`]; resolves to a `Vec` of every joined future's output.
pub struct JoinAll<F: std::future::Future>(SdkGuardedFuture<futures_util::future::JoinAll<F>>);
721
722impl<F: std::future::Future> std::future::Future for JoinAll<F> {
723 type Output = Vec<F::Output>;
724
725 fn poll(
726 mut self: std::pin::Pin<&mut Self>,
727 cx: &mut std::task::Context<'_>,
728 ) -> std::task::Poll<Self::Output> {
729 self.0.poll_unpin(cx)
730 }
731}