1use std::collections::HashMap;
14use std::sync::Arc;
15
16use error_stack::ResultExt as _;
17use serde::{Deserialize, Serialize};
18
19use super::Flow;
20
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
26pub struct WorkflowOverrides {
27 pub steps: HashMap<String, StepOverride>,
29}
30
31impl WorkflowOverrides {
32 pub fn new() -> Self {
34 Self {
35 steps: HashMap::new(),
36 }
37 }
38
39 pub fn is_empty(&self) -> bool {
41 self.steps.is_empty()
42 }
43
44 pub fn add_step_override(&mut self, step_id: String, override_spec: StepOverride) {
46 self.steps.insert(step_id, override_spec);
47 }
48}
49
50impl Default for WorkflowOverrides {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
61pub struct StepOverride {
62 #[serde(rename = "$type", default = "default_override_type")]
64 pub override_type: OverrideType,
65
66 pub value: serde_json::Value,
68}
69
70impl StepOverride {
71 pub fn merge_patch(value: serde_json::Value) -> Self {
73 Self {
74 override_type: OverrideType::MergePatch,
75 value,
76 }
77 }
78
79 pub fn with_type(override_type: OverrideType, value: serde_json::Value) -> Self {
81 Self {
82 override_type,
83 value,
84 }
85 }
86}
87
88#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
90#[serde(rename_all = "snake_case")]
91pub enum OverrideType {
92 MergePatch,
98
99 #[allow(dead_code)]
104 JsonPatch,
105}
106
107fn default_override_type() -> OverrideType {
109 OverrideType::MergePatch
110}
111
112#[derive(Debug, thiserror::Error)]
114pub enum OverrideError {
115 #[error("Step '{step_id}' not found in workflow")]
116 StepNotFound { step_id: String },
117
118 #[error("Invalid override value for step '{step_id}': {reason}")]
119 InvalidOverrideValue { step_id: String, reason: String },
120
121 #[error("Unsupported override type: {override_type:?}")]
122 UnsupportedOverrideType { override_type: OverrideType },
123
124 #[error("JSON merge patch failed for step '{step_id}': {reason}")]
125 MergePatchFailed { step_id: String, reason: String },
126}
127
128pub type OverrideResult<T> = error_stack::Result<T, OverrideError>;
129
130pub trait OverrideProcessor {
132 fn apply_overrides(
138 &self,
139 flow: Arc<Flow>,
140 overrides: &WorkflowOverrides,
141 ) -> OverrideResult<Arc<Flow>>;
142}
143
/// The stock [`OverrideProcessor`]. It supports [`OverrideType::MergePatch`] overrides.
///
/// It is a stateless unit struct, so it is cheap to copy and prints as its type name.
#[derive(Debug, Clone, Copy)]
pub struct DefaultOverrideProcessor;
146
147impl OverrideProcessor for DefaultOverrideProcessor {
148 fn apply_overrides(
149 &self,
150 flow: Arc<Flow>,
151 overrides: &WorkflowOverrides,
152 ) -> OverrideResult<Arc<Flow>> {
153 if overrides.is_empty() {
154 return Ok(flow);
155 }
156
157 log::debug!(
158 "Applying {} step overrides to workflow",
159 overrides.steps.len()
160 );
161
162 self.validate_override_targets(&flow, overrides)?;
164
165 let mut cloned_flow = flow.slow_clone();
167
168 for step in &mut cloned_flow.steps {
170 if let Some(step_override) = overrides.steps.get(&step.id) {
171 log::debug!(
172 "Applying override to step '{}' with type '{:?}'",
173 step.id,
174 step_override.override_type
175 );
176 self.apply_step_override(step, step_override)
177 .change_context(OverrideError::InvalidOverrideValue {
178 step_id: step.id.clone(),
179 reason: "Failed to apply step override".to_string(),
180 })?;
181 }
182 }
183
184 Ok(Arc::new(cloned_flow))
185 }
186}
187
188impl DefaultOverrideProcessor {
189 pub fn new() -> Self {
191 Self
192 }
193
194 fn validate_override_targets(
196 &self,
197 flow: &Flow,
198 overrides: &WorkflowOverrides,
199 ) -> OverrideResult<()> {
200 let step_ids: std::collections::HashSet<&String> =
201 flow.steps().iter().map(|step| &step.id).collect();
202
203 for step_id in overrides.steps.keys() {
204 if !step_ids.contains(&step_id) {
205 return Err(error_stack::report!(OverrideError::StepNotFound {
206 step_id: step_id.clone(),
207 }));
208 }
209 }
210
211 Ok(())
212 }
213
214 fn apply_step_override(
216 &self,
217 step: &mut super::Step,
218 step_override: &StepOverride,
219 ) -> OverrideResult<()> {
220 match step_override.override_type {
221 OverrideType::MergePatch => self.apply_merge_patch(step, &step_override.value),
222 OverrideType::JsonPatch => Err(error_stack::report!(
223 OverrideError::UnsupportedOverrideType {
224 override_type: step_override.override_type.clone(),
225 }
226 )),
227 }
228 }
229
230 fn apply_merge_patch(
232 &self,
233 step: &mut super::Step,
234 patch: &serde_json::Value,
235 ) -> OverrideResult<()> {
236 let step_id = step.id.clone(); let mut step_json =
240 serde_json::to_value(&*step).change_context(OverrideError::MergePatchFailed {
241 step_id: step_id.clone(),
242 reason: "Failed to serialize step to JSON".to_string(),
243 })?;
244
245 json_patch::merge(&mut step_json, patch);
247
248 *step =
250 serde_json::from_value(step_json).change_context(OverrideError::MergePatchFailed {
251 step_id,
252 reason: "Failed to deserialize modified step from JSON".to_string(),
253 })?;
254
255 Ok(())
256 }
257}
258
259impl Default for DefaultOverrideProcessor {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
265pub fn apply_overrides(
267 flow: Arc<Flow>,
268 overrides: &WorkflowOverrides,
269) -> OverrideResult<Arc<Flow>> {
270 DefaultOverrideProcessor::new().apply_overrides(flow, overrides)
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ValueExpr;
    use serde_json::json;

    // Builds a minimal flow with a single step `step1` running `/test/component`.
    fn create_test_flow() -> Flow {
        Flow {
            name: Some("test_flow".to_string()),
            description: None,
            version: None,
            schemas: super::super::FlowSchema::default(),
            steps: vec![super::super::Step {
                id: "step1".to_string(),
                component: super::super::Component::from_string("/test/component"),
                on_error: None,
                input: ValueExpr::null(),
                must_execute: None,
                metadata: std::collections::HashMap::new(),
            }],
            output: ValueExpr::null(),
            test: None,
            examples: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_workflow_overrides_creation() {
        let overrides = WorkflowOverrides::new();
        assert!(overrides.is_empty());

        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::merge_patch(json!({"input": {"temperature": 0.8}})),
        );
        assert!(!overrides.is_empty());
        assert_eq!(overrides.steps.len(), 1);
    }

    #[test]
    fn test_step_override_creation() {
        let override_spec = StepOverride::merge_patch(json!({"input": {"temperature": 0.8}}));
        assert!(matches!(
            override_spec.override_type,
            OverrideType::MergePatch
        ));
        assert_eq!(override_spec.value, json!({"input": {"temperature": 0.8}}));

        let override_spec = StepOverride::with_type(
            OverrideType::MergePatch,
            json!({"component": "/different/component"}),
        );
        assert!(matches!(
            override_spec.override_type,
            OverrideType::MergePatch
        ));
        assert_eq!(
            override_spec.value,
            json!({"component": "/different/component"})
        );
    }

    #[test]
    fn test_apply_empty_overrides() {
        let flow = Arc::new(create_test_flow());
        let overrides = WorkflowOverrides::new();

        let original_step_count = flow.steps().len();
        let result = apply_overrides(flow, &overrides).unwrap();
        assert_eq!(result.steps().len(), original_step_count);
    }

    #[test]
    fn test_apply_merge_patch_override() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::merge_patch(json!({
                "input": {"temperature": 0.8},
                "component": "/new/component"
            })),
        );

        let result = apply_overrides(flow, &overrides).unwrap();
        let step = &result.steps()[0];

        assert_eq!(step.component.to_string(), "/new/component");
        // Fields absent from the patch must survive the merge.
        assert_eq!(step.id, "step1");
    }

    #[test]
    fn test_apply_overrides_leaves_original_flow_untouched() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::merge_patch(json!({"component": "/new/component"})),
        );

        // Keep a second handle so the processor must not mutate the shared flow.
        let result = apply_overrides(Arc::clone(&flow), &overrides).unwrap();

        assert_eq!(result.steps()[0].component.to_string(), "/new/component");
        assert_eq!(
            flow.steps()[0].component.to_string(),
            create_test_flow().steps()[0].component.to_string()
        );
    }

    #[test]
    fn test_json_patch_override_is_rejected() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "step1".to_string(),
            StepOverride::with_type(OverrideType::JsonPatch, json!([])),
        );

        assert!(apply_overrides(flow, &overrides).is_err());
    }

    #[test]
    fn test_non_object_merge_patch_is_rejected() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override("step1".to_string(), StepOverride::merge_patch(json!(42)));

        assert!(apply_overrides(flow, &overrides).is_err());
    }

    #[test]
    fn test_validate_override_targets_missing_step() {
        let flow = Arc::new(create_test_flow());
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override(
            "nonexistent_step".to_string(),
            StepOverride::merge_patch(json!({"input": {"temperature": 0.8}})),
        );

        let result = apply_overrides(flow, &overrides);
        assert!(result.is_err());

        let error = result.unwrap_err();
        assert!(
            error
                .to_string()
                .contains("Step 'nonexistent_step' not found in workflow")
        );
    }

    #[test]
    fn test_serde_override_type_default() {
        let json_str = r#"{"value": {"temperature": 0.8}}"#;
        let step_override: StepOverride = serde_json::from_str(json_str).unwrap();

        assert!(matches!(
            step_override.override_type,
            OverrideType::MergePatch
        ));
        assert_eq!(step_override.value, json!({"temperature": 0.8}));
    }

    #[test]
    fn test_serde_override_type_explicit() {
        let json_str = r#"{"$type": "merge_patch", "value": {"temperature": 0.8}}"#;
        let step_override: StepOverride = serde_json::from_str(json_str).unwrap();

        assert!(matches!(
            step_override.override_type,
            OverrideType::MergePatch
        ));
        assert_eq!(step_override.value, json!({"temperature": 0.8}));
    }

    #[test]
    fn test_step_override_serde_roundtrip() {
        let original = StepOverride::merge_patch(json!({"input": {"temperature": 0.8}}));
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: StepOverride = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }
}