1use std::{collections::HashMap, fmt::Debug, rc::Rc, sync::Arc};
2
3use anyhow::Context;
4use temporalio_common::{
5 WorkflowDefinition,
6 data_converters::{
7 DataConverter, GenericPayloadConverter, PayloadConverter, SerializationContext,
8 SerializationContextData,
9 },
10 protos::{
11 coresdk::workflow_activation::InitializeWorkflow, temporal::api::common::v1::Payload,
12 },
13};
14use temporalio_workflow::{
15 BaseWorkflowContext,
16 runtime::{
17 entry::WorkflowImplementation,
18 guest::WorkflowInstance,
19 host::WorkflowHost,
20 instance::{GuestWorkflowInstance, instantiate_workflow},
21 types::WorkflowDefinitionDescriptor,
22 },
23};
24
/// Inputs used to start a single workflow execution; this is the argument passed to a
/// [`WorkflowExecutionFactory`].
pub(crate) struct WorkflowExecutionInput {
    /// Namespace of the execution; forwarded to [`BaseWorkflowContext::new`].
    pub namespace: String,
    /// Task queue of the execution; forwarded to [`BaseWorkflowContext::new`].
    pub task_queue: String,
    /// Run id of the execution; forwarded to [`BaseWorkflowContext::new`].
    pub run_id: String,
    /// The initialization job. Its `arguments` are the serialized workflow input that
    /// the factories decode; the job itself is then handed to the workflow context.
    pub init_workflow_job: InitializeWorkflow,
    /// Converter whose payload converter decodes the input arguments; the converter
    /// itself is also handed to the workflow context.
    pub data_converter: DataConverter,
    /// Host handle for the workflow; forwarded to [`BaseWorkflowContext::new`].
    pub host: Rc<dyn WorkflowHost>,
}
34
35pub(crate) type WorkflowExecutionFactory = Arc<
37 dyn Fn(WorkflowExecutionInput) -> Result<Box<dyn WorkflowInstance>, anyhow::Error>
38 + Send
39 + Sync,
40>;
41
/// A registry entry pairing a workflow's descriptor with the factory that builds its
/// instances.
#[derive(Clone)]
struct RegisteredWorkflow {
    /// Descriptor for the workflow; exposed via `WorkflowDefinitions::workflow_definitions`.
    definition: WorkflowDefinitionDescriptor,
    /// Builds a workflow instance from a [`WorkflowExecutionInput`].
    factory: WorkflowExecutionFactory,
}
47
/// Reasons a workflow cannot be added to [`WorkflowDefinitions`].
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum WorkflowRegistrationError {
    /// A workflow with the same type name has already been registered.
    #[error("Workflow type {workflow_type} is already registered")]
    DuplicateWorkflowType {
        /// The conflicting workflow type name.
        workflow_type: String,
    },

    /// A factory-based registration was attempted for a workflow that defines an
    /// `#[init]` method.
    #[error(
        "Workflow type {workflow_type} must not define an #[init] method when registered with a factory"
    )]
    FactoryRegistrationWithInit {
        /// The offending workflow type name.
        workflow_type: String,
    },
}
67
/// Registry of workflow implementations, keyed by workflow type name.
///
/// Cloning the registry copies the map; each factory is held in an `Arc`, so clones
/// share the same factory closures.
#[derive(Default, Clone)]
pub struct WorkflowDefinitions {
    /// Registered workflows keyed by their descriptor's `workflow_type`.
    workflows: HashMap<String, RegisteredWorkflow>,
}
73
74impl WorkflowDefinitions {
75 pub fn new() -> Self {
77 Self::default()
78 }
79
80 pub fn register_workflow<W: WorkflowImplementation>(
84 &mut self,
85 ) -> Result<&mut Self, WorkflowRegistrationError>
86 where
87 <W::Run as WorkflowDefinition>::Input: Send,
88 {
89 let factory = Arc::new(move |input| {
90 let (payloads, payload_converter, base_ctx) = workflow_input_parts(input);
91 instantiate_workflow::<W>(payloads, payload_converter, base_ctx)
92 .context("Failed to instantiate native workflow")
93 });
94 self.insert_workflow(W::definition(), factory)?;
95 Ok(self)
96 }
97
98 pub fn register_workflow_run_with_factory<W, F>(
103 &mut self,
104 user_factory: F,
105 ) -> Result<&mut Self, WorkflowRegistrationError>
106 where
107 W: WorkflowImplementation,
108 <W::Run as WorkflowDefinition>::Input: Send,
109 F: Fn() -> W + Send + Sync + 'static,
110 {
111 if W::HAS_INIT {
112 return Err(WorkflowRegistrationError::FactoryRegistrationWithInit {
113 workflow_type: W::definition().workflow_type,
114 });
115 }
116
117 let factory = Arc::new(move |input| {
118 let (payloads, payload_converter, base_ctx) = workflow_input_parts(input);
119 let ser_ctx = SerializationContext {
120 data: &SerializationContextData::Workflow,
121 converter: &payload_converter,
122 };
123 let input: <W::Run as WorkflowDefinition>::Input =
124 payload_converter.from_payloads(&ser_ctx, payloads)?;
125
126 let workflow = user_factory();
127 Ok(Box::new(GuestWorkflowInstance::<W>::new_with_workflow(
128 workflow,
129 base_ctx,
130 Some(input),
131 )) as Box<dyn WorkflowInstance>)
132 });
133
134 self.insert_workflow(W::definition(), factory)?;
135 Ok(self)
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.workflows.is_empty()
141 }
142
143 pub(crate) fn insert_workflow(
144 &mut self,
145 definition: WorkflowDefinitionDescriptor,
146 factory: WorkflowExecutionFactory,
147 ) -> Result<(), WorkflowRegistrationError> {
148 let workflow_type = definition.workflow_type.clone();
149 if self.workflows.contains_key(&workflow_type) {
150 return Err(WorkflowRegistrationError::DuplicateWorkflowType { workflow_type });
151 }
152 self.workflows.insert(
153 workflow_type,
154 RegisteredWorkflow {
155 definition,
156 factory,
157 },
158 );
159 Ok(())
160 }
161
162 pub(crate) fn get_workflow(&self, workflow_type: &str) -> Option<WorkflowExecutionFactory> {
163 self.workflows
164 .get(workflow_type)
165 .map(|wf| wf.factory.clone())
166 }
167
168 pub fn workflow_definitions(&self) -> impl Iterator<Item = &WorkflowDefinitionDescriptor> + '_ {
170 self.workflows.values().map(|wf| &wf.definition)
171 }
172}
173
174fn workflow_input_parts(
175 input: WorkflowExecutionInput,
176) -> (Vec<Payload>, PayloadConverter, BaseWorkflowContext) {
177 let WorkflowExecutionInput {
178 namespace,
179 task_queue,
180 run_id,
181 init_workflow_job,
182 data_converter,
183 host,
184 } = input;
185 let payloads = init_workflow_job.arguments.clone();
186 let payload_converter = data_converter.payload_converter().clone();
187 let base_ctx = BaseWorkflowContext::new(
188 namespace,
189 task_queue,
190 run_id,
191 init_workflow_job,
192 data_converter,
193 host,
194 );
195 (payloads, payload_converter, base_ctx)
196}
197
198impl Debug for WorkflowDefinitions {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 f.debug_struct("WorkflowDefinitions")
201 .field("workflows", &self.workflows.keys().collect::<Vec<_>>())
202 .finish()
203 }
204}