1use async_trait::async_trait;
2use chrono::Utc;
3
4use crate::models::ExecutionResult;
5use crate::models::schema::WorkflowSchema;
6use crate::traits::step::{StepBody, StepExecutionContext};
7
/// A workflow step that starts a child workflow and suspends until it completes.
///
/// On its first run the step starts the child workflow `workflow_id` at
/// `version` through the host context, then waits for a
/// `wfe.workflow.completed` event keyed by the child's instance id. When that
/// event arrives, the event's `data` payload becomes this step's output.
///
/// When `workflow_id` is empty and the step has a `step_config`, the
/// `workflow_id`, `version`, `inputs` and `output_keys` settings are read from
/// that config instead.
#[derive(Default)]
pub struct SubWorkflowStep {
    /// Definition id of the child workflow to start.
    pub workflow_id: String,
    /// Version of the child workflow definition to start.
    pub version: u32,
    /// Data handed to the child workflow. When `null`, the parent workflow's
    /// data is forwarded if it is a JSON object; otherwise an empty object is sent.
    pub inputs: serde_json::Value,
    /// Keys copied from the child's completion `data` into this step's output.
    /// When empty, the whole `data` object is passed through.
    pub output_keys: Vec<String>,
    /// Optional schema that `inputs` must satisfy before the child is started.
    pub input_schema: Option<WorkflowSchema>,
    /// Optional schema that this step's output must satisfy once the child completes.
    pub output_schema: Option<WorkflowSchema>,
}
30
31#[async_trait]
32impl StepBody for SubWorkflowStep {
33 async fn run(&mut self, context: &StepExecutionContext<'_>) -> crate::Result<ExecutionResult> {
34 if let Some(event_data) = &context.execution_pointer.event_data {
36 let mut output = serde_json::Map::new();
38
39 let child_data = event_data
41 .get("data")
42 .cloned()
43 .unwrap_or(serde_json::Value::Null);
44
45 if self.output_keys.is_empty() {
46 if let serde_json::Value::Object(map) = child_data {
48 output = map;
49 }
50 } else {
51 for key in &self.output_keys {
53 if let Some(val) = child_data.get(key) {
54 output.insert(key.clone(), val.clone());
55 }
56 }
57 }
58
59 let output_value = serde_json::Value::Object(output);
60
61 if let Some(ref schema) = self.output_schema
63 && let Err(errors) = schema.validate_outputs(&output_value)
64 {
65 return Err(crate::WfeError::StepExecution(format!(
66 "SubWorkflow output validation failed: {}",
67 errors.join("; ")
68 )));
69 }
70
71 let mut result = ExecutionResult::next();
72 result.output_data = Some(output_value);
73 return Ok(result);
74 }
75
76 if self.workflow_id.is_empty()
78 && let Some(config) = &context.step.step_config
79 {
80 if let Some(wf_id) = config.get("workflow_id").and_then(|v| v.as_str()) {
81 self.workflow_id = wf_id.to_string();
82 }
83 if let Some(ver) = config.get("version").and_then(|v| v.as_u64()) {
84 self.version = ver as u32;
85 }
86 if let Some(inputs) = config.get("inputs") {
87 self.inputs = inputs.clone();
88 }
89 if let Some(keys) = config.get("output_keys").and_then(|v| v.as_array()) {
90 self.output_keys = keys
91 .iter()
92 .filter_map(|v| v.as_str().map(|s| s.to_string()))
93 .collect();
94 }
95 }
96
97 if let Some(ref schema) = self.input_schema
99 && let Err(errors) = schema.validate_inputs(&self.inputs)
100 {
101 return Err(crate::WfeError::StepExecution(format!(
102 "SubWorkflow input validation failed: {}",
103 errors.join("; ")
104 )));
105 }
106
107 let host = context.host_context.ok_or_else(|| {
108 crate::WfeError::StepExecution(
109 "SubWorkflowStep requires a host context to start child workflows".to_string(),
110 )
111 })?;
112
113 let child_data = if !self.inputs.is_null() {
120 self.inputs.clone()
121 } else if context.workflow.data.is_object() {
122 context.workflow.data.clone()
123 } else {
124 serde_json::json!({})
125 };
126 let parent_root = context
132 .workflow
133 .root_workflow_id
134 .clone()
135 .or_else(|| Some(context.workflow.id.clone()));
136 let child_instance_id = host
137 .start_workflow(&self.workflow_id, self.version, child_data, parent_root)
138 .await?;
139
140 Ok(ExecutionResult::wait_for_event(
141 "wfe.workflow.completed",
142 child_instance_id,
143 Utc::now(),
144 ))
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::ExecutionPointer;
    use crate::models::schema::SchemaType;
    use crate::primitives::test_helpers::*;
    use crate::traits::step::HostContext;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// A recorded `start_workflow` call:
    /// `(definition_id, version, data, parent_root_workflow_id)`.
    type StartCall = (String, u32, serde_json::Value, Option<String>);

    /// Host stub that records every `start_workflow` call and answers with a
    /// fixed child instance id.
    struct MockHostContext {
        started: Mutex<Vec<StartCall>>,
        result_id: String,
    }

    impl MockHostContext {
        fn new(result_id: &str) -> Self {
            Self {
                started: Mutex::new(Vec::new()),
                result_id: result_id.to_string(),
            }
        }

        /// Snapshot of the calls recorded so far.
        fn calls(&self) -> Vec<StartCall> {
            self.started.lock().unwrap().clone()
        }
    }

    impl HostContext for MockHostContext {
        fn start_workflow(
            &self,
            definition_id: &str,
            version: u32,
            data: serde_json::Value,
            parent_root_workflow_id: Option<String>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>
        {
            let def_id = definition_id.to_string();
            let result_id = self.result_id.clone();
            Box::pin(async move {
                // The guard is a temporary dropped at the end of this statement,
                // so it is never held across an await point.
                self.started
                    .lock()
                    .unwrap()
                    .push((def_id, version, data, parent_root_workflow_id));
                Ok(result_id)
            })
        }
    }

    /// Host stub whose `start_workflow` always fails.
    struct FailingHostContext;

    impl HostContext for FailingHostContext {
        fn start_workflow(
            &self,
            _definition_id: &str,
            _version: u32,
            _data: serde_json::Value,
            _parent_root_workflow_id: Option<String>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>
        {
            Box::pin(async {
                Err(crate::WfeError::StepExecution(
                    "failed to start child".to_string(),
                ))
            })
        }
    }

    /// Builds a step context that exposes `host` as its host context.
    fn make_context_with_host<'a>(
        pointer: &'a ExecutionPointer,
        step: &'a crate::models::WorkflowStep,
        workflow: &'a crate::models::WorkflowInstance,
        host: &'a dyn HostContext,
    ) -> StepExecutionContext<'a> {
        StepExecutionContext {
            definition: None,
            item: None,
            execution_pointer: pointer,
            persistence_data: pointer.persistence_data.as_ref(),
            step,
            workflow,
            cancellation_token: tokio_util::sync::CancellationToken::new(),
            host_context: Some(host),
            log_sink: None,
        }
    }

    #[tokio::test]
    async fn first_call_starts_child_and_waits() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({"x": 10}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(!result.proceed);
        assert_eq!(result.event_name.as_deref(), Some("wfe.workflow.completed"));
        assert_eq!(result.event_key.as_deref(), Some("child-123"));
        assert!(result.event_as_of.is_some());

        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, "child-def");
        assert_eq!(calls[0].1, 1);
        assert_eq!(calls[0].2, json!({"x": 10}));

        // The child inherits the parent's root, or the parent itself when it has none.
        let expected_root = workflow
            .root_workflow_id
            .clone()
            .unwrap_or_else(|| workflow.id.clone());
        assert_eq!(calls[0].3.as_deref(), Some(expected_root.as_str()));
    }

    #[tokio::test]
    async fn null_inputs_fall_back_to_parent_data() {
        let host = MockHostContext::new("child-789");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();

        let expected = if workflow.data.is_object() {
            workflow.data.clone()
        } else {
            json!({})
        };
        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].2, expected);
    }

    #[tokio::test]
    async fn child_completed_proceeds_with_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec!["result".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"result": "success", "extra": "ignored"}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"result": "success"})));
    }

    #[tokio::test]
    async fn child_completed_no_output_keys_passes_all() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec![],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"a": 1, "b": 2}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"a": 1, "b": 2})));
    }

    #[tokio::test]
    async fn no_host_context_errors() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("host context"));
    }

    #[tokio::test]
    async fn input_validation_failure() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({"name": 42}),
            input_schema: Some(WorkflowSchema {
                inputs: HashMap::from([("name".into(), SchemaType::String)]),
                outputs: HashMap::new(),
            }),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("input validation failed"));
        assert!(host.calls().is_empty());
    }

    #[tokio::test]
    async fn output_validation_failure() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec![],
            output_schema: Some(WorkflowSchema {
                inputs: HashMap::new(),
                outputs: HashMap::from([("result".into(), SchemaType::String)]),
            }),
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"result": 42}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("output validation failed"));
    }

    #[tokio::test]
    async fn input_validation_passes_then_starts_child() {
        let host = MockHostContext::new("child-456");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 2,
            inputs: json!({"name": "Alice"}),
            input_schema: Some(WorkflowSchema {
                inputs: HashMap::from([("name".into(), SchemaType::String)]),
                outputs: HashMap::new(),
            }),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(!result.proceed);
        assert_eq!(result.event_key.as_deref(), Some("child-456"));
        assert_eq!(host.calls().len(), 1);
    }

    #[tokio::test]
    async fn host_start_workflow_error_propagates() {
        let host = FailingHostContext;
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("failed to start child"));
    }

    #[tokio::test]
    async fn event_data_without_data_field_returns_empty_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec!["foo".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({"status": "Complete"}));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({})));
    }

    #[tokio::test]
    async fn default_step_has_empty_fields() {
        let step = SubWorkflowStep::default();
        assert!(step.workflow_id.is_empty());
        assert_eq!(step.version, 0);
        assert_eq!(step.inputs, json!(null));
        assert!(step.output_keys.is_empty());
        assert!(step.input_schema.is_none());
        assert!(step.output_schema.is_none());
    }
}