Skip to main content

wfe_core/primitives/
sub_workflow.rs

1use async_trait::async_trait;
2use chrono::Utc;
3
4use crate::models::ExecutionResult;
5use crate::models::schema::WorkflowSchema;
6use crate::traits::step::{StepBody, StepExecutionContext};
7
8/// A step that starts a child workflow and waits for its completion.
9///
10/// On first invocation, it validates inputs against `input_schema`, starts the
11/// child workflow via the host context, and returns a "wait for event" result.
12///
13/// When the child workflow completes, the event data arrives, output keys are
14/// extracted, and the step proceeds.
15#[derive(Default)]
16pub struct SubWorkflowStep {
17    /// The definition ID of the child workflow to start.
18    pub workflow_id: String,
19    /// The version of the child workflow definition.
20    pub version: u32,
21    /// Input data to pass to the child workflow.
22    pub inputs: serde_json::Value,
23    /// Keys to extract from the child workflow's completion event data.
24    pub output_keys: Vec<String>,
25    /// Optional schema to validate inputs before starting the child.
26    pub input_schema: Option<WorkflowSchema>,
27    /// Optional schema to validate outputs from the child.
28    pub output_schema: Option<WorkflowSchema>,
29}
30
31#[async_trait]
32impl StepBody for SubWorkflowStep {
33    async fn run(&mut self, context: &StepExecutionContext<'_>) -> crate::Result<ExecutionResult> {
34        // If event data has arrived, the child workflow completed.
35        if let Some(event_data) = &context.execution_pointer.event_data {
36            // Extract output_keys from event data.
37            let mut output = serde_json::Map::new();
38
39            // The event data contains { "status": "...", "data": { ... } }.
40            let child_data = event_data
41                .get("data")
42                .cloned()
43                .unwrap_or(serde_json::Value::Null);
44
45            if self.output_keys.is_empty() {
46                // If no specific keys requested, pass all child data through.
47                if let serde_json::Value::Object(map) = child_data {
48                    output = map;
49                }
50            } else {
51                // Extract only the requested keys.
52                for key in &self.output_keys {
53                    if let Some(val) = child_data.get(key) {
54                        output.insert(key.clone(), val.clone());
55                    }
56                }
57            }
58
59            let output_value = serde_json::Value::Object(output);
60
61            // Validate against output schema if present.
62            if let Some(ref schema) = self.output_schema
63                && let Err(errors) = schema.validate_outputs(&output_value)
64            {
65                return Err(crate::WfeError::StepExecution(format!(
66                    "SubWorkflow output validation failed: {}",
67                    errors.join("; ")
68                )));
69            }
70
71            let mut result = ExecutionResult::next();
72            result.output_data = Some(output_value);
73            return Ok(result);
74        }
75
76        // Hydrate from step_config if our fields are empty (created via Default).
77        if self.workflow_id.is_empty()
78            && let Some(config) = &context.step.step_config
79        {
80            if let Some(wf_id) = config.get("workflow_id").and_then(|v| v.as_str()) {
81                self.workflow_id = wf_id.to_string();
82            }
83            if let Some(ver) = config.get("version").and_then(|v| v.as_u64()) {
84                self.version = ver as u32;
85            }
86            if let Some(inputs) = config.get("inputs") {
87                self.inputs = inputs.clone();
88            }
89            if let Some(keys) = config.get("output_keys").and_then(|v| v.as_array()) {
90                self.output_keys = keys
91                    .iter()
92                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
93                    .collect();
94            }
95        }
96
97        // First call: validate inputs and start child workflow.
98        if let Some(ref schema) = self.input_schema
99            && let Err(errors) = schema.validate_inputs(&self.inputs)
100        {
101            return Err(crate::WfeError::StepExecution(format!(
102                "SubWorkflow input validation failed: {}",
103                errors.join("; ")
104            )));
105        }
106
107        let host = context.host_context.ok_or_else(|| {
108            crate::WfeError::StepExecution(
109                "SubWorkflowStep requires a host context to start child workflows".to_string(),
110            )
111        })?;
112
113        // Use explicit inputs if set; otherwise inherit the parent workflow's
114        // data so child steps can reference the same top-level fields (e.g.
115        // REPO_URL, COMMIT_SHA) without every `type: workflow` step having to
116        // re-declare them. Fall back to an empty object when the parent has
117        // no data either so the child still has a valid JSON object for
118        // storing step outputs.
119        let child_data = if !self.inputs.is_null() {
120            self.inputs.clone()
121        } else if context.workflow.data.is_object() {
122            context.workflow.data.clone()
123        } else {
124            serde_json::json!({})
125        };
126        // Inherit the parent's root — or, if the parent is itself a root
127        // (has no root set), use the parent's own id as the root for the
128        // child. This makes every descendant of a top-level ci run share
129        // the same root_workflow_id and therefore the same namespace and
130        // shared volume on backends that care.
131        let parent_root = context
132            .workflow
133            .root_workflow_id
134            .clone()
135            .or_else(|| Some(context.workflow.id.clone()));
136        let child_instance_id = host
137            .start_workflow(&self.workflow_id, self.version, child_data, parent_root)
138            .await?;
139
140        Ok(ExecutionResult::wait_for_event(
141            "wfe.workflow.completed",
142            child_instance_id,
143            Utc::now(),
144        ))
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::ExecutionPointer;
    use crate::models::schema::SchemaType;
    use crate::primitives::test_helpers::*;
    use crate::traits::step::HostContext;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// A mock HostContext that records every `start_workflow` call and returns
    /// a fixed instance ID.
    struct MockHostContext {
        /// `(definition_id, version, data)` for each call, in call order.
        started: Mutex<Vec<(String, u32, serde_json::Value)>>,
        /// The `parent_root_workflow_id` argument of each call, in call order.
        roots: Mutex<Vec<Option<String>>>,
        result_id: String,
    }

    impl MockHostContext {
        fn new(result_id: &str) -> Self {
            Self {
                started: Mutex::new(Vec::new()),
                roots: Mutex::new(Vec::new()),
                result_id: result_id.to_string(),
            }
        }

        fn calls(&self) -> Vec<(String, u32, serde_json::Value)> {
            self.started.lock().unwrap().clone()
        }

        fn roots(&self) -> Vec<Option<String>> {
            self.roots.lock().unwrap().clone()
        }
    }

    impl HostContext for MockHostContext {
        fn start_workflow(
            &self,
            definition_id: &str,
            version: u32,
            data: serde_json::Value,
            parent_root_workflow_id: Option<String>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>
        {
            let def_id = definition_id.to_string();
            let result_id = self.result_id.clone();
            Box::pin(async move {
                self.started.lock().unwrap().push((def_id, version, data));
                self.roots.lock().unwrap().push(parent_root_workflow_id);
                Ok(result_id)
            })
        }
    }

    /// A mock HostContext that returns an error.
    struct FailingHostContext;

    impl HostContext for FailingHostContext {
        fn start_workflow(
            &self,
            _definition_id: &str,
            _version: u32,
            _data: serde_json::Value,
            _parent_root_workflow_id: Option<String>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>
        {
            Box::pin(async {
                Err(crate::WfeError::StepExecution(
                    "failed to start child".to_string(),
                ))
            })
        }
    }

    fn make_context_with_host<'a>(
        pointer: &'a ExecutionPointer,
        step: &'a crate::models::WorkflowStep,
        workflow: &'a crate::models::WorkflowInstance,
        host: &'a dyn HostContext,
    ) -> StepExecutionContext<'a> {
        StepExecutionContext {
            definition: None,
            item: None,
            execution_pointer: pointer,
            persistence_data: pointer.persistence_data.as_ref(),
            step,
            workflow,
            cancellation_token: tokio_util::sync::CancellationToken::new(),
            host_context: Some(host),
            log_sink: None,
            artifact_store: None,
            artifact_volume: None,
            artifact_package: None,
            persistence: None,
        }
    }

    #[tokio::test]
    async fn first_call_starts_child_and_waits() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({"x": 10}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(!result.proceed);
        assert_eq!(result.event_name.as_deref(), Some("wfe.workflow.completed"));
        assert_eq!(result.event_key.as_deref(), Some("child-123"));
        assert!(result.event_as_of.is_some());

        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, "child-def");
        assert_eq!(calls[0].1, 1);
        assert_eq!(calls[0].2, json!({"x": 10}));
    }

    #[tokio::test]
    async fn null_inputs_inherit_parent_data() {
        let host = MockHostContext::new("child-1");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let mut workflow = default_workflow();
        workflow.data = json!({"REPO_URL": "https://example.com/repo.git"});
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();
        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].2, json!({"REPO_URL": "https://example.com/repo.git"}));
    }

    #[tokio::test]
    async fn null_inputs_and_non_object_parent_data_yield_empty_object() {
        let host = MockHostContext::new("child-1");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let mut workflow = default_workflow();
        workflow.data = serde_json::Value::Null;
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();
        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].2, json!({}));
    }

    #[tokio::test]
    async fn root_workflow_id_defaults_to_parent_id() {
        let host = MockHostContext::new("child-1");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let mut workflow = default_workflow();
        workflow.id = "parent-1".into();
        workflow.root_workflow_id = None;
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();
        assert_eq!(host.roots(), vec![Some("parent-1".to_string())]);
    }

    #[tokio::test]
    async fn root_workflow_id_is_inherited_from_parent() {
        let host = MockHostContext::new("child-1");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let mut workflow = default_workflow();
        workflow.id = "parent-1".into();
        workflow.root_workflow_id = Some("root-0".into());
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();
        assert_eq!(host.roots(), vec![Some("root-0".to_string())]);
    }

    #[tokio::test]
    async fn child_completed_proceeds_with_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec!["result".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"result": "success", "extra": "ignored"}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"result": "success"})));
    }

    #[tokio::test]
    async fn child_completed_no_output_keys_passes_all() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec![],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"a": 1, "b": 2}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"a": 1, "b": 2})));
    }

    #[tokio::test]
    async fn non_object_child_data_without_output_keys_returns_empty_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({"status": "Complete", "data": [1, 2]}));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({})));
    }

    #[tokio::test]
    async fn no_host_context_errors() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("host context"));
    }

    #[tokio::test]
    async fn input_validation_failure() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({"name": 42}), // wrong type
            input_schema: Some(WorkflowSchema {
                inputs: HashMap::from([("name".into(), SchemaType::String)]),
                outputs: HashMap::new(),
            }),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("input validation failed"));
        assert!(host.calls().is_empty());
    }

    #[tokio::test]
    async fn output_validation_failure() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec![],
            output_schema: Some(WorkflowSchema {
                inputs: HashMap::new(),
                outputs: HashMap::from([("result".into(), SchemaType::String)]),
            }),
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"result": 42}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("output validation failed"));
    }

    #[tokio::test]
    async fn input_validation_passes_then_starts_child() {
        let host = MockHostContext::new("child-456");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 2,
            inputs: json!({"name": "Alice"}),
            input_schema: Some(WorkflowSchema {
                inputs: HashMap::from([("name".into(), SchemaType::String)]),
                outputs: HashMap::new(),
            }),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(!result.proceed);
        assert_eq!(result.event_key.as_deref(), Some("child-456"));
        assert_eq!(host.calls().len(), 1);
    }

    #[tokio::test]
    async fn host_start_workflow_error_propagates() {
        let host = FailingHostContext;
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("failed to start child"));
    }

    #[tokio::test]
    async fn event_data_without_data_field_returns_empty_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec!["foo".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({"status": "Complete"}));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({})));
    }

    #[tokio::test]
    async fn default_step_has_empty_fields() {
        let step = SubWorkflowStep::default();
        assert!(step.workflow_id.is_empty());
        assert_eq!(step.version, 0);
        assert_eq!(step.inputs, json!(null));
        assert!(step.output_keys.is_empty());
        assert!(step.input_schema.is_none());
        assert!(step.output_schema.is_none());
    }
}