rustkernel_risk/
correlation.rs

1//! Real-time correlation kernels.
2//!
3//! This module provides streaming correlation computation:
4//! - Incremental correlation matrix updates using Welford's algorithm
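//! - Exponentially weighted moving correlation (configurable via `CorrelationType::Exponential`; the streaming update currently applies equal weights)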
6//! - Correlation change detection
7
8use rustkernel_core::traits::GpuKernel;
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::time::Instant;
13
14// ============================================================================
15// Real-Time Correlation Kernel
16// ============================================================================
17
/// Correlation estimator selected through `CorrelationConfig::correlation_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum CorrelationType {
    /// Pearson correlation coefficient (every observation weighted equally).
    /// This is the default variant.
    #[default]
    Pearson,
    /// Exponentially weighted correlation, intended to be parameterised by
    /// `CorrelationConfig::decay_factor`.
    ///
    /// NOTE(review): `RealTimeCorrelation::update` does not branch on the
    /// correlation type, so selecting this variant does not currently change the
    /// computed values — confirm before relying on it.
    Exponential,
}
27
/// Configuration for correlation computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationConfig {
    /// Maximum number of assets tracked (capacity). All state buffers are sized
    /// from this value; asset IDs beyond this many are ignored.
    pub n_assets: usize,
    /// Type of correlation to compute.
    ///
    /// NOTE(review): the streaming update (`RealTimeCorrelation::update`) does not
    /// read this field, so it currently has no effect on the computed values.
    pub correlation_type: CorrelationType,
    /// Exponential decay factor λ, expected in (0, 1), intended for
    /// `CorrelationType::Exponential`.
    ///
    /// NOTE(review): not read by the streaming update path. The weighting
    /// convention is not pinned down here either: under the common RiskMetrics
    /// form (`estimate = λ·previous + (1 - λ)·observation`) a higher λ gives more
    /// weight to history, not to recent data. Confirm before implementing.
    pub decay_factor: f64,
    /// Minimum number of paired observations a pair needs before its correlation
    /// is computed; until then the pair's matrix entry is left untouched.
    pub min_observations: usize,
    /// Absolute change in a pair's correlation coefficient (not a percentage)
    /// at or above which the change is reported as significant.
    pub change_threshold: f64,
}
43
44impl Default for CorrelationConfig {
45    fn default() -> Self {
46        Self {
47            n_assets: 100,
48            correlation_type: CorrelationType::Pearson,
49            decay_factor: 0.94, // ~15-day half-life
50            min_observations: 30,
51            change_threshold: 0.1, // 10% change
52        }
53    }
54}
55
/// Welford-style running moments for a single asset's observation stream.
#[derive(Debug, Clone, Default)]
pub struct AssetStats {
    /// Number of observations folded in so far.
    pub count: u64,
    /// Running arithmetic mean of the observations.
    pub mean: f64,
    /// Accumulated sum of squared deviations from the mean (Welford's `M2`).
    pub m2: f64,
    /// Most recently observed value.
    pub last_value: f64,
    /// Timestamp supplied with the most recent observation.
    pub last_timestamp: u64,
}

impl AssetStats {
    /// Folds one observation into the running mean and `M2`, and records it as
    /// the latest value/timestamp.
    pub fn update(&mut self, value: f64, timestamp: u64) {
        self.count += 1;
        // Deviation from the mean before and after the mean moves; their product
        // is the numerically stable `M2` increment.
        let before = value - self.mean;
        self.mean += before / self.count as f64;
        let after = value - self.mean;
        self.m2 += before * after;
        self.last_timestamp = timestamp;
        self.last_value = value;
    }

    /// Unbiased sample variance, `M2 / (count - 1)`; 0.0 with fewer than two
    /// observations.
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            count => self.m2 / (count - 1) as f64,
        }
    }

    /// Sample standard deviation, the square root of [`Self::variance`].
    pub fn std_dev(&self) -> f64 {
        f64::sqrt(self.variance())
    }
}
97
/// Running co-moment of two assets' paired observations.
#[derive(Debug, Clone, Default)]
pub struct PairwiseStats {
    /// Number of `(value_i, value_j)` pairs folded in so far.
    pub count: u64,
    /// Running mean of the first member's paired values.
    pub mean_i: f64,
    /// Running mean of the second member's paired values.
    pub mean_j: f64,
    /// Accumulated co-moment, Σ (x - mean_i)(y - mean_j), over the paired samples.
    pub co_moment: f64,
}

impl PairwiseStats {
    /// Folds one paired observation into both means and the co-moment
    /// (online co-moment update).
    pub fn update(&mut self, value_i: f64, value_j: f64) {
        self.count += 1;
        let count = self.count as f64;

        // The co-moment increment uses the first member's deviation from its
        // *old* mean and the second member's deviation from its *new* mean.
        let shift_i = value_i - self.mean_i;
        self.mean_i += shift_i / count;
        self.mean_j += (value_j - self.mean_j) / count;

        let residual_j = value_j - self.mean_j;
        self.co_moment += shift_i * residual_j;
    }

    /// Sample covariance, `co_moment / (count - 1)`; 0.0 with fewer than two
    /// paired observations.
    pub fn covariance(&self) -> f64 {
        if self.count >= 2 {
            self.co_moment / (self.count - 1) as f64
        } else {
            0.0
        }
    }
}
137
138/// Internal state for real-time correlation tracking.
139#[derive(Debug, Clone, Default)]
140pub struct CorrelationState {
141    /// Configuration.
142    pub config: CorrelationConfig,
143    /// Per-asset statistics.
144    pub asset_stats: Vec<AssetStats>,
145    /// Pairwise statistics (upper triangular, stored as Vec).
146    /// Index (i, j) where i < j is at position: i * (n - 1) - i * (i - 1) / 2 + (j - i - 1)
147    pub pairwise_stats: Vec<PairwiseStats>,
148    /// Cached correlation matrix (full N×N).
149    pub correlation_matrix: Vec<f64>,
150    /// Previous correlation matrix (for change detection).
151    pub prev_correlation_matrix: Vec<f64>,
152    /// Total observations processed.
153    pub total_observations: u64,
154    /// Asset ID to index mapping.
155    pub asset_index: HashMap<u64, usize>,
156}
157
158impl CorrelationState {
159    /// Create new state with configuration.
160    pub fn new(config: CorrelationConfig) -> Self {
161        let n = config.n_assets;
162        let n_pairs = n * (n - 1) / 2;
163
164        Self {
165            config,
166            asset_stats: vec![AssetStats::default(); n],
167            pairwise_stats: vec![PairwiseStats::default(); n_pairs],
168            correlation_matrix: vec![0.0; n * n],
169            prev_correlation_matrix: vec![0.0; n * n],
170            total_observations: 0,
171            asset_index: HashMap::new(),
172        }
173    }
174
175    /// Get index into pairwise_stats for pair (i, j) where i < j.
176    fn pair_index(&self, i: usize, j: usize) -> usize {
177        let (i, j) = if i < j { (i, j) } else { (j, i) };
178        let n = self.config.n_assets;
179        i * (2 * n - i - 1) / 2 + (j - i - 1)
180    }
181}
182
/// A single observation for one asset, fed to `RealTimeCorrelation::update`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationUpdate {
    /// Caller-chosen asset identifier.
    pub asset_id: u64,
    /// Observation value folded into the running statistics (typically a return
    /// or a price).
    pub value: f64,
    /// Caller-supplied timestamp. It is stored as the asset's latest timestamp;
    /// its units are not interpreted here.
    pub timestamp: u64,
}
193
/// Result of a correlation update.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationUpdateResult {
    /// Asset ID carried by the processed update. It is echoed even when the
    /// update was ignored because the tracker was at capacity.
    pub asset_id: u64,
    /// Number of pair correlations recomputed by this update (0 if ignored).
    pub correlations_updated: usize,
    /// Pairs whose correlation moved by at least `change_threshold`.
    pub significant_changes: Vec<CorrelationChange>,
    /// Wall-clock time spent inside the update call, in microseconds.
    pub latency_us: u64,
}
206
/// A pair whose correlation moved by at least the configured threshold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationChange {
    /// ID of the pair member with the lower internal index.
    pub asset_i: u64,
    /// ID of the pair member with the higher internal index.
    pub asset_j: u64,
    /// Correlation stored before this update.
    pub old_correlation: f64,
    /// Correlation stored after this update.
    pub new_correlation: f64,
    /// Absolute change, `|new_correlation - old_correlation|`.
    pub change: f64,
}
221
/// Correlation matrix snapshot (streaming) or batch result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationMatrixResult {
    /// Number of assets described by the result.
    pub n_assets: usize,
    /// Correlation matrix in row-major order.
    ///
    /// NOTE(review): `compute_from_returns` fills exactly `n_assets * n_assets`
    /// entries. Verify that the streaming snapshot uses the same stride before
    /// indexing it with `n_assets`.
    pub correlations: Vec<f64>,
    /// Observations used: total streamed updates, or series length for batch input.
    pub observations: u64,
    /// Latest observation timestamp across assets (always 0 for batch results).
    pub timestamp: u64,
    /// Compute time in microseconds.
    pub compute_time_us: u64,
}
236
/// Real-time correlation kernel.
///
/// Maintains streaming correlation matrices using Welford's online algorithm,
/// which weights every observation equally (Pearson statistics).
///
/// NOTE(review): `CorrelationType::Exponential` and `decay_factor` exist in the
/// configuration, but the update path does not read them. Confirm before relying
/// on exponentially weighted output.
///
/// Designed for Ring mode operation. The kernel metadata advertises a 10 µs
/// latency target per update.
#[derive(Debug)]
pub struct RealTimeCorrelation {
    /// Static kernel descriptor returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
    /// Internal state for tracking correlations. Updates take the write lock;
    /// queries take the read lock.
    state: std::sync::RwLock<CorrelationState>,
}
248
249impl Clone for RealTimeCorrelation {
250    fn clone(&self) -> Self {
251        Self {
252            metadata: self.metadata.clone(),
253            state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
254        }
255    }
256}
257
258impl Default for RealTimeCorrelation {
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
264impl RealTimeCorrelation {
265    /// Create a new real-time correlation kernel.
266    #[must_use]
267    pub fn new() -> Self {
268        Self {
269            metadata: KernelMetadata::ring("risk/realtime-correlation", Domain::RiskAnalytics)
270                .with_description("Streaming correlation matrix updates")
271                .with_throughput(500_000)
272                .with_latency_us(10.0),
273            state: std::sync::RwLock::new(CorrelationState::new(CorrelationConfig::default())),
274        }
275    }
276
277    /// Create with custom configuration.
278    #[must_use]
279    pub fn with_config(config: CorrelationConfig) -> Self {
280        Self {
281            metadata: KernelMetadata::ring("risk/realtime-correlation", Domain::RiskAnalytics)
282                .with_description("Streaming correlation matrix updates")
283                .with_throughput(500_000)
284                .with_latency_us(10.0),
285            state: std::sync::RwLock::new(CorrelationState::new(config)),
286        }
287    }
288
289    /// Initialize with a set of asset IDs.
290    pub fn initialize(&self, asset_ids: &[u64]) {
291        let mut state = self.state.write().unwrap();
292        state.asset_index.clear();
293        for (idx, &id) in asset_ids.iter().enumerate() {
294            if idx < state.config.n_assets {
295                state.asset_index.insert(id, idx);
296            }
297        }
298        // Reset statistics
299        let n = state.config.n_assets;
300        state.asset_stats = vec![AssetStats::default(); n];
301        state.pairwise_stats = vec![PairwiseStats::default(); n * (n - 1) / 2];
302        state.correlation_matrix = vec![0.0; n * n];
303        state.prev_correlation_matrix = vec![0.0; n * n];
304        state.total_observations = 0;
305    }
306
307    /// Process a single update and return correlation changes.
308    pub fn update(&self, update: &CorrelationUpdate) -> CorrelationUpdateResult {
309        let start = Instant::now();
310        let mut state = self.state.write().unwrap();
311
312        // Get or assign index for this asset
313        let asset_idx = if let Some(&idx) = state.asset_index.get(&update.asset_id) {
314            idx
315        } else if state.asset_index.len() < state.config.n_assets {
316            let idx = state.asset_index.len();
317            state.asset_index.insert(update.asset_id, idx);
318            idx
319        } else {
320            // At capacity, ignore new assets
321            return CorrelationUpdateResult {
322                asset_id: update.asset_id,
323                correlations_updated: 0,
324                significant_changes: Vec::new(),
325                latency_us: start.elapsed().as_micros() as u64,
326            };
327        };
328
329        // Update asset statistics
330        state.asset_stats[asset_idx].update(update.value, update.timestamp);
331        state.total_observations += 1;
332
333        // Update pairwise statistics for all pairs involving this asset
334        let n = state.config.n_assets;
335        let mut correlations_updated = 0;
336        let mut significant_changes = Vec::new();
337
338        // We need the last values of other assets to update covariance
339        // In a true streaming system, we'd batch updates or use a different approach
340        // For now, we update when both assets have been observed at least once
341        for other_idx in 0..state.asset_index.len() {
342            if other_idx == asset_idx {
343                continue;
344            }
345
346            let other_stats = &state.asset_stats[other_idx];
347            if other_stats.count == 0 {
348                continue;
349            }
350
351            // Update pairwise statistics
352            let (i, j) = if asset_idx < other_idx {
353                (asset_idx, other_idx)
354            } else {
355                (other_idx, asset_idx)
356            };
357            let pair_idx = state.pair_index(i, j);
358
359            // Use the last values for covariance update
360            let value_i = if asset_idx == i {
361                update.value
362            } else {
363                state.asset_stats[i].last_value
364            };
365            let value_j = if asset_idx == j {
366                update.value
367            } else {
368                state.asset_stats[j].last_value
369            };
370
371            state.pairwise_stats[pair_idx].update(value_i, value_j);
372
373            // Recompute correlation for this pair
374            if state.pairwise_stats[pair_idx].count >= state.config.min_observations as u64 {
375                let cov = state.pairwise_stats[pair_idx].covariance();
376                let std_i = state.asset_stats[i].std_dev();
377                let std_j = state.asset_stats[j].std_dev();
378
379                let new_corr = if std_i > 1e-10 && std_j > 1e-10 {
380                    (cov / (std_i * std_j)).clamp(-1.0, 1.0)
381                } else {
382                    0.0
383                };
384
385                // Store previous and update
386                let old_corr = state.correlation_matrix[i * n + j];
387                state.prev_correlation_matrix[i * n + j] = old_corr;
388                state.prev_correlation_matrix[j * n + i] = old_corr;
389                state.correlation_matrix[i * n + j] = new_corr;
390                state.correlation_matrix[j * n + i] = new_corr;
391
392                correlations_updated += 1;
393
394                // Check for significant change
395                let change = (new_corr - old_corr).abs();
396                if change >= state.config.change_threshold {
397                    // Get asset IDs
398                    let id_i = state
399                        .asset_index
400                        .iter()
401                        .find(|&(_, idx)| *idx == i)
402                        .map(|(&id, _)| id)
403                        .unwrap_or(0);
404                    let id_j = state
405                        .asset_index
406                        .iter()
407                        .find(|&(_, idx)| *idx == j)
408                        .map(|(&id, _)| id)
409                        .unwrap_or(0);
410
411                    significant_changes.push(CorrelationChange {
412                        asset_i: id_i,
413                        asset_j: id_j,
414                        old_correlation: old_corr,
415                        new_correlation: new_corr,
416                        change,
417                    });
418                }
419            }
420        }
421
422        // Set diagonal to 1.0
423        state.correlation_matrix[asset_idx * n + asset_idx] = 1.0;
424
425        CorrelationUpdateResult {
426            asset_id: update.asset_id,
427            correlations_updated,
428            significant_changes,
429            latency_us: start.elapsed().as_micros() as u64,
430        }
431    }
432
433    /// Process a batch of updates.
434    pub fn update_batch(&self, updates: &[CorrelationUpdate]) -> Vec<CorrelationUpdateResult> {
435        updates.iter().map(|u| self.update(u)).collect()
436    }
437
438    /// Get current correlation between two assets.
439    pub fn get_correlation(&self, asset_i: u64, asset_j: u64) -> Option<f64> {
440        let state = self.state.read().unwrap();
441        let idx_i = state.asset_index.get(&asset_i)?;
442        let idx_j = state.asset_index.get(&asset_j)?;
443        let n = state.config.n_assets;
444        Some(state.correlation_matrix[idx_i * n + idx_j])
445    }
446
447    /// Get full correlation matrix.
448    pub fn get_matrix(&self) -> CorrelationMatrixResult {
449        let start = Instant::now();
450        let state = self.state.read().unwrap();
451
452        CorrelationMatrixResult {
453            n_assets: state.asset_index.len(),
454            correlations: state.correlation_matrix.clone(),
455            observations: state.total_observations,
456            timestamp: state
457                .asset_stats
458                .iter()
459                .map(|s| s.last_timestamp)
460                .max()
461                .unwrap_or(0),
462            compute_time_us: start.elapsed().as_micros() as u64,
463        }
464    }
465
466    /// Get correlation row for a specific asset.
467    pub fn get_row(&self, asset_id: u64) -> Option<Vec<(u64, f64)>> {
468        let state = self.state.read().unwrap();
469        let idx = state.asset_index.get(&asset_id)?;
470        let n = state.config.n_assets;
471
472        Some(
473            state
474                .asset_index
475                .iter()
476                .map(|(&id, &j)| (id, state.correlation_matrix[idx * n + j]))
477                .collect(),
478        )
479    }
480
481    /// Reset state while keeping configuration.
482    pub fn reset(&self) {
483        let mut state = self.state.write().unwrap();
484        let config = state.config.clone();
485        *state = CorrelationState::new(config);
486    }
487
488    /// Batch compute correlation matrix from historical data.
489    pub fn compute_from_returns(returns: &[Vec<f64>]) -> CorrelationMatrixResult {
490        let start = Instant::now();
491
492        if returns.is_empty() || returns[0].is_empty() {
493            return CorrelationMatrixResult {
494                n_assets: 0,
495                correlations: Vec::new(),
496                observations: 0,
497                timestamp: 0,
498                compute_time_us: start.elapsed().as_micros() as u64,
499            };
500        }
501
502        let n = returns.len();
503        let t = returns[0].len();
504
505        // Compute means
506        let means: Vec<f64> = returns
507            .iter()
508            .map(|r| r.iter().sum::<f64>() / t as f64)
509            .collect();
510
511        // Compute standard deviations
512        let stds: Vec<f64> = returns
513            .iter()
514            .zip(means.iter())
515            .map(|(r, &mean)| {
516                let var = r.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (t - 1) as f64;
517                var.sqrt()
518            })
519            .collect();
520
521        // Compute correlation matrix
522        let mut correlations = vec![0.0; n * n];
523
524        for i in 0..n {
525            correlations[i * n + i] = 1.0; // Diagonal
526
527            for j in (i + 1)..n {
528                let cov: f64 = returns[i]
529                    .iter()
530                    .zip(returns[j].iter())
531                    .map(|(&xi, &xj)| (xi - means[i]) * (xj - means[j]))
532                    .sum::<f64>()
533                    / (t - 1) as f64;
534
535                let corr = if stds[i] > 1e-10 && stds[j] > 1e-10 {
536                    (cov / (stds[i] * stds[j])).clamp(-1.0, 1.0)
537                } else {
538                    0.0
539                };
540
541                correlations[i * n + j] = corr;
542                correlations[j * n + i] = corr;
543            }
544        }
545
546        CorrelationMatrixResult {
547            n_assets: n,
548            correlations,
549            observations: t as u64,
550            timestamp: 0,
551            compute_time_us: start.elapsed().as_micros() as u64,
552        }
553    }
554}
555
556impl GpuKernel for RealTimeCorrelation {
557    fn metadata(&self) -> &KernelMetadata {
558        &self.metadata
559    }
560}
561
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_realtime_correlation_metadata() {
        let kernel = RealTimeCorrelation::new();
        assert_eq!(kernel.metadata().id, "risk/realtime-correlation");
        assert_eq!(kernel.metadata().domain, Domain::RiskAnalytics);
    }

    #[test]
    fn test_asset_stats_welford() {
        let mut stats = AssetStats::default();

        // Known sequence: 2, 4, 6, 8, 10
        // Mean = 6, Var = 10
        for v in [2.0, 4.0, 6.0, 8.0, 10.0] {
            stats.update(v, 0);
        }

        assert!((stats.mean - 6.0).abs() < 1e-10);
        assert!((stats.variance() - 10.0).abs() < 1e-10);
        assert!((stats.std_dev() - (10.0_f64).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_pairwise_stats_matches_two_pass_covariance() {
        let xs = [0.5, -1.25, 2.0, 3.5, -0.75];
        let ys = [1.0, 0.25, -2.0, 4.0, 0.5];

        let mut pair = PairwiseStats::default();
        for (&x, &y) in xs.iter().zip(ys.iter()) {
            pair.update(x, y);
        }

        let n = xs.len() as f64;
        let mean_x = xs.iter().sum::<f64>() / n;
        let mean_y = ys.iter().sum::<f64>() / n;
        let expected = xs
            .iter()
            .zip(ys.iter())
            .map(|(&x, &y)| (x - mean_x) * (y - mean_y))
            .sum::<f64>()
            / (n - 1.0);

        assert!(
            (pair.covariance() - expected).abs() < 1e-12,
            "Expected {}, got {}",
            expected,
            pair.covariance()
        );
    }

    #[test]
    fn test_pair_index_is_dense_and_unique() {
        let n = 6;
        let state = CorrelationState::new(CorrelationConfig {
            n_assets: n,
            ..Default::default()
        });

        // Every unordered pair must map to its own slot, and all slots must be used.
        let mut seen = vec![false; state.pairwise_stats.len()];
        for i in 0..n {
            for j in (i + 1)..n {
                let k = state.pair_index(i, j);
                assert_eq!(k, state.pair_index(j, i));
                assert!(!seen[k], "pair ({}, {}) reuses slot {}", i, j, k);
                seen[k] = true;
            }
        }
        assert!(seen.iter().all(|&used| used));
    }

    #[test]
    fn test_initialize_assets() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[100, 101, 102]);

        // Should have registered 3 assets
        let state = kernel.state.read().unwrap();
        assert_eq!(state.asset_index.len(), 3);
    }

    #[test]
    fn test_streaming_updates() {
        let config = CorrelationConfig {
            n_assets: 10,
            min_observations: 2,
            ..Default::default()
        };
        let kernel = RealTimeCorrelation::with_config(config);
        kernel.initialize(&[1, 2]);

        // Generate correlated returns
        for i in 0..50 {
            let r1 = (i as f64) * 0.01;
            let r2 = r1 * 0.8 + 0.002; // Highly correlated

            kernel.update(&CorrelationUpdate {
                asset_id: 1,
                value: r1,
                timestamp: i as u64,
            });
            kernel.update(&CorrelationUpdate {
                asset_id: 2,
                value: r2,
                timestamp: i as u64,
            });
        }

        // Check correlation is high
        let corr = kernel.get_correlation(1, 2).unwrap();
        assert!(corr > 0.9, "Expected high correlation, got: {}", corr);
    }

    #[test]
    fn test_uncorrelated_assets() {
        let config = CorrelationConfig {
            n_assets: 10,
            min_observations: 2,
            ..Default::default()
        };
        let kernel = RealTimeCorrelation::with_config(config);
        kernel.initialize(&[1, 2]);

        // Generate uncorrelated returns using alternating pattern
        for i in 0..100 {
            let r1 = if i % 2 == 0 { 0.01 } else { -0.01 };
            let r2 = if i % 3 == 0 { 0.01 } else { -0.01 };

            kernel.update(&CorrelationUpdate {
                asset_id: 1,
                value: r1,
                timestamp: i as u64,
            });
            kernel.update(&CorrelationUpdate {
                asset_id: 2,
                value: r2,
                timestamp: i as u64,
            });
        }

        // Correlation should be low
        let corr = kernel.get_correlation(1, 2).unwrap();
        assert!(corr.abs() < 0.5, "Expected low correlation, got: {}", corr);
    }

    #[test]
    fn test_correlation_matrix_diagonal() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[1, 2, 3]);

        // Add some data
        for i in 0..30 {
            kernel.update(&CorrelationUpdate {
                asset_id: 1,
                value: i as f64 * 0.01,
                timestamp: i as u64,
            });
            kernel.update(&CorrelationUpdate {
                asset_id: 2,
                value: i as f64 * 0.02,
                timestamp: i as u64,
            });
            kernel.update(&CorrelationUpdate {
                asset_id: 3,
                value: i as f64 * 0.015,
                timestamp: i as u64,
            });
        }

        // Diagonal should be 1.0
        let corr_11 = kernel.get_correlation(1, 1).unwrap();
        let corr_22 = kernel.get_correlation(2, 2).unwrap();
        let corr_33 = kernel.get_correlation(3, 3).unwrap();

        assert!((corr_11 - 1.0).abs() < 1e-10);
        assert!((corr_22 - 1.0).abs() < 1e-10);
        assert!((corr_33 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_correlation() {
        // Returns for 3 assets over 10 periods
        let returns = vec![
            vec![
                0.01, 0.02, -0.01, 0.03, 0.01, -0.02, 0.01, 0.02, -0.01, 0.01,
            ],
            vec![
                0.02, 0.03, -0.02, 0.04, 0.02, -0.03, 0.02, 0.03, -0.02, 0.02,
            ], // Similar to asset 0
            vec![
                -0.01, 0.01, 0.02, -0.02, 0.03, 0.01, -0.01, 0.02, 0.01, -0.01,
            ], // Different pattern
        ];

        let result = RealTimeCorrelation::compute_from_returns(&returns);

        assert_eq!(result.n_assets, 3);
        assert_eq!(result.observations, 10);

        // Check matrix properties
        let n = result.n_assets;
        // Diagonal should be 1.0
        for i in 0..n {
            assert!((result.correlations[i * n + i] - 1.0).abs() < 1e-10);
        }
        // Should be symmetric
        for i in 0..n {
            for j in 0..n {
                let diff = (result.correlations[i * n + j] - result.correlations[j * n + i]).abs();
                assert!(diff < 1e-10);
            }
        }
        // Assets 0 and 1 should be highly correlated
        let corr_01 = result.correlations[1];
        assert!(corr_01 > 0.9, "Expected high correlation: {}", corr_01);
    }

    #[test]
    fn test_batch_perfect_negative_correlation() {
        let returns = vec![vec![1.0, 2.0, 3.0, 4.0], vec![4.0, 3.0, 2.0, 1.0]];
        let result = RealTimeCorrelation::compute_from_returns(&returns);

        assert_eq!(result.n_assets, 2);
        assert_eq!(result.observations, 4);
        assert!((result.correlations[1] + 1.0).abs() < 1e-12);
        assert!((result.correlations[2] + 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_significant_change_detection() {
        let threshold = 0.3; // absolute change in the coefficient
        let config = CorrelationConfig {
            n_assets: 10,
            min_observations: 2,
            change_threshold: threshold,
            ..Default::default()
        };
        let kernel = RealTimeCorrelation::with_config(config);
        kernel.initialize(&[1, 2]);

        // Establish a positive correlation, collecting every reported change.
        let mut reported = Vec::new();
        for i in 0..50 {
            for (asset_id, value) in [(1, i as f64 * 0.01), (2, i as f64 * 0.01 + 0.001)] {
                let result = kernel.update(&CorrelationUpdate {
                    asset_id,
                    value,
                    timestamp: i as u64,
                });
                assert_eq!(result.asset_id, asset_id);
                reported.extend(result.significant_changes);
            }
        }

        let baseline_corr = kernel.get_correlation(1, 2).unwrap();
        assert!(
            baseline_corr > 0.9,
            "Expected high positive correlation: {}",
            baseline_corr
        );

        // Every reported change must respect the threshold, be self-consistent,
        // and name the tracked pair.
        for change in &reported {
            assert!(
                change.change >= threshold,
                "Reported change below threshold: {}",
                change.change
            );
            let recomputed = (change.new_correlation - change.old_correlation).abs();
            assert!((change.change - recomputed).abs() < 1e-12);
            let ids = [change.asset_i, change.asset_j];
            assert!(ids.contains(&1) && ids.contains(&2), "Unexpected pair: {:?}", ids);
        }
    }

    #[test]
    fn test_get_row() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[1, 2, 3]);

        // Add data
        for i in 0..30 {
            kernel.update(&CorrelationUpdate {
                asset_id: 1,
                value: i as f64,
                timestamp: i as u64,
            });
            kernel.update(&CorrelationUpdate {
                asset_id: 2,
                value: i as f64 * 2.0,
                timestamp: i as u64,
            });
            kernel.update(&CorrelationUpdate {
                asset_id: 3,
                value: i as f64 * 1.5,
                timestamp: i as u64,
            });
        }

        let row = kernel.get_row(1).unwrap();
        assert_eq!(row.len(), 3);

        // Should include self-correlation of 1.0
        let self_corr = row.iter().find(|(id, _)| *id == 1).map(|(_, c)| *c);
        assert!((self_corr.unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_reset() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[1, 2]);

        for i in 0..30 {
            kernel.update(&CorrelationUpdate {
                asset_id: 1,
                value: i as f64,
                timestamp: i as u64,
            });
        }

        let matrix_before = kernel.get_matrix();
        assert!(matrix_before.observations > 0);

        kernel.reset();

        let matrix_after = kernel.get_matrix();
        assert_eq!(matrix_after.observations, 0);
    }

    #[test]
    fn test_empty_returns() {
        let result = RealTimeCorrelation::compute_from_returns(&[]);
        assert_eq!(result.n_assets, 0);

        let empty_inner: Vec<Vec<f64>> = vec![vec![]];
        let result2 = RealTimeCorrelation::compute_from_returns(&empty_inner);
        assert_eq!(result2.n_assets, 0);
    }
}