1use super::{Detector, DetectorResult};
16use crate::correlation::{CampaignUpdate, CorrelationReason, CorrelationType, FingerprintIndex};
17use dashmap::DashMap;
18use sha2::{Digest, Sha256};
19use std::collections::{HashSet, VecDeque};
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::{Duration, Instant};
22
/// Caller-supplied knobs for exporting the relation graph.
///
/// The `Default` value leaves both paging fields unset and disables identifier hashing.
#[derive(Clone, Debug, Default)]
pub struct GraphExportOptions {
    /// Maximum number of nodes in the returned page; the paginated export falls back to
    /// 500 when this is `None`.
    pub limit: Option<usize>,
    /// Number of nodes to skip before the returned page; treated as 0 when `None`.
    pub offset: Option<usize>,
    /// When `true`, exported identifiers are replaced by a type prefix plus a truncated
    /// SHA-256 digest instead of the raw value.
    pub hash_identifiers: bool,
}
33
34#[derive(Debug, Clone, serde::Serialize)]
36pub struct PaginatedGraph {
37 pub nodes: Vec<serde_json::Value>,
39 pub edges: Vec<serde_json::Value>,
40 pub total_nodes: usize,
42 pub has_more: bool,
44 pub snapshot_version: u64,
46}
47
48fn hash_identifier(id: &str) -> String {
50 let mut hasher = Sha256::new();
51 hasher.update(id.as_bytes());
52 let result = hasher.finalize();
53 format!("{:x}", result)[..12].to_string() }
55
/// Tunables for the graph detector.
#[derive(Clone, Debug)]
pub struct GraphConfig {
    /// Minimum number of connected IPs required before `analyze` reports a group.
    pub min_component_size: usize,

    /// Depth bound applied to the breadth-first traversals.
    pub max_traversal_depth: usize,

    /// Nodes whose `last_seen` is at least this old are dropped by the periodic cleanup.
    pub edge_ttl: Duration,

    /// Detector weight. Not read anywhere in this file — presumably consumed by the
    /// caller; TODO confirm.
    pub weight: u8,

    /// Upper bound on the number of nodes; relations that would exceed it are rejected.
    pub max_nodes: usize,

    /// Upper bound on the size of a single node's neighbor set.
    pub max_edges_per_node: usize,

    /// Upper bound on the number of nodes dequeued by one connected-IP search.
    pub max_bfs_iterations: usize,
}

impl Default for GraphConfig {
    /// Defaults: components of 3+ IPs, depth 3, one-hour TTL, weight 20,
    /// 10 000 nodes, 1 000 edges per node, 50 000 BFS iterations.
    fn default() -> Self {
        let one_hour = Duration::from_secs(3600);

        GraphConfig {
            min_component_size: 3,
            max_traversal_depth: 3,
            edge_ttl: one_hour,
            weight: 20,
            max_nodes: 10_000,
            max_edges_per_node: 1_000,
            max_bfs_iterations: 50_000,
        }
    }
}
101
102#[derive(Debug, Clone, PartialEq, Eq, Hash)]
104struct GraphNode {
105 id: String,
106 node_type: NodeType,
107 last_seen: Instant,
108}
109
/// Category of a graph node, inferred from the prefix of its identifier.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum NodeType {
    /// Identifier starts with `ip:`.
    Ip,
    /// Identifier starts with `fp:`.
    Fingerprint,
    /// Identifier starts with `token:`.
    Token,
    /// Identifier starts with `asn:`.
    Asn,
    /// Any other identifier.
    Other,
}

