// wfe_core/primitives/sub_workflow.rs

use async_trait::async_trait;
use chrono::Utc;

use crate::models::ExecutionResult;
use crate::models::schema::WorkflowSchema;
use crate::traits::step::{StepBody, StepExecutionContext};

8/// A step that starts a child workflow and waits for its completion.
9///
10/// On first invocation, it validates inputs against `input_schema`, starts the
11/// child workflow via the host context, and returns a "wait for event" result.
12///
13/// When the child workflow completes, the event data arrives, output keys are
14/// extracted, and the step proceeds.
15#[derive(Default)]
16pub struct SubWorkflowStep {
17    /// The definition ID of the child workflow to start.
18    pub workflow_id: String,
19    /// The version of the child workflow definition.
20    pub version: u32,
21    /// Input data to pass to the child workflow.
22    pub inputs: serde_json::Value,
23    /// Keys to extract from the child workflow's completion event data.
24    pub output_keys: Vec<String>,
25    /// Optional schema to validate inputs before starting the child.
26    pub input_schema: Option<WorkflowSchema>,
27    /// Optional schema to validate outputs from the child.
28    pub output_schema: Option<WorkflowSchema>,
29}
30
31#[async_trait]
32impl StepBody for SubWorkflowStep {
33    async fn run(&mut self, context: &StepExecutionContext<'_>) -> crate::Result<ExecutionResult> {
34        // If event data has arrived, the child workflow completed.
35        if let Some(event_data) = &context.execution_pointer.event_data {
36            // Extract output_keys from event data.
37            let mut output = serde_json::Map::new();
38
39            // The event data contains { "status": "...", "data": { ... } }.
40            let child_data = event_data
41                .get("data")
42                .cloned()
43                .unwrap_or(serde_json::Value::Null);
44
45            if self.output_keys.is_empty() {
46                // If no specific keys requested, pass all child data through.
47                if let serde_json::Value::Object(map) = child_data {
48                    output = map;
49                }
50            } else {
51                // Extract only the requested keys.
52                for key in &self.output_keys {
53                    if let Some(val) = child_data.get(key) {
54                        output.insert(key.clone(), val.clone());
55                    }
56                }
57            }
58
59            let output_value = serde_json::Value::Object(output);
60
61            // Validate against output schema if present.
62            if let Some(ref schema) = self.output_schema
63                && let Err(errors) = schema.validate_outputs(&output_value)
64            {
65                return Err(crate::WfeError::StepExecution(format!(
66                    "SubWorkflow output validation failed: {}",
67                    errors.join("; ")
68                )));
69            }
70
71            let mut result = ExecutionResult::next();
72            result.output_data = Some(output_value);
73            return Ok(result);
74        }
75
76        // Hydrate from step_config if our fields are empty (created via Default).
77        if self.workflow_id.is_empty()
78            && let Some(config) = &context.step.step_config
79        {
80            if let Some(wf_id) = config.get("workflow_id").and_then(|v| v.as_str()) {
81                self.workflow_id = wf_id.to_string();
82            }
83            if let Some(ver) = config.get("version").and_then(|v| v.as_u64()) {
84                self.version = ver as u32;
85            }
86            if let Some(inputs) = config.get("inputs") {
87                self.inputs = inputs.clone();
88            }
89            if let Some(keys) = config.get("output_keys").and_then(|v| v.as_array()) {
90                self.output_keys = keys
91                    .iter()
92                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
93                    .collect();
94            }
95        }
96
97        // First call: validate inputs and start child workflow.
98        if let Some(ref schema) = self.input_schema
99            && let Err(errors) = schema.validate_inputs(&self.inputs)
100        {
101            return Err(crate::WfeError::StepExecution(format!(
102                "SubWorkflow input validation failed: {}",
103                errors.join("; ")
104            )));
105        }
106
107        let host = context.host_context.ok_or_else(|| {
108            crate::WfeError::StepExecution(
109                "SubWorkflowStep requires a host context to start child workflows".to_string(),
110            )
111        })?;
112
113        // Use explicit inputs if set; otherwise inherit the parent workflow's
114        // data so child steps can reference the same top-level fields (e.g.
115        // REPO_URL, COMMIT_SHA) without every `type: workflow` step having to
116        // re-declare them. Fall back to an empty object when the parent has
117        // no data either so the child still has a valid JSON object for
118        // storing step outputs.
119        let child_data = if !self.inputs.is_null() {
120            self.inputs.clone()
121        } else if context.workflow.data.is_object() {
122            context.workflow.data.clone()
123        } else {
124            serde_json::json!({})
125        };
126        // Inherit the parent's root — or, if the parent is itself a root
127        // (has no root set), use the parent's own id as the root for the
128        // child. This makes every descendant of a top-level ci run share
129        // the same root_workflow_id and therefore the same namespace and
130        // shared volume on backends that care.
131        let parent_root = context
132            .workflow
133            .root_workflow_id
134            .clone()
135            .or_else(|| Some(context.workflow.id.clone()));
136        let child_instance_id = host
137            .start_workflow(&self.workflow_id, self.version, child_data, parent_root)
138            .await?;
139
140        Ok(ExecutionResult::wait_for_event(
141            "wfe.workflow.completed",
142            child_instance_id,
143            Utc::now(),
144        ))
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::ExecutionPointer;
    use crate::models::schema::SchemaType;
    use crate::primitives::test_helpers::*;
    use crate::traits::step::HostContext;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// One recorded `start_workflow` call:
    /// (definition id, version, child data, parent root workflow id).
    type StartCall = (String, u32, serde_json::Value, Option<String>);

    /// A mock HostContext that records every call (including the root id it
    /// was handed) and returns a fixed instance ID.
    struct MockHostContext {
        started: Mutex<Vec<StartCall>>,
        result_id: String,
    }

    impl MockHostContext {
        fn new(result_id: &str) -> Self {
            Self {
                started: Mutex::new(Vec::new()),
                result_id: result_id.to_string(),
            }
        }

        fn calls(&self) -> Vec<StartCall> {
            self.started.lock().unwrap().clone()
        }
    }

    impl HostContext for MockHostContext {
        fn start_workflow(
            &self,
            definition_id: &str,
            version: u32,
            data: serde_json::Value,
            parent_root_workflow_id: Option<String>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>
        {
            let def_id = definition_id.to_string();
            let result_id = self.result_id.clone();
            Box::pin(async move {
                self.started
                    .lock()
                    .unwrap()
                    .push((def_id, version, data, parent_root_workflow_id));
                Ok(result_id)
            })
        }
    }

    /// A mock HostContext that returns an error.
    struct FailingHostContext;

    impl HostContext for FailingHostContext {
        fn start_workflow(
            &self,
            _definition_id: &str,
            _version: u32,
            _data: serde_json::Value,
            _parent_root_workflow_id: Option<String>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>
        {
            Box::pin(async {
                Err(crate::WfeError::StepExecution(
                    "failed to start child".to_string(),
                ))
            })
        }
    }

    fn make_context_with_host<'a>(
        pointer: &'a ExecutionPointer,
        step: &'a crate::models::WorkflowStep,
        workflow: &'a crate::models::WorkflowInstance,
        host: &'a dyn HostContext,
    ) -> StepExecutionContext<'a> {
        StepExecutionContext {
            definition: None,
            item: None,
            execution_pointer: pointer,
            persistence_data: pointer.persistence_data.as_ref(),
            step,
            workflow,
            cancellation_token: tokio_util::sync::CancellationToken::new(),
            host_context: Some(host),
            log_sink: None,
        }
    }

    #[tokio::test]
    async fn first_call_starts_child_and_waits() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({"x": 10}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(!result.proceed);
        assert_eq!(result.event_name.as_deref(), Some("wfe.workflow.completed"));
        assert_eq!(result.event_key.as_deref(), Some("child-123"));
        assert!(result.event_as_of.is_some());

        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, "child-def");
        assert_eq!(calls[0].1, 1);
        assert_eq!(calls[0].2, json!({"x": 10}));
    }

    #[tokio::test]
    async fn child_inherits_parent_root_workflow_id() {
        let host = MockHostContext::new("child-1");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let mut workflow = default_workflow();
        workflow.id = "parent-wf".into();
        workflow.root_workflow_id = Some("root-wf".into());
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();

        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].3.as_deref(), Some("root-wf"));
    }

    #[tokio::test]
    async fn root_parent_passes_own_id_as_child_root() {
        let host = MockHostContext::new("child-1");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let mut workflow = default_workflow();
        workflow.id = "parent-wf".into();
        workflow.root_workflow_id = None;
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();

        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].3.as_deref(), Some("parent-wf"));
    }

    #[tokio::test]
    async fn null_inputs_inherit_parent_data() {
        let host = MockHostContext::new("child-1");
        // `inputs` stays at its default `null`.
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let mut workflow = default_workflow();
        workflow.data = json!({"REPO_URL": "https://example.invalid/repo.git", "COMMIT_SHA": "abc123"});
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();

        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(
            calls[0].2,
            json!({"REPO_URL": "https://example.invalid/repo.git", "COMMIT_SHA": "abc123"})
        );
    }

    #[tokio::test]
    async fn null_inputs_with_non_object_parent_data_send_empty_object() {
        let host = MockHostContext::new("child-1");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let mut workflow = default_workflow();
        workflow.data = serde_json::Value::Null;
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        step.run(&ctx).await.unwrap();

        let calls = host.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].2, json!({}));
    }

    #[tokio::test]
    async fn child_completed_proceeds_with_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec!["result".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"result": "success", "extra": "ignored"}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"result": "success"})));
    }

    #[tokio::test]
    async fn child_completed_no_output_keys_passes_all() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec![],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"a": 1, "b": 2}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({"a": 1, "b": 2})));
    }

    #[tokio::test]
    async fn no_host_context_errors() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("host context"));
    }

    #[tokio::test]
    async fn input_validation_failure() {
        let host = MockHostContext::new("child-123");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({"name": 42}), // wrong type
            input_schema: Some(WorkflowSchema {
                inputs: HashMap::from([("name".into(), SchemaType::String)]),
                outputs: HashMap::new(),
            }),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("input validation failed"));
        assert!(host.calls().is_empty());
    }

    #[tokio::test]
    async fn output_validation_failure() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec![],
            output_schema: Some(WorkflowSchema {
                inputs: HashMap::new(),
                outputs: HashMap::from([("result".into(), SchemaType::String)]),
            }),
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({
            "status": "Complete",
            "data": {"result": 42}
        }));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("output validation failed"));
    }

    #[tokio::test]
    async fn input_validation_passes_then_starts_child() {
        let host = MockHostContext::new("child-456");
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 2,
            inputs: json!({"name": "Alice"}),
            input_schema: Some(WorkflowSchema {
                inputs: HashMap::from([("name".into(), SchemaType::String)]),
                outputs: HashMap::new(),
            }),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let result = step.run(&ctx).await.unwrap();
        assert!(!result.proceed);
        assert_eq!(result.event_key.as_deref(), Some("child-456"));
        assert_eq!(host.calls().len(), 1);
    }

    #[tokio::test]
    async fn host_start_workflow_error_propagates() {
        let host = FailingHostContext;
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            ..Default::default()
        };

        let pointer = ExecutionPointer::new(0);
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host);

        let err = step.run(&ctx).await.unwrap_err();
        assert!(err.to_string().contains("failed to start child"));
    }

    #[tokio::test]
    async fn event_data_without_data_field_returns_empty_output() {
        let mut step = SubWorkflowStep {
            workflow_id: "child-def".into(),
            version: 1,
            inputs: json!({}),
            output_keys: vec!["foo".into()],
            ..Default::default()
        };

        let mut pointer = ExecutionPointer::new(0);
        pointer.event_data = Some(json!({"status": "Complete"}));
        let wf_step = default_step();
        let workflow = default_workflow();
        let ctx = make_context(&pointer, &wf_step, &workflow);

        let result = step.run(&ctx).await.unwrap();
        assert!(result.proceed);
        assert_eq!(result.output_data, Some(json!({})));
    }

    #[tokio::test]
    async fn default_step_has_empty_fields() {
        let step = SubWorkflowStep::default();
        assert!(step.workflow_id.is_empty());
        assert_eq!(step.version, 0);
        assert_eq!(step.inputs, json!(null));
        assert!(step.output_keys.is_empty());
        assert!(step.input_schema.is_none());
        assert!(step.output_schema.is_none());
    }
}