1use crate::types::{
9 CorrelationCluster, CorrelationResult, CorrelationType, EventCorrelation, UserEvent,
10};
11use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
12use std::collections::{HashMap, HashSet};
13
14#[derive(Debug, Clone)]
23pub struct EventCorrelationKernel {
24 metadata: KernelMetadata,
25}
26
27impl Default for EventCorrelationKernel {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl EventCorrelationKernel {
34 #[must_use]
36 pub fn new() -> Self {
37 Self {
38 metadata: KernelMetadata::ring(
39 "behavioral/event-correlation",
40 Domain::BehavioralAnalytics,
41 )
42 .with_description("Event correlation and clustering")
43 .with_throughput(50_000)
44 .with_latency_us(100.0),
45 }
46 }
47
48 pub fn compute(
55 event: &UserEvent,
56 all_events: &[UserEvent],
57 config: &CorrelationConfig,
58 ) -> CorrelationResult {
59 let mut correlations = Vec::new();
60
61 for candidate in all_events {
62 if candidate.id == event.id {
63 continue;
64 }
65
66 if let Some(correlation) = Self::calculate_correlation(event, candidate, config) {
68 correlations.push(correlation);
69 }
70 }
71
72 correlations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
74
75 if let Some(max) = config.max_correlations {
77 correlations.truncate(max);
78 }
79
80 let clusters = Self::build_clusters(&correlations, all_events, config);
82
83 CorrelationResult {
84 event_id: event.id,
85 correlations,
86 clusters,
87 }
88 }
89
90 pub fn compute_batch(
92 events: &[UserEvent],
93 config: &CorrelationConfig,
94 ) -> Vec<CorrelationResult> {
95 events
96 .iter()
97 .map(|e| Self::compute(e, events, config))
98 .collect()
99 }
100
101 fn calculate_correlation(
103 event: &UserEvent,
104 candidate: &UserEvent,
105 config: &CorrelationConfig,
106 ) -> Option<EventCorrelation> {
107 let mut score = 0.0;
108 let mut correlation_types = Vec::new();
109
110 let time_diff = (event.timestamp as i64 - candidate.timestamp as i64).abs();
112 if time_diff <= config.temporal_window_secs as i64 {
113 let temporal_score = 1.0 - (time_diff as f64 / config.temporal_window_secs as f64);
114 score += temporal_score * config.weights.temporal;
115 if temporal_score > 0.5 {
116 correlation_types.push(CorrelationType::Temporal);
117 }
118 }
119
120 if event.user_id == candidate.user_id {
122 score += config.weights.user;
123 correlation_types.push(CorrelationType::User);
124 }
125
126 if let (Some(s1), Some(s2)) = (event.session_id, candidate.session_id) {
128 if s1 == s2 {
129 score += config.weights.session;
130 correlation_types.push(CorrelationType::Session);
131 }
132 }
133
134 if let (Some(d1), Some(d2)) = (&event.device_id, &candidate.device_id) {
136 if d1 == d2 {
137 score += config.weights.device;
138 correlation_types.push(CorrelationType::Device);
139 }
140 }
141
142 if let (Some(l1), Some(l2)) = (&event.location, &candidate.location) {
144 if l1 == l2 {
145 score += config.weights.location;
146 correlation_types.push(CorrelationType::Location);
147 }
148 }
149
150 let max_possible = config.weights.temporal
152 + config.weights.user
153 + config.weights.session
154 + config.weights.device
155 + config.weights.location;
156 score /= max_possible;
157
158 if score < config.min_score {
159 return None;
160 }
161
162 let correlation_type = if correlation_types.is_empty() {
164 CorrelationType::Temporal
165 } else {
166 correlation_types
168 .into_iter()
169 .max_by(|a, b| {
170 Self::type_weight(a, &config.weights)
171 .partial_cmp(&Self::type_weight(b, &config.weights))
172 .unwrap()
173 })
174 .unwrap()
175 };
176
177 Some(EventCorrelation {
178 correlated_event_id: candidate.id,
179 score,
180 correlation_type,
181 time_diff: event.timestamp as i64 - candidate.timestamp as i64,
182 })
183 }
184
185 fn type_weight(t: &CorrelationType, weights: &CorrelationWeights) -> f64 {
187 match t {
188 CorrelationType::Temporal => weights.temporal,
189 CorrelationType::User => weights.user,
190 CorrelationType::Session => weights.session,
191 CorrelationType::Device => weights.device,
192 CorrelationType::Location => weights.location,
193 CorrelationType::Causal => 1.0, }
195 }
196
197 fn build_clusters(
199 correlations: &[EventCorrelation],
200 all_events: &[UserEvent],
201 config: &CorrelationConfig,
202 ) -> Vec<CorrelationCluster> {
203 if correlations.is_empty() {
204 return Vec::new();
205 }
206
207 let id_to_idx: HashMap<u64, usize> = all_events
209 .iter()
210 .enumerate()
211 .map(|(i, e)| (e.id, i))
212 .collect();
213
214 let n = all_events.len();
216 let mut parent: Vec<usize> = (0..n).collect();
217 let mut rank: Vec<usize> = vec![0; n];
218
219 fn find(parent: &mut [usize], i: usize) -> usize {
220 if parent[i] != i {
221 parent[i] = find(parent, parent[i]);
222 }
223 parent[i]
224 }
225
226 fn union(parent: &mut [usize], rank: &mut [usize], x: usize, y: usize) {
227 let px = find(parent, x);
228 let py = find(parent, y);
229
230 if px == py {
231 return;
232 }
233
234 match rank[px].cmp(&rank[py]) {
235 std::cmp::Ordering::Less => parent[px] = py,
236 std::cmp::Ordering::Greater => parent[py] = px,
237 std::cmp::Ordering::Equal => {
238 parent[py] = px;
239 rank[px] += 1;
240 }
241 }
242 }
243
244 for (i, e1) in all_events.iter().enumerate() {
247 for corr in correlations
248 .iter()
249 .filter(|c| c.score >= config.cluster_threshold)
250 {
251 if let Some(&idx2) = id_to_idx.get(&corr.correlated_event_id) {
252 if e1.id != corr.correlated_event_id && idx2 < n {
255 union(&mut parent, &mut rank, i, idx2);
256 }
257 }
258 }
259 }
260
261 let mut cluster_members: HashMap<usize, Vec<u64>> = HashMap::new();
263 let mut cluster_types: HashMap<usize, HashMap<CorrelationType, usize>> = HashMap::new();
264
265 for event in all_events {
266 if let Some(&idx) = id_to_idx.get(&event.id) {
267 let root = find(&mut parent, idx);
268 cluster_members.entry(root).or_default().push(event.id);
269 }
270 }
271
272 for corr in correlations {
274 if let Some(&idx) = id_to_idx.get(&corr.correlated_event_id) {
275 let root = find(&mut parent, idx);
276 *cluster_types
277 .entry(root)
278 .or_default()
279 .entry(corr.correlation_type)
280 .or_insert(0) += 1;
281 }
282 }
283
284 let mut clusters: Vec<CorrelationCluster> = Vec::new();
286 let mut cluster_id = 0u64;
287
288 for (root, event_ids) in cluster_members {
289 if event_ids.len() < 2 {
290 continue; }
292
293 let cluster_event_set: HashSet<_> = event_ids.iter().collect();
295 let internal_correlations: Vec<_> = correlations
296 .iter()
297 .filter(|c| cluster_event_set.contains(&c.correlated_event_id))
298 .collect();
299
300 let coherence = if internal_correlations.is_empty() {
301 0.0
302 } else {
303 internal_correlations.iter().map(|c| c.score).sum::<f64>()
304 / internal_correlations.len() as f64
305 };
306
307 let type_counts = cluster_types.get(&root);
309 let dominant_type = type_counts
310 .and_then(|counts| {
311 counts
312 .iter()
313 .max_by_key(|&(_, count)| *count)
314 .map(|(&t, _)| t)
315 })
316 .unwrap_or(CorrelationType::Temporal);
317
318 clusters.push(CorrelationCluster {
319 id: cluster_id,
320 event_ids,
321 coherence,
322 dominant_type,
323 });
324
325 cluster_id += 1;
326 }
327
328 clusters.sort_by(|a, b| b.coherence.partial_cmp(&a.coherence).unwrap());
330
331 clusters
332 }
333
334 pub fn detect_causal_correlations(
336 events: &[UserEvent],
337 config: &CorrelationConfig,
338 ) -> Vec<EventCorrelation> {
339 let mut causal = Vec::new();
340
341 let mut sorted: Vec<_> = events.iter().collect();
343 sorted.sort_by_key(|e| e.timestamp);
344
345 let mut pair_counts: HashMap<(&str, &str), Vec<i64>> = HashMap::new();
347
348 for window in sorted.windows(2) {
349 let time_diff = (window[1].timestamp - window[0].timestamp) as i64;
350 if time_diff <= config.temporal_window_secs as i64 {
351 pair_counts
352 .entry((&window[0].event_type, &window[1].event_type))
353 .or_default()
354 .push(time_diff);
355 }
356 }
357
358 for ((type_a, type_b), time_diffs) in pair_counts {
360 if time_diffs.len() < 3 {
361 continue;
362 }
363
364 let mean = time_diffs.iter().sum::<i64>() as f64 / time_diffs.len() as f64;
365 let variance = time_diffs
366 .iter()
367 .map(|&t| (t as f64 - mean).powi(2))
368 .sum::<f64>()
369 / time_diffs.len() as f64;
370 let cv = variance.sqrt() / mean.abs().max(1.0); if cv < 0.5 {
374 for window in sorted.windows(2) {
376 if window[0].event_type == *type_a && window[1].event_type == *type_b {
377 let score = 1.0 - cv;
378 causal.push(EventCorrelation {
379 correlated_event_id: window[1].id,
380 score,
381 correlation_type: CorrelationType::Causal,
382 time_diff: (window[1].timestamp - window[0].timestamp) as i64,
383 });
384 }
385 }
386 }
387 }
388
389 causal
390 }
391
392 pub fn find_strongly_correlated(
394 events: &[UserEvent],
395 required_types: &[CorrelationType],
396 ) -> Vec<(u64, u64, f64)> {
397 let mut pairs = Vec::new();
398
399 for (i, e1) in events.iter().enumerate() {
400 for e2 in events.iter().skip(i + 1) {
401 let mut matches = Vec::new();
402
403 for req_type in required_types {
405 let matched = match req_type {
406 CorrelationType::User => e1.user_id == e2.user_id,
407 CorrelationType::Session => {
408 e1.session_id.is_some() && e1.session_id == e2.session_id
409 }
410 CorrelationType::Device => {
411 e1.device_id.is_some() && e1.device_id == e2.device_id
412 }
413 CorrelationType::Location => {
414 e1.location.is_some() && e1.location == e2.location
415 }
416 CorrelationType::Temporal => {
417 (e1.timestamp as i64 - e2.timestamp as i64).abs() < 3600
418 }
419 CorrelationType::Causal => false, };
421 matches.push(matched);
422 }
423
424 if matches.iter().all(|&m| m) {
425 let score =
426 matches.iter().filter(|&&m| m).count() as f64 / required_types.len() as f64;
427 pairs.push((e1.id, e2.id, score));
428 }
429 }
430 }
431
432 pairs
433 }
434}
435
436impl GpuKernel for EventCorrelationKernel {
437 fn metadata(&self) -> &KernelMetadata {
438 &self.metadata
439 }
440}
441
442#[derive(Debug, Clone)]
444pub struct CorrelationConfig {
445 pub temporal_window_secs: u64,
447 pub min_score: f64,
449 pub max_correlations: Option<usize>,
451 pub cluster_threshold: f64,
453 pub weights: CorrelationWeights,
455}
456
457impl Default for CorrelationConfig {
458 fn default() -> Self {
459 Self {
460 temporal_window_secs: 3600, min_score: 0.3,
462 max_correlations: Some(50),
463 cluster_threshold: 0.5,
464 weights: CorrelationWeights::default(),
465 }
466 }
467}
468
/// Relative importance of each correlation dimension.
///
/// A correlation score is the sum of the weights of the matching dimensions
/// divided by the sum of all weights.
#[derive(Clone, Debug)]
pub struct CorrelationWeights {
    /// Weight of temporal proximity (scaled by how close the timestamps are).
    pub temporal: f64,
    /// Weight of both events belonging to the same user.
    pub user: f64,
    /// Weight of both events carrying the same session id.
    pub session: f64,
    /// Weight of both events carrying the same device id.
    pub device: f64,
    /// Weight of both events carrying the same location.
    pub location: f64,
}
483
484impl Default for CorrelationWeights {
485 fn default() -> Self {
486 Self {
487 temporal: 0.2,
488 user: 0.3,
489 session: 0.25,
490 device: 0.15,
491 location: 0.1,
492 }
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture:
    /// - Events 1-3 share user 100, session 1, device and location, and lie
    ///   within one minute of each other.
    /// - Event 4 belongs to a different user 15 s after event 1.
    /// - Event 5 is user 100 again, two hours later in a new session.
    fn create_correlated_events() -> Vec<UserEvent> {
        let base_ts = 1_700_000_000u64;
        vec![
            UserEvent {
                id: 1,
                user_id: 100,
                event_type: "login".to_string(),
                timestamp: base_ts,
                attributes: HashMap::new(),
                session_id: Some(1),
                device_id: Some("device_a".to_string()),
                ip_address: Some("192.168.1.1".to_string()),
                location: Some("US".to_string()),
            },
            UserEvent {
                id: 2,
                user_id: 100,
                event_type: "view".to_string(),
                timestamp: base_ts + 30,
                attributes: HashMap::new(),
                session_id: Some(1),
                device_id: Some("device_a".to_string()),
                ip_address: Some("192.168.1.1".to_string()),
                location: Some("US".to_string()),
            },
            UserEvent {
                id: 3,
                user_id: 100,
                event_type: "purchase".to_string(),
                timestamp: base_ts + 60,
                attributes: HashMap::new(),
                session_id: Some(1),
                device_id: Some("device_a".to_string()),
                ip_address: Some("192.168.1.1".to_string()),
                location: Some("US".to_string()),
            },
            // Different user, close in time.
            UserEvent {
                id: 4,
                user_id: 200,
                event_type: "login".to_string(),
                timestamp: base_ts + 15,
                attributes: HashMap::new(),
                session_id: Some(2),
                device_id: Some("device_b".to_string()),
                ip_address: Some("10.0.0.1".to_string()),
                location: Some("UK".to_string()),
            },
            // Same user much later, in a different session.
            UserEvent {
                id: 5,
                user_id: 100,
                event_type: "login".to_string(),
                timestamp: base_ts + 7200,
                attributes: HashMap::new(),
                session_id: Some(3),
                device_id: Some("device_a".to_string()),
                ip_address: Some("192.168.1.1".to_string()),
                location: Some("US".to_string()),
            },
        ]
    }

    #[test]
    fn test_correlation_kernel_metadata() {
        let kernel = EventCorrelationKernel::new();
        assert_eq!(kernel.metadata().id, "behavioral/event-correlation");
        assert_eq!(kernel.metadata().domain, Domain::BehavioralAnalytics);
    }

    #[test]
    fn test_same_user_correlation() {
        let events = create_correlated_events();
        let config = CorrelationConfig::default();

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        assert!(!result.correlations.is_empty(), "Should find correlations");

        let same_user_corrs: Vec<_> = result
            .correlations
            .iter()
            .filter(|c| {
                events
                    .iter()
                    .find(|e| e.id == c.correlated_event_id)
                    .is_some_and(|e| e.user_id == 100)
            })
            .collect();

        assert!(!same_user_corrs.is_empty());
    }

    #[test]
    fn test_correlations_sorted_by_score() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            min_score: 0.0,
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // Every other event survives a zero minimum score.
        assert_eq!(result.correlations.len(), events.len() - 1);
        assert!(result
            .correlations
            .windows(2)
            .all(|w| w[0].score >= w[1].score));
    }

    #[test]
    fn test_temporal_correlation() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            temporal_window_secs: 100,
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        let temporal_corrs: Vec<_> = result
            .correlations
            .iter()
            .filter(|c| c.time_diff.abs() < 100)
            .collect();

        assert!(!temporal_corrs.is_empty());
    }

    #[test]
    fn test_session_correlation() {
        let events = create_correlated_events();
        let config = CorrelationConfig::default();

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // Events 2 and 3 share session 1 with event 1. Their correlations are
        // labelled `User` (the heaviest matching weight), so check that they are
        // present rather than checking their type.
        for id in [2, 3] {
            assert!(
                result
                    .correlations
                    .iter()
                    .any(|c| c.correlated_event_id == id),
                "Should correlate with same-session event {id}"
            );
        }
    }

    #[test]
    fn test_min_score_filter() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            min_score: 0.8,
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        assert!(result.correlations.iter().all(|c| c.score >= 0.8));
    }

    #[test]
    fn test_max_correlations_limit() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            max_correlations: Some(2),
            min_score: 0.0,
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // The two retained correlations are the highest-scoring ones
        // (same user, session, device and location, seconds apart).
        let ids: Vec<u64> = result
            .correlations
            .iter()
            .map(|c| c.correlated_event_id)
            .collect();
        assert_eq!(ids, vec![2, 3]);
    }

    #[test]
    fn test_cluster_building() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            cluster_threshold: 0.3,
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        assert!(!result.clusters.is_empty(), "Expected at least one cluster");
        assert!(
            result.clusters.iter().any(|c| c.event_ids.contains(&1)),
            "The source event should belong to a cluster"
        );
        for cluster in &result.clusters {
            assert!(
                cluster.event_ids.len() >= 2,
                "Clusters should have 2+ events"
            );
            assert!(cluster.coherence >= 0.0 && cluster.coherence <= 1.0);
        }
    }

    #[test]
    fn test_batch_correlation() {
        let events = create_correlated_events();
        let config = CorrelationConfig::default();

        let results = EventCorrelationKernel::compute_batch(&events, &config);

        assert_eq!(results.len(), events.len());
        for result in &results {
            assert!(events.iter().any(|e| e.id == result.event_id));
        }
    }

    #[test]
    fn test_causal_correlation_detection() {
        let base_ts = 1_700_000_000u64;
        // Ten "cause" events, each followed exactly 50 s later by an "effect".
        let events: Vec<UserEvent> = (0u64..10)
            .flat_map(|i| {
                vec![
                    UserEvent {
                        id: i * 2,
                        user_id: 100,
                        event_type: "cause".to_string(),
                        timestamp: base_ts + (i * 1000),
                        attributes: HashMap::new(),
                        session_id: Some(i),
                        device_id: None,
                        ip_address: None,
                        location: None,
                    },
                    UserEvent {
                        id: i * 2 + 1,
                        user_id: 100,
                        event_type: "effect".to_string(),
                        timestamp: base_ts + (i * 1000) + 50,
                        attributes: HashMap::new(),
                        session_id: Some(i),
                        device_id: None,
                        ip_address: None,
                        location: None,
                    },
                ]
            })
            .collect();

        let config = CorrelationConfig::default();
        let causal = EventCorrelationKernel::detect_causal_correlations(&events, &config);

        assert!(
            !causal.is_empty(),
            "Should detect causal correlations in consistent patterns"
        );
        // 10 cause->effect links (50 s) plus 9 effect->cause links (950 s),
        // each perfectly regular.
        assert_eq!(causal.len(), 19);
        assert!(causal
            .iter()
            .all(|c| c.correlation_type == CorrelationType::Causal));
        assert!(causal.iter().all(|c| (c.score - 1.0).abs() < 1e-12));
    }

    #[test]
    fn test_strongly_correlated() {
        let events = create_correlated_events();
        let required = vec![CorrelationType::User, CorrelationType::Session];

        let pairs = EventCorrelationKernel::find_strongly_correlated(&events, &required);

        // Only events 1-3 share both user 100 and session 1.
        let ids: Vec<(u64, u64)> = pairs.iter().map(|&(a, b, _)| (a, b)).collect();
        assert_eq!(ids, vec![(1, 2), (1, 3), (2, 3)]);
        assert!(pairs.iter().all(|(_, _, score)| *score == 1.0));
    }

    #[test]
    fn test_empty_events() {
        let events: Vec<UserEvent> = Vec::new();
        let config = CorrelationConfig::default();

        let result = EventCorrelationKernel::compute(
            &UserEvent {
                id: 1,
                user_id: 100,
                event_type: "test".to_string(),
                timestamp: 0,
                attributes: HashMap::new(),
                session_id: None,
                device_id: None,
                ip_address: None,
                location: None,
            },
            &events,
            &config,
        );

        assert!(result.correlations.is_empty());
        assert!(result.clusters.is_empty());
    }

    #[test]
    fn test_correlation_weights() {
        let events = create_correlated_events();

        // A zero minimum score keeps the low-scoring different-user event (4) so
        // the comparison below always runs instead of passing vacuously.
        let user_config = CorrelationConfig {
            min_score: 0.0,
            weights: CorrelationWeights {
                user: 0.8,
                session: 0.1,
                device: 0.05,
                location: 0.03,
                temporal: 0.02,
            },
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &user_config);

        let score_of = |id: u64| {
            result
                .correlations
                .iter()
                .find(|c| c.correlated_event_id == id)
                .map(|c| c.score)
                .expect("correlation should survive min_score = 0.0")
        };

        assert!(
            score_of(2) > score_of(4),
            "Same-user correlation should be stronger with high user weight"
        );
    }
}
825}