1use crate::types::{CausalEdge, CausalGraphResult, CausalNode, UserEvent};
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
10use std::collections::{HashMap, HashSet};
11
12#[derive(Debug, Clone)]
21pub struct CausalGraphConstruction {
22 metadata: KernelMetadata,
23}
24
25impl Default for CausalGraphConstruction {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl CausalGraphConstruction {
32 #[must_use]
34 pub fn new() -> Self {
35 Self {
36 metadata: KernelMetadata::batch("behavioral/causal-graph", Domain::BehavioralAnalytics)
37 .with_description("Causal DAG inference from event streams")
38 .with_throughput(10_000)
39 .with_latency_us(500.0),
40 }
41 }
42
43 pub fn compute(events: &[UserEvent], config: &CausalConfig) -> CausalGraphResult {
49 if events.len() < 2 {
50 return CausalGraphResult {
51 nodes: Vec::new(),
52 edges: Vec::new(),
53 root_causes: Vec::new(),
54 effects: Vec::new(),
55 };
56 }
57
58 let mut sorted_events: Vec<_> = events.iter().collect();
60 sorted_events.sort_by_key(|e| e.timestamp);
61
62 let (nodes, type_to_id) = Self::build_nodes(&sorted_events);
64
65 let edges = Self::build_edges(&sorted_events, &type_to_id, config);
67
68 let root_causes = Self::identify_root_causes(&nodes, &edges);
70
71 let effects = Self::identify_effects(&nodes, &edges);
73
74 CausalGraphResult {
75 nodes,
76 edges,
77 root_causes,
78 effects,
79 }
80 }
81
82 fn build_nodes(events: &[&UserEvent]) -> (Vec<CausalNode>, HashMap<String, u64>) {
84 let mut type_counts: HashMap<&str, u64> = HashMap::new();
85 let total = events.len() as f64;
86
87 for event in events {
88 *type_counts.entry(&event.event_type).or_insert(0) += 1;
89 }
90
91 let mut sorted_types: Vec<_> = type_counts.into_iter().collect();
93 sorted_types.sort_by(|a, b| a.0.cmp(b.0));
94
95 let mut nodes = Vec::new();
96 let mut type_to_id = HashMap::new();
97
98 for (i, (event_type, count)) in sorted_types.iter().enumerate() {
99 let node_id = i as u64;
100 nodes.push(CausalNode {
101 id: node_id,
102 event_type: event_type.to_string(),
103 probability: *count as f64 / total,
104 });
105 type_to_id.insert(event_type.to_string(), node_id);
106 }
107
108 (nodes, type_to_id)
109 }
110
111 fn build_edges(
113 events: &[&UserEvent],
114 type_to_id: &HashMap<String, u64>,
115 config: &CausalConfig,
116 ) -> Vec<CausalEdge> {
117 let mut transitions: HashMap<(u64, u64), TransitionStats> = HashMap::new();
119
120 for window in events.windows(2) {
121 let source_id = type_to_id.get(&window[0].event_type);
122 let target_id = type_to_id.get(&window[1].event_type);
123
124 if let (Some(&src), Some(&tgt)) = (source_id, target_id) {
125 if src == tgt && !config.allow_self_loops {
126 continue;
127 }
128
129 let time_diff = window[1].timestamp.saturating_sub(window[0].timestamp);
130
131 if time_diff > config.max_lag_seconds {
132 continue;
133 }
134
135 let stats = transitions.entry((src, tgt)).or_default();
136 stats.add(time_diff);
137 }
138 }
139
140 let mut source_totals: HashMap<u64, u64> = HashMap::new();
142 for ((src, _), stats) in &transitions {
143 *source_totals.entry(*src).or_insert(0) += stats.count;
144 }
145
146 let mut edges = Vec::new();
148
149 for ((source, target), stats) in transitions {
150 let source_total = source_totals.get(&source).copied().unwrap_or(1);
151 let strength = stats.count as f64 / source_total as f64;
152
153 if strength < config.min_strength {
154 continue;
155 }
156
157 if stats.count < config.min_observations as u64 {
158 continue;
159 }
160
161 edges.push(CausalEdge {
162 source,
163 target,
164 strength,
165 lag: stats.mean_lag(),
166 count: stats.count,
167 });
168 }
169
170 if config.enforce_dag {
172 Self::prune_to_dag(&mut edges);
173 }
174
175 edges
176 }
177
178 fn prune_to_dag(edges: &mut Vec<CausalEdge>) {
180 edges.sort_by(|a, b| {
182 b.strength
183 .partial_cmp(&a.strength)
184 .unwrap()
185 .then_with(|| a.source.cmp(&b.source))
186 .then_with(|| a.target.cmp(&b.target))
187 });
188
189 let mut graph: HashMap<u64, HashSet<u64>> = HashMap::new();
190
191 let mut kept_edges = Vec::new();
193
194 for edge in edges.iter() {
195 if !Self::would_create_cycle(&graph, edge.source, edge.target) {
197 graph.entry(edge.source).or_default().insert(edge.target);
198 kept_edges.push(edge.clone());
199 }
200 }
201
202 *edges = kept_edges;
203 }
204
205 fn would_create_cycle(graph: &HashMap<u64, HashSet<u64>>, source: u64, target: u64) -> bool {
207 let mut visited = HashSet::new();
209 let mut queue = vec![target];
210
211 while let Some(node) = queue.pop() {
212 if node == source {
213 return true;
214 }
215
216 if visited.contains(&node) {
217 continue;
218 }
219 visited.insert(node);
220
221 if let Some(neighbors) = graph.get(&node) {
222 queue.extend(neighbors.iter());
223 }
224 }
225
226 false
227 }
228
229 fn identify_root_causes(nodes: &[CausalNode], edges: &[CausalEdge]) -> Vec<u64> {
231 let mut out_degree: HashMap<u64, u64> = HashMap::new();
232 let mut in_degree: HashMap<u64, u64> = HashMap::new();
233
234 for edge in edges {
235 *out_degree.entry(edge.source).or_insert(0) += 1;
236 *in_degree.entry(edge.target).or_insert(0) += 1;
237 }
238
239 let mut root_scores: Vec<(u64, f64)> = nodes
240 .iter()
241 .map(|n| {
242 let out = out_degree.get(&n.id).copied().unwrap_or(0) as f64;
243 let in_d = in_degree.get(&n.id).copied().unwrap_or(0) as f64;
244 let score = out / (in_d + 1.0);
246 (n.id, score)
247 })
248 .collect();
249
250 root_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
251
252 root_scores
254 .iter()
255 .filter(|(_, score)| *score >= 1.0)
256 .map(|(id, _)| *id)
257 .collect()
258 }
259
260 fn identify_effects(nodes: &[CausalNode], edges: &[CausalEdge]) -> Vec<u64> {
262 let mut out_degree: HashMap<u64, u64> = HashMap::new();
263 let mut in_degree: HashMap<u64, u64> = HashMap::new();
264
265 for edge in edges {
266 *out_degree.entry(edge.source).or_insert(0) += 1;
267 *in_degree.entry(edge.target).or_insert(0) += 1;
268 }
269
270 let mut effect_scores: Vec<(u64, f64)> = nodes
271 .iter()
272 .map(|n| {
273 let out = out_degree.get(&n.id).copied().unwrap_or(0) as f64;
274 let in_d = in_degree.get(&n.id).copied().unwrap_or(0) as f64;
275 let score = in_d / (out + 1.0);
277 (n.id, score)
278 })
279 .collect();
280
281 effect_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
282
283 effect_scores
285 .iter()
286 .filter(|(_, score)| *score >= 1.0)
287 .map(|(id, _)| *id)
288 .collect()
289 }
290
291 pub fn calculate_impact(graph: &CausalGraphResult, event_type: &str) -> CausalImpact {
293 let node_id = graph
294 .nodes
295 .iter()
296 .find(|n| n.event_type == event_type)
297 .map(|n| n.id);
298
299 let node_id = match node_id {
300 Some(id) => id,
301 None => {
302 return CausalImpact {
303 event_type: event_type.to_string(),
304 direct_effects: Vec::new(),
305 indirect_effects: Vec::new(),
306 total_impact: 0.0,
307 };
308 }
309 };
310
311 let direct_effects: Vec<_> = graph
313 .edges
314 .iter()
315 .filter(|e| e.source == node_id)
316 .map(|e| {
317 let target_type = graph
318 .nodes
319 .iter()
320 .find(|n| n.id == e.target)
321 .map(|n| n.event_type.clone())
322 .unwrap_or_default();
323 (target_type, e.strength)
324 })
325 .collect();
326
327 let mut indirect_effects = Vec::new();
329 let mut visited: HashSet<u64> = HashSet::new();
330 visited.insert(node_id);
331
332 let mut current_level: Vec<u64> = direct_effects
333 .iter()
334 .map(|(t, _)| {
335 graph
336 .nodes
337 .iter()
338 .find(|n| n.event_type == *t)
339 .map(|n| n.id)
340 .unwrap_or(0)
341 })
342 .collect();
343
344 let mut depth = 1;
345 while !current_level.is_empty() && depth < 3 {
346 let mut next_level = Vec::new();
347
348 for &node in ¤t_level {
349 if visited.contains(&node) {
350 continue;
351 }
352 visited.insert(node);
353
354 for edge in graph.edges.iter().filter(|e| e.source == node) {
355 let target_type = graph
356 .nodes
357 .iter()
358 .find(|n| n.id == edge.target)
359 .map(|n| n.event_type.clone())
360 .unwrap_or_default();
361
362 let decayed_strength = edge.strength / (depth as f64 + 1.0);
364 indirect_effects.push((target_type, decayed_strength, depth));
365
366 next_level.push(edge.target);
367 }
368 }
369
370 current_level = next_level;
371 depth += 1;
372 }
373
374 let total_impact = direct_effects.iter().map(|(_, s)| s).sum::<f64>()
375 + indirect_effects.iter().map(|(_, s, _)| s).sum::<f64>();
376
377 CausalImpact {
378 event_type: event_type.to_string(),
379 direct_effects,
380 indirect_effects,
381 total_impact,
382 }
383 }
384}
385
386impl GpuKernel for CausalGraphConstruction {
387 fn metadata(&self) -> &KernelMetadata {
388 &self.metadata
389 }
390}
391
/// Running tally for one `source -> target` transition: how often it was observed
/// and the summed lag between the two events.
#[derive(Debug, Default)]
struct TransitionStats {
    /// Number of observed transitions.
    count: u64,
    /// Sum of the observed lags, in the same unit as the event timestamps.
    total_lag: u64,
}

impl TransitionStats {
    /// Records one more observed transition with the given `lag`.
    fn add(&mut self, lag: u64) {
        self.count += 1;
        self.total_lag += lag;
    }

    /// Average lag over all recorded transitions, or `0.0` when none were recorded.
    fn mean_lag(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => self.total_lag as f64 / n as f64,
        }
    }
}
413
/// Thresholds controlling which event transitions become edges in the causal graph.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalConfig {
    /// Minimum edge strength needed to keep an edge. Strength is the share of a source's
    /// qualifying transitions that go to the target.
    pub min_strength: f64,
    /// Maximum time gap between two consecutive events for their transition to count.
    pub max_lag_seconds: u64,
    /// Minimum number of observed transitions needed to keep an edge.
    pub min_observations: u32,
    /// Drop the weakest edges that would close a cycle, so the graph is a DAG.
    /// A self-loop is a cycle, so this also removes self-loops.
    pub enforce_dag: bool,
    /// Whether consecutive events of the same type may form a `node -> node` edge.
    pub allow_self_loops: bool,
}

impl Default for CausalConfig {
    /// Defaults: strength >= 0.1, lag <= 3600, at least 3 observations, DAG enforced,
    /// self-loops disallowed.
    fn default() -> Self {
        Self {
            min_strength: 0.1,
            max_lag_seconds: 3600,
            min_observations: 3,
            enforce_dag: true,
            allow_self_loops: false,
        }
    }
}
440
/// Downstream influence of one event type, as computed by
/// `CausalGraphConstruction::calculate_impact`.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalImpact {
    /// The event type whose impact was measured.
    pub event_type: String,
    /// `(target event type, edge strength)` for each outgoing edge of the event type.
    pub direct_effects: Vec<(String, f64)>,
    /// `(target event type, depth-decayed strength, depth)` for effects reached beyond
    /// the direct ones.
    pub indirect_effects: Vec<(String, f64, usize)>,
    /// Sum of all direct and indirect strengths.
    pub total_impact: f64,
}
453
#[cfg(test)]
mod tests {
    use super::*;

    const BASE_TS: u64 = 1_700_000_000;

    /// Builds an event for user 100 with no attributes and no device/IP/location data.
    fn make_event(id: u64, event_type: &str, timestamp: u64, session_id: Option<u64>) -> UserEvent {
        UserEvent {
            id,
            user_id: 100,
            event_type: event_type.to_string(),
            timestamp,
            attributes: HashMap::new(),
            session_id,
            device_id: None,
            ip_address: None,
            location: None,
        }
    }

    /// Thirty rounds of `event_a -> event_b -> event_c`. Events inside a round are 10 s
    /// apart; each round's `event_c` is 980 s before the next round's `event_a`.
    fn create_causal_chain_events() -> Vec<UserEvent> {
        let mut events = Vec::with_capacity(90);
        for i in 0u64..30 {
            let start = BASE_TS + i * 1000;
            events.push(make_event(i * 3, "event_a", start, Some(i)));
            events.push(make_event(i * 3 + 1, "event_b", start + 10, Some(i)));
            events.push(make_event(i * 3 + 2, "event_c", start + 20, Some(i)));
        }
        events
    }

    /// Looks up a node id by event type and fails the test if the node is missing.
    fn id_of(graph: &CausalGraphResult, event_type: &str) -> u64 {
        graph
            .nodes
            .iter()
            .find(|n| n.event_type == event_type)
            .map(|n| n.id)
            .unwrap_or_else(|| panic!("missing node for {}", event_type))
    }

    #[test]
    fn test_causal_graph_metadata() {
        let kernel = CausalGraphConstruction::new();
        assert_eq!(kernel.metadata().id, "behavioral/causal-graph");
        assert_eq!(kernel.metadata().domain, Domain::BehavioralAnalytics);
    }

    #[test]
    fn test_causal_graph_construction() {
        let events = create_causal_chain_events();
        let result = CausalGraphConstruction::compute(&events, &CausalConfig::default());

        assert_eq!(result.nodes.len(), 3);
        assert!(
            result.edges.len() >= 2,
            "Should have at least 2 edges, got {}",
            result.edges.len()
        );
    }

    #[test]
    fn test_root_cause_identification() {
        let events = create_causal_chain_events();
        // A 100 s limit excludes the 980 s `event_c -> event_a` wrap-around.
        let config = CausalConfig {
            max_lag_seconds: 100,
            ..Default::default()
        };

        let result = CausalGraphConstruction::compute(&events, &config);
        let a_id = id_of(&result, "event_a");

        assert!(
            result.root_causes.contains(&a_id),
            "event_a should be root cause"
        );
    }

    #[test]
    fn test_effect_identification() {
        let events = create_causal_chain_events();
        let config = CausalConfig {
            max_lag_seconds: 100,
            ..Default::default()
        };

        let result = CausalGraphConstruction::compute(&events, &config);
        let c_id = id_of(&result, "event_c");

        assert!(
            result.effects.contains(&c_id),
            "event_c should be an effect"
        );
    }

    #[test]
    fn test_causal_strength() {
        let events = create_causal_chain_events();
        let result = CausalGraphConstruction::compute(&events, &CausalConfig::default());

        let a_id = id_of(&result, "event_a");
        let b_id = id_of(&result, "event_b");

        let ab_edge = result
            .edges
            .iter()
            .find(|e| e.source == a_id && e.target == b_id)
            .expect("Should have A->B edge");

        assert!(ab_edge.strength > 0.5, "A->B should have high strength");
    }

    #[test]
    fn test_dag_enforcement() {
        let mut events = Vec::with_capacity(40);
        for i in 0u64..20 {
            let start = BASE_TS + i * 100;
            events.push(make_event(i * 2, "type_a", start, None));
            events.push(make_event(i * 2 + 1, "type_b", start + 10, None));
        }

        let config = CausalConfig {
            enforce_dag: true,
            ..Default::default()
        };

        let result = CausalGraphConstruction::compute(&events, &config);

        assert!(!detect_cycle(&result), "DAG should have no cycles");
    }

    /// Returns `true` if the graph's edges contain any directed cycle.
    fn detect_cycle(graph: &CausalGraphResult) -> bool {
        let mut adjacency: HashMap<u64, Vec<u64>> = HashMap::new();
        for edge in &graph.edges {
            adjacency.entry(edge.source).or_default().push(edge.target);
        }

        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();

        graph
            .nodes
            .iter()
            .any(|node| dfs_cycle(&adjacency, node.id, &mut visited, &mut rec_stack))
    }

    /// Recursive DFS that reports a back edge into the current recursion stack.
    fn dfs_cycle(
        adj: &HashMap<u64, Vec<u64>>,
        node: u64,
        visited: &mut HashSet<u64>,
        rec_stack: &mut HashSet<u64>,
    ) -> bool {
        if rec_stack.contains(&node) {
            return true;
        }
        if visited.contains(&node) {
            return false;
        }

        visited.insert(node);
        rec_stack.insert(node);

        if let Some(neighbors) = adj.get(&node) {
            for &neighbor in neighbors {
                if dfs_cycle(adj, neighbor, visited, rec_stack) {
                    return true;
                }
            }
        }

        rec_stack.remove(&node);
        false
    }

    #[test]
    fn test_impact_analysis() {
        let events = create_causal_chain_events();
        let graph = CausalGraphConstruction::compute(&events, &CausalConfig::default());

        assert_eq!(graph.nodes.len(), 3, "Should have 3 event types");
        assert!(!graph.edges.is_empty(), "Graph should have edges");

        let impact = CausalGraphConstruction::calculate_impact(&graph, "event_a");

        // DAG pruning keeps a->b and b->c (both strength 1.0) and drops c->a.
        assert_eq!(impact.event_type, "event_a");
        assert_eq!(impact.direct_effects, vec![("event_b".to_string(), 1.0)]);
        assert_eq!(impact.indirect_effects, vec![("event_c".to_string(), 0.5, 1)]);
        assert_eq!(impact.total_impact, 1.5);
    }

    #[test]
    fn test_impact_unknown_event_type() {
        let events = create_causal_chain_events();
        let graph = CausalGraphConstruction::compute(&events, &CausalConfig::default());

        let impact = CausalGraphConstruction::calculate_impact(&graph, "no_such_event");

        assert_eq!(impact.event_type, "no_such_event");
        assert!(impact.direct_effects.is_empty());
        assert!(impact.indirect_effects.is_empty());
        assert_eq!(impact.total_impact, 0.0);
    }

    #[test]
    fn test_empty_events() {
        let result = CausalGraphConstruction::compute(&[], &CausalConfig::default());

        assert!(result.nodes.is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn test_single_event_yields_empty_graph() {
        let events = vec![make_event(1, "only", BASE_TS, None)];
        let result = CausalGraphConstruction::compute(&events, &CausalConfig::default());

        assert!(result.nodes.is_empty());
        assert!(result.edges.is_empty());
        assert!(result.root_causes.is_empty());
        assert!(result.effects.is_empty());
    }

    #[test]
    fn test_min_observations_filter() {
        let events = vec![
            make_event(1, "rare_a", BASE_TS, None),
            make_event(2, "rare_b", BASE_TS + 10, None),
        ];

        let config = CausalConfig {
            min_observations: 5,
            ..Default::default()
        };

        let result = CausalGraphConstruction::compute(&events, &config);

        assert!(
            result.edges.is_empty(),
            "Should filter out edges with few observations"
        );
    }

    #[test]
    fn test_self_loops_skipped_by_default() {
        let events: Vec<UserEvent> = (0u64..5)
            .map(|i| make_event(i, "repeat", BASE_TS + i * 10, None))
            .collect();

        let result = CausalGraphConstruction::compute(&events, &CausalConfig::default());

        assert_eq!(result.nodes.len(), 1);
        assert!(result.edges.is_empty());
    }

    #[test]
    fn test_self_loops_allowed_without_dag() {
        let events: Vec<UserEvent> = (0u64..5)
            .map(|i| make_event(i, "repeat", BASE_TS + i * 10, None))
            .collect();

        let config = CausalConfig {
            allow_self_loops: true,
            enforce_dag: false,
            ..Default::default()
        };

        let result = CausalGraphConstruction::compute(&events, &config);

        assert_eq!(result.edges.len(), 1);
        let edge = &result.edges[0];
        assert_eq!(edge.source, edge.target);
        assert_eq!(edge.count, 4);
        assert_eq!(edge.lag, 10.0);
        assert_eq!(edge.strength, 1.0);
    }
}