//! `simple_agents_workflow/replay.rs`: structural replay validation for recorded workflow traces.

use thiserror::Error;

use crate::trace::{TraceEventKind, TraceTerminalStatus, WorkflowTrace};

/// Cache policy used by replay workflows.
///
/// The policy travels inside [`ReplayOptions`]. Note that
/// `replay_trace_with_options` in this module currently runs the same full
/// validation pass for every variant; the variants describe intended caching
/// behavior rather than a difference that is observable here today.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReplayCachePolicy {
    /// Always trust cached replay metadata when present.
    Always,
    /// Always recompute replay validation from the trace.
    Refresh,
    /// Use cached metadata when complete, otherwise recompute.
    Mixed,
}

16/// Replay behavior controls.
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct ReplayOptions {
19    /// Cache behavior for replay metadata.
20    pub cache_policy: ReplayCachePolicy,
21}
22
23impl Default for ReplayOptions {
24    fn default() -> Self {
25        Self {
26            cache_policy: ReplayCachePolicy::Refresh,
27        }
28    }
29}
30
31/// Successful replay validation report.
32#[derive(Debug, Clone, PartialEq, Eq)]
33pub struct ReplayReport {
34    /// Number of validated events.
35    pub total_events: usize,
36    /// Observed terminal status.
37    pub terminal_status: TraceTerminalStatus,
38}
39
/// Stable replay validation codes.
///
/// Each code identifies one class of structural problem detected while
/// validating a trace; the accompanying message on `ReplayViolation` is
/// human-readable detail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReplayViolationCode {
    /// Event sequence is not monotonic: an event's `seq` differs from the
    /// expected next value (starting at 0 and increasing by one).
    NonMonotonicSequence,
    /// Node enter/exit/error stack is invalid: a node was closed that is not
    /// the most recently entered one, or no node was open at all.
    MismatchedNodeLifecycle,
    /// Trace has no terminal workflow event.
    MissingTerminalEvent,
    /// Trace ended with unterminated entered nodes.
    UnclosedNodeLifecycle,
}

53/// A single replay validation violation.
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct ReplayViolation {
56    /// Stable error code.
57    pub code: ReplayViolationCode,
58    /// Human-readable message.
59    pub message: String,
60    /// Zero-based event index in the trace, if applicable.
61    pub event_index: Option<usize>,
62}
63
64/// Aggregate replay validation failure.
65#[derive(Debug, Clone, PartialEq, Eq, Error)]
66#[error("workflow trace replay validation failed")]
67pub struct ReplayError {
68    /// Collected structural violations.
69    pub violations: Vec<ReplayViolation>,
70}
71
72/// Validates that a recorded trace can be structurally replayed.
73pub fn replay_trace(trace: &WorkflowTrace) -> Result<ReplayReport, ReplayError> {
74    replay_trace_with_options(trace, &ReplayOptions::default())
75}
76
77/// Validates that a recorded trace can be structurally replayed with options.
78pub fn replay_trace_with_options(
79    trace: &WorkflowTrace,
80    options: &ReplayOptions,
81) -> Result<ReplayReport, ReplayError> {
82    match options.cache_policy {
83        ReplayCachePolicy::Always | ReplayCachePolicy::Refresh | ReplayCachePolicy::Mixed => {
84            replay_trace_internal(trace)
85        }
86    }
87}
88
89fn replay_trace_internal(trace: &WorkflowTrace) -> Result<ReplayReport, ReplayError> {
90    let mut violations = Vec::new();
91    let mut expected_seq = 0u64;
92    let mut stack: Vec<&str> = Vec::new();
93    let mut terminal_status = None;
94
95    for (index, event) in trace.events.iter().enumerate() {
96        if event.seq != expected_seq {
97            violations.push(ReplayViolation {
98                code: ReplayViolationCode::NonMonotonicSequence,
99                message: format!(
100                    "expected event seq {} at index {}, found {}",
101                    expected_seq, index, event.seq
102                ),
103                event_index: Some(index),
104            });
105            expected_seq = event.seq.saturating_add(1);
106        } else {
107            expected_seq = expected_seq.saturating_add(1);
108        }
109
110        match &event.kind {
111            TraceEventKind::NodeEnter { node_id } => {
112                stack.push(node_id.as_str());
113            }
114            TraceEventKind::NodeExit { node_id } | TraceEventKind::NodeError { node_id, .. } => {
115                match stack.pop() {
116                    Some(active_node) if active_node == node_id => {}
117                    Some(active_node) => violations.push(ReplayViolation {
118                        code: ReplayViolationCode::MismatchedNodeLifecycle,
119                        message: format!(
120                            "expected node '{}' to close, found '{}'",
121                            active_node, node_id
122                        ),
123                        event_index: Some(index),
124                    }),
125                    None => violations.push(ReplayViolation {
126                        code: ReplayViolationCode::MismatchedNodeLifecycle,
127                        message: format!(
128                            "node '{}' closed without a matching enter event",
129                            node_id
130                        ),
131                        event_index: Some(index),
132                    }),
133                }
134            }
135            TraceEventKind::Terminal { status } => {
136                terminal_status = Some(*status);
137            }
138        }
139    }
140
141    if terminal_status.is_none() {
142        violations.push(ReplayViolation {
143            code: ReplayViolationCode::MissingTerminalEvent,
144            message: "trace does not contain a terminal event".to_string(),
145            event_index: None,
146        });
147    }
148
149    if !stack.is_empty() {
150        violations.push(ReplayViolation {
151            code: ReplayViolationCode::UnclosedNodeLifecycle,
152            message: format!("{} node(s) remain open at end of trace", stack.len()),
153            event_index: None,
154        });
155    }
156
157    if violations.is_empty() {
158        Ok(ReplayReport {
159            total_events: trace.events.len(),
160            terminal_status: terminal_status.expect("terminal status must exist"),
161        })
162    } else {
163        Err(ReplayError { violations })
164    }
165}
166
#[cfg(test)]
mod tests {
    use crate::recorder::TraceRecorder;
    use crate::replay::{
        replay_trace, replay_trace_with_options, ReplayCachePolicy, ReplayOptions,
        ReplayViolationCode,
    };
    use crate::trace::{
        TraceEvent, TraceEventKind, TraceTerminalStatus, WorkflowTrace, WorkflowTraceMetadata,
    };

    /// Metadata shared by every fixture trace.
    fn metadata() -> WorkflowTraceMetadata {
        WorkflowTraceMetadata {
            trace_id: "trace-1".to_string(),
            workflow_name: "demo".to_string(),
            workflow_version: "v0".to_string(),
            started_at_unix_ms: 100,
            finished_at_unix_ms: None,
        }
    }

    /// Records enter/exit of one node followed by a `Completed` terminal event.
    fn recorded_valid_trace() -> WorkflowTrace {
        let recorder = TraceRecorder::new(metadata());
        recorder.record_node_enter(101, "start").unwrap();
        recorder.record_node_exit(102, "start").unwrap();
        recorder
            .record_terminal(103, TraceTerminalStatus::Completed)
            .unwrap();
        recorder.finalize(104).unwrap()
    }

    fn enter(node_id: &str) -> TraceEventKind {
        TraceEventKind::NodeEnter {
            node_id: node_id.to_string(),
        }
    }

    fn exit(node_id: &str) -> TraceEventKind {
        TraceEventKind::NodeExit {
            node_id: node_id.to_string(),
        }
    }

    fn terminal(status: TraceTerminalStatus) -> TraceEventKind {
        TraceEventKind::Terminal { status }
    }

    /// Builds a hand-written trace from `(seq, kind)` pairs.
    ///
    /// Replay validation in this module never reads timestamps, so every
    /// event gets the same fixed value.
    fn trace_from(events: Vec<(u64, TraceEventKind)>) -> WorkflowTrace {
        WorkflowTrace {
            metadata: metadata(),
            events: events
                .into_iter()
                .map(|(seq, kind)| TraceEvent {
                    seq,
                    timestamp_unix_ms: 101,
                    kind,
                })
                .collect(),
        }
    }

    /// Replays `trace`, expecting failure, and returns `(code, event_index)`
    /// for each violation in detection order.
    fn violations_of(trace: &WorkflowTrace) -> Vec<(ReplayViolationCode, Option<usize>)> {
        replay_trace(trace)
            .expect_err("trace should be rejected")
            .violations
            .iter()
            .map(|violation| (violation.code, violation.event_index))
            .collect()
    }

    #[test]
    fn replays_valid_trace() {
        let trace = recorded_valid_trace();
        let report = replay_trace(&trace).expect("valid trace should replay");

        assert_eq!(report.total_events, 3);
        assert_eq!(report.terminal_status, TraceTerminalStatus::Completed);
    }

    #[test]
    fn rejects_out_of_order_sequence() {
        let trace = trace_from(vec![
            (0, enter("start")),
            (2, exit("start")),
            (3, terminal(TraceTerminalStatus::Completed)),
        ]);

        // The gap is reported once; validation resynchronizes afterwards.
        assert_eq!(
            violations_of(&trace),
            vec![(ReplayViolationCode::NonMonotonicSequence, Some(1))]
        );
    }

    #[test]
    fn rejects_missing_terminal_event() {
        let trace = trace_from(vec![(0, enter("start")), (1, exit("start"))]);

        assert_eq!(
            violations_of(&trace),
            vec![(ReplayViolationCode::MissingTerminalEvent, None)]
        );
    }

    #[test]
    fn rejects_empty_trace() {
        let trace = trace_from(Vec::new());

        assert_eq!(
            violations_of(&trace),
            vec![(ReplayViolationCode::MissingTerminalEvent, None)]
        );
    }

    #[test]
    fn rejects_mismatched_enter_exit() {
        let trace = trace_from(vec![
            (0, enter("a")),
            (1, exit("b")),
            (2, terminal(TraceTerminalStatus::Failed)),
        ]);

        assert_eq!(
            violations_of(&trace),
            vec![(ReplayViolationCode::MismatchedNodeLifecycle, Some(1))]
        );
    }

    #[test]
    fn rejects_exit_without_enter() {
        let trace = trace_from(vec![
            (0, exit("a")),
            (1, terminal(TraceTerminalStatus::Failed)),
        ]);

        assert_eq!(
            violations_of(&trace),
            vec![(ReplayViolationCode::MismatchedNodeLifecycle, Some(0))]
        );
    }

    #[test]
    fn rejects_unclosed_node() {
        let trace = trace_from(vec![
            (0, enter("a")),
            (1, terminal(TraceTerminalStatus::Completed)),
        ]);

        assert_eq!(
            violations_of(&trace),
            vec![(ReplayViolationCode::UnclosedNodeLifecycle, None)]
        );
    }

    #[test]
    fn supports_cache_policy_options() {
        let trace = recorded_valid_trace();

        for cache_policy in [
            ReplayCachePolicy::Always,
            ReplayCachePolicy::Refresh,
            ReplayCachePolicy::Mixed,
        ] {
            let report = replay_trace_with_options(&trace, &ReplayOptions { cache_policy })
                .unwrap_or_else(|err| {
                    panic!("{:?} policy should replay trace: {:?}", cache_policy, err)
                });
            assert_eq!(report.total_events, 3);
        }
    }
}