1use serde::{Serialize, Deserialize};
4
/// Per-packet observations recorded for a single network flow.
///
/// The three per-packet vectors are parallel: `add_packet` pushes exactly one
/// element onto each of them, so index `i` in every vector describes the same
/// packet. The fields are public, so code that mutates them directly is
/// responsible for keeping the lengths in sync.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowRecord {
    /// Identifier of the flow; copied into `CorrelationResult::flow_a` / `flow_b`.
    pub flow_id: String,
    /// Size of each recorded packet (summed by `total_bytes`).
    pub packet_sizes: Vec<usize>,
    /// Timestamp of each recorded packet, in microseconds.
    pub timestamps_us: Vec<u64>,
    /// Direction of each recorded packet; `true` is stored for outgoing packets.
    pub directions: Vec<bool>,
    /// Free-form label for the flow; `FlowRecord::new` initializes it to an
    /// empty string and nothing else in this file reads it.
    pub source_label: String,
}
19
20impl FlowRecord {
21 pub fn new(flow_id: impl Into<String>) -> Self {
22 Self {
23 flow_id: flow_id.into(),
24 packet_sizes: Vec::new(),
25 timestamps_us: Vec::new(),
26 directions: Vec::new(),
27 source_label: String::new(),
28 }
29 }
30
31 pub fn add_packet(&mut self, size: usize, timestamp_us: u64, outgoing: bool) {
33 self.packet_sizes.push(size);
34 self.timestamps_us.push(timestamp_us);
35 self.directions.push(outgoing);
36 }
37
38 pub fn duration_us(&self) -> u64 {
40 if self.timestamps_us.len() < 2 {
41 return 0;
42 }
43 self.timestamps_us.last().unwrap() - self.timestamps_us.first().unwrap()
44 }
45
46 pub fn total_bytes(&self) -> usize {
48 self.packet_sizes.iter().sum()
49 }
50
51 pub fn inter_packet_delays(&self) -> Vec<u64> {
53 self.timestamps_us
54 .windows(2)
55 .map(|w| w[1] - w[0])
56 .collect()
57 }
58
59 pub fn avg_packet_size(&self) -> f64 {
61 if self.packet_sizes.is_empty() {
62 return 0.0;
63 }
64 self.packet_sizes.iter().sum::<usize>() as f64 / self.packet_sizes.len() as f64
65 }
66
67 pub fn packet_count(&self) -> usize {
69 self.packet_sizes.len()
70 }
71}
72
/// Outcome of comparing two flows.
///
/// The field descriptions below state how `CorrelationAnalyzer::correlate`
/// fills them in; the fields are public, so other code may construct values
/// differently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationResult {
    /// `flow_id` of the first flow passed to `correlate`.
    pub flow_a: String,
    /// `flow_id` of the second flow passed to `correlate`.
    pub flow_b: String,
    /// Absolute Pearson correlation of the two flows' inter-packet delays.
    pub timing_correlation: f64,
    /// Absolute Pearson correlation of the two flows' packet sizes.
    pub size_correlation: f64,
    /// Weighted blend: `0.6 * timing_correlation + 0.4 * size_correlation`.
    pub combined_score: f64,
    /// `true` when `combined_score` is at least the analyzer's
    /// `correlation_threshold`.
    pub is_correlated: bool,
}
87
/// Configuration for pairwise flow correlation.
pub struct CorrelationAnalyzer {
    /// Minimum combined score at which `correlate` marks a pair as correlated
    /// (`CorrelationAnalyzer::new` sets this to `0.7`).
    pub correlation_threshold: f64,
    /// Timing tolerance in microseconds (`CorrelationAnalyzer::new` sets this
    /// to `10_000`).
    // NOTE(review): no method visible in this file reads this field — confirm
    // whether it is consumed elsewhere or is currently unused.
    pub timing_tolerance_us: u64,
}
95
96impl CorrelationAnalyzer {
97 pub fn new() -> Self {
98 Self {
99 correlation_threshold: 0.7,
100 timing_tolerance_us: 10_000, }
102 }
103
104 pub fn with_threshold(mut self, threshold: f64) -> Self {
105 self.correlation_threshold = threshold;
106 self
107 }
108
109 pub fn timing_correlation(&self, a: &FlowRecord, b: &FlowRecord) -> f64 {
111 let delays_a: Vec<f64> = a.inter_packet_delays().iter().map(|&d| d as f64).collect();
112 let delays_b: Vec<f64> = b.inter_packet_delays().iter().map(|&d| d as f64).collect();
113
114 if delays_a.is_empty() || delays_b.is_empty() {
115 return 0.0;
116 }
117
118 pearson(&delays_a, &delays_b).abs()
119 }
120
121 pub fn size_correlation(&self, a: &FlowRecord, b: &FlowRecord) -> f64 {
123 let sizes_a: Vec<f64> = a.packet_sizes.iter().map(|&s| s as f64).collect();
124 let sizes_b: Vec<f64> = b.packet_sizes.iter().map(|&s| s as f64).collect();
125
126 if sizes_a.is_empty() || sizes_b.is_empty() {
127 return 0.0;
128 }
129
130 pearson(&sizes_a, &sizes_b).abs()
131 }
132
133 pub fn correlate(&self, a: &FlowRecord, b: &FlowRecord) -> CorrelationResult {
135 let timing = self.timing_correlation(a, b);
136 let size = self.size_correlation(a, b);
137
138 let combined = 0.6 * timing + 0.4 * size;
140
141 CorrelationResult {
142 flow_a: a.flow_id.clone(),
143 flow_b: b.flow_id.clone(),
144 timing_correlation: timing,
145 size_correlation: size,
146 combined_score: combined,
147 is_correlated: combined >= self.correlation_threshold,
148 }
149 }
150
151 pub fn analyze_all(&self, flows: &[FlowRecord]) -> Vec<CorrelationResult> {
153 let mut results = Vec::new();
154 for i in 0..flows.len() {
155 for j in (i + 1)..flows.len() {
156 results.push(self.correlate(&flows[i], &flows[j]));
157 }
158 }
159 results
160 }
161
162 pub fn find_correlated(
164 &self,
165 target: &FlowRecord,
166 candidates: &[FlowRecord],
167 ) -> Vec<CorrelationResult> {
168 candidates
169 .iter()
170 .filter(|c| c.flow_id != target.flow_id)
171 .map(|c| self.correlate(target, c))
172 .filter(|r| r.is_correlated)
173 .collect()
174 }
175
176 pub fn resistance_score(&self, original: &FlowRecord, shaped: &FlowRecord) -> f64 {
178 let result = self.correlate(original, shaped);
179 1.0 - result.combined_score
181 }
182}
183
184impl Default for CorrelationAnalyzer {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189
/// Computes the Pearson correlation coefficient of two series.
///
/// Only the first `min(a.len(), b.len())` samples of each slice are used.
/// Returns `0.0` when fewer than two samples overlap, or when the square root
/// of the product of the two sums of squared deviations is below `1e-12`
/// (for example when either series is constant).
fn pearson(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    if len < 2 {
        return 0.0;
    }

    // Work on the overlapping prefix only.
    let (xs, ys) = (&a[..len], &b[..len]);
    let count = len as f64;
    let mean_x = xs.iter().sum::<f64>() / count;
    let mean_y = ys.iter().sum::<f64>() / count;

    // Accumulate the co-deviation and both squared deviations in one pass.
    let (co_dev, sq_dev_x, sq_dev_y) = xs.iter().zip(ys).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(co, sx, sy), (&x, &y)| {
            let dx = x - mean_x;
            let dy = y - mean_y;
            (co + dx * dy, sx + dx * dx, sy + dy * dy)
        },
    );

    let scale = (sq_dev_x * sq_dev_y).sqrt();
    if scale < 1e-12 {
        0.0
    } else {
        co_dev / scale
    }
}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a flow whose packets are all outgoing, pairing `sizes[i]` with
    /// `timestamps[i]`.
    fn make_flow(id: &str, sizes: Vec<usize>, timestamps: Vec<u64>) -> FlowRecord {
        let mut record = FlowRecord::new(id);
        for (idx, &size) in sizes.iter().enumerate() {
            record.add_packet(size, timestamps[idx], true);
        }
        record
    }

    #[test]
    fn test_identical_flows_correlated() {
        // Same sizes and the same delays (timestamps shifted by a constant 50 us).
        let first = make_flow("a", vec![100, 200, 150, 300], vec![0, 1000, 2500, 4000]);
        let second = make_flow("b", vec![100, 200, 150, 300], vec![50, 1050, 2550, 4050]);
        let analyzer = CorrelationAnalyzer::new().with_threshold(0.9);

        let outcome = analyzer.correlate(&first, &second);

        assert!(outcome.combined_score > 0.9);
        assert!(outcome.is_correlated);
    }

    #[test]
    fn test_different_flows_uncorrelated() {
        // Flow "a" has constant delays, so its timing score is 0.0; only the
        // size term (weight 0.4) can contribute to the combined score.
        let first = make_flow("a", vec![100, 200, 100, 200], vec![0, 1000, 2000, 3000]);
        let second = make_flow("b", vec![500, 50, 500, 50], vec![0, 5000, 5500, 20000]);
        let analyzer = CorrelationAnalyzer::new().with_threshold(0.7);

        let outcome = analyzer.correlate(&first, &second);

        assert!(outcome.combined_score < 0.7, "Different patterns should not correlate: {}", outcome.combined_score);
    }

    #[test]
    fn test_resistance_score() {
        // The shaped flow has constant sizes and constant delays, so both
        // Pearson terms fall back to 0.0.
        let original = make_flow("orig", vec![100, 200, 150], vec![0, 1000, 2500]);
        let shaped = make_flow("shaped", vec![256, 256, 256], vec![0, 1000, 2000]);
        let analyzer = CorrelationAnalyzer::new();

        let score = analyzer.resistance_score(&original, &shaped);

        assert!(score > 0.0, "Shaped traffic should show some resistance");
    }

    #[test]
    fn test_analyze_all_pairs() {
        let flows = vec![
            make_flow("a", vec![100, 200, 300], vec![0, 1000, 2000]),
            make_flow("b", vec![100, 200, 300], vec![50, 1050, 2050]),
            make_flow("c", vec![400, 50, 400], vec![0, 5000, 10000]),
        ];
        let analyzer = CorrelationAnalyzer::new();

        let pairs = analyzer.analyze_all(&flows);

        // Three flows produce three unordered pairs: (a, b), (a, c), (b, c).
        assert_eq!(pairs.len(), 3);
    }

    #[test]
    fn test_flow_record_stats() {
        let record = make_flow("test", vec![100, 200, 300], vec![0, 1000, 3000]);

        assert_eq!(record.packet_count(), 3);
        assert_eq!(record.total_bytes(), 600);
        assert_eq!(record.duration_us(), 3000);
        assert!((record.avg_packet_size() - 200.0).abs() < 0.01);

        let gaps = record.inter_packet_delays();
        assert_eq!(gaps, vec![1000, 2000]);
    }
}