Skip to main content

stepflow_flow/workflow/
overrides.rs

1// Copyright 2025 DataStax Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the License
9// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10// or implied. See the License for the specific language governing permissions and limitations under
11// the License.
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use error_stack::ResultExt as _;
17use serde::{Deserialize, Serialize};
18
19use super::Flow;
20
21/// Workflow overrides that can be applied to modify step behavior at runtime.
22///
23/// Overrides are keyed by step ID and contain merge patches or other transformation
24/// specifications to modify step properties before execution.
25#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
26pub struct WorkflowOverrides {
27    /// Map of step ID to override specification
28    pub steps: HashMap<String, StepOverride>,
29}
30
31impl WorkflowOverrides {
32    /// Create new empty workflow overrides
33    pub fn new() -> Self {
34        Self {
35            steps: HashMap::new(),
36        }
37    }
38
39    /// Check if there are any overrides defined
40    pub fn is_empty(&self) -> bool {
41        self.steps.is_empty()
42    }
43
44    /// Add an override for a specific step
45    pub fn add_step_override(&mut self, step_id: String, override_spec: StepOverride) {
46        self.steps.insert(step_id, override_spec);
47    }
48}
49
50impl Default for WorkflowOverrides {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56/// Override specification for a single step.
57///
58/// Contains the override type (merge patch, json patch, etc.) and the value
59/// to apply. The type field uses `$type` to avoid collisions with step properties.
60#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
61pub struct StepOverride {
62    /// The type of override to apply. Defaults to "merge_patch" if not specified.
63    #[serde(rename = "$type", default = "default_override_type")]
64    pub override_type: OverrideType,
65
66    /// The override value to apply, interpreted based on the override type.
67    pub value: serde_json::Value,
68}
69
70impl StepOverride {
71    /// Create a new step override with merge patch type
72    pub fn merge_patch(value: serde_json::Value) -> Self {
73        Self {
74            override_type: OverrideType::MergePatch,
75            value,
76        }
77    }
78
79    /// Create a new step override with explicit type
80    pub fn with_type(override_type: OverrideType, value: serde_json::Value) -> Self {
81        Self {
82            override_type,
83            value,
84        }
85    }
86}
87
88/// The type of override operation to perform.
89#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
90#[serde(rename_all = "snake_case")]
91pub enum OverrideType {
92    /// Apply a JSON Merge Patch (RFC 7396) to the step.
93    ///
94    /// This is the default override type. The value should be a JSON object
95    /// where null values indicate fields to remove and other values are merged
96    /// into the target step.
97    MergePatch,
98
99    /// Apply a JSON Patch (RFC 6902) to the step. (Future extension)
100    ///
101    /// The value should be an array of JSON Patch operations.
102    /// This is reserved for future use.
103    #[allow(dead_code)]
104    JsonPatch,
105}
106
107/// Default override type is merge patch
108fn default_override_type() -> OverrideType {
109    OverrideType::MergePatch
110}
111
112/// Errors that can occur when applying workflow overrides
113#[derive(Debug, thiserror::Error)]
114pub enum OverrideError {
115    #[error("Step '{step_id}' not found in workflow")]
116    StepNotFound { step_id: String },
117
118    #[error("Invalid override value for step '{step_id}': {reason}")]
119    InvalidOverrideValue { step_id: String, reason: String },
120
121    #[error("Unsupported override type: {override_type:?}")]
122    UnsupportedOverrideType { override_type: OverrideType },
123
124    #[error("JSON merge patch failed for step '{step_id}': {reason}")]
125    MergePatchFailed { step_id: String, reason: String },
126}
127
128pub type OverrideResult<T> = error_stack::Result<T, OverrideError>;
129
130/// Trait for applying workflow overrides to flows
131pub trait OverrideProcessor {
132    /// Apply the given overrides to a workflow, returning a modified workflow.
133    ///
134    /// This validates that all override targets exist in the workflow and
135    /// applies the specified transformations. If no overrides are needed,
136    /// the original Arc is returned unchanged.
137    fn apply_overrides(
138        &self,
139        flow: Arc<Flow>,
140        overrides: &WorkflowOverrides,
141    ) -> OverrideResult<Arc<Flow>>;
142}
143
/// Stateless [`OverrideProcessor`] used by the free [`apply_overrides`] function.
///
/// Only [`OverrideType::MergePatch`] overrides are applied; any other override
/// type is rejected with [`OverrideError::UnsupportedOverrideType`].
// `Debug` lets the processor appear in diagnostics, and `Clone`/`Copy` are
// free for a field-less unit struct.
#[derive(Debug, Clone, Copy)]
pub struct DefaultOverrideProcessor;
146
147impl OverrideProcessor for DefaultOverrideProcessor {
148    fn apply_overrides(
149        &self,
150        flow: Arc<Flow>,
151        overrides: &WorkflowOverrides,
152    ) -> OverrideResult<Arc<Flow>> {
153        if overrides.is_empty() {
154            return Ok(flow);
155        }
156
157        log::debug!(
158            "Applying {} step overrides to workflow",
159            overrides.steps.len()
160        );
161
162        // Validate all override targets exist in the workflow
163        self.validate_override_targets(&flow, overrides)?;
164
165        // Need to clone the flow to apply modifications
166        let mut cloned_flow = flow.slow_clone();
167
168        // Apply overrides to each step
169        for step in &mut cloned_flow.steps {
170            if let Some(step_override) = overrides.steps.get(&step.id) {
171                log::debug!(
172                    "Applying override to step '{}' with type '{:?}'",
173                    step.id,
174                    step_override.override_type
175                );
176                self.apply_step_override(step, step_override)
177                    .change_context(OverrideError::InvalidOverrideValue {
178                        step_id: step.id.clone(),
179                        reason: "Failed to apply step override".to_string(),
180                    })?;
181            }
182        }
183
184        Ok(Arc::new(cloned_flow))
185    }
186}
187
188impl DefaultOverrideProcessor {
189    /// Create a new default override processor
190    pub fn new() -> Self {
191        Self
192    }
193
194    /// Validate that all override targets exist in the workflow
195    fn validate_override_targets(
196        &self,
197        flow: &Flow,
198        overrides: &WorkflowOverrides,
199    ) -> OverrideResult<()> {
200        let step_ids: std::collections::HashSet<&String> =
201            flow.steps().iter().map(|step| &step.id).collect();
202
203        for step_id in overrides.steps.keys() {
204            if !step_ids.contains(&step_id) {
205                return Err(error_stack::report!(OverrideError::StepNotFound {
206                    step_id: step_id.clone(),
207                }));
208            }
209        }
210
211        Ok(())
212    }
213
214    /// Apply a single step override based on its type
215    fn apply_step_override(
216        &self,
217        step: &mut super::Step,
218        step_override: &StepOverride,
219    ) -> OverrideResult<()> {
220        match step_override.override_type {
221            OverrideType::MergePatch => self.apply_merge_patch(step, &step_override.value),
222            OverrideType::JsonPatch => Err(error_stack::report!(
223                OverrideError::UnsupportedOverrideType {
224                    override_type: step_override.override_type.clone(),
225                }
226            )),
227        }
228    }
229
230    /// Apply a JSON merge patch to a step
231    fn apply_merge_patch(
232        &self,
233        step: &mut super::Step,
234        patch: &serde_json::Value,
235    ) -> OverrideResult<()> {
236        let step_id = step.id.clone(); // Capture the ID before borrowing
237
238        // Convert step to JSON for merging
239        let mut step_json =
240            serde_json::to_value(&*step).change_context(OverrideError::MergePatchFailed {
241                step_id: step_id.clone(),
242                reason: "Failed to serialize step to JSON".to_string(),
243            })?;
244
245        // Apply merge patch
246        json_patch::merge(&mut step_json, patch);
247
248        // Convert back to Step
249        *step =
250            serde_json::from_value(step_json).change_context(OverrideError::MergePatchFailed {
251                step_id,
252                reason: "Failed to deserialize modified step from JSON".to_string(),
253            })?;
254
255        Ok(())
256    }
257}
258
259impl Default for DefaultOverrideProcessor {
260    fn default() -> Self {
261        Self::new()
262    }
263}
264
265/// Convenience function to apply overrides to a workflow using the default processor
266pub fn apply_overrides(
267    flow: Arc<Flow>,
268    overrides: &WorkflowOverrides,
269) -> OverrideResult<Arc<Flow>> {
270    DefaultOverrideProcessor::new().apply_overrides(flow, overrides)
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ValueExpr;
    use serde_json::json;

    /// Builds a one-step flow whose only step, `step1`, runs `/test/component`.
    fn create_test_flow() -> Flow {
        Flow {
            name: Some("test_flow".to_string()),
            description: None,
            version: None,
            schemas: super::super::FlowSchema::default(),
            steps: vec![super::super::Step {
                id: "step1".to_string(),
                component: super::super::Component::from_string("/test/component"),
                on_error: None,
                input: ValueExpr::null(),
                must_execute: None,
                metadata: std::collections::HashMap::new(),
            }],
            output: ValueExpr::null(),
            test: None,
            examples: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_workflow_overrides_creation() {
        let overrides = WorkflowOverrides::new();
        assert!(overrides.is_empty());

        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::merge_patch(json!({"input": {"temperature": 0.8}})),
        );
        assert!(!overrides.is_empty());
        assert_eq!(overrides.steps.len(), 1);
    }

    #[test]
    fn test_step_override_creation() {
        let override_spec = StepOverride::merge_patch(json!({"input": {"temperature": 0.8}}));
        assert!(matches!(
            override_spec.override_type,
            OverrideType::MergePatch
        ));
        assert_eq!(override_spec.value, json!({"input": {"temperature": 0.8}}));

        let override_spec = StepOverride::with_type(
            OverrideType::MergePatch,
            json!({"component": "/different/component"}),
        );
        assert!(matches!(
            override_spec.override_type,
            OverrideType::MergePatch
        ));
        assert_eq!(
            override_spec.value,
            json!({"component": "/different/component"})
        );
    }

    #[test]
    fn test_apply_empty_overrides() {
        let flow = Arc::new(create_test_flow());
        let overrides = WorkflowOverrides::new();

        let original_step_count = flow.steps().len();
        let result = apply_overrides(flow, &overrides).unwrap();
        assert_eq!(result.steps().len(), original_step_count);
    }

    #[test]
    fn test_apply_empty_overrides_returns_same_arc() {
        // The processor documents that the original Arc is returned unchanged
        // when there is nothing to apply.
        let flow = Arc::new(create_test_flow());
        let overrides = WorkflowOverrides::new();

        let result = apply_overrides(Arc::clone(&flow), &overrides).unwrap();
        assert!(Arc::ptr_eq(&flow, &result));
    }

    #[test]
    fn test_apply_merge_patch_override() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::merge_patch(json!({
                "input": {"temperature": 0.8},
                "component": "/new/component"
            })),
        );

        let result = apply_overrides(flow, &overrides).unwrap();
        let step = &result.steps()[0];

        // Check that the component was overridden
        assert_eq!(step.component.to_string(), "/new/component");

        // The patched `input` is not asserted here: its serialized form depends
        // on `ValueExpr`'s representation, which this module does not define.
    }

    #[test]
    fn test_apply_override_does_not_mutate_original_flow() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::merge_patch(json!({"component": "/new/component"})),
        );

        let result = apply_overrides(Arc::clone(&flow), &overrides).unwrap();
        assert!(!Arc::ptr_eq(&flow, &result));
        assert_eq!(flow.steps()[0].component.to_string(), "/test/component");
        assert_eq!(result.steps()[0].component.to_string(), "/new/component");
    }

    #[test]
    fn test_validate_override_targets_missing_step() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "nonexistent_step".to_string(),
            StepOverride::merge_patch(json!({"input": {"temperature": 0.8}})),
        );

        let result = apply_overrides(flow, &overrides);
        assert!(result.is_err());

        let error = result.unwrap_err();
        assert!(
            error
                .to_string()
                .contains("Step 'nonexistent_step' not found in workflow")
        );
    }

    #[test]
    fn test_json_patch_override_is_unsupported() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::with_type(OverrideType::JsonPatch, json!([])),
        );

        let error = apply_overrides(flow, &overrides).unwrap_err();
        // The unsupported-type error is wrapped by the processor, so search
        // the whole report rather than only its outermost context.
        let is_unsupported = error
            .frames()
            .filter_map(|frame| frame.downcast_ref::<OverrideError>())
            .any(|err| matches!(err, OverrideError::UnsupportedOverrideType { .. }));
        assert!(is_unsupported);
    }

    #[test]
    fn test_non_object_merge_patch_is_rejected() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::merge_patch(json!("not an object")),
        );

        assert!(apply_overrides(flow, &overrides).is_err());
    }

    #[test]
    fn test_serde_override_type_default() {
        let json_str = r#"{"value": {"temperature": 0.8}}"#;
        let step_override: StepOverride = serde_json::from_str(json_str).unwrap();

        assert!(matches!(
            step_override.override_type,
            OverrideType::MergePatch
        ));
        assert_eq!(step_override.value, json!({"temperature": 0.8}));
    }

    #[test]
    fn test_serde_override_type_explicit() {
        let json_str = r#"{"$type": "merge_patch", "value": {"temperature": 0.8}}"#;
        let step_override: StepOverride = serde_json::from_str(json_str).unwrap();

        assert!(matches!(
            step_override.override_type,
            OverrideType::MergePatch
        ));
        assert_eq!(step_override.value, json!({"temperature": 0.8}));
    }

    #[test]
    fn test_serde_json_patch_type_name() {
        let json_str = r#"{"$type": "json_patch", "value": []}"#;
        let step_override: StepOverride = serde_json::from_str(json_str).unwrap();

        assert_eq!(step_override.override_type, OverrideType::JsonPatch);
        assert_eq!(step_override.value, json!([]));
    }

    #[test]
    fn test_serde_round_trip() {
        let step_override = StepOverride::merge_patch(json!({"temperature": 0.8}));

        let serialized = serde_json::to_value(&step_override).unwrap();
        assert_eq!(
            serialized,
            json!({"$type": "merge_patch", "value": {"temperature": 0.8}})
        );

        let deserialized: StepOverride = serde_json::from_value(serialized).unwrap();
        assert_eq!(deserialized, step_override);
    }
}