impl NodeType {
    /// Classifies `id` by the text in front of its first `:`.
    ///
    /// Identifiers without a colon, or with an unrecognised prefix, map to
    /// [`NodeType::Other`].
    fn from_id(id: &str) -> Self {
        let prefix = id.split_once(':').map(|(head, _)| head);

        match prefix {
            Some("ip") => Self::Ip,
            Some("fp") => Self::Fingerprint,
            Some("token") => Self::Token,
            Some("asn") => Self::Asn,
            _ => Self::Other,
        }
    }
}
134
135pub struct GraphDetector {
137 config: GraphConfig,
138 adjacency: DashMap<String, HashSet<String>>,
141 nodes: DashMap<String, GraphNode>,
143 last_cleanup: std::sync::Mutex<Instant>,
145 edges_count: AtomicU64,
147}
148
149impl GraphDetector {
150 pub fn new(config: GraphConfig) -> Self {
151 Self {
152 config,
153 adjacency: DashMap::new(),
154 nodes: DashMap::new(),
155 last_cleanup: std::sync::Mutex::new(Instant::now()),
156 edges_count: AtomicU64::new(0),
157 }
158 }
159
160 pub fn record_relation(&self, entity_a: &str, entity_b: &str) -> bool {
165 if entity_a == entity_b {
166 return true;
167 }
168
169 let now = Instant::now();
170
171 let current_node_count = self.nodes.len();
173 let is_a_new = !self.nodes.contains_key(entity_a);
174 let is_b_new = !self.nodes.contains_key(entity_b);
175 let new_nodes_needed = (is_a_new as usize) + (is_b_new as usize);
176
177 if current_node_count + new_nodes_needed > self.config.max_nodes {
178 tracing::warn!(
179 current = current_node_count,
180 max = self.config.max_nodes,
181 "Graph node limit reached, skipping relation"
182 );
183 return false;
184 }
185
186 self.update_node(entity_a, now);
188 self.update_node(entity_b, now);
189
190 let mut edge_added = false;
192
193 {
195 let mut entry = self.adjacency.entry(entity_a.to_string()).or_default();
196 if entry.len() < self.config.max_edges_per_node {
197 entry.insert(entity_b.to_string());
198 edge_added = true;
199 } else {
200 tracing::debug!(
201 node = entity_a,
202 edges = entry.len(),
203 "Edge limit reached for node"
204 );
205 }
206 }
207
208 {
210 let mut entry = self.adjacency.entry(entity_b.to_string()).or_default();
211 if entry.len() < self.config.max_edges_per_node {
212 entry.insert(entity_a.to_string());
213 }
214 }
215
216 if edge_added {
217 self.edges_count.fetch_add(1, Ordering::Relaxed);
218 }
219
220 true
221 }
222
223 fn update_node(&self, id: &str, now: Instant) {
225 self.nodes
227 .entry(id.to_string())
228 .and_modify(|node| {
229 node.last_seen = now;
230 })
231 .or_insert_with(|| GraphNode {
232 id: id.to_string(),
233 node_type: NodeType::from_id(id),
234 last_seen: now,
235 });
236 }
237
238 pub fn ip_id(ip: &str) -> String {
240 format!("ip:{}", ip)
241 }
242 pub fn fp_id(fp: &str) -> String {
243 format!("fp:{}", fp)
244 }
245 pub fn token_id(token: &str) -> String {
246 format!("token:{}", token)
247 }
248 pub fn asn_id(asn: &str) -> String {
249 format!("asn:{}", asn)
250 }
251
252 fn find_connected_ips(&self, start_node: &str) -> HashSet<String> {
254 let mut visited = HashSet::new();
255 let mut queue = VecDeque::new();
256 let mut ips = HashSet::new();
257 let mut iterations: usize = 0;
258
259 queue.push_back((start_node.to_string(), 0));
260 visited.insert(start_node.to_string());
261
262 while let Some((current_id, depth)) = queue.pop_front() {
263 iterations += 1;
265 if iterations > self.config.max_bfs_iterations {
266 tracing::warn!(
267 start = start_node,
268 iterations = iterations,
269 max = self.config.max_bfs_iterations,
270 "BFS iteration limit reached, returning partial result"
271 );
272 break;
273 }
274
275 if depth >= self.config.max_traversal_depth {
276 continue;
277 }
278
279 if NodeType::from_id(¤t_id) == NodeType::Ip {
281 if let Some(ip) = current_id.strip_prefix("ip:") {
283 ips.insert(ip.to_string());
284 }
285 }
286
287 if let Some(neighbors) = self.adjacency.get(¤t_id) {
289 for neighbor in neighbors.iter() {
290 if !visited.contains(neighbor) {
291 visited.insert(neighbor.clone());
292 queue.push_back((neighbor.clone(), depth + 1));
293 }
294 }
295 }
296 }
297
298 ips
299 }
300
301 pub fn get_cytoscape_data(&self, ips: &[String]) -> serde_json::Value {
304 let result = self.get_cytoscape_data_paginated(ips, GraphExportOptions::default());
305 serde_json::json!({
306 "nodes": result.nodes,
307 "edges": result.edges
308 })
309 }
310
311 pub fn get_cytoscape_data_paginated(
315 &self,
316 ips: &[String],
317 options: GraphExportOptions,
318 ) -> PaginatedGraph {
319 let limit = options.limit.unwrap_or(500);
320 let offset = options.offset.unwrap_or(0);
321 let hash_ids = options.hash_identifiers;
322
323 let mut all_nodes = Vec::new();
324 let mut edges = Vec::new();
325 let mut visited = HashSet::new();
326 let mut queue = VecDeque::new();
327
328 for ip in ips {
330 let id = Self::ip_id(ip);
331 if !visited.contains(&id) {
332 visited.insert(id.clone());
333 queue.push_back((id, 0));
334 }
335 }
336
337 while let Some((current_id, depth)) = queue.pop_front() {
338 let display_id = if hash_ids {
340 let node_type = NodeType::from_id(¤t_id);
341 let prefix = match node_type {
342 NodeType::Ip => "ip",
343 NodeType::Fingerprint => "fp",
344 NodeType::Token => "tok",
345 NodeType::Asn => "asn",
346 _ => "unk",
347 };
348 format!("{}:{}", prefix, hash_identifier(¤t_id))
349 } else {
350 current_id.clone()
351 };
352
353 let node_type = NodeType::from_id(¤t_id);
355 all_nodes.push((
356 current_id.clone(),
357 serde_json::json!({
358 "data": {
359 "id": display_id.clone(),
360 "label": if hash_ids {
361 display_id.split(':').nth(1).unwrap_or(&display_id).to_string()
362 } else {
363 current_id.split(':').nth(1).unwrap_or(¤t_id).to_string()
364 },
365 "type": match node_type {
366 NodeType::Ip => "ip",
367 NodeType::Fingerprint => "actor", NodeType::Token => "token",
369 NodeType::Asn => "asn",
370 _ => "other",
371 }
372 }
373 }),
374 ));
375
376 if depth >= self.config.max_traversal_depth {
377 continue;
378 }
379
380 if let Some(neighbors) = self.adjacency.get(¤t_id) {
382 for neighbor in neighbors.iter() {
383 let source_display = if hash_ids {
385 let node_type = NodeType::from_id(¤t_id);
386 let prefix = match node_type {
387 NodeType::Ip => "ip",
388 NodeType::Fingerprint => "fp",
389 NodeType::Token => "tok",
390 NodeType::Asn => "asn",
391 _ => "unk",
392 };
393 format!("{}:{}", prefix, hash_identifier(¤t_id))
394 } else {
395 current_id.clone()
396 };
397
398 let target_display = if hash_ids {
399 let node_type = NodeType::from_id(neighbor);
400 let prefix = match node_type {
401 NodeType::Ip => "ip",
402 NodeType::Fingerprint => "fp",
403 NodeType::Token => "tok",
404 NodeType::Asn => "asn",
405 _ => "unk",
406 };
407 format!("{}:{}", prefix, hash_identifier(neighbor))
408 } else {
409 neighbor.clone()
410 };
411
412 let mut edge_ids = [source_display.as_str(), target_display.as_str()];
414 edge_ids.sort();
415 let edge_id = format!("e_{}_{}", edge_ids[0], edge_ids[1]);
416
417 edges.push(serde_json::json!({
418 "data": {
419 "id": edge_id,
420 "source": source_display,
421 "target": target_display,
422 "label": "linked"
423 }
424 }));
425
426 if !visited.contains(neighbor) {
427 visited.insert(neighbor.clone());
428 queue.push_back((neighbor.clone(), depth + 1));
429 }
430 }
431 }
432 }
433
434 let total_nodes = all_nodes.len();
435
436 let paginated_nodes: Vec<serde_json::Value> = all_nodes
438 .into_iter()
439 .skip(offset)
440 .take(limit)
441 .map(|(_, node)| node)
442 .collect();
443
444 let mut unique_edges = Vec::new();
446 let mut edge_id_set = HashSet::new();
447 for edge in edges {
448 let id = edge["data"]["id"].as_str().unwrap().to_string();
449 if edge_id_set.insert(id) {
450 unique_edges.push(edge);
451 }
452 }
453
454 PaginatedGraph {
455 nodes: paginated_nodes,
456 edges: unique_edges,
457 total_nodes,
458 has_more: offset + limit < total_nodes,
459 snapshot_version: self.edges_count.load(Ordering::Relaxed),
460 }
461 }
462
463 fn cleanup(&self) {
465 let now = Instant::now();
466 let ttl = self.config.edge_ttl;
467
468 self.nodes
470 .retain(|_, node| now.duration_since(node.last_seen) < ttl);
471
472 self.adjacency.retain(|k, _| self.nodes.contains_key(k));
475
476 }
483}
484
485impl Detector for GraphDetector {
486 fn name(&self) -> &'static str {
487 "graph_correlation"
488 }
489
490 fn analyze(&self, _index: &FingerprintIndex) -> DetectorResult<Vec<CampaignUpdate>> {
491 let mut updates = Vec::new();
492 let mut processed_ips = HashSet::new();
493
494 let ip_nodes: Vec<String> = self
497 .nodes
498 .iter()
499 .filter(|r| r.value().node_type == NodeType::Ip)
500 .map(|r| r.key().clone())
501 .collect();
502
503 for ip_node in ip_nodes {
504 let raw_ip = ip_node.strip_prefix("ip:").unwrap_or(&ip_node);
507 if processed_ips.contains(raw_ip) {
508 continue;
509 }
510
511 let component_ips = self.find_connected_ips(&ip_node);
513
514 for ip in &component_ips {
516 processed_ips.insert(ip.clone());
517 }
518
519 if component_ips.len() >= self.config.min_component_size {
521 let reason = CorrelationReason {
522 correlation_type: CorrelationType::BehavioralSimilarity, confidence: 0.9, evidence: component_ips.into_iter().collect(),
525 description: format!(
526 "Graph correlation: {} IPs connected via shared attributes (depth {})",
527 self.config.min_component_size, self.config.max_traversal_depth
528 ),
529 };
530
531 updates.push(CampaignUpdate {
532 campaign_id: None, status: None,
534 risk_score: None,
535 add_correlation_reason: Some(reason),
536 attack_types: Some(vec!["coordinated_botnet".to_string()]),
537 confidence: Some(0.9),
538 add_member_ips: None,
539 increment_requests: None,
540 increment_blocked: None,
541 increment_rules: None,
542 });
543 }
544 }
545
546 if let Ok(mut last) = self.last_cleanup.try_lock() {
548 if last.elapsed() > Duration::from_secs(300) {
549 *last = Instant::now();
550 self.cleanup();
554 }
555 }
556
557 Ok(updates)
558 }
559
560 fn should_trigger(&self, _ip: &std::net::IpAddr, _index: &FingerprintIndex) -> bool {
561 false
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Two IPs sharing one fingerprint are found from either end.
    #[test]
    fn test_graph_connection() {
        let detector = GraphDetector::new(GraphConfig::default());
        let first_ip = GraphDetector::ip_id("1.1.1.1");
        let second_ip = GraphDetector::ip_id("2.2.2.2");
        let shared_fp = GraphDetector::fp_id("fp_a");

        assert!(detector.record_relation(&first_ip, &shared_fp));
        assert!(detector.record_relation(&shared_fp, &second_ip));

        let ips = detector.find_connected_ips(&first_ip);
        assert!(ips.contains("1.1.1.1"));
        assert!(ips.contains("2.2.2.2"));
        assert_eq!(ips.len(), 2);
    }

    /// A chain of three IPs linked through two attributes yields exactly one update.
    #[test]
    fn test_component_detection() {
        let config = GraphConfig {
            min_component_size: 3,
            max_traversal_depth: 5,
            ..Default::default()
        };
        let detector = GraphDetector::new(config);

        // ip:1 - fp:a - ip:2 - tok:x - ip:3
        assert!(detector.record_relation("ip:1", "fp:a"));
        assert!(detector.record_relation("fp:a", "ip:2"));
        assert!(detector.record_relation("ip:2", "tok:x"));
        assert!(detector.record_relation("tok:x", "ip:3"));

        let updates = detector.analyze(&FingerprintIndex::new()).unwrap();
        assert_eq!(updates.len(), 1);

        let update = &updates[0];
        let reason = update.add_correlation_reason.as_ref().unwrap();
        assert!(reason.evidence.contains(&"1".to_string()));
        assert!(reason.evidence.contains(&"2".to_string()));
        assert!(reason.evidence.contains(&"3".to_string()));
    }

    /// Relations that would push the node count past `max_nodes` are rejected.
    #[test]
    fn test_node_limit_enforced() {
        let detector = GraphDetector::new(GraphConfig {
            max_nodes: 5,
            ..Default::default()
        });

        // Two new nodes -> 2 in total.
        assert!(detector.record_relation("ip:1", "fp:a"));
        // Two new nodes -> 4 in total.
        assert!(detector.record_relation("ip:2", "fp:b"));
        // One new node (fp:a already exists) -> 5 in total.
        assert!(detector.record_relation("ip:3", "fp:a"));
        // Would need two more nodes -> rejected.
        assert!(!detector.record_relation("ip:4", "fp:c"));
        // Both endpoints already exist -> accepted.
        assert!(detector.record_relation("ip:1", "ip:2"));
    }

    /// A node's neighbor set never grows beyond `max_edges_per_node`.
    #[test]
    fn test_edge_limit_enforced() {
        let detector = GraphDetector::new(GraphConfig {
            max_edges_per_node: 2,
            ..Default::default()
        });

        assert!(detector.record_relation("ip:hub", "fp:a"));
        assert!(detector.record_relation("ip:hub", "fp:b"));

        // The hub is full; this relation must not add a third neighbor to it.
        detector.record_relation("ip:hub", "fp:c");

        let neighbors = detector.adjacency.get("ip:hub").unwrap();
        assert_eq!(neighbors.len(), 2);
    }

    /// The search stops once `max_bfs_iterations` nodes have been dequeued.
    #[test]
    fn test_bfs_iteration_limit() {
        let detector = GraphDetector::new(GraphConfig {
            max_bfs_iterations: 10,
            max_traversal_depth: 100,
            ..Default::default()
        });

        // Build one long chain of 20 IPs joined through fingerprints.
        for i in 0..20 {
            detector.record_relation(&format!("ip:{}", i), &format!("fp:{}", i));
            if i > 0 {
                detector.record_relation(&format!("fp:{}", i), &format!("ip:{}", i - 1));
            }
        }

        let ips = detector.find_connected_ips("ip:0");
        assert!(
            ips.len() < 20,
            "Should have stopped early due to iteration limit"
        );
    }
}