Skip to main content

simple_agents_workflow/yaml_runner/
recovery.rs

1use super::output::{RunMetadata, StepTiming, TokenTotals};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use simple_agent_type::message::Message;
5use std::collections::BTreeMap;
6
7/// Checkpoint captured on workflow failure, enabling retry from the failed node.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct WorkflowCheckpoint {
10    /// Path to the workflow YAML file.
11    pub workflow_path: String,
12    /// ID of the node that failed.
13    pub failed_node_id: String,
14    /// Ordered list of nodes that completed before failure.
15    pub completed_trace: Vec<String>,
16    /// Outputs from completed nodes.
17    pub completed_outputs: BTreeMap<String, Value>,
18    /// Global variable state at time of failure.
19    pub globals: BTreeMap<String, Value>,
20    /// The original input messages.
21    pub original_messages: Vec<Message>,
22    /// Step timings collected before failure.
23    pub step_timings: Vec<StepTiming>,
24    /// Token usage collected before failure.
25    pub token_totals: TokenTotals,
26}
27
28/// Partial output returned when a workflow fails after some nodes succeed.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct PartialWorkflowOutput {
31    /// Workflow identifier.
32    pub workflow_id: String,
33    /// Nodes that completed before failure.
34    pub completed_trace: Vec<String>,
35    /// Outputs from completed nodes.
36    pub completed_outputs: BTreeMap<String, Value>,
37    /// The node that failed.
38    pub failed_node_id: String,
39    /// Error description.
40    pub error: String,
41    /// Checkpoint for retrying.
42    pub checkpoint: WorkflowCheckpoint,
43    /// Optional performance metadata.
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub nerdstats: Option<RunMetadata>,
46}
47
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal checkpoint: `node_a` completed, `node_b` failed.
    fn sample_checkpoint() -> WorkflowCheckpoint {
        WorkflowCheckpoint {
            workflow_path: "test.yaml".into(),
            failed_node_id: "node_b".into(),
            completed_trace: vec!["node_a".into()],
            completed_outputs: BTreeMap::from([(
                "node_a".into(),
                serde_json::json!({"result": 1}),
            )]),
            globals: BTreeMap::new(),
            original_messages: vec![Message::user("hello")],
            step_timings: vec![],
            token_totals: TokenTotals::default(),
        }
    }

    /// A checkpoint must survive a JSON round trip with its fields intact,
    /// including the arbitrary JSON payloads stored in `completed_outputs`.
    #[test]
    fn test_checkpoint_serialization_roundtrip() {
        let cp = sample_checkpoint();
        let json = serde_json::to_string(&cp).expect("checkpoint should serialize");
        let parsed: WorkflowCheckpoint =
            serde_json::from_str(&json).expect("checkpoint should deserialize");

        assert_eq!(parsed.workflow_path, cp.workflow_path);
        assert_eq!(parsed.failed_node_id, "node_b");
        assert_eq!(parsed.completed_trace, vec!["node_a"]);
        assert_eq!(parsed.completed_outputs, cp.completed_outputs);
        assert!(parsed.globals.is_empty());
        assert_eq!(parsed.original_messages.len(), 1);
        assert!(parsed.step_timings.is_empty());
    }

    /// With `nerdstats` unset the key is omitted from the JSON, and the
    /// resulting JSON (which lacks the key) still deserializes.
    #[test]
    fn test_partial_output_omits_absent_nerdstats() {
        let partial = PartialWorkflowOutput {
            workflow_id: "wf".into(),
            completed_trace: vec!["node_a".into()],
            completed_outputs: BTreeMap::new(),
            failed_node_id: "node_b".into(),
            error: "boom".into(),
            checkpoint: sample_checkpoint(),
            nerdstats: None,
        };

        let json = serde_json::to_string(&partial).expect("partial output should serialize");
        let raw: Value = serde_json::from_str(&json).expect("output should be valid JSON");
        assert!(raw.get("nerdstats").is_none());

        let parsed: PartialWorkflowOutput =
            serde_json::from_str(&json).expect("partial output should deserialize");
        assert!(parsed.nerdstats.is_none());
        assert_eq!(parsed.workflow_id, "wf");
        assert_eq!(parsed.failed_node_id, "node_b");
        assert_eq!(parsed.error, "boom");
        assert_eq!(parsed.checkpoint.failed_node_id, "node_b");
    }
}