1use crate::{
47 BaseWorkflowContext, SyncWorkflowContext, WorkflowContext, WorkflowContextView,
48 WorkflowTermination,
49};
50use futures_util::future::{Fuse, FutureExt, LocalBoxFuture};
51use std::{
52 cell::RefCell,
53 collections::HashMap,
54 fmt::Debug,
55 pin::Pin,
56 rc::Rc,
57 sync::Arc,
58 task::{Context as TaskContext, Poll},
59};
60use temporalio_common::{
61 QueryDefinition, SignalDefinition, UpdateDefinition, WorkflowDefinition,
62 data_converters::{
63 GenericPayloadConverter, PayloadConversionError, PayloadConverter, SerializationContext,
64 SerializationContextData, TemporalDeserializable, TemporalSerializable,
65 },
66 protos::temporal::api::{
67 common::v1::{Payload, Payloads},
68 failure::v1::Failure,
69 },
70};
71
72#[derive(Debug, thiserror::Error)]
74pub enum WorkflowError {
75 #[error("Payload conversion error: {0}")]
77 PayloadConversion(#[from] PayloadConversionError),
78
79 #[error("Workflow execution error: {0}")]
81 Execution(#[from] anyhow::Error),
82}
83
84impl From<WorkflowError> for Failure {
85 fn from(err: WorkflowError) -> Self {
86 Failure {
87 message: err.to_string(),
88 ..Default::default()
89 }
90 }
91}
92
93#[doc(hidden)]
98pub trait WorkflowImplementation: Sized + 'static {
99 type Run: WorkflowDefinition;
101
102 const HAS_INIT: bool;
105
106 const INIT_TAKES_INPUT: bool;
109
110 fn name() -> &'static str;
112
113 fn init(
118 ctx: WorkflowContextView,
119 input: Option<<Self::Run as WorkflowDefinition>::Input>,
120 ) -> Self;
121
122 fn run(
126 ctx: WorkflowContext<Self>,
127 input: Option<<Self::Run as WorkflowDefinition>::Input>,
128 ) -> LocalBoxFuture<'static, Result<Payload, WorkflowTermination>>;
129
130 fn dispatch_update(
132 _ctx: WorkflowContext<Self>,
133 _name: &str,
134 _payloads: Payloads,
135 _converter: &PayloadConverter,
136 ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>> {
137 None
138 }
139
140 fn validate_update(
145 &self,
146 _ctx: WorkflowContextView,
147 _name: &str,
148 _payloads: &Payloads,
149 _converter: &PayloadConverter,
150 ) -> Option<Result<(), WorkflowError>> {
151 None
152 }
153
154 fn dispatch_signal(
160 _ctx: WorkflowContext<Self>,
161 _name: &str,
162 _payloads: Payloads,
163 _converter: &PayloadConverter,
164 ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>> {
165 None
166 }
167
168 fn dispatch_query(
173 &self,
174 _ctx: WorkflowContextView,
175 _name: &str,
176 _payloads: &Payloads,
177 _converter: &PayloadConverter,
178 ) -> Option<Result<Payload, WorkflowError>> {
179 None
180 }
181}
182
183#[doc(hidden)]
189pub trait ExecutableSyncSignal<S: SignalDefinition>: WorkflowImplementation {
190 fn handle(&mut self, ctx: &mut SyncWorkflowContext<Self>, input: S::Input);
192
193 fn dispatch(
195 ctx: WorkflowContext<Self>,
196 payloads: Payloads,
197 converter: &PayloadConverter,
198 ) -> LocalBoxFuture<'static, Result<(), WorkflowError>> {
199 match deserialize_input::<S::Input>(payloads.payloads, converter) {
200 Ok(input) => {
201 let mut sync_ctx = ctx.sync_context();
202 ctx.state_mut(|wf| Self::handle(wf, &mut sync_ctx, input));
203 std::future::ready(Ok(())).boxed_local()
204 }
205 Err(e) => std::future::ready(Err(e)).boxed_local(),
206 }
207 }
208}
209
210#[doc(hidden)]
212pub trait ExecutableAsyncSignal<S: SignalDefinition>: WorkflowImplementation {
213 fn handle(ctx: WorkflowContext<Self>, input: S::Input) -> LocalBoxFuture<'static, ()>;
215
216 fn dispatch(
218 ctx: WorkflowContext<Self>,
219 payloads: Payloads,
220 converter: &PayloadConverter,
221 ) -> LocalBoxFuture<'static, Result<(), WorkflowError>> {
222 match deserialize_input::<S::Input>(payloads.payloads, converter) {
223 Ok(input) => Self::handle(ctx, input).map(|()| Ok(())).boxed_local(),
224 Err(e) => std::future::ready(Err(e)).boxed_local(),
225 }
226 }
227}
228
229#[doc(hidden)]
234pub trait ExecutableQuery<Q: QueryDefinition>: WorkflowImplementation {
235 fn handle(
240 &self,
241 ctx: &WorkflowContextView,
242 input: Q::Input,
243 ) -> Result<Q::Output, Box<dyn std::error::Error + Send + Sync>>;
244
245 fn dispatch(
247 &self,
248 ctx: &WorkflowContextView,
249 payloads: &Payloads,
250 converter: &PayloadConverter,
251 ) -> Result<Payload, WorkflowError> {
252 let input = deserialize_input::<Q::Input>(payloads.payloads.clone(), converter)?;
253 let output = self.handle(ctx, input).map_err(wrap_handler_error)?;
254 serialize_output(&output, converter)
255 }
256}
257
258#[doc(hidden)]
260pub trait ExecutableSyncUpdate<U: UpdateDefinition>: WorkflowImplementation {
261 fn handle(
264 &mut self,
265 ctx: &mut SyncWorkflowContext<Self>,
266 input: U::Input,
267 ) -> Result<U::Output, Box<dyn std::error::Error + Send + Sync>>;
268
269 fn validate(
271 &self,
272 _ctx: &WorkflowContextView,
273 _input: &U::Input,
274 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
275 Ok(())
276 }
277
278 fn dispatch(
280 ctx: WorkflowContext<Self>,
281 payloads: Payloads,
282 converter: &PayloadConverter,
283 ) -> LocalBoxFuture<'static, Result<Payload, WorkflowError>> {
284 let input = match deserialize_input::<U::Input>(payloads.payloads, converter) {
285 Ok(v) => v,
286 Err(e) => return std::future::ready(Err(e)).boxed_local(),
287 };
288 let converter = converter.clone();
289 let mut sync_ctx = ctx.sync_context();
290 let result = ctx.state_mut(|wf| Self::handle(wf, &mut sync_ctx, input));
291 match result {
292 Ok(output) => match serialize_output(&output, &converter) {
293 Ok(payload) => std::future::ready(Ok(payload)).boxed_local(),
294 Err(e) => std::future::ready(Err(e)).boxed_local(),
295 },
296 Err(e) => std::future::ready(Err(wrap_handler_error(e))).boxed_local(),
297 }
298 }
299
300 fn dispatch_validate(
302 &self,
303 ctx: &WorkflowContextView,
304 payloads: &Payloads,
305 converter: &PayloadConverter,
306 ) -> Result<(), WorkflowError> {
307 let input = deserialize_input::<U::Input>(payloads.payloads.clone(), converter)?;
308 self.validate(ctx, &input).map_err(wrap_handler_error)
309 }
310}
311
312#[doc(hidden)]
314pub trait ExecutableAsyncUpdate<U: UpdateDefinition>: WorkflowImplementation {
315 fn handle(
318 ctx: WorkflowContext<Self>,
319 input: U::Input,
320 ) -> LocalBoxFuture<'static, Result<U::Output, Box<dyn std::error::Error + Send + Sync>>>;
321
322 fn validate(
324 &self,
325 _ctx: &WorkflowContextView,
326 _input: &U::Input,
327 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
328 Ok(())
329 }
330
331 fn dispatch(
333 ctx: WorkflowContext<Self>,
334 payloads: Payloads,
335 converter: &PayloadConverter,
336 ) -> LocalBoxFuture<'static, Result<Payload, WorkflowError>> {
337 let input = match deserialize_input::<U::Input>(payloads.payloads, converter) {
338 Ok(v) => v,
339 Err(e) => return std::future::ready(Err(e)).boxed_local(),
340 };
341 let converter = converter.clone();
342 async move {
343 let output = Self::handle(ctx, input).await.map_err(wrap_handler_error)?;
344 serialize_output(&output, &converter)
345 }
346 .boxed_local()
347 }
348
349 fn dispatch_validate(
351 &self,
352 ctx: &WorkflowContextView,
353 payloads: &Payloads,
354 converter: &PayloadConverter,
355 ) -> Result<(), WorkflowError> {
356 let input = deserialize_input::<U::Input>(payloads.payloads.clone(), converter)?;
357 self.validate(ctx, &input).map_err(wrap_handler_error)
358 }
359}
360
361pub(crate) struct DispatchData<'a> {
363 pub(crate) payloads: Payloads,
364 pub(crate) headers: HashMap<String, Payload>,
365 pub(crate) converter: &'a PayloadConverter,
366}
367
368#[doc(hidden)]
372pub trait WorkflowImplementer: WorkflowImplementation {
373 fn register_all(defs: &mut WorkflowDefinitions);
375}
376
377pub(crate) trait DynWorkflowExecution {
379 fn poll_run(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<Payload, WorkflowTermination>>;
381
382 fn validate_update(&self, name: &str, data: &DispatchData)
384 -> Option<Result<(), WorkflowError>>;
385
386 fn start_update(
388 &mut self,
389 name: &str,
390 data: DispatchData,
391 ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>>;
392
393 fn dispatch_signal(
395 &mut self,
396 name: &str,
397 data: DispatchData,
398 ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>>;
399
400 fn dispatch_query(
402 &self,
403 name: &str,
404 data: DispatchData,
405 ) -> Option<Result<Payload, WorkflowError>>;
406}
407
408pub(crate) struct WorkflowExecution<W: WorkflowImplementation> {
410 ctx: WorkflowContext<W>,
411 run_future: Fuse<LocalBoxFuture<'static, Result<Payload, WorkflowTermination>>>,
412}
413
414impl<W: WorkflowImplementation> WorkflowExecution<W>
415where
416 <W::Run as WorkflowDefinition>::Input: Send,
417{
418 pub(crate) fn new(
420 base_ctx: BaseWorkflowContext,
421 init_input: Option<<W::Run as WorkflowDefinition>::Input>,
422 run_input: Option<<W::Run as WorkflowDefinition>::Input>,
423 ) -> Self {
424 let view = base_ctx.view();
425 let workflow = W::init(view, init_input);
426 Self::new_with_workflow(workflow, base_ctx, run_input)
427 }
428
429 pub(crate) fn new_with_workflow(
431 workflow: W,
432 base_ctx: BaseWorkflowContext,
433 run_input: Option<<W::Run as WorkflowDefinition>::Input>,
434 ) -> Self {
435 let workflow = Rc::new(RefCell::new(workflow));
436 let ctx = WorkflowContext::from_base(base_ctx, workflow);
437 let run_future = W::run(ctx.clone(), run_input).fuse();
438
439 Self { ctx, run_future }
440 }
441}
442
443impl<W: WorkflowImplementation> DynWorkflowExecution for WorkflowExecution<W> {
444 fn poll_run(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<Payload, WorkflowTermination>> {
445 Pin::new(&mut self.run_future).poll(cx)
446 }
447
448 fn validate_update(
449 &self,
450 name: &str,
451 data: &DispatchData,
452 ) -> Option<Result<(), WorkflowError>> {
453 let view = self.ctx.view();
454 self.ctx
455 .state(|wf| wf.validate_update(view, name, &data.payloads, data.converter))
456 }
457
458 fn start_update(
459 &mut self,
460 name: &str,
461 data: DispatchData,
462 ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>> {
463 let ctx = self.ctx.with_headers(data.headers);
464 W::dispatch_update(ctx, name, data.payloads, data.converter)
465 }
466
467 fn dispatch_signal(
468 &mut self,
469 name: &str,
470 data: DispatchData,
471 ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>> {
472 let ctx = self.ctx.with_headers(data.headers);
473 W::dispatch_signal(ctx, name, data.payloads, data.converter)
474 }
475
476 fn dispatch_query(
477 &self,
478 name: &str,
479 data: DispatchData,
480 ) -> Option<Result<Payload, WorkflowError>> {
481 let view = self.ctx.view();
482 self.ctx
483 .state(|wf| wf.dispatch_query(view, name, &data.payloads, data.converter))
484 }
485}
486
487pub(crate) type WorkflowExecutionFactory = Arc<
491 dyn Fn(
492 Vec<Payload>,
493 PayloadConverter,
494 BaseWorkflowContext,
495 ) -> Result<Box<dyn DynWorkflowExecution>, PayloadConversionError>
496 + Send
497 + Sync,
498>;
499
500#[derive(Default, Clone)]
502pub struct WorkflowDefinitions {
503 workflows: HashMap<&'static str, WorkflowExecutionFactory>,
505}
506
507impl WorkflowDefinitions {
508 pub fn new() -> Self {
510 Self::default()
511 }
512
513 pub fn register_workflow<W: WorkflowImplementer>(&mut self) -> &mut Self {
515 W::register_all(self);
516 self
517 }
518
519 #[doc(hidden)]
521 pub fn register_workflow_run<W: WorkflowImplementation>(&mut self) -> &mut Self
522 where
523 <W::Run as WorkflowDefinition>::Input: Send,
524 {
525 let workflow_name = W::name();
526 let factory: WorkflowExecutionFactory =
527 Arc::new(move |payloads, converter: PayloadConverter, base_ctx| {
528 let ser_ctx = SerializationContext {
529 data: &SerializationContextData::Workflow,
530 converter: &converter,
531 };
532 let input = converter.from_payloads(&ser_ctx, payloads)?;
533 let (init_input, run_input) = if W::INIT_TAKES_INPUT {
534 (Some(input), None)
535 } else {
536 (None, Some(input))
537 };
538 Ok(
539 Box::new(WorkflowExecution::<W>::new(base_ctx, init_input, run_input))
540 as Box<dyn DynWorkflowExecution>,
541 )
542 });
543 self.workflows.insert(workflow_name, factory);
544 self
545 }
546
547 pub fn register_workflow_run_with_factory<W, F>(&mut self, user_factory: F) -> &mut Self
549 where
550 W: WorkflowImplementation,
551 <W::Run as WorkflowDefinition>::Input: Send,
552 F: Fn() -> W + Send + Sync + 'static,
553 {
554 assert!(
555 !W::HAS_INIT,
556 "Workflows registered with a factory must not define an #[init] method. \
557 The factory replaces init for instance creation."
558 );
559
560 let workflow_name = W::name();
561 let user_factory = Arc::new(user_factory);
562 let factory: WorkflowExecutionFactory =
563 Arc::new(move |payloads, converter: PayloadConverter, base_ctx| {
564 let ser_ctx = SerializationContext {
565 data: &SerializationContextData::Workflow,
566 converter: &converter,
567 };
568 let input: <W::Run as WorkflowDefinition>::Input =
569 converter.from_payloads(&ser_ctx, payloads)?;
570
571 let workflow = user_factory();
573 Ok(Box::new(WorkflowExecution::<W>::new_with_workflow(
574 workflow,
575 base_ctx,
576 Some(input),
577 )) as Box<dyn DynWorkflowExecution>)
578 });
579
580 self.workflows.insert(workflow_name, factory);
581 self
582 }
583
584 pub fn is_empty(&self) -> bool {
586 self.workflows.is_empty()
587 }
588
589 pub(crate) fn get_workflow(&self, workflow_type: &str) -> Option<WorkflowExecutionFactory> {
591 self.workflows.get(workflow_type).cloned()
592 }
593
594 pub fn workflow_types(&self) -> impl Iterator<Item = &'static str> + '_ {
596 self.workflows.keys().copied()
597 }
598}
599
600impl Debug for WorkflowDefinitions {
601 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
602 f.debug_struct("WorkflowDefinitions")
603 .field("workflows", &self.workflows.keys().collect::<Vec<_>>())
604 .finish()
605 }
606}
607
608pub fn deserialize_input<I: TemporalDeserializable + 'static>(
610 payloads: Vec<Payload>,
611 converter: &PayloadConverter,
612) -> Result<I, WorkflowError> {
613 let ctx = SerializationContext {
614 data: &SerializationContextData::Workflow,
615 converter,
616 };
617 converter.from_payloads(&ctx, payloads).map_err(Into::into)
618}
619
620pub fn serialize_output<O: TemporalSerializable + 'static>(
622 output: &O,
623 converter: &PayloadConverter,
624) -> Result<Payload, WorkflowError> {
625 let ctx = SerializationContext {
626 data: &SerializationContextData::Workflow,
627 converter,
628 };
629 converter.to_payload(&ctx, output).map_err(Into::into)
630}
631
632pub fn wrap_handler_error(e: Box<dyn std::error::Error + Send + Sync>) -> WorkflowError {
634 WorkflowError::Execution(anyhow::anyhow!(e))
635}
636
637pub fn serialize_result<T: TemporalSerializable + 'static>(
639 result: T,
640 converter: &PayloadConverter,
641) -> Result<Payload, WorkflowError> {
642 serialize_output(&result, converter)
643}