Skip to main content

shadow_network_sim/
correlation.rs

//! Traffic correlation analysis — detect relationships between flows.
2
3use serde::{Serialize, Deserialize};
4
5/// A recorded traffic flow for correlation analysis.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct FlowRecord {
8    /// Unique flow identifier.
9    pub flow_id: String,
10    /// Packet sizes in order.
11    pub packet_sizes: Vec<usize>,
12    /// Timestamps in microseconds.
13    pub timestamps_us: Vec<u64>,
14    /// Direction markers (true = outgoing, false = incoming).
15    pub directions: Vec<bool>,
16    /// Source label (for testing purposes).
17    pub source_label: String,
18}
19
20impl FlowRecord {
21    pub fn new(flow_id: impl Into<String>) -> Self {
22        Self {
23            flow_id: flow_id.into(),
24            packet_sizes: Vec::new(),
25            timestamps_us: Vec::new(),
26            directions: Vec::new(),
27            source_label: String::new(),
28        }
29    }
30
31    /// Add a packet to this flow.
32    pub fn add_packet(&mut self, size: usize, timestamp_us: u64, outgoing: bool) {
33        self.packet_sizes.push(size);
34        self.timestamps_us.push(timestamp_us);
35        self.directions.push(outgoing);
36    }
37
38    /// Duration of the flow in microseconds.
39    pub fn duration_us(&self) -> u64 {
40        if self.timestamps_us.len() < 2 {
41            return 0;
42        }
43        self.timestamps_us.last().unwrap() - self.timestamps_us.first().unwrap()
44    }
45
46    /// Total bytes transferred.
47    pub fn total_bytes(&self) -> usize {
48        self.packet_sizes.iter().sum()
49    }
50
51    /// Inter-packet delays in microseconds.
52    pub fn inter_packet_delays(&self) -> Vec<u64> {
53        self.timestamps_us
54            .windows(2)
55            .map(|w| w[1] - w[0])
56            .collect()
57    }
58
59    /// Average packet size.
60    pub fn avg_packet_size(&self) -> f64 {
61        if self.packet_sizes.is_empty() {
62            return 0.0;
63        }
64        self.packet_sizes.iter().sum::<usize>() as f64 / self.packet_sizes.len() as f64
65    }
66
67    /// Packet count.
68    pub fn packet_count(&self) -> usize {
69        self.packet_sizes.len()
70    }
71}
72
73/// Result of correlating two flows.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct CorrelationResult {
76    pub flow_a: String,
77    pub flow_b: String,
78    /// Timing correlation coefficient (0.0–1.0).
79    pub timing_correlation: f64,
80    /// Size correlation coefficient (0.0–1.0).
81    pub size_correlation: f64,
82    /// Combined correlation score.
83    pub combined_score: f64,
84    /// Whether the correlation is significant.
85    pub is_correlated: bool,
86}
87
/// Analyzer for detecting traffic correlation patterns.
///
/// Holds only plain configuration values, so it derives `Debug` and `Clone`
/// like the other public types in this module.
#[derive(Debug, Clone)]
pub struct CorrelationAnalyzer {
    /// Threshold at or above which two flows are considered correlated.
    pub correlation_threshold: f64,
    /// Tolerance for timing offset (microseconds).
    ///
    /// NOTE(review): no analysis method in this file reads this value; it is
    /// only initialised by the constructor. Confirm whether other modules use it.
    pub timing_tolerance_us: u64,
}
95
96impl CorrelationAnalyzer {
97    pub fn new() -> Self {
98        Self {
99            correlation_threshold: 0.7,
100            timing_tolerance_us: 10_000, // 10ms
101        }
102    }
103
104    pub fn with_threshold(mut self, threshold: f64) -> Self {
105        self.correlation_threshold = threshold;
106        self
107    }
108
109    /// Analyze timing correlation between two flows.
110    pub fn timing_correlation(&self, a: &FlowRecord, b: &FlowRecord) -> f64 {
111        let delays_a: Vec<f64> = a.inter_packet_delays().iter().map(|&d| d as f64).collect();
112        let delays_b: Vec<f64> = b.inter_packet_delays().iter().map(|&d| d as f64).collect();
113
114        if delays_a.is_empty() || delays_b.is_empty() {
115            return 0.0;
116        }
117
118        pearson(&delays_a, &delays_b).abs()
119    }
120
121    /// Analyze size-pattern correlation between two flows.
122    pub fn size_correlation(&self, a: &FlowRecord, b: &FlowRecord) -> f64 {
123        let sizes_a: Vec<f64> = a.packet_sizes.iter().map(|&s| s as f64).collect();
124        let sizes_b: Vec<f64> = b.packet_sizes.iter().map(|&s| s as f64).collect();
125
126        if sizes_a.is_empty() || sizes_b.is_empty() {
127            return 0.0;
128        }
129
130        pearson(&sizes_a, &sizes_b).abs()
131    }
132
133    /// Full correlation analysis between two flows.
134    pub fn correlate(&self, a: &FlowRecord, b: &FlowRecord) -> CorrelationResult {
135        let timing = self.timing_correlation(a, b);
136        let size = self.size_correlation(a, b);
137
138        // Weighted combination: timing is more important
139        let combined = 0.6 * timing + 0.4 * size;
140
141        CorrelationResult {
142            flow_a: a.flow_id.clone(),
143            flow_b: b.flow_id.clone(),
144            timing_correlation: timing,
145            size_correlation: size,
146            combined_score: combined,
147            is_correlated: combined >= self.correlation_threshold,
148        }
149    }
150
151    /// Analyze all pairs from a set of flows.
152    pub fn analyze_all(&self, flows: &[FlowRecord]) -> Vec<CorrelationResult> {
153        let mut results = Vec::new();
154        for i in 0..flows.len() {
155            for j in (i + 1)..flows.len() {
156                results.push(self.correlate(&flows[i], &flows[j]));
157            }
158        }
159        results
160    }
161
162    /// Find flows correlated with a target flow.
163    pub fn find_correlated(
164        &self,
165        target: &FlowRecord,
166        candidates: &[FlowRecord],
167    ) -> Vec<CorrelationResult> {
168        candidates
169            .iter()
170            .filter(|c| c.flow_id != target.flow_id)
171            .map(|c| self.correlate(target, c))
172            .filter(|r| r.is_correlated)
173            .collect()
174    }
175
176    /// Evaluate how well traffic shaping resists correlation.
177    pub fn resistance_score(&self, original: &FlowRecord, shaped: &FlowRecord) -> f64 {
178        let result = self.correlate(original, shaped);
179        // Lower correlation = better resistance
180        1.0 - result.combined_score
181    }
182}
183
184impl Default for CorrelationAnalyzer {
185    fn default() -> Self {
186        Self::new()
187    }
188}
189
/// Pearson correlation coefficient of two sample series.
///
/// Only the first `min(a.len(), b.len())` samples of each slice are compared.
/// Returns 0.0 when fewer than two paired samples exist, when the denominator
/// is below `1e-12` (at least one series is effectively constant), or when the
/// denominator is not finite (NaN or infinite input).
///
/// The result is clamped to `[-1.0, 1.0]`: floating-point rounding can push the
/// raw ratio marginally past ±1, which would otherwise leak out of the
/// documented coefficient range and into scores derived from it.
fn pearson(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    if n < 2 {
        return 0.0;
    }

    // Restrict both series to the paired prefix once, up front.
    let (a, b) = (&a[..n], &b[..n]);
    let count = n as f64;
    let mean_a = a.iter().sum::<f64>() / count;
    let mean_b = b.iter().sum::<f64>() / count;

    let mut cov = 0.0_f64;
    let mut var_a = 0.0_f64;
    let mut var_b = 0.0_f64;
    for (&x, &y) in a.iter().zip(b) {
        let da = x - mean_a;
        let db = y - mean_b;
        cov += da * db;
        var_a += da * da;
        var_b += db * db;
    }

    let denom = (var_a * var_b).sqrt();
    if !denom.is_finite() || denom < 1e-12 {
        return 0.0;
    }
    (cov / denom).clamp(-1.0, 1.0)
}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an all-outgoing flow from parallel size and timestamp lists.
    ///
    /// Packet `i` takes `sizes[i]` and `timestamps[i]`; indexing panics if
    /// `timestamps` is shorter than `sizes`.
    fn make_flow(id: &str, sizes: Vec<usize>, timestamps: Vec<u64>) -> FlowRecord {
        let mut flow = FlowRecord::new(id);
        for (idx, size) in sizes.into_iter().enumerate() {
            flow.add_packet(size, timestamps[idx], true);
        }
        flow
    }

    /// Same sizes and same delays (timestamps shifted by 50 µs) must correlate.
    #[test]
    fn test_identical_flows_correlated() {
        let first = make_flow("a", vec![100, 200, 150, 300], vec![0, 1000, 2500, 4000]);
        let second = make_flow("b", vec![100, 200, 150, 300], vec![50, 1050, 2550, 4050]);
        let analyzer = CorrelationAnalyzer::new().with_threshold(0.9);

        let outcome = analyzer.correlate(&first, &second);
        assert!(outcome.combined_score > 0.9);
        assert!(outcome.is_correlated);
    }

    /// Flows with unrelated timing must stay below the 0.7 threshold.
    #[test]
    fn test_different_flows_uncorrelated() {
        let first = make_flow("a", vec![100, 200, 100, 200], vec![0, 1000, 2000, 3000]);
        let second = make_flow("b", vec![500, 50, 500, 50], vec![0, 5000, 5500, 20000]);
        let analyzer = CorrelationAnalyzer::new().with_threshold(0.7);

        let score = analyzer.correlate(&first, &second).combined_score;
        assert!(score < 0.7, "Different patterns should not correlate: {}", score);
    }

    /// Constant-size, constant-rate shaping should yield a positive resistance.
    #[test]
    fn test_resistance_score() {
        let original = make_flow("orig", vec![100, 200, 150], vec![0, 1000, 2500]);
        let shaped = make_flow("shaped", vec![256, 256, 256], vec![0, 1000, 2000]);

        let resistance = CorrelationAnalyzer::new().resistance_score(&original, &shaped);
        // Constant-size traffic should reduce correlation
        assert!(resistance > 0.0, "Shaped traffic should show some resistance");
    }

    /// Three flows produce C(3, 2) = 3 pairwise results.
    #[test]
    fn test_analyze_all_pairs() {
        let flows = vec![
            make_flow("a", vec![100, 200, 300], vec![0, 1000, 2000]),
            make_flow("b", vec![100, 200, 300], vec![50, 1050, 2050]),
            make_flow("c", vec![400, 50, 400], vec![0, 5000, 10000]),
        ];

        let pair_count = CorrelationAnalyzer::new().analyze_all(&flows).len();
        assert_eq!(pair_count, 3); // C(3,2) = 3 pairs
    }

    /// Checks the basic per-flow statistics on a three-packet flow.
    #[test]
    fn test_flow_record_stats() {
        let flow = make_flow("test", vec![100, 200, 300], vec![0, 1000, 3000]);

        assert_eq!(flow.packet_count(), 3);
        assert_eq!(flow.total_bytes(), 600);
        assert_eq!(flow.duration_us(), 3000);
        assert!((flow.avg_packet_size() - 200.0).abs() < 0.01);
        assert_eq!(flow.inter_packet_delays(), vec![1000, 2000]);
    }
}