1use async_trait::async_trait;
2use chrono::Utc;
3
4use crate::models::ExecutionResult;
5use crate::models::schema::WorkflowSchema;
6use crate::traits::step::{StepBody, StepExecutionContext};
7
8#[derive(Default)]
16pub struct SubWorkflowStep {
17 pub workflow_id: String,
19 pub version: u32,
21 pub inputs: serde_json::Value,
23 pub output_keys: Vec<String>,
25 pub input_schema: Option<WorkflowSchema>,
27 pub output_schema: Option<WorkflowSchema>,
29}
30
31#[async_trait]
32impl StepBody for SubWorkflowStep {
33 async fn run(&mut self, context: &StepExecutionContext<'_>) -> crate::Result<ExecutionResult> {
34 if let Some(event_data) = &context.execution_pointer.event_data {
36 let mut output = serde_json::Map::new();
38
39 let child_data = event_data
41 .get("data")
42 .cloned()
43 .unwrap_or(serde_json::Value::Null);
44
45 if self.output_keys.is_empty() {
46 if let serde_json::Value::Object(map) = child_data {
48 output = map;
49 }
50 } else {
51 for key in &self.output_keys {
53 if let Some(val) = child_data.get(key) {
54 output.insert(key.clone(), val.clone());
55 }
56 }
57 }
58
59 let output_value = serde_json::Value::Object(output);
60
61 if let Some(ref schema) = self.output_schema
63 && let Err(errors) = schema.validate_outputs(&output_value)
64 {
65 return Err(crate::WfeError::StepExecution(format!(
66 "SubWorkflow output validation failed: {}",
67 errors.join("; ")
68 )));
69 }
70
71 let mut result = ExecutionResult::next();
72 result.output_data = Some(output_value);
73 return Ok(result);
74 }
75
76 if self.workflow_id.is_empty()
78 && let Some(config) = &context.step.step_config
79 {
80 if let Some(wf_id) = config.get("workflow_id").and_then(|v| v.as_str()) {
81 self.workflow_id = wf_id.to_string();
82 }
83 if let Some(ver) = config.get("version").and_then(|v| v.as_u64()) {
84 self.version = ver as u32;
85 }
86 if let Some(inputs) = config.get("inputs") {
87 self.inputs = inputs.clone();
88 }
89 if let Some(keys) = config.get("output_keys").and_then(|v| v.as_array()) {
90 self.output_keys = keys
91 .iter()
92 .filter_map(|v| v.as_str().map(|s| s.to_string()))
93 .collect();
94 }
95 }
96
97 if let Some(ref schema) = self.input_schema
99 && let Err(errors) = schema.validate_inputs(&self.inputs)
100 {
101 return Err(crate::WfeError::StepExecution(format!(
102 "SubWorkflow input validation failed: {}",
103 errors.join("; ")
104 )));
105 }
106
107 let host = context.host_context.ok_or_else(|| {
108 crate::WfeError::StepExecution(
109 "SubWorkflowStep requires a host context to start child workflows".to_string(),
110 )
111 })?;
112
113 let child_data = if !self.inputs.is_null() {
120 self.inputs.clone()
121 } else if context.workflow.data.is_object() {
122 context.workflow.data.clone()
123 } else {
124 serde_json::json!({})
125 };
126 let parent_root = context
132 .workflow
133 .root_workflow_id
134 .clone()
135 .or_else(|| Some(context.workflow.id.clone()));
136 let child_instance_id = host
137 .start_workflow(&self.workflow_id, self.version, child_data, parent_root)
138 .await?;
139
140 Ok(ExecutionResult::wait_for_event(
141 "wfe.workflow.completed",
142 child_instance_id,
143 Utc::now(),
144 ))
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::ExecutionPointer;
    use crate::models::schema::SchemaType;
    use crate::primitives::test_helpers::*;
    use crate::traits::step::HostContext;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// Host double that records every `start_workflow` request and returns a fixed id.
    struct MockHostContext {
        /// `(definition_id, version, data)` of each start request, in call order.
        started: Mutex<Vec<(String, u32, serde_json::Value)>>,
        /// Parent root workflow id passed with each start request, in call order.
        parent_roots: Mutex<Vec<Option<String>>>,
        /// Instance id returned for every started child.
        result_id: String,
    }

    impl MockHostContext {
        fn new(result_id: &str) -> Self {
            Self {
                started: Mutex::new(Vec::new()),
                parent_roots: Mutex::new(Vec::new()),
                result_id: result_id.to_string(),
            }
        }

        /// Snapshot of the recorded `(definition_id, version, data)` start requests.
        fn calls(&self) -> Vec<(String, u32, serde_json::Value)> {
            self.started.lock().unwrap().clone()
        }

        /// Snapshot of the recorded parent root workflow ids.
        fn parent_roots(&self) -> Vec<Option<String>> {
            self.parent_roots.lock().unwrap().clone()
        }
    }

    impl HostContext for MockHostContext {
        fn start_workflow(
            &self,
            definition_id: &str,
            version: u32,
            data: serde_json::Value,
            parent_root_workflow_id: Option<String>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>
        {
            let def_id = definition_id.to_string();
            let result_id = self.result_id.clone();
            Box::pin(async move {
                // Each guard is a temporary dropped at the end of its statement, so no lock
                // is held across an await point.
                self.started.lock().unwrap().push((def_id, version, data));
                self.parent_roots
                    .lock()
                    .unwrap()
                    .push(parent_root_workflow_id);
                Ok(result_id)
            })
        }
    }

    /// Host double whose `start_workflow` always fails.
    struct FailingHostContext;

    impl HostContext for FailingHostContext {
        fn start_workflow(
            &self,
            _definition_id: &str,
            _version: u32,
            _data: serde_json::Value,
            _parent_root_workflow_id: Option<String>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>
        {
            Box::pin(async {
                Err(crate::WfeError::StepExecution(
                    "failed to start child".to_string(),
                ))
            })
        }
    }

    /// Build a step context like `make_context`, but with `host` installed as host context.
    fn make_context_with_host<'a>(
        pointer: &'a ExecutionPointer,
        step: &'a crate::models::WorkflowStep,
        workflow: &'a crate::models::WorkflowInstance,
        host: &'a dyn HostContext,
    ) -> StepExecutionContext<'a> {
        StepExecutionContext {
            definition: None,
            item: None,
            execution_pointer: pointer,
            persistence_data: pointer.persistence_data.as_ref(),
            step,
            workflow,
            cancellation_token: tokio_util::sync::CancellationToken::new(),
            host_context: Some(host),
            log_sink: None,
            artifact_store: None,
            artifact_volume: None,
            artifact_package: None,
            persistence: None,
        }
    }

    #[tokio::test]
    async fn first_call_starts_child_and_waits() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({"x": 10}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(!result.proceed);
        assert_eq!(result.event_name.as_deref(), Some("wfe.workflow.completed"));
        assert_eq!(result.event_key.as_deref(), Some("child-123"));
        assert!(result.event_as_of.is_some());

        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, "child-def");
        assert_eq!(calls[0].1, 1);
        assert_eq!(calls[0].2, json!({"x": 10}));
    }

    #[tokio::test]
    async fn start_passes_parent_root_workflow_id() {
        let host = MockHostContext::new("child-789");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();

        // The child is rooted at the parent's root, or at the parent itself if it has none.
        let expected_root = workflow
            .root_workflow_id
            .clone()
            .unwrap_or_else(|| workflow.id.clone());
        assert_eq!(host.parent_roots(), vec![Some(expected_root)]);
    }

    #[tokio::test]
    async fn child_completed_proceeds_with_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec!["result".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"result": "success", "extra": "ignored"}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"result": "success"})));
    }

    #[tokio::test]
    async fn child_completed_no_output_keys_passes_all() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec![],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"a": 1, "b": 2}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"a": 1, "b": 2})));
    }

    #[tokio::test]
    async fn child_completed_non_object_data_yields_empty_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": [1, 2, 3]
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({})));
    }

    #[tokio::test]
    async fn child_completed_skips_missing_output_keys() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            output_keys: vec!["present".into(), "absent".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"present": true}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"present": true})));
    }

    #[tokio::test]
    async fn child_completed_does_not_start_another_child() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({"status": "Complete", "data": {}}));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert!(host.calls().is_empty());
    }

    #[tokio::test]
    async fn no_host_context_errors() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("host context"));
    }

    #[tokio::test]
    async fn input_validation_failure() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({"name": 42}),
            input_schema: Some(WorkflowSchema {
                inputs: HashMap::from([("name".into(), SchemaType::String)]),
                outputs: HashMap::new(),
            }),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("input validation failed"));
        assert!(host.calls().is_empty());
    }

    #[tokio::test]
    async fn output_validation_failure() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec![],
            output_schema: Some(WorkflowSchema {
                inputs: HashMap::new(),
                outputs: HashMap::from([("result".into(), SchemaType::String)]),
            }),
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"result": 42}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("output validation failed"));
    }

    #[tokio::test]
    async fn input_validation_passes_then_starts_child() {
        let host = MockHostContext::new("child-456");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 2,
            inputs: json!({"name": "Alice"}),
            input_schema: Some(WorkflowSchema {
                inputs: HashMap::from([("name".into(), SchemaType::String)]),
                outputs: HashMap::new(),
            }),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(!result.proceed);
        assert_eq!(result.event_key.as_deref(), Some("child-456"));
        assert_eq!(host.calls().len(), 1);
    }

    #[tokio::test]
    async fn host_start_workflow_error_propagates() {
        let host = FailingHostContext;
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("failed to start child"));
    }

    #[tokio::test]
    async fn event_data_without_data_field_returns_empty_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec!["foo".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({"status": "Complete"}));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({})));
    }

    // Purely synchronous: no runtime is needed to inspect the defaults.
    #[test]
    fn default_step_has_empty_fields() {
        let step = SubWorkflowStep::default();
        assert!(step.workflow_id.is_empty());
        assert_eq!(step.version, 0);
        assert_eq!(step.inputs, json!(null));
        assert!(step.output_keys.is_empty());
        assert!(step.input_schema.is_none());
        assert!(step.output_schema.is_none());
    }
}