Skip to main content

synapse_pingora/correlation/detectors/
graph.rs

1//! Graph-Based Correlation Detector
2//!
3//! Maintains a graph of relationships between entities (IPs, fingerprints, tokens, ASNs).
4//! Detects campaigns by identifying connected components in the graph.
5//!
6//! # Graph Structure
7//! - **Nodes**: Strings with type prefixes (e.g., "ip:1.2.3.4", "fp:abc", "token:xyz")
8//! - **Edges**: Undirected connections representing observed co-occurrence
9//!
10//! # Detection Logic
11//! - Finds connected components of IP addresses linked by shared attributes
12//! - Triggers campaign if component size exceeds threshold
13//! - Supports depth-limited traversal to limit performance impact
14
15use super::{Detector, DetectorResult};
16use crate::correlation::{CampaignUpdate, CorrelationReason, CorrelationType, FingerprintIndex};
17use dashmap::DashMap;
18use sha2::{Digest, Sha256};
19use std::collections::{HashSet, VecDeque};
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::{Duration, Instant};
22
/// Options controlling a paginated Cytoscape export of the correlation graph.
#[derive(Clone, Debug, Default)]
pub struct GraphExportOptions {
    /// Page size in nodes; `None` falls back to 500.
    pub limit: Option<usize>,
    /// Number of nodes to skip before the page begins; `None` means 0.
    pub offset: Option<usize>,
    /// When set, raw identifiers (IPs, tokens, ...) are replaced by truncated hashes.
    pub hash_identifiers: bool,
}
33
34/// Paginated graph export result
35#[derive(Debug, Clone, serde::Serialize)]
36pub struct PaginatedGraph {
37    /// Cytoscape-format nodes and edges
38    pub nodes: Vec<serde_json::Value>,
39    pub edges: Vec<serde_json::Value>,
40    /// Total node count (before pagination)
41    pub total_nodes: usize,
42    /// Whether there are more nodes
43    pub has_more: bool,
44    /// Snapshot version for consistency checking
45    pub snapshot_version: u64,
46}
47
48/// Hash an identifier for external exposure
49fn hash_identifier(id: &str) -> String {
50    let mut hasher = Sha256::new();
51    hasher.update(id.as_bytes());
52    let result = hasher.finalize();
53    format!("{:x}", result)[..12].to_string() // First 12 hex chars
54}
55
/// Configuration for [`GraphDetector`].
#[derive(Debug, Clone)]
pub struct GraphConfig {
    /// Minimum number of unique IPs that must be reachable from one another
    /// for a campaign to be reported.
    /// Default: 3
    pub min_component_size: usize,

    /// Maximum number of hops followed from the starting node during BFS.
    /// Default: 3 (e.g. IP -> FP -> IP -> Token); an IP that is 4 hops
    /// away (IP -> FP -> IP -> Token -> IP) needs a depth of at least 4.
    pub max_traversal_depth: usize,

    /// How long a node stays in the graph after it was last seen.
    /// Periodic cleanup removes expired nodes together with their edges.
    /// Default: 3600 seconds (1 hour)
    pub edge_ttl: Duration,

    /// Weight of this detector in campaign scoring.
    /// Default: 20
    pub weight: u8,

    /// Maximum number of nodes in the graph; relations that would exceed it
    /// are rejected to prevent memory exhaustion.
    /// Default: 10,000
    pub max_nodes: usize,

    /// Maximum neighbors per node, to prevent star-explosion attacks.
    /// Default: 1,000
    pub max_edges_per_node: usize,

    /// Maximum BFS dequeues per traversal, to prevent CPU exhaustion.
    /// Default: 50,000
    pub max_bfs_iterations: usize,
}

impl Default for GraphConfig {
    fn default() -> Self {
        Self {
            min_component_size: 3,
            max_traversal_depth: 3,
            edge_ttl: Duration::from_secs(3600),
            weight: 20,
            max_nodes: 10_000,
            max_edges_per_node: 1_000,
            max_bfs_iterations: 50_000,
        }
    }
}
101
102/// Node in the correlation graph.
103#[derive(Debug, Clone, PartialEq, Eq, Hash)]
104struct GraphNode {
105    id: String,
106    node_type: NodeType,
107    last_seen: Instant,
108}
109
/// Kind of entity a graph node represents, derived from its ID prefix.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum NodeType {
    /// IDs starting with `ip:`.
    Ip,
    /// IDs starting with `fp:`.
    Fingerprint,
    /// IDs starting with `token:`.
    Token,
    /// IDs starting with `asn:`.
    Asn,
    /// Any ID without one of the recognised prefixes.
    Other,
}

impl NodeType {
    /// Classifies a node ID by the text before its first `:`.
    ///
    /// Matching is case-sensitive; IDs with no colon are [`NodeType::Other`].
    fn from_id(id: &str) -> Self {
        match id.split_once(':') {
            Some(("ip", _)) => NodeType::Ip,
            Some(("fp", _)) => NodeType::Fingerprint,
            Some(("token", _)) => NodeType::Token,
            Some(("asn", _)) => NodeType::Asn,
            _ => NodeType::Other,
        }
    }
}
134
135/// Graph-based correlation detector.
136pub struct GraphDetector {
137    config: GraphConfig,
138    /// Adjacency list: Node -> Set of connected Nodes
139    /// Stores just string IDs to reduce cloning overhead
140    adjacency: DashMap<String, HashSet<String>>,
141    /// Node metadata (last seen, type)
142    nodes: DashMap<String, GraphNode>,
143    /// Last cleanup timestamp
144    last_cleanup: std::sync::Mutex<Instant>,
145    /// Statistics
146    edges_count: AtomicU64,
147}
148
149impl GraphDetector {
150    pub fn new(config: GraphConfig) -> Self {
151        Self {
152            config,
153            adjacency: DashMap::new(),
154            nodes: DashMap::new(),
155            last_cleanup: std::sync::Mutex::new(Instant::now()),
156            edges_count: AtomicU64::new(0),
157        }
158    }
159
160    /// Record a relationship between two entities.
161    /// e.g., IP "1.2.3.4" used Fingerprint "abc"
162    ///
163    /// Returns false if graph bounds are exceeded.
164    pub fn record_relation(&self, entity_a: &str, entity_b: &str) -> bool {
165        if entity_a == entity_b {
166            return true;
167        }
168
169        let now = Instant::now();
170
171        // Check node count limit before adding new nodes
172        let current_node_count = self.nodes.len();
173        let is_a_new = !self.nodes.contains_key(entity_a);
174        let is_b_new = !self.nodes.contains_key(entity_b);
175        let new_nodes_needed = (is_a_new as usize) + (is_b_new as usize);
176
177        if current_node_count + new_nodes_needed > self.config.max_nodes {
178            tracing::warn!(
179                current = current_node_count,
180                max = self.config.max_nodes,
181                "Graph node limit reached, skipping relation"
182            );
183            return false;
184        }
185
186        // Update or create nodes using atomic entry API
187        self.update_node(entity_a, now);
188        self.update_node(entity_b, now);
189
190        // Check edge count limit per node before adding
191        let mut edge_added = false;
192
193        // Add edge a -> b (if within limit)
194        {
195            let mut entry = self.adjacency.entry(entity_a.to_string()).or_default();
196            if entry.len() < self.config.max_edges_per_node {
197                entry.insert(entity_b.to_string());
198                edge_added = true;
199            } else {
200                tracing::debug!(
201                    node = entity_a,
202                    edges = entry.len(),
203                    "Edge limit reached for node"
204                );
205            }
206        }
207
208        // Add edge b -> a (if within limit)
209        {
210            let mut entry = self.adjacency.entry(entity_b.to_string()).or_default();
211            if entry.len() < self.config.max_edges_per_node {
212                entry.insert(entity_a.to_string());
213            }
214        }
215
216        if edge_added {
217            self.edges_count.fetch_add(1, Ordering::Relaxed);
218        }
219
220        true
221    }
222
223    /// Update or create a node using atomic entry API (fixes race condition).
224    fn update_node(&self, id: &str, now: Instant) {
225        // Use entry API for atomic update-or-insert (fixes race condition)
226        self.nodes
227            .entry(id.to_string())
228            .and_modify(|node| {
229                node.last_seen = now;
230            })
231            .or_insert_with(|| GraphNode {
232                id: id.to_string(),
233                node_type: NodeType::from_id(id),
234                last_seen: now,
235            });
236    }
237
238    /// Helpers to format IDs
239    pub fn ip_id(ip: &str) -> String {
240        format!("ip:{}", ip)
241    }
242    pub fn fp_id(fp: &str) -> String {
243        format!("fp:{}", fp)
244    }
245    pub fn token_id(token: &str) -> String {
246        format!("token:{}", token)
247    }
248    pub fn asn_id(asn: &str) -> String {
249        format!("asn:{}", asn)
250    }
251
252    /// BFS to find component
253    fn find_connected_ips(&self, start_node: &str) -> HashSet<String> {
254        let mut visited = HashSet::new();
255        let mut queue = VecDeque::new();
256        let mut ips = HashSet::new();
257        let mut iterations: usize = 0;
258
259        queue.push_back((start_node.to_string(), 0));
260        visited.insert(start_node.to_string());
261
262        while let Some((current_id, depth)) = queue.pop_front() {
263            // Check iteration limit to prevent CPU exhaustion
264            iterations += 1;
265            if iterations > self.config.max_bfs_iterations {
266                tracing::warn!(
267                    start = start_node,
268                    iterations = iterations,
269                    max = self.config.max_bfs_iterations,
270                    "BFS iteration limit reached, returning partial result"
271                );
272                break;
273            }
274
275            if depth >= self.config.max_traversal_depth {
276                continue;
277            }
278
279            // If current node is an IP, add to results
280            if NodeType::from_id(&current_id) == NodeType::Ip {
281                // Strip prefix
282                if let Some(ip) = current_id.strip_prefix("ip:") {
283                    ips.insert(ip.to_string());
284                }
285            }
286
287            // Visit neighbors
288            if let Some(neighbors) = self.adjacency.get(&current_id) {
289                for neighbor in neighbors.iter() {
290                    if !visited.contains(neighbor) {
291                        visited.insert(neighbor.clone());
292                        queue.push_back((neighbor.clone(), depth + 1));
293                    }
294                }
295            }
296        }
297
298        ips
299    }
300
301    /// Export graph data for a connected component starting from a given set of IPs.
302    /// Legacy method - delegates to get_cytoscape_data_paginated with default options.
303    pub fn get_cytoscape_data(&self, ips: &[String]) -> serde_json::Value {
304        let result = self.get_cytoscape_data_paginated(ips, GraphExportOptions::default());
305        serde_json::json!({
306            "nodes": result.nodes,
307            "edges": result.edges
308        })
309    }
310
311    /// Export graph data with pagination and optional identifier hashing.
312    /// P1 fix: Adds pagination to prevent unbounded memory usage and
313    /// hashes identifiers to prevent information disclosure.
314    pub fn get_cytoscape_data_paginated(
315        &self,
316        ips: &[String],
317        options: GraphExportOptions,
318    ) -> PaginatedGraph {
319        let limit = options.limit.unwrap_or(500);
320        let offset = options.offset.unwrap_or(0);
321        let hash_ids = options.hash_identifiers;
322
323        let mut all_nodes = Vec::new();
324        let mut edges = Vec::new();
325        let mut visited = HashSet::new();
326        let mut queue = VecDeque::new();
327
328        // Start from all campaign IPs
329        for ip in ips {
330            let id = Self::ip_id(ip);
331            if !visited.contains(&id) {
332                visited.insert(id.clone());
333                queue.push_back((id, 0));
334            }
335        }
336
337        while let Some((current_id, depth)) = queue.pop_front() {
338            // Create display ID (hashed or raw)
339            let display_id = if hash_ids {
340                let node_type = NodeType::from_id(&current_id);
341                let prefix = match node_type {
342                    NodeType::Ip => "ip",
343                    NodeType::Fingerprint => "fp",
344                    NodeType::Token => "tok",
345                    NodeType::Asn => "asn",
346                    _ => "unk",
347                };
348                format!("{}:{}", prefix, hash_identifier(&current_id))
349            } else {
350                current_id.clone()
351            };
352
353            // Add node
354            let node_type = NodeType::from_id(&current_id);
355            all_nodes.push((
356                current_id.clone(),
357                serde_json::json!({
358                    "data": {
359                        "id": display_id.clone(),
360                        "label": if hash_ids {
361                            display_id.split(':').nth(1).unwrap_or(&display_id).to_string()
362                        } else {
363                            current_id.split(':').nth(1).unwrap_or(&current_id).to_string()
364                        },
365                        "type": match node_type {
366                            NodeType::Ip => "ip",
367                            NodeType::Fingerprint => "actor", // Mapping to UI terminology
368                            NodeType::Token => "token",
369                            NodeType::Asn => "asn",
370                            _ => "other",
371                        }
372                    }
373                }),
374            ));
375
376            if depth >= self.config.max_traversal_depth {
377                continue;
378            }
379
380            // Add neighbors and edges
381            if let Some(neighbors) = self.adjacency.get(&current_id) {
382                for neighbor in neighbors.iter() {
383                    // Create display IDs for edge
384                    let source_display = if hash_ids {
385                        let node_type = NodeType::from_id(&current_id);
386                        let prefix = match node_type {
387                            NodeType::Ip => "ip",
388                            NodeType::Fingerprint => "fp",
389                            NodeType::Token => "tok",
390                            NodeType::Asn => "asn",
391                            _ => "unk",
392                        };
393                        format!("{}:{}", prefix, hash_identifier(&current_id))
394                    } else {
395                        current_id.clone()
396                    };
397
398                    let target_display = if hash_ids {
399                        let node_type = NodeType::from_id(neighbor);
400                        let prefix = match node_type {
401                            NodeType::Ip => "ip",
402                            NodeType::Fingerprint => "fp",
403                            NodeType::Token => "tok",
404                            NodeType::Asn => "asn",
405                            _ => "unk",
406                        };
407                        format!("{}:{}", prefix, hash_identifier(neighbor))
408                    } else {
409                        neighbor.clone()
410                    };
411
412                    // Always add edge (deduplicated below)
413                    let mut edge_ids = [source_display.as_str(), target_display.as_str()];
414                    edge_ids.sort();
415                    let edge_id = format!("e_{}_{}", edge_ids[0], edge_ids[1]);
416
417                    edges.push(serde_json::json!({
418                        "data": {
419                            "id": edge_id,
420                            "source": source_display,
421                            "target": target_display,
422                            "label": "linked"
423                        }
424                    }));
425
426                    if !visited.contains(neighbor) {
427                        visited.insert(neighbor.clone());
428                        queue.push_back((neighbor.clone(), depth + 1));
429                    }
430                }
431            }
432        }
433
434        let total_nodes = all_nodes.len();
435
436        // Apply pagination to nodes
437        let paginated_nodes: Vec<serde_json::Value> = all_nodes
438            .into_iter()
439            .skip(offset)
440            .take(limit)
441            .map(|(_, node)| node)
442            .collect();
443
444        // Deduplicate edges
445        let mut unique_edges = Vec::new();
446        let mut edge_id_set = HashSet::new();
447        for edge in edges {
448            let id = edge["data"]["id"].as_str().unwrap().to_string();
449            if edge_id_set.insert(id) {
450                unique_edges.push(edge);
451            }
452        }
453
454        PaginatedGraph {
455            nodes: paginated_nodes,
456            edges: unique_edges,
457            total_nodes,
458            has_more: offset + limit < total_nodes,
459            snapshot_version: self.edges_count.load(Ordering::Relaxed),
460        }
461    }
462
463    /// Clean up old nodes and edges.
464    fn cleanup(&self) {
465        let now = Instant::now();
466        let ttl = self.config.edge_ttl;
467
468        // Remove old nodes
469        self.nodes
470            .retain(|_, node| now.duration_since(node.last_seen) < ttl);
471
472        // Clean up adjacency list (remove keys that no longer exist in nodes)
473        // This is expensive, so it should run infrequently
474        self.adjacency.retain(|k, _| self.nodes.contains_key(k));
475
476        // We also need to remove values from the HashSets inside adjacency
477        // This requires iterating all values. For performance in this PoC,
478        // we might rely on the fact that if A links to B, and B expires,
479        // A's link to B becomes a dead end which find_connected_ips handles gracefully
480        // (it just won't find B in adjacency or won't find B's neighbors).
481        // A complete cleanup would iterate all sets.
482    }
483}
484
485impl Detector for GraphDetector {
486    fn name(&self) -> &'static str {
487        "graph_correlation"
488    }
489
490    fn analyze(&self, _index: &FingerprintIndex) -> DetectorResult<Vec<CampaignUpdate>> {
491        let mut updates = Vec::new();
492        let mut processed_ips = HashSet::new();
493
494        // Iterate over all IP nodes to find components
495        // We clone the keys to avoid holding locks during traversal
496        let ip_nodes: Vec<String> = self
497            .nodes
498            .iter()
499            .filter(|r| r.value().node_type == NodeType::Ip)
500            .map(|r| r.key().clone())
501            .collect();
502
503        for ip_node in ip_nodes {
504            // Skip if already part of a processed component
505            // Note: `processed_ips` tracks raw IPs ("1.2.3.4"), `ip_node` is "ip:1.2.3.4"
506            let raw_ip = ip_node.strip_prefix("ip:").unwrap_or(&ip_node);
507            if processed_ips.contains(raw_ip) {
508                continue;
509            }
510
511            // BFS to find component
512            let component_ips = self.find_connected_ips(&ip_node);
513
514            // Mark all as processed
515            for ip in &component_ips {
516                processed_ips.insert(ip.clone());
517            }
518
519            // Check if component meets threshold
520            if component_ips.len() >= self.config.min_component_size {
521                let reason = CorrelationReason {
522                    correlation_type: CorrelationType::BehavioralSimilarity, // Graph falls under behavioral/structural
523                    confidence: 0.9, // High confidence for graph connections
524                    evidence: component_ips.into_iter().collect(),
525                    description: format!(
526                        "Graph correlation: {} IPs connected via shared attributes (depth {})",
527                        self.config.min_component_size, self.config.max_traversal_depth
528                    ),
529                };
530
531                updates.push(CampaignUpdate {
532                    campaign_id: None, // New campaign or update existing
533                    status: None,
534                    risk_score: None,
535                    add_correlation_reason: Some(reason),
536                    attack_types: Some(vec!["coordinated_botnet".to_string()]),
537                    confidence: Some(0.9),
538                    add_member_ips: None,
539                    increment_requests: None,
540                    increment_blocked: None,
541                    increment_rules: None,
542                });
543            }
544        }
545
546        // Run cleanup if needed (e.g., every 5 minutes)
547        if let Ok(mut last) = self.last_cleanup.try_lock() {
548            if last.elapsed() > Duration::from_secs(300) {
549                *last = Instant::now();
550                // Spawn cleanup to avoid blocking analyze?
551                // For safety in this trait method, we'll run it synchronously but it might be slow.
552                // In production, use a background task.
553                self.cleanup();
554            }
555        }
556
557        Ok(updates)
558    }
559
560    fn should_trigger(&self, _ip: &std::net::IpAddr, _index: &FingerprintIndex) -> bool {
561        // Graph updates are implicit via record_relation, this check is less relevant
562        // unless we want to do immediate subgraph checks.
563        // For now, return false to rely on periodic analyze().
564        false
565    }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_connection() {
        let detector = GraphDetector::new(GraphConfig::default());

        // Link IP1 - FP1 - IP2
        assert!(detector.record_relation(
            &GraphDetector::ip_id("1.1.1.1"),
            &GraphDetector::fp_id("fp_a")
        ));
        assert!(detector.record_relation(
            &GraphDetector::fp_id("fp_a"),
            &GraphDetector::ip_id("2.2.2.2")
        ));

        let ips = detector.find_connected_ips(&GraphDetector::ip_id("1.1.1.1"));
        assert!(ips.contains("1.1.1.1"));
        assert!(ips.contains("2.2.2.2"));
        assert_eq!(ips.len(), 2);
    }

    #[test]
    fn test_component_detection() {
        // Chain: ip:1 - fp:a - ip:2 - token:x - ip:3.
        // ip:3 is four hops from ip:1, so use a traversal depth comfortably above that.
        let detector = GraphDetector::new(GraphConfig {
            min_component_size: 3,
            max_traversal_depth: 5,
            ..Default::default()
        });

        assert!(detector.record_relation("ip:1", "fp:a"));
        assert!(detector.record_relation("fp:a", "ip:2"));
        assert!(detector.record_relation("ip:2", "token:x"));
        assert!(detector.record_relation("token:x", "ip:3"));

        let updates = detector.analyze(&FingerprintIndex::new()).unwrap();
        assert_eq!(updates.len(), 1);

        let update = &updates[0];
        let reason = update.add_correlation_reason.as_ref().unwrap();
        assert!(reason.evidence.contains(&"1".to_string()));
        assert!(reason.evidence.contains(&"2".to_string()));
        assert!(reason.evidence.contains(&"3".to_string()));
    }

    #[test]
    fn test_node_limit_enforced() {
        let detector = GraphDetector::new(GraphConfig {
            max_nodes: 5,
            ..Default::default()
        });

        // Reach exactly 5 unique nodes (should succeed)
        assert!(detector.record_relation("ip:1", "fp:a")); // 2 nodes
        assert!(detector.record_relation("ip:2", "fp:b")); // 4 nodes
        assert!(detector.record_relation("ip:3", "fp:a")); // 5 nodes (ip:3 is new, fp:a exists)

        // Two more new nodes would exceed the limit
        assert!(!detector.record_relation("ip:4", "fp:c"));

        // A relation between existing nodes needs no new nodes and still works
        assert!(detector.record_relation("ip:1", "ip:2"));
    }

    #[test]
    fn test_edge_limit_enforced() {
        let detector = GraphDetector::new(GraphConfig {
            max_edges_per_node: 2,
            ..Default::default()
        });

        // Add edges up to the limit
        assert!(detector.record_relation("ip:hub", "fp:a"));
        assert!(detector.record_relation("ip:hub", "fp:b"));

        // The third edge is dropped; record_relation only reports the node
        // limit, so it still returns true.
        assert!(detector.record_relation("ip:hub", "fp:c"));

        // Verify hub only has 2 edges
        let neighbors = detector.adjacency.get("ip:hub").unwrap();
        assert_eq!(neighbors.len(), 2);
    }

    #[test]
    fn test_bfs_iteration_limit() {
        let detector = GraphDetector::new(GraphConfig {
            max_bfs_iterations: 10,
            max_traversal_depth: 100, // High depth to ensure the iteration limit is what stops BFS
            ..Default::default()
        });

        // Create a chain: ip:0 - fp:0, fp:1 - ip:0, ip:1 - fp:1, ...
        for i in 0..20 {
            detector.record_relation(&format!("ip:{}", i), &format!("fp:{}", i));
            if i > 0 {
                detector.record_relation(&format!("fp:{}", i), &format!("ip:{}", i - 1));
            }
        }

        // BFS should terminate early
        let ips = detector.find_connected_ips("ip:0");
        // Due to the iteration limit, not all IPs are found
        assert!(
            ips.len() < 20,
            "Should have stopped early due to iteration limit"
        );
    }
}