rustkernel_behavioral/
correlation.rs

1//! Event correlation kernels.
2//!
3//! This module provides event correlation analysis:
4//! - Temporal correlation detection
5//! - User/session/device-based correlation
6//! - Event clustering
7
8use crate::types::{
9    CorrelationCluster, CorrelationResult, CorrelationType, EventCorrelation, UserEvent,
10};
11use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
12use std::collections::{HashMap, HashSet};
13
14// ============================================================================
15// Event Correlation Kernel
16// ============================================================================
17
18/// Event correlation kernel.
19///
20/// Identifies correlated events based on temporal, user, session,
21/// device, and location relationships.
22#[derive(Debug, Clone)]
23pub struct EventCorrelationKernel {
24    metadata: KernelMetadata,
25}
26
27impl Default for EventCorrelationKernel {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl EventCorrelationKernel {
34    /// Create a new event correlation kernel.
35    #[must_use]
36    pub fn new() -> Self {
37        Self {
38            metadata: KernelMetadata::ring(
39                "behavioral/event-correlation",
40                Domain::BehavioralAnalytics,
41            )
42            .with_description("Event correlation and clustering")
43            .with_throughput(50_000)
44            .with_latency_us(100.0),
45        }
46    }
47
48    /// Find correlations for an event.
49    ///
50    /// # Arguments
51    /// * `event` - The event to find correlations for
52    /// * `all_events` - Pool of events to correlate against
53    /// * `config` - Correlation configuration
54    pub fn compute(
55        event: &UserEvent,
56        all_events: &[UserEvent],
57        config: &CorrelationConfig,
58    ) -> CorrelationResult {
59        let mut correlations = Vec::new();
60
61        for candidate in all_events {
62            if candidate.id == event.id {
63                continue;
64            }
65
66            // Calculate correlation score and type
67            if let Some(correlation) = Self::calculate_correlation(event, candidate, config) {
68                correlations.push(correlation);
69            }
70        }
71
72        // Sort by score descending
73        correlations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
74
75        // Limit results
76        if let Some(max) = config.max_correlations {
77            correlations.truncate(max);
78        }
79
80        // Build clusters from correlations
81        let clusters = Self::build_clusters(&correlations, all_events, config);
82
83        CorrelationResult {
84            event_id: event.id,
85            correlations,
86            clusters,
87        }
88    }
89
90    /// Batch correlation analysis for multiple events.
91    pub fn compute_batch(
92        events: &[UserEvent],
93        config: &CorrelationConfig,
94    ) -> Vec<CorrelationResult> {
95        events
96            .iter()
97            .map(|e| Self::compute(e, events, config))
98            .collect()
99    }
100
101    /// Calculate correlation between two events.
102    fn calculate_correlation(
103        event: &UserEvent,
104        candidate: &UserEvent,
105        config: &CorrelationConfig,
106    ) -> Option<EventCorrelation> {
107        let mut score = 0.0;
108        let mut correlation_types = Vec::new();
109
110        // Temporal correlation
111        let time_diff = (event.timestamp as i64 - candidate.timestamp as i64).abs();
112        if time_diff <= config.temporal_window_secs as i64 {
113            let temporal_score = 1.0 - (time_diff as f64 / config.temporal_window_secs as f64);
114            score += temporal_score * config.weights.temporal;
115            if temporal_score > 0.5 {
116                correlation_types.push(CorrelationType::Temporal);
117            }
118        }
119
120        // User correlation
121        if event.user_id == candidate.user_id {
122            score += config.weights.user;
123            correlation_types.push(CorrelationType::User);
124        }
125
126        // Session correlation
127        if let (Some(s1), Some(s2)) = (event.session_id, candidate.session_id) {
128            if s1 == s2 {
129                score += config.weights.session;
130                correlation_types.push(CorrelationType::Session);
131            }
132        }
133
134        // Device correlation
135        if let (Some(d1), Some(d2)) = (&event.device_id, &candidate.device_id) {
136            if d1 == d2 {
137                score += config.weights.device;
138                correlation_types.push(CorrelationType::Device);
139            }
140        }
141
142        // Location correlation
143        if let (Some(l1), Some(l2)) = (&event.location, &candidate.location) {
144            if l1 == l2 {
145                score += config.weights.location;
146                correlation_types.push(CorrelationType::Location);
147            }
148        }
149
150        // Normalize score
151        let max_possible = config.weights.temporal
152            + config.weights.user
153            + config.weights.session
154            + config.weights.device
155            + config.weights.location;
156        score /= max_possible;
157
158        if score < config.min_score {
159            return None;
160        }
161
162        // Determine dominant correlation type
163        let correlation_type = if correlation_types.is_empty() {
164            CorrelationType::Temporal
165        } else {
166            // Return the strongest type based on weights
167            correlation_types
168                .into_iter()
169                .max_by(|a, b| {
170                    Self::type_weight(a, &config.weights)
171                        .partial_cmp(&Self::type_weight(b, &config.weights))
172                        .unwrap()
173                })
174                .unwrap()
175        };
176
177        Some(EventCorrelation {
178            correlated_event_id: candidate.id,
179            score,
180            correlation_type,
181            time_diff: event.timestamp as i64 - candidate.timestamp as i64,
182        })
183    }
184
185    /// Get weight for a correlation type.
186    fn type_weight(t: &CorrelationType, weights: &CorrelationWeights) -> f64 {
187        match t {
188            CorrelationType::Temporal => weights.temporal,
189            CorrelationType::User => weights.user,
190            CorrelationType::Session => weights.session,
191            CorrelationType::Device => weights.device,
192            CorrelationType::Location => weights.location,
193            CorrelationType::Causal => 1.0, // Causal is highest priority if detected
194        }
195    }
196
197    /// Build clusters from correlations using union-find.
198    fn build_clusters(
199        correlations: &[EventCorrelation],
200        all_events: &[UserEvent],
201        config: &CorrelationConfig,
202    ) -> Vec<CorrelationCluster> {
203        if correlations.is_empty() {
204            return Vec::new();
205        }
206
207        // Build event ID to index mapping
208        let id_to_idx: HashMap<u64, usize> = all_events
209            .iter()
210            .enumerate()
211            .map(|(i, e)| (e.id, i))
212            .collect();
213
214        // Union-Find data structure
215        let n = all_events.len();
216        let mut parent: Vec<usize> = (0..n).collect();
217        let mut rank: Vec<usize> = vec![0; n];
218
219        fn find(parent: &mut [usize], i: usize) -> usize {
220            if parent[i] != i {
221                parent[i] = find(parent, parent[i]);
222            }
223            parent[i]
224        }
225
226        fn union(parent: &mut [usize], rank: &mut [usize], x: usize, y: usize) {
227            let px = find(parent, x);
228            let py = find(parent, y);
229
230            if px == py {
231                return;
232            }
233
234            match rank[px].cmp(&rank[py]) {
235                std::cmp::Ordering::Less => parent[px] = py,
236                std::cmp::Ordering::Greater => parent[py] = px,
237                std::cmp::Ordering::Equal => {
238                    parent[py] = px;
239                    rank[px] += 1;
240                }
241            }
242        }
243
244        // Build high-score correlation edges
245        // For each correlation, union the correlated events
246        for (i, e1) in all_events.iter().enumerate() {
247            for corr in correlations
248                .iter()
249                .filter(|c| c.score >= config.cluster_threshold)
250            {
251                if let Some(&idx2) = id_to_idx.get(&corr.correlated_event_id) {
252                    // Check if e1 could be the source of this correlation
253                    // by checking if their IDs are related
254                    if e1.id != corr.correlated_event_id && idx2 < n {
255                        union(&mut parent, &mut rank, i, idx2);
256                    }
257                }
258            }
259        }
260
261        // Group events by cluster
262        let mut cluster_members: HashMap<usize, Vec<u64>> = HashMap::new();
263        let mut cluster_types: HashMap<usize, HashMap<CorrelationType, usize>> = HashMap::new();
264
265        for event in all_events {
266            if let Some(&idx) = id_to_idx.get(&event.id) {
267                let root = find(&mut parent, idx);
268                cluster_members.entry(root).or_default().push(event.id);
269            }
270        }
271
272        // Calculate dominant types from correlations
273        for corr in correlations {
274            if let Some(&idx) = id_to_idx.get(&corr.correlated_event_id) {
275                let root = find(&mut parent, idx);
276                *cluster_types
277                    .entry(root)
278                    .or_default()
279                    .entry(corr.correlation_type)
280                    .or_insert(0) += 1;
281            }
282        }
283
284        // Build cluster results
285        let mut clusters: Vec<CorrelationCluster> = Vec::new();
286        let mut cluster_id = 0u64;
287
288        for (root, event_ids) in cluster_members {
289            if event_ids.len() < 2 {
290                continue; // Skip singleton clusters
291            }
292
293            // Calculate coherence (average correlation score within cluster)
294            let cluster_event_set: HashSet<_> = event_ids.iter().collect();
295            let internal_correlations: Vec<_> = correlations
296                .iter()
297                .filter(|c| cluster_event_set.contains(&c.correlated_event_id))
298                .collect();
299
300            let coherence = if internal_correlations.is_empty() {
301                0.0
302            } else {
303                internal_correlations.iter().map(|c| c.score).sum::<f64>()
304                    / internal_correlations.len() as f64
305            };
306
307            // Find dominant type
308            let type_counts = cluster_types.get(&root);
309            let dominant_type = type_counts
310                .and_then(|counts| {
311                    counts
312                        .iter()
313                        .max_by_key(|&(_, count)| *count)
314                        .map(|(&t, _)| t)
315                })
316                .unwrap_or(CorrelationType::Temporal);
317
318            clusters.push(CorrelationCluster {
319                id: cluster_id,
320                event_ids,
321                coherence,
322                dominant_type,
323            });
324
325            cluster_id += 1;
326        }
327
328        // Sort by coherence descending
329        clusters.sort_by(|a, b| b.coherence.partial_cmp(&a.coherence).unwrap());
330
331        clusters
332    }
333
334    /// Detect causal correlations (A causes B pattern).
335    pub fn detect_causal_correlations(
336        events: &[UserEvent],
337        config: &CorrelationConfig,
338    ) -> Vec<EventCorrelation> {
339        let mut causal = Vec::new();
340
341        // Sort by timestamp
342        let mut sorted: Vec<_> = events.iter().collect();
343        sorted.sort_by_key(|e| e.timestamp);
344
345        // Look for consistent A->B patterns
346        let mut pair_counts: HashMap<(&str, &str), Vec<i64>> = HashMap::new();
347
348        for window in sorted.windows(2) {
349            let time_diff = (window[1].timestamp - window[0].timestamp) as i64;
350            if time_diff <= config.temporal_window_secs as i64 {
351                pair_counts
352                    .entry((&window[0].event_type, &window[1].event_type))
353                    .or_default()
354                    .push(time_diff);
355            }
356        }
357
358        // Find pairs with consistent timing (low variance)
359        for ((type_a, type_b), time_diffs) in pair_counts {
360            if time_diffs.len() < 3 {
361                continue;
362            }
363
364            let mean = time_diffs.iter().sum::<i64>() as f64 / time_diffs.len() as f64;
365            let variance = time_diffs
366                .iter()
367                .map(|&t| (t as f64 - mean).powi(2))
368                .sum::<f64>()
369                / time_diffs.len() as f64;
370            let cv = variance.sqrt() / mean.abs().max(1.0); // Coefficient of variation
371
372            // Low CV suggests consistent causal relationship
373            if cv < 0.5 {
374                // Find specific event pairs
375                for window in sorted.windows(2) {
376                    if window[0].event_type == *type_a && window[1].event_type == *type_b {
377                        let score = 1.0 - cv;
378                        causal.push(EventCorrelation {
379                            correlated_event_id: window[1].id,
380                            score,
381                            correlation_type: CorrelationType::Causal,
382                            time_diff: (window[1].timestamp - window[0].timestamp) as i64,
383                        });
384                    }
385                }
386            }
387        }
388
389        causal
390    }
391
392    /// Find events correlated by all specified types.
393    pub fn find_strongly_correlated(
394        events: &[UserEvent],
395        required_types: &[CorrelationType],
396    ) -> Vec<(u64, u64, f64)> {
397        let mut pairs = Vec::new();
398
399        for (i, e1) in events.iter().enumerate() {
400            for e2 in events.iter().skip(i + 1) {
401                let mut matches = Vec::new();
402
403                // Check each required type
404                for req_type in required_types {
405                    let matched = match req_type {
406                        CorrelationType::User => e1.user_id == e2.user_id,
407                        CorrelationType::Session => {
408                            e1.session_id.is_some() && e1.session_id == e2.session_id
409                        }
410                        CorrelationType::Device => {
411                            e1.device_id.is_some() && e1.device_id == e2.device_id
412                        }
413                        CorrelationType::Location => {
414                            e1.location.is_some() && e1.location == e2.location
415                        }
416                        CorrelationType::Temporal => {
417                            (e1.timestamp as i64 - e2.timestamp as i64).abs() < 3600
418                        }
419                        CorrelationType::Causal => false, // Requires separate analysis
420                    };
421                    matches.push(matched);
422                }
423
424                if matches.iter().all(|&m| m) {
425                    let score =
426                        matches.iter().filter(|&&m| m).count() as f64 / required_types.len() as f64;
427                    pairs.push((e1.id, e2.id, score));
428                }
429            }
430        }
431
432        pairs
433    }
434}
435
436impl GpuKernel for EventCorrelationKernel {
437    fn metadata(&self) -> &KernelMetadata {
438        &self.metadata
439    }
440}
441
/// Correlation configuration.
///
/// [`Default`] gives a one-hour temporal window, a minimum score of `0.3`,
/// at most 50 correlations per event, and a cluster threshold of `0.5`.
#[derive(Debug, Clone, PartialEq)]
pub struct CorrelationConfig {
    /// Time window for temporal correlation (seconds).
    pub temporal_window_secs: u64,
    /// Minimum correlation score to include.
    pub min_score: f64,
    /// Maximum correlations to return per event (`None` means unlimited).
    pub max_correlations: Option<usize>,
    /// Minimum score for cluster membership.
    pub cluster_threshold: f64,
    /// Correlation type weights.
    pub weights: CorrelationWeights,
}

impl Default for CorrelationConfig {
    fn default() -> Self {
        Self {
            temporal_window_secs: 3600, // 1 hour
            min_score: 0.3,
            max_correlations: Some(50),
            cluster_threshold: 0.5,
            weights: CorrelationWeights::default(),
        }
    }
}

/// Weights for different correlation types.
///
/// Scores are normalised by the sum of all five weights, so only their
/// relative sizes matter. The default weights add up to `1.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct CorrelationWeights {
    /// Weight for temporal proximity.
    pub temporal: f64,
    /// Weight for same user.
    pub user: f64,
    /// Weight for same session.
    pub session: f64,
    /// Weight for same device.
    pub device: f64,
    /// Weight for same location.
    pub location: f64,
}

impl Default for CorrelationWeights {
    fn default() -> Self {
        Self {
            temporal: 0.2,
            user: 0.3,
            session: 0.25,
            device: 0.15,
            location: 0.1,
        }
    }
}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event for user 100 in session 1 on `device_a`, located in the US.
    /// Tests override individual fields with struct-update syntax.
    fn session_event(id: u64, event_type: &str, timestamp: u64) -> UserEvent {
        UserEvent {
            id,
            user_id: 100,
            event_type: event_type.to_string(),
            timestamp,
            attributes: HashMap::new(),
            session_id: Some(1),
            device_id: Some("device_a".to_string()),
            ip_address: Some("192.168.1.1".to_string()),
            location: Some("US".to_string()),
        }
    }

    /// Fixture: events 1-3 share user, session, device and location within a
    /// minute; event 4 is an unrelated user 15 s after event 1; event 5 is the
    /// same user and device two hours later in a new session.
    fn create_correlated_events() -> Vec<UserEvent> {
        let base_ts = 1700000000u64;
        vec![
            session_event(1, "login", base_ts),
            session_event(2, "view", base_ts + 30),
            session_event(3, "purchase", base_ts + 60),
            // Different user, same time window
            UserEvent {
                user_id: 200,
                session_id: Some(2),
                device_id: Some("device_b".to_string()),
                ip_address: Some("10.0.0.1".to_string()),
                location: Some("UK".to_string()),
                ..session_event(4, "login", base_ts + 15)
            },
            // Same user, different session
            UserEvent {
                session_id: Some(3),
                ..session_event(5, "login", base_ts + 7200) // 2 hours later
            },
        ]
    }

    /// Ids of the correlated events, in result order (highest score first).
    fn correlated_ids(result: &CorrelationResult) -> Vec<u64> {
        result
            .correlations
            .iter()
            .map(|c| c.correlated_event_id)
            .collect()
    }

    #[test]
    fn test_correlation_kernel_metadata() {
        let kernel = EventCorrelationKernel::new();
        assert_eq!(kernel.metadata().id, "behavioral/event-correlation");
        assert_eq!(kernel.metadata().domain, Domain::BehavioralAnalytics);
    }

    #[test]
    fn test_same_user_correlation() {
        let events = create_correlated_events();
        let config = CorrelationConfig::default();

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // Events 2 and 3 match on every dimension; event 5 matches on user,
        // device and location only. Event 4 (other user) scores below the
        // default minimum and is dropped.
        assert_eq!(correlated_ids(&result), vec![2, 3, 5]);

        // With default weights the user dimension carries the largest weight.
        assert!(
            result
                .correlations
                .iter()
                .all(|c| c.correlation_type == CorrelationType::User)
        );
    }

    #[test]
    fn test_temporal_correlation() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            temporal_window_secs: 100,
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // Events 2 and 3 are inside the window. Event 4 is too, but temporal
        // proximity alone stays below the default minimum score.
        let temporal_ids: Vec<u64> = result
            .correlations
            .iter()
            .filter(|c| c.time_diff.abs() < 100)
            .map(|c| c.correlated_event_id)
            .collect();

        assert_eq!(temporal_ids, vec![2, 3]);

        // `time_diff` is source minus candidate, so later candidates are negative.
        assert_eq!(result.correlations[0].time_diff, -30);
    }

    #[test]
    fn test_session_correlation() {
        let events = create_correlated_events();
        // Make the session the heaviest dimension so it becomes the dominant type.
        let config = CorrelationConfig {
            weights: CorrelationWeights {
                session: 0.5,
                user: 0.2,
                temporal: 0.1,
                device: 0.1,
                location: 0.1,
            },
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // Events 2 and 3 share the session with event 1; event 5 does not.
        let session_ids: Vec<u64> = result
            .correlations
            .iter()
            .filter(|c| c.correlation_type == CorrelationType::Session)
            .map(|c| c.correlated_event_id)
            .collect();

        assert_eq!(
            session_ids,
            vec![2, 3],
            "Should correlate with same-session events"
        );
    }

    #[test]
    fn test_min_score_filter() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            min_score: 0.8, // High threshold
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // Only the full matches survive; all correlations are above threshold.
        assert_eq!(correlated_ids(&result), vec![2, 3]);
        assert!(result.correlations.iter().all(|c| c.score >= 0.8));
    }

    #[test]
    fn test_max_correlations_limit() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            max_correlations: Some(2),
            min_score: 0.0, // Allow all
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // Four candidates qualify; only the two strongest are kept.
        assert_eq!(correlated_ids(&result), vec![2, 3]);
    }

    #[test]
    fn test_cluster_building() {
        let events = create_correlated_events();
        let config = CorrelationConfig {
            cluster_threshold: 0.3,
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &config);

        // Events 2, 3 and 5 all score above the threshold, so a cluster exists.
        assert!(!result.clusters.is_empty(), "Should build a cluster");
        for cluster in &result.clusters {
            assert!(
                cluster.event_ids.len() >= 2,
                "Clusters should have 2+ events"
            );
            assert!((0.0..=1.0).contains(&cluster.coherence));
        }
    }

    #[test]
    fn test_batch_correlation() {
        let events = create_correlated_events();
        let config = CorrelationConfig::default();

        let results = EventCorrelationKernel::compute_batch(&events, &config);

        // One result per input event, in input order.
        let result_ids: Vec<u64> = results.iter().map(|r| r.event_id).collect();
        let event_ids: Vec<u64> = events.iter().map(|e| e.id).collect();
        assert_eq!(result_ids, event_ids);
    }

    #[test]
    fn test_causal_correlation_detection() {
        let base_ts = 1700000000u64;
        // Create events with consistent A->B pattern
        let events: Vec<UserEvent> = (0u64..10)
            .flat_map(|i| {
                vec![
                    UserEvent {
                        id: i * 2,
                        user_id: 100,
                        event_type: "cause".to_string(),
                        timestamp: base_ts + (i * 1000),
                        attributes: HashMap::new(),
                        session_id: Some(i),
                        device_id: None,
                        ip_address: None,
                        location: None,
                    },
                    UserEvent {
                        id: i * 2 + 1,
                        user_id: 100,
                        event_type: "effect".to_string(),
                        timestamp: base_ts + (i * 1000) + 50, // Consistent 50s delay
                        attributes: HashMap::new(),
                        session_id: Some(i),
                        device_id: None,
                        ip_address: None,
                        location: None,
                    },
                ]
            })
            .collect();

        let config = CorrelationConfig::default();
        let causal = EventCorrelationKernel::detect_causal_correlations(&events, &config);

        // Should detect causal relationships
        assert!(
            !causal.is_empty(),
            "Should detect causal correlations in consistent patterns"
        );

        // All should be marked as causal
        assert!(
            causal
                .iter()
                .all(|c| c.correlation_type == CorrelationType::Causal)
        );

        // Every "effect" event follows its "cause" by exactly 50 seconds.
        assert!(
            causal
                .iter()
                .any(|c| c.correlated_event_id == 1 && c.time_diff == 50)
        );
    }

    #[test]
    fn test_strongly_correlated() {
        let events = create_correlated_events();
        let required = vec![CorrelationType::User, CorrelationType::Session];

        let pairs = EventCorrelationKernel::find_strongly_correlated(&events, &required);

        // Events 1, 2, 3 share user and session
        assert_eq!(pairs, vec![(1, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0)]);
    }

    #[test]
    fn test_empty_events() {
        let events: Vec<UserEvent> = Vec::new();
        let config = CorrelationConfig::default();

        let result = EventCorrelationKernel::compute(
            &UserEvent {
                id: 1,
                user_id: 100,
                event_type: "test".to_string(),
                timestamp: 0,
                attributes: HashMap::new(),
                session_id: None,
                device_id: None,
                ip_address: None,
                location: None,
            },
            &events,
            &config,
        );

        assert!(result.correlations.is_empty());
        assert!(result.clusters.is_empty());
    }

    #[test]
    fn test_correlation_weights() {
        let events = create_correlated_events();

        // High user weight. The minimum score is lowered so that the
        // different-user event is kept and can actually be compared.
        let user_config = CorrelationConfig {
            min_score: 0.0,
            weights: CorrelationWeights {
                user: 0.8,
                session: 0.1,
                device: 0.05,
                location: 0.03,
                temporal: 0.02,
            },
            ..Default::default()
        };

        let result = EventCorrelationKernel::compute(&events[0], &events, &user_config);

        let score_of = |id: u64| {
            result
                .correlations
                .iter()
                .find(|c| c.correlated_event_id == id)
                .map(|c| c.score)
        };

        let same_user = score_of(2).expect("same-user event should be correlated");
        let diff_user = score_of(4).expect("different-user event should be correlated");

        // Same-user events should have higher scores
        assert!(
            same_user > diff_user,
            "Same-user correlation should be stronger with high user weight"
        );
    }
}