Skip to main content

stepflow_flow/workflow/
step.rs

1// Copyright 2025 DataStax Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the License
9// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10// or implied. See the License for the specific language governing permissions and limitations under
11// the License.
12
13use std::collections::HashMap;
14
15use serde_with::{DefaultOnNull, serde_as};
16
17use super::Component;
18use crate::ValueExpr;
19
20/// A step in a workflow that executes a component with specific arguments.
21///
22/// Note: Step output schemas are stored in the flow's `types.steps` field,
23/// not on individual steps. This allows for shared `$defs` and avoids duplication.
24#[serde_as]
25#[derive(Clone, serde::Serialize, serde::Deserialize, Debug, PartialEq, schemars::JsonSchema)]
26#[serde(rename_all = "camelCase")]
27pub struct Step {
28    /// Identifier for the step
29    pub id: String,
30
31    /// The component to execute in this step
32    pub component: Component,
33
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub on_error: Option<ErrorAction>,
36
37    /// Arguments to pass to the component for this step
38    #[serde(default, skip_serializing_if = "ValueExpr::is_null")]
39    pub input: ValueExpr,
40
41    /// If true, this step must execute even if its output is not used by the workflow output.
42    /// Useful for steps with side effects (e.g., writing to databases, sending notifications).
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub must_execute: Option<bool>,
45
46    /// Extensible metadata for the step that can be used by tools and frameworks.
47    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
48    #[serde_as(as = "DefaultOnNull")]
49    pub metadata: HashMap<String, serde_json::Value>,
50}
51
52impl Step {
53    pub fn on_error(&self) -> Option<&ErrorAction> {
54        self.on_error.as_ref()
55    }
56
57    /// Get the effective error action, applying the default if none is specified.
58    pub fn on_error_or_default(&self) -> ErrorAction {
59        self.on_error().cloned().unwrap_or_default()
60    }
61
62    /// Check if this step must execute, treating None as false (the default).
63    pub fn must_execute(&self) -> bool {
64        self.must_execute.unwrap_or(false)
65    }
66}
67
68/// Error action determines what happens when a step fails.
69#[derive(
70    Clone, Debug, PartialEq, Default, serde::Serialize, serde::Deserialize, schemars::JsonSchema,
71)]
72#[serde(tag = "action", rename_all = "camelCase")]
73#[schemars(transform = crate::discriminator_schema::AddDiscriminator::new("action"))]
74pub enum ErrorAction {
75    /// If the step fails, the flow will fail.
76    #[default]
77    #[schemars(title = "OnErrorFail")]
78    Fail,
79    /// If the step fails, use the `defaultValue` instead.
80    /// If `defaultValue` is not specified, the step returns null.
81    /// The default value must be a literal JSON value (not an expression).
82    /// For dynamic defaults, use `$coalesce` in the consuming expression instead.
83    #[serde(rename_all = "camelCase")]
84    #[schemars(title = "OnErrorDefault")]
85    UseDefault {
86        #[serde(skip_serializing_if = "Option::is_none")]
87        default_value: Option<serde_json::Value>,
88    },
89    /// If the step fails, retry it.
90    ///
91    /// `max_retries` limits retries due to component errors — cases where the
92    /// component ran and returned an error. Transport-level failures (subprocess
93    /// crashes, network errors) are retried separately according to the plugin's
94    /// retry configuration and do not count against this budget.
95    #[serde(rename_all = "camelCase")]
96    #[schemars(title = "OnErrorRetry")]
97    Retry {
98        /// Maximum number of retries due to component errors (default: 3).
99        ///
100        /// Total attempts for component errors = max_retries + 1 (initial).
101        #[serde(default, skip_serializing_if = "Option::is_none")]
102        max_retries: Option<u32>,
103    },
104}
105
106impl ErrorAction {
107    /// Default maximum retries for component errors.
108    pub const DEFAULT_MAX_RETRIES: u32 = 3;
109
110    pub fn is_default(&self) -> bool {
111        matches!(self, Self::Fail)
112    }
113
114    /// Get the max retries for component errors, if this is a Retry action.
115    pub fn max_retries(&self) -> Option<u32> {
116        match self {
117            Self::Retry { max_retries } => Some(max_retries.unwrap_or(Self::DEFAULT_MAX_RETRIES)),
118            _ => None,
119        }
120    }
121}
122
// Unit tests for `Step` and `ErrorAction`: YAML/JSON serialization shape,
// defaults, and tolerance of explicit `null` for optional fields.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::StepBuilder;

    /// Pins the exact YAML produced for each `ErrorAction` variant, including
    /// that `None` fields are omitted rather than emitted as null.
    #[test]
    fn test_error_action_serialization() {
        let fail = ErrorAction::Fail;
        assert_eq!(serde_yaml_ng::to_string(&fail).unwrap(), "action: fail\n");

        // Retry with default max_retries omits the field
        let retry = ErrorAction::Retry { max_retries: None };
        assert_eq!(serde_yaml_ng::to_string(&retry).unwrap(), "action: retry\n");

        // Retry with explicit max_retries includes the field
        let retry_with_max = ErrorAction::Retry {
            max_retries: Some(5),
        };
        assert_eq!(
            serde_yaml_ng::to_string(&retry_with_max).unwrap(),
            "action: retry\nmaxRetries: 5\n"
        );

        let use_default = ErrorAction::UseDefault {
            default_value: Some(serde_json::json!("test_default")),
        };
        assert_eq!(
            serde_yaml_ng::to_string(&use_default).unwrap(),
            "action: useDefault\ndefaultValue: test_default\n"
        );

        // UseDefault with no value serializes without defaultValue field
        let use_default_none = ErrorAction::UseDefault {
            default_value: None,
        };
        assert_eq!(
            serde_yaml_ng::to_string(&use_default_none).unwrap(),
            "action: useDefault\n"
        );
    }

    /// Parses the camelCase YAML forms back into the expected variants.
    #[test]
    fn test_error_action_deserialization() {
        let fail: ErrorAction = serde_yaml_ng::from_str("action: fail").unwrap();
        assert_eq!(fail, ErrorAction::Fail);

        // Retry without maxRetries
        let retry: ErrorAction = serde_yaml_ng::from_str("action: retry").unwrap();
        assert_eq!(retry, ErrorAction::Retry { max_retries: None });

        // Retry with maxRetries
        let retry_with_max: ErrorAction =
            serde_yaml_ng::from_str("action: retry\nmaxRetries: 5").unwrap();
        assert_eq!(
            retry_with_max,
            ErrorAction::Retry {
                max_retries: Some(5)
            }
        );

        let use_default: ErrorAction =
            serde_yaml_ng::from_str("action: useDefault\ndefaultValue: test_default").unwrap();
        assert_eq!(
            use_default,
            ErrorAction::UseDefault {
                default_value: Some(serde_json::json!("test_default"))
            }
        );
    }

    /// `Fail` is the `Default` variant and the only one reported by `is_default`.
    #[test]
    fn test_error_action_default() {
        assert_eq!(ErrorAction::default(), ErrorAction::Fail);
        assert!(ErrorAction::Fail.is_default());
        assert!(!ErrorAction::Retry { max_retries: None }.is_default());
        assert!(
            !ErrorAction::UseDefault {
                default_value: Some(serde_json::json!("test"))
            }
            .is_default()
        );
    }

    /// `max_retries` is only `Some` for `Retry`, falling back to the built-in default.
    #[test]
    fn test_error_action_max_retries() {
        assert_eq!(ErrorAction::Fail.max_retries(), None);
        assert_eq!(
            ErrorAction::Retry { max_retries: None }.max_retries(),
            Some(ErrorAction::DEFAULT_MAX_RETRIES)
        );
        assert_eq!(
            ErrorAction::Retry {
                max_retries: Some(5)
            }
            .max_retries(),
            Some(5)
        );
    }

    /// A step with an explicit error action serializes it under `onError`.
    #[test]
    fn test_step_serialization_with_error_action() {
        let step = StepBuilder::new("test_step")
            .component("/mock/test_component")
            .on_error(ErrorAction::UseDefault {
                default_value: Some(serde_json::json!("fallback")),
            })
            .input(ValueExpr::null())
            .build();

        let yaml = serde_yaml_ng::to_string(&step).unwrap();
        assert!(yaml.contains("onError:"));
        assert!(yaml.contains("action: useDefault"));
        assert!(yaml.contains("defaultValue: fallback"));
    }

    /// A step built without an error action omits `onError` entirely.
    #[test]
    fn test_step_default_error_action_not_serialized() {
        let step = StepBuilder::new("test_step")
            .component("/mock/test_component")
            .input(ValueExpr::null())
            .build();

        let yaml = serde_yaml_ng::to_string(&step).unwrap();
        assert!(!yaml.contains("onError:"));
    }

    /// Every optional `Step` field given as explicit `null` deserializes to its default.
    #[test]
    fn test_step_all_optional_null() {
        // All optional/defaulted Step fields as explicit null — simulates a Python
        // client calling model_dump() without exclude_none=True.
        let json = serde_json::json!({
            "id": "test_step",
            "component": "/mock/test_component",
            "onError": null,
            "input": null,
            "mustExecute": null,
            "metadata": null,
        });
        let step: Step = serde_json::from_value(json).unwrap();
        assert_eq!(step.id, "test_step");
        assert!(step.on_error.is_none());
        assert_eq!(step.on_error_or_default(), ErrorAction::Fail);
        assert!(step.input.is_null());
        assert!(step.must_execute.is_none());
        assert!(step.metadata.is_empty());
    }

    /// An explicit `defaultValue: null` is read as "no default value".
    #[test]
    fn test_error_action_use_default_null_value() {
        // defaultValue: null means "no default value" (use null as the step output)
        let json = serde_json::json!({"action": "useDefault", "defaultValue": null});
        let action: ErrorAction = serde_json::from_value(json).unwrap();
        assert!(matches!(
            action,
            ErrorAction::UseDefault {
                default_value: None
            }
        ));
    }

    /// An explicit `maxRetries: null` is read as "use the built-in default".
    #[test]
    fn test_error_action_retry_null_max_retries() {
        // maxRetries: null means "use the built-in default"
        let json = serde_json::json!({"action": "retry", "maxRetries": null});
        let action: ErrorAction = serde_json::from_value(json).unwrap();
        assert!(matches!(action, ErrorAction::Retry { max_retries: None }));
    }
}