Skip to main content

vellaveto_engine/
cascade_graph.rs

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// Copyright 2026 Paolo Vella
// SPDX-License-Identifier: MPL-2.0

//! Cascading failure propagation graph (OWASP ASI08).
//!
//! Tracks error propagation across tool calls and agent interactions
//! to detect cascading failure patterns that could indicate systemic
//! compromise or denial-of-service propagation. A cascade is reported
//! when the number of distinct tools with at least one failure inside a
//! sliding time window reaches a configurable threshold.
13
14use std::collections::HashMap;
15
/// Upper bound on the number of distinct tool names the cascade graph keeps
/// failure history for at any one time.
const MAX_NODES: usize = 500;
18
19/// A failure propagation event.
20#[derive(Debug, Clone)]
21pub struct FailureEvent {
22    pub source: String,
23    pub failure_type: FailureType,
24    pub timestamp_ms: u64,
25}
26
/// The kind of failure reported for a tool call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FailureType {
    /// The tool call did not finish before timing out.
    Timeout,
    /// The tool itself reported an error.
    ToolError,
    /// The call was rejected by policy.
    PolicyDenial,
    /// A circuit breaker tripped for the call.
    CircuitBreaker,
    /// The call exceeded a rate limit.
    RateLimit,
}
40
/// Report produced when enough distinct tools fail close together in time.
#[derive(Clone, Debug)]
pub struct CascadeFinding {
    /// How many distinct tools were failing inside the detection window.
    pub cascade_depth: usize,
    /// Names of the tools that were failing inside the window.
    pub affected_tools: Vec<String>,
    /// Tool whose failure caused this finding to be reported.
    pub trigger: String,
    /// Human-readable summary of the finding.
    pub description: String,
}
49
50/// Tracks failure propagation across tool calls.
51pub struct CascadeGraph {
52    /// Per-tool failure counts in current window.
53    failure_counts: HashMap<String, Vec<FailureEvent>>,
54    /// Time window for cascade detection (ms).
55    window_ms: u64,
56    /// Threshold: failures in window to flag cascade.
57    cascade_threshold: usize,
58}
59
60impl CascadeGraph {
61    pub fn new(window_ms: u64, cascade_threshold: usize) -> Self {
62        Self {
63            failure_counts: HashMap::new(),
64            window_ms,
65            cascade_threshold: cascade_threshold.max(2),
66        }
67    }
68
69    /// Record a failure and check for cascade patterns.
70    pub fn record_failure(
71        &mut self,
72        tool_name: &str,
73        failure_type: FailureType,
74    ) -> Option<CascadeFinding> {
75        let now = now_ms();
76
77        if self.failure_counts.len() >= MAX_NODES && !self.failure_counts.contains_key(tool_name) {
78            return None;
79        }
80
81        let events = self
82            .failure_counts
83            .entry(tool_name[..tool_name.len().min(256)].to_string())
84            .or_default();
85        events.push(FailureEvent {
86            source: tool_name[..tool_name.len().min(256)].to_string(),
87            failure_type,
88            timestamp_ms: now,
89        });
90
91        // Prune old events
92        let cutoff = now.saturating_sub(self.window_ms);
93        events.retain(|e| e.timestamp_ms >= cutoff);
94
95        // Count distinct failing tools in the window
96        let mut failing_tools = Vec::new();
97        for (tool, tool_events) in &self.failure_counts {
98            let recent = tool_events
99                .iter()
100                .filter(|e| e.timestamp_ms >= cutoff)
101                .count();
102            if recent > 0 {
103                failing_tools.push(tool.clone());
104            }
105        }
106
107        if failing_tools.len() >= self.cascade_threshold {
108            Some(CascadeFinding {
109                cascade_depth: failing_tools.len(),
110                affected_tools: failing_tools,
111                trigger: tool_name.to_string(),
112                description: format!(
113                    "Cascading failure: {} tools failing within {}ms window",
114                    self.failure_counts
115                        .values()
116                        .filter(|v| v.iter().any(|e| e.timestamp_ms >= cutoff))
117                        .count(),
118                    self.window_ms
119                ),
120            })
121        } else {
122            None
123        }
124    }
125
126    /// Get current failure count across all tools.
127    pub fn total_failures_in_window(&self) -> usize {
128        let cutoff = now_ms().saturating_sub(self.window_ms);
129        self.failure_counts
130            .values()
131            .flat_map(|v| v.iter())
132            .filter(|e| e.timestamp_ms >= cutoff)
133            .count()
134    }
135}
136
/// Returns the current wall-clock time in milliseconds since the Unix epoch,
/// or 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // `as_millis` yields u128; the narrowing cast only truncates for dates
        // hundreds of millions of years in the future.
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// One-minute detection window shared by every scenario below.
    const WINDOW_MS: u64 = 60_000;

    #[test]
    fn test_single_failure_no_cascade() {
        let mut graph = CascadeGraph::new(WINDOW_MS, 3);
        let outcome = graph.record_failure("tool_a", FailureType::Timeout);
        assert!(outcome.is_none(), "a single failing tool is not a cascade");
    }

    #[test]
    fn test_cascade_detected() {
        let mut graph = CascadeGraph::new(WINDOW_MS, 3);
        for (tool, kind) in [
            ("tool_a", FailureType::Timeout),
            ("tool_b", FailureType::ToolError),
        ] {
            graph.record_failure(tool, kind);
        }
        let finding = graph
            .record_failure("tool_c", FailureType::CircuitBreaker)
            .expect("three distinct failing tools should form a cascade");
        assert!(finding.cascade_depth >= 3);
    }

    #[test]
    fn test_total_failures() {
        let mut graph = CascadeGraph::new(WINDOW_MS, 10);
        let failures = [
            ("tool_a", FailureType::Timeout),
            ("tool_a", FailureType::Timeout),
            ("tool_b", FailureType::ToolError),
        ];
        for (tool, kind) in failures {
            graph.record_failure(tool, kind);
        }
        assert_eq!(graph.total_failures_in_window(), failures.len());
    }

    #[test]
    fn test_capacity_bounded() {
        let mut graph = CascadeGraph::new(WINDOW_MS, 1000);
        (0..MAX_NODES + 50).for_each(|i| {
            graph.record_failure(&format!("tool_{i}"), FailureType::Timeout);
        });
        assert!(graph.failure_counts.len() <= MAX_NODES);
    }
}
183}