1use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5
6use super::types::{Signal, SignalType};
7
/// The kind of relationship a [`Correlation`] describes.
///
/// Serialized with serde using `snake_case` variant names
/// (for example `entity_cluster`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CorrelationType {
    /// Several distinct entities reported the same signal value.
    EntityCluster,
    /// Signal chain. No detector in this file produces this variant.
    SignalChain,
    /// Several distinct entities were active within one time window.
    TemporalCorrelation,
    /// Several distinct fingerprint values share a common leading prefix.
    FingerprintFamily,
}
21
/// A detected relationship between entities, together with the signals that
/// support it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Correlation {
    /// Unique identifier; the detectors in this file assign a random UUID v4 string.
    pub id: String,
    /// Which detector family produced this correlation.
    pub correlation_type: CorrelationType,
    /// Confidence score. The detectors in this file cap the raw score at 1.0
    /// and then raise it to at least the engine's `min_strength`.
    pub strength: f64,
    /// Identifiers of the distinct entities involved.
    pub entities: Vec<String>,
    /// Copies of the signals that contributed to the correlation.
    pub signals: Vec<Signal>,
    /// Human-readable summary of the finding.
    pub description: String,
    /// Time the correlation was created, in milliseconds since the Unix epoch
    /// (taken from `chrono::Utc::now().timestamp_millis()` by the detectors).
    pub detected_at: i64,
    /// Detector-specific details; fields that do not apply are `None`.
    pub metadata: CorrelationMetadata,
}
37
/// Optional, detector-specific details attached to a [`Correlation`].
///
/// `Default` yields a value with every field set to `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorrelationMetadata {
    /// The value (or value prefix) the correlated signals have in common.
    /// Set by the entity-cluster and fingerprint-family detectors.
    pub shared_value: Option<String>,
    /// A count associated with the correlation: the number of matching signals
    /// for entity clusters and temporal windows, and the number of distinct
    /// fingerprint values for fingerprint families.
    pub signal_count: Option<usize>,
    /// Width of the time window in milliseconds. Set by the temporal detector.
    pub time_window: Option<i64>,
}
45
/// Filters applied by `CorrelationEngine::find_correlations`.
///
/// Every field is optional; `None` means "do not filter on this criterion".
#[derive(Debug, Clone, Default)]
pub struct CorrelationQueryOptions {
    /// Run only the detector for this correlation type; `None` runs all detectors.
    pub correlation_type: Option<CorrelationType>,
    /// Keep only correlations whose `entities` contain this identifier.
    pub entity_id: Option<String>,
    /// Signal type to restrict the query to.
    /// NOTE(review): `find_correlations` in this file never reads this field,
    /// so setting it currently has no effect — confirm the intended semantics.
    pub signal_type: Option<SignalType>,
    /// Inclusive lower bound on `Correlation::detected_at` (epoch milliseconds).
    pub from: Option<i64>,
    /// Inclusive upper bound on `Correlation::detected_at` (epoch milliseconds).
    pub to: Option<i64>,
    /// Drop correlations whose strength is below this value.
    pub min_strength: Option<f64>,
    /// Maximum number of correlations to return, applied after sorting by strength.
    pub limit: Option<usize>,
}
57
/// Detects correlations (shared values, temporal bursts, fingerprint
/// families) across a slice of signals.
pub struct CorrelationEngine {
    /// Minimum number of distinct entities required before an entity cluster
    /// or temporal burst is reported.
    min_cluster_size: usize,
    /// Width of the sliding window used by the temporal detector, in the same
    /// unit as `Signal::timestamp` (milliseconds, per the field name).
    temporal_window_ms: i64,
    /// Lower bound applied to every computed correlation strength.
    min_strength: f64,
}
67
68impl Default for CorrelationEngine {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl CorrelationEngine {
75 pub fn new() -> Self {
77 Self {
78 min_cluster_size: 3,
79 temporal_window_ms: 60_000,
80 min_strength: 0.5,
81 }
82 }
83
84 pub fn with_settings(
86 min_cluster_size: usize,
87 temporal_window_ms: i64,
88 min_strength: f64,
89 ) -> Self {
90 Self {
91 min_cluster_size,
92 temporal_window_ms,
93 min_strength,
94 }
95 }
96
97 pub fn find_correlations(
99 &self,
100 signals: &[Signal],
101 options: &CorrelationQueryOptions,
102 ) -> Vec<Correlation> {
103 let mut correlations = Vec::new();
104
105 if options.correlation_type.is_none()
107 || options.correlation_type == Some(CorrelationType::EntityCluster)
108 {
109 correlations.extend(self.find_entity_clusters(signals));
110 }
111
112 if options.correlation_type.is_none()
114 || options.correlation_type == Some(CorrelationType::TemporalCorrelation)
115 {
116 correlations.extend(self.find_temporal_correlations(signals));
117 }
118
119 if options.correlation_type.is_none()
121 || options.correlation_type == Some(CorrelationType::FingerprintFamily)
122 {
123 correlations.extend(self.find_fingerprint_families(signals));
124 }
125
126 let mut filtered = correlations
128 .into_iter()
129 .filter(|c| {
130 if let Some(ref entity_id) = options.entity_id {
131 if !c.entities.contains(entity_id) {
132 return false;
133 }
134 }
135 if let Some(min_str) = options.min_strength {
136 if c.strength < min_str {
137 return false;
138 }
139 }
140 if let Some(from) = options.from {
141 if c.detected_at < from {
142 return false;
143 }
144 }
145 if let Some(to) = options.to {
146 if c.detected_at > to {
147 return false;
148 }
149 }
150 true
151 })
152 .collect::<Vec<_>>();
153
154 filtered.sort_by(|a, b| {
156 b.strength
157 .partial_cmp(&a.strength)
158 .unwrap_or(std::cmp::Ordering::Equal)
159 });
160
161 if let Some(limit) = options.limit {
163 filtered.truncate(limit);
164 }
165
166 filtered
167 }
168
169 fn find_entity_clusters(&self, signals: &[Signal]) -> Vec<Correlation> {
171 let mut correlations = Vec::new();
172
173 let mut value_entities: HashMap<String, HashSet<String>> = HashMap::new();
175 for signal in signals {
176 value_entities
177 .entry(signal.value.clone())
178 .or_default()
179 .insert(signal.entity_id.clone());
180 }
181
182 for (value, entities) in value_entities {
183 let entity_count = entities.len();
184 if entity_count >= self.min_cluster_size {
185 let strength = (entity_count as f64 - 2.0) / 10.0;
186 let strength = strength.min(1.0).max(self.min_strength);
187
188 correlations.push(Correlation {
189 id: uuid::Uuid::new_v4().to_string(),
190 correlation_type: CorrelationType::EntityCluster,
191 strength,
192 entities: entities.into_iter().collect(),
193 signals: signals
194 .iter()
195 .filter(|s| s.value == value)
196 .cloned()
197 .collect(),
198 description: format!("Entity cluster: {} IPs share signal value", entity_count),
199 detected_at: chrono::Utc::now().timestamp_millis(),
200 metadata: CorrelationMetadata {
201 shared_value: Some(value[..16.min(value.len())].to_string()),
202 signal_count: Some(signals.iter().filter(|s| s.value == value).count()),
203 ..Default::default()
204 },
205 });
206 }
207 }
208
209 correlations
210 }
211
212 fn find_temporal_correlations(&self, signals: &[Signal]) -> Vec<Correlation> {
214 let mut correlations = Vec::new();
215
216 if signals.len() < 2 {
217 return correlations;
218 }
219
220 let mut sorted = signals.to_vec();
222 sorted.sort_by_key(|s| s.timestamp);
223
224 let mut window_start = 0;
226 for i in 0..sorted.len() {
227 while sorted[i].timestamp - sorted[window_start].timestamp > self.temporal_window_ms {
229 window_start += 1;
230 }
231
232 let window = &sorted[window_start..=i];
234 let entities: HashSet<_> = window.iter().map(|s| &s.entity_id).collect();
235
236 let entity_count = entities.len();
237 if entity_count >= self.min_cluster_size {
238 let strength = (entity_count as f64 - 2.0) / 8.0;
240 let strength = strength.min(1.0).max(self.min_strength);
241
242 correlations.push(Correlation {
243 id: uuid::Uuid::new_v4().to_string(),
244 correlation_type: CorrelationType::TemporalCorrelation,
245 strength,
246 entities: entities.into_iter().cloned().collect(),
247 signals: window.to_vec(),
248 description: format!(
249 "Temporal burst: {} entities active within {}ms",
250 entity_count, self.temporal_window_ms
251 ),
252 detected_at: chrono::Utc::now().timestamp_millis(),
253 metadata: CorrelationMetadata {
254 signal_count: Some(window.len()),
255 time_window: Some(self.temporal_window_ms),
256 ..Default::default()
257 },
258 });
259 }
260 }
261
262 self.deduplicate_correlations(correlations)
264 }
265
266 fn find_fingerprint_families(&self, signals: &[Signal]) -> Vec<Correlation> {
268 let mut correlations = Vec::new();
269
270 let fingerprints: Vec<_> = signals
272 .iter()
273 .filter(|s| {
274 matches!(
275 s.signal_type,
276 SignalType::Ja4 | SignalType::Ja4h | SignalType::HttpFingerprint
277 )
278 })
279 .collect();
280
281 let mut prefix_groups: HashMap<String, Vec<&Signal>> = HashMap::new();
283 for fp in &fingerprints {
284 if fp.value.len() >= 8 {
285 let prefix = fp.value[..8].to_string();
286 prefix_groups.entry(prefix).or_default().push(fp);
287 }
288 }
289
290 for (prefix, group) in prefix_groups {
291 let unique_values: HashSet<_> = group.iter().map(|s| &s.value).collect();
292
293 if unique_values.len() >= 2 {
295 let entities: HashSet<_> = group.iter().map(|s| s.entity_id.clone()).collect();
296 let strength = unique_values.len() as f64 / 10.0;
297 let strength = strength.min(1.0).max(self.min_strength);
298
299 correlations.push(Correlation {
300 id: uuid::Uuid::new_v4().to_string(),
301 correlation_type: CorrelationType::FingerprintFamily,
302 strength,
303 entities: entities.into_iter().collect(),
304 signals: group.into_iter().cloned().collect(),
305 description: format!(
306 "Fingerprint family: {} variants with prefix {}...",
307 unique_values.len(),
308 prefix
309 ),
310 detected_at: chrono::Utc::now().timestamp_millis(),
311 metadata: CorrelationMetadata {
312 shared_value: Some(prefix),
313 signal_count: Some(unique_values.len()),
314 ..Default::default()
315 },
316 });
317 }
318 }
319
320 correlations
321 }
322
323 fn deduplicate_correlations(&self, correlations: Vec<Correlation>) -> Vec<Correlation> {
325 if correlations.is_empty() {
326 return correlations;
327 }
328
329 let mut result = Vec::new();
330 let mut seen_entities: HashSet<String> = HashSet::new();
331
332 for corr in correlations {
333 let entities_set: HashSet<_> = corr.entities.iter().cloned().collect();
335 let overlap = entities_set.intersection(&seen_entities).count();
336
337 if overlap as f64 / entities_set.len() as f64 <= 0.5 {
339 seen_entities.extend(corr.entities.iter().cloned());
340 result.push(corr);
341 }
342 }
343
344 result
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::super::types::{SignalCategory, SignalMetadata};
    use super::*;

    /// Builds a JA4 signal in the network category for the given entity,
    /// value and timestamp; all remaining fields use neutral defaults.
    fn create_test_signal(entity_id: &str, value: &str, timestamp: i64) -> Signal {
        Signal {
            id: uuid::Uuid::new_v4().to_string(),
            timestamp,
            category: SignalCategory::Network,
            signal_type: SignalType::Ja4,
            value: value.to_owned(),
            entity_id: entity_id.to_owned(),
            session_id: None,
            metadata: SignalMetadata::default(),
        }
    }

    /// Three entities sharing one value meet the default cluster size of 3.
    #[test]
    fn test_entity_cluster_detection() {
        let signals: Vec<Signal> = (1..=3)
            .map(|n| create_test_signal(&format!("ip-{}", n), "shared_value", n * 1000))
            .collect();

        let found = CorrelationEngine::new().find_entity_clusters(&signals);

        let first = found.first().expect("expected at least one entity cluster");
        assert_eq!(first.correlation_type, CorrelationType::EntityCluster);
    }

    /// Three entities within a 10 s window, with the cluster size lowered to 2.
    #[test]
    fn test_temporal_correlation() {
        let base = chrono::Utc::now().timestamp_millis();
        let signals: Vec<Signal> = (0..3)
            .map(|k| {
                create_test_signal(
                    &format!("ip-{}", k + 1),
                    &format!("value-{}", k + 1),
                    base + k * 1000,
                )
            })
            .collect();

        let engine = CorrelationEngine::with_settings(2, 10_000, 0.3);
        let found = engine.find_temporal_correlations(&signals);

        let first = found.first().expect("expected at least one temporal correlation");
        assert_eq!(first.correlation_type, CorrelationType::TemporalCorrelation);
    }

    /// Three distinct fingerprints that share their first eight characters.
    #[test]
    fn test_fingerprint_family() {
        let fixtures = [
            ("ip-1", "t13d1516h2_variant1_abc", 1000),
            ("ip-2", "t13d1516h2_variant2_def", 2000),
            ("ip-3", "t13d1516h2_variant3_ghi", 3000),
        ];
        let signals: Vec<Signal> = fixtures
            .iter()
            .map(|&(entity, value, ts)| create_test_signal(entity, value, ts))
            .collect();

        let found = CorrelationEngine::new().find_fingerprint_families(&signals);

        let first = found.first().expect("expected at least one fingerprint family");
        assert_eq!(first.correlation_type, CorrelationType::FingerprintFamily);
    }
}