Skip to main content

synapse_pingora/trends/
correlation.rs

1//! Correlation engine for finding relationships between signals.
2
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5
6use super::types::{Signal, SignalType};
7
/// Types of correlations we detect.
///
/// Serialized with `snake_case` names (e.g. `EntityCluster` -> `"entity_cluster"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CorrelationType {
    /// Multiple IPs (entities) sharing the same signal value.
    EntityCluster,
    /// Sequence of related signals.
    // NOTE(review): `CorrelationEngine` in this module has no detector that
    // produces this variant; querying for it currently returns nothing.
    SignalChain,
    /// Many distinct entities active within one time window.
    TemporalCorrelation,
    /// Similar but not identical fingerprints (sharing a common prefix).
    FingerprintFamily,
}
21
/// A detected correlation between signals/entities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Correlation {
    /// Random UUID (v4) assigned when the correlation is produced.
    pub id: String,
    /// Which detector produced this correlation.
    pub correlation_type: CorrelationType,
    /// Correlation strength, nominally in 0-1. The engine caps it at 1.0 and
    /// floors it at its configured minimum strength.
    pub strength: f64,
    /// Involved entity IDs (IPs). Order is not meaningful (built from hash sets).
    pub entities: Vec<String>,
    /// Copies of the signals that make up this correlation.
    pub signals: Vec<Signal>,
    /// Human-readable summary.
    pub description: String,
    /// When the engine produced the correlation (`chrono::Utc::now()` in
    /// Unix milliseconds), not the time of the underlying signals.
    pub detected_at: i64,
    /// Detector-specific extra data.
    pub metadata: CorrelationMetadata,
}
37
/// Correlation metadata. Which fields are populated depends on the detector.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorrelationMetadata {
    /// Entity clusters: the shared signal value, truncated to its first 16 bytes.
    /// Fingerprint families: the shared fingerprint prefix.
    pub shared_value: Option<String>,
    /// Entity clusters / temporal bursts: number of signals involved.
    /// Fingerprint families: number of distinct fingerprint variants.
    pub signal_count: Option<usize>,
    /// Temporal bursts: the sliding-window width in milliseconds.
    pub time_window: Option<i64>,
}
45
/// Query options for correlations. Every `None` field means "no restriction".
#[derive(Debug, Clone, Default)]
pub struct CorrelationQueryOptions {
    /// Only run the detector for this correlation type.
    pub correlation_type: Option<CorrelationType>,
    /// Only keep correlations whose `entities` contain this ID.
    pub entity_id: Option<String>,
    /// Signal type to restrict the query to; see
    /// `CorrelationEngine::find_correlations` for how it is applied.
    pub signal_type: Option<SignalType>,
    /// Lower bound (inclusive) on `Correlation::detected_at`, in Unix ms.
    pub from: Option<i64>,
    /// Upper bound (inclusive) on `Correlation::detected_at`, in Unix ms.
    pub to: Option<i64>,
    /// Drop correlations whose strength is below this value.
    pub min_strength: Option<f64>,
    /// Maximum number of correlations returned (strongest first).
    pub limit: Option<usize>,
}
57
/// Correlation engine for finding relationships between signals.
///
/// Stateless apart from its tuning parameters: every query recomputes
/// correlations from the signals it is given.
pub struct CorrelationEngine {
    /// Minimum number of distinct entities for an entity cluster or temporal burst.
    min_cluster_size: usize,
    /// Sliding-window width for temporal correlation (ms).
    temporal_window_ms: i64,
    /// Floor applied to every computed strength. Weaker correlations are raised
    /// to this value rather than discarded.
    min_strength: f64,
}
67
68impl Default for CorrelationEngine {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl CorrelationEngine {
75    /// Create a new correlation engine.
76    pub fn new() -> Self {
77        Self {
78            min_cluster_size: 3,
79            temporal_window_ms: 60_000,
80            min_strength: 0.5,
81        }
82    }
83
84    /// Create with custom settings.
85    pub fn with_settings(
86        min_cluster_size: usize,
87        temporal_window_ms: i64,
88        min_strength: f64,
89    ) -> Self {
90        Self {
91            min_cluster_size,
92            temporal_window_ms,
93            min_strength,
94        }
95    }
96
97    /// Find correlations in a set of signals.
98    pub fn find_correlations(
99        &self,
100        signals: &[Signal],
101        options: &CorrelationQueryOptions,
102    ) -> Vec<Correlation> {
103        let mut correlations = Vec::new();
104
105        // Entity clusters
106        if options.correlation_type.is_none()
107            || options.correlation_type == Some(CorrelationType::EntityCluster)
108        {
109            correlations.extend(self.find_entity_clusters(signals));
110        }
111
112        // Temporal correlations
113        if options.correlation_type.is_none()
114            || options.correlation_type == Some(CorrelationType::TemporalCorrelation)
115        {
116            correlations.extend(self.find_temporal_correlations(signals));
117        }
118
119        // Fingerprint families
120        if options.correlation_type.is_none()
121            || options.correlation_type == Some(CorrelationType::FingerprintFamily)
122        {
123            correlations.extend(self.find_fingerprint_families(signals));
124        }
125
126        // Apply filters
127        let mut filtered = correlations
128            .into_iter()
129            .filter(|c| {
130                if let Some(ref entity_id) = options.entity_id {
131                    if !c.entities.contains(entity_id) {
132                        return false;
133                    }
134                }
135                if let Some(min_str) = options.min_strength {
136                    if c.strength < min_str {
137                        return false;
138                    }
139                }
140                if let Some(from) = options.from {
141                    if c.detected_at < from {
142                        return false;
143                    }
144                }
145                if let Some(to) = options.to {
146                    if c.detected_at > to {
147                        return false;
148                    }
149                }
150                true
151            })
152            .collect::<Vec<_>>();
153
154        // Sort by strength (strongest first)
155        filtered.sort_by(|a, b| {
156            b.strength
157                .partial_cmp(&a.strength)
158                .unwrap_or(std::cmp::Ordering::Equal)
159        });
160
161        // Apply limit
162        if let Some(limit) = options.limit {
163            filtered.truncate(limit);
164        }
165
166        filtered
167    }
168
169    /// Find entity clusters (IPs sharing signals).
170    fn find_entity_clusters(&self, signals: &[Signal]) -> Vec<Correlation> {
171        let mut correlations = Vec::new();
172
173        // Group signals by value
174        let mut value_entities: HashMap<String, HashSet<String>> = HashMap::new();
175        for signal in signals {
176            value_entities
177                .entry(signal.value.clone())
178                .or_default()
179                .insert(signal.entity_id.clone());
180        }
181
182        for (value, entities) in value_entities {
183            let entity_count = entities.len();
184            if entity_count >= self.min_cluster_size {
185                let strength = (entity_count as f64 - 2.0) / 10.0;
186                let strength = strength.min(1.0).max(self.min_strength);
187
188                correlations.push(Correlation {
189                    id: uuid::Uuid::new_v4().to_string(),
190                    correlation_type: CorrelationType::EntityCluster,
191                    strength,
192                    entities: entities.into_iter().collect(),
193                    signals: signals
194                        .iter()
195                        .filter(|s| s.value == value)
196                        .cloned()
197                        .collect(),
198                    description: format!("Entity cluster: {} IPs share signal value", entity_count),
199                    detected_at: chrono::Utc::now().timestamp_millis(),
200                    metadata: CorrelationMetadata {
201                        shared_value: Some(value[..16.min(value.len())].to_string()),
202                        signal_count: Some(signals.iter().filter(|s| s.value == value).count()),
203                        ..Default::default()
204                    },
205                });
206            }
207        }
208
209        correlations
210    }
211
212    /// Find temporal correlations (signals occurring together).
213    fn find_temporal_correlations(&self, signals: &[Signal]) -> Vec<Correlation> {
214        let mut correlations = Vec::new();
215
216        if signals.len() < 2 {
217            return correlations;
218        }
219
220        // Sort by timestamp
221        let mut sorted = signals.to_vec();
222        sorted.sort_by_key(|s| s.timestamp);
223
224        // Sliding window to find bursts
225        let mut window_start = 0;
226        for i in 0..sorted.len() {
227            // Shrink window from left
228            while sorted[i].timestamp - sorted[window_start].timestamp > self.temporal_window_ms {
229                window_start += 1;
230            }
231
232            // Check if window has multiple entities
233            let window = &sorted[window_start..=i];
234            let entities: HashSet<_> = window.iter().map(|s| &s.entity_id).collect();
235
236            let entity_count = entities.len();
237            if entity_count >= self.min_cluster_size {
238                // Found a temporal burst
239                let strength = (entity_count as f64 - 2.0) / 8.0;
240                let strength = strength.min(1.0).max(self.min_strength);
241
242                correlations.push(Correlation {
243                    id: uuid::Uuid::new_v4().to_string(),
244                    correlation_type: CorrelationType::TemporalCorrelation,
245                    strength,
246                    entities: entities.into_iter().cloned().collect(),
247                    signals: window.to_vec(),
248                    description: format!(
249                        "Temporal burst: {} entities active within {}ms",
250                        entity_count, self.temporal_window_ms
251                    ),
252                    detected_at: chrono::Utc::now().timestamp_millis(),
253                    metadata: CorrelationMetadata {
254                        signal_count: Some(window.len()),
255                        time_window: Some(self.temporal_window_ms),
256                        ..Default::default()
257                    },
258                });
259            }
260        }
261
262        // Deduplicate overlapping correlations
263        self.deduplicate_correlations(correlations)
264    }
265
266    /// Find fingerprint families (similar fingerprints).
267    fn find_fingerprint_families(&self, signals: &[Signal]) -> Vec<Correlation> {
268        let mut correlations = Vec::new();
269
270        // Get fingerprint signals
271        let fingerprints: Vec<_> = signals
272            .iter()
273            .filter(|s| {
274                matches!(
275                    s.signal_type,
276                    SignalType::Ja4 | SignalType::Ja4h | SignalType::HttpFingerprint
277                )
278            })
279            .collect();
280
281        // Group by prefix (first 8 chars)
282        let mut prefix_groups: HashMap<String, Vec<&Signal>> = HashMap::new();
283        for fp in &fingerprints {
284            if fp.value.len() >= 8 {
285                let prefix = fp.value[..8].to_string();
286                prefix_groups.entry(prefix).or_default().push(fp);
287            }
288        }
289
290        for (prefix, group) in prefix_groups {
291            let unique_values: HashSet<_> = group.iter().map(|s| &s.value).collect();
292
293            // Only if there are multiple similar but not identical fingerprints
294            if unique_values.len() >= 2 {
295                let entities: HashSet<_> = group.iter().map(|s| s.entity_id.clone()).collect();
296                let strength = unique_values.len() as f64 / 10.0;
297                let strength = strength.min(1.0).max(self.min_strength);
298
299                correlations.push(Correlation {
300                    id: uuid::Uuid::new_v4().to_string(),
301                    correlation_type: CorrelationType::FingerprintFamily,
302                    strength,
303                    entities: entities.into_iter().collect(),
304                    signals: group.into_iter().cloned().collect(),
305                    description: format!(
306                        "Fingerprint family: {} variants with prefix {}...",
307                        unique_values.len(),
308                        prefix
309                    ),
310                    detected_at: chrono::Utc::now().timestamp_millis(),
311                    metadata: CorrelationMetadata {
312                        shared_value: Some(prefix),
313                        signal_count: Some(unique_values.len()),
314                        ..Default::default()
315                    },
316                });
317            }
318        }
319
320        correlations
321    }
322
323    /// Deduplicate overlapping correlations.
324    fn deduplicate_correlations(&self, correlations: Vec<Correlation>) -> Vec<Correlation> {
325        if correlations.is_empty() {
326            return correlations;
327        }
328
329        let mut result = Vec::new();
330        let mut seen_entities: HashSet<String> = HashSet::new();
331
332        for corr in correlations {
333            // Check if any entity in this correlation is already covered
334            let entities_set: HashSet<_> = corr.entities.iter().cloned().collect();
335            let overlap = entities_set.intersection(&seen_entities).count();
336
337            // Only add if less than 50% overlap
338            if overlap as f64 / entities_set.len() as f64 <= 0.5 {
339                seen_entities.extend(corr.entities.iter().cloned());
340                result.push(corr);
341            }
342        }
343
344        result
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a JA4 signal for `entity_id` carrying `value` at `timestamp` (ms).
    fn create_test_signal(entity_id: &str, value: &str, timestamp: i64) -> Signal {
        Signal {
            id: uuid::Uuid::new_v4().to_string(),
            timestamp,
            category: super::super::types::SignalCategory::Network,
            signal_type: SignalType::Ja4,
            value: value.to_string(),
            entity_id: entity_id.to_string(),
            session_id: None,
            metadata: super::super::types::SignalMetadata::default(),
        }
    }

    /// Three entities sharing one value, one second apart.
    fn shared_value_signals() -> Vec<Signal> {
        vec![
            create_test_signal("ip-1", "shared_value", 1000),
            create_test_signal("ip-2", "shared_value", 2000),
            create_test_signal("ip-3", "shared_value", 3000),
        ]
    }

    /// Entities come back in hash order; sort them for stable assertions.
    fn sorted_entities(correlation: &Correlation) -> Vec<String> {
        let mut entities = correlation.entities.clone();
        entities.sort();
        entities
    }

    /// Compare a strength with a small tolerance instead of exact float equality.
    fn assert_strength(correlation: &Correlation, expected: f64) {
        assert!(
            (correlation.strength - expected).abs() < 1e-9,
            "strength {} != {}",
            correlation.strength,
            expected
        );
    }

    #[test]
    fn test_entity_cluster_detection() {
        let engine = CorrelationEngine::new();
        let correlations = engine.find_entity_clusters(&shared_value_signals());

        assert_eq!(correlations.len(), 1);
        let cluster = &correlations[0];
        assert_eq!(cluster.correlation_type, CorrelationType::EntityCluster);
        assert_eq!(sorted_entities(cluster), vec!["ip-1", "ip-2", "ip-3"]);
        assert_eq!(cluster.signals.len(), 3);
        assert_eq!(cluster.metadata.signal_count, Some(3));
        assert_eq!(cluster.metadata.shared_value.as_deref(), Some("shared_value"));
        // (3 - 2) / 10 = 0.1 is raised to the default strength floor of 0.5.
        assert_strength(cluster, 0.5);
    }

    #[test]
    fn test_entity_cluster_requires_min_size() {
        let engine = CorrelationEngine::new();
        let signals = vec![
            create_test_signal("ip-1", "shared_value", 1000),
            create_test_signal("ip-2", "shared_value", 2000),
        ];
        assert!(engine.find_entity_clusters(&signals).is_empty());
    }

    #[test]
    fn test_temporal_correlation() {
        let engine = CorrelationEngine::with_settings(2, 10_000, 0.3);

        let now = chrono::Utc::now().timestamp_millis();
        let signals = vec![
            create_test_signal("ip-1", "value-1", now),
            create_test_signal("ip-2", "value-2", now + 1000),
            create_test_signal("ip-3", "value-3", now + 2000),
        ];

        let correlations = engine.find_temporal_correlations(&signals);

        // The first qualifying window (ip-1, ip-2) is reported. The later
        // three-entity window overlaps it by more than half and is dropped.
        assert_eq!(correlations.len(), 1);
        let burst = &correlations[0];
        assert_eq!(burst.correlation_type, CorrelationType::TemporalCorrelation);
        assert_eq!(sorted_entities(burst), vec!["ip-1", "ip-2"]);
        assert_eq!(burst.metadata.signal_count, Some(2));
        assert_eq!(burst.metadata.time_window, Some(10_000));
        // (2 - 2) / 8 = 0 is raised to the configured floor of 0.3.
        assert_strength(burst, 0.3);
    }

    #[test]
    fn test_temporal_correlation_respects_window() {
        // Signals are 1 s apart but the window is only 500 ms wide.
        let engine = CorrelationEngine::with_settings(2, 500, 0.3);
        let signals = vec![
            create_test_signal("ip-1", "value-1", 0),
            create_test_signal("ip-2", "value-2", 1000),
            create_test_signal("ip-3", "value-3", 2000),
        ];
        assert!(engine.find_temporal_correlations(&signals).is_empty());
    }

    #[test]
    fn test_fingerprint_family() {
        let engine = CorrelationEngine::new();

        let signals = vec![
            create_test_signal("ip-1", "t13d1516h2_variant1_abc", 1000),
            create_test_signal("ip-2", "t13d1516h2_variant2_def", 2000),
            create_test_signal("ip-3", "t13d1516h2_variant3_ghi", 3000),
        ];

        let correlations = engine.find_fingerprint_families(&signals);
        assert_eq!(correlations.len(), 1);
        let family = &correlations[0];
        assert_eq!(family.correlation_type, CorrelationType::FingerprintFamily);
        assert_eq!(family.metadata.shared_value.as_deref(), Some("t13d1516"));
        assert_eq!(family.metadata.signal_count, Some(3));
        assert_eq!(family.signals.len(), 3);
        // 3 variants / 10 = 0.3 is raised to the default floor of 0.5.
        assert_strength(family, 0.5);
    }

    #[test]
    fn test_identical_fingerprints_are_not_a_family() {
        let engine = CorrelationEngine::new();
        let signals = vec![
            create_test_signal("ip-1", "t13d1516h2_same", 1000),
            create_test_signal("ip-2", "t13d1516h2_same", 2000),
            create_test_signal("ip-3", "t13d1516h2_same", 3000),
        ];
        assert!(engine.find_fingerprint_families(&signals).is_empty());
    }

    #[test]
    fn test_find_correlations_filters() {
        let engine = CorrelationEngine::new();
        let signals = shared_value_signals();

        // One entity cluster plus one temporal burst. A single repeated value is
        // not a fingerprint family.
        let all = engine.find_correlations(&signals, &CorrelationQueryOptions::default());
        assert_eq!(all.len(), 2);

        let clusters_only = engine.find_correlations(
            &signals,
            &CorrelationQueryOptions {
                correlation_type: Some(CorrelationType::EntityCluster),
                ..Default::default()
            },
        );
        assert_eq!(clusters_only.len(), 1);
        assert_eq!(clusters_only[0].correlation_type, CorrelationType::EntityCluster);

        let limited = engine.find_correlations(
            &signals,
            &CorrelationQueryOptions {
                limit: Some(1),
                ..Default::default()
            },
        );
        assert_eq!(limited.len(), 1);

        let strong_only = engine.find_correlations(
            &signals,
            &CorrelationQueryOptions {
                min_strength: Some(0.6),
                ..Default::default()
            },
        );
        assert!(strong_only.is_empty());

        let unknown_entity = engine.find_correlations(
            &signals,
            &CorrelationQueryOptions {
                entity_id: Some("ip-9".to_string()),
                ..Default::default()
            },
        );
        assert!(unknown_entity.is_empty());

        let known_entity = engine.find_correlations(
            &signals,
            &CorrelationQueryOptions {
                entity_id: Some("ip-2".to_string()),
                ..Default::default()
            },
        );
        assert_eq!(known_entity.len(), 2);
    }
}