1use rustkernel_core::traits::GpuKernel;
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::time::Instant;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum CorrelationType {
21 Pearson,
23 Exponential,
25}
26
27impl Default for CorrelationType {
28 fn default() -> Self {
29 Self::Pearson
30 }
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CorrelationConfig {
36 pub n_assets: usize,
38 pub correlation_type: CorrelationType,
40 pub decay_factor: f64,
43 pub min_observations: usize,
45 pub change_threshold: f64,
47}
48
49impl Default for CorrelationConfig {
50 fn default() -> Self {
51 Self {
52 n_assets: 100,
53 correlation_type: CorrelationType::Pearson,
54 decay_factor: 0.94, min_observations: 30,
56 change_threshold: 0.1, }
58 }
59}
60
/// Running statistics for a single asset, maintained with Welford's online algorithm.
#[derive(Debug, Clone, Default)]
pub struct AssetStats {
    /// Number of observations folded in so far.
    pub count: u64,
    /// Running mean of the observations.
    pub mean: f64,
    /// Running sum of squared deviations from the mean (Welford's `M2`).
    pub m2: f64,
    /// Most recently observed value.
    pub last_value: f64,
    /// Timestamp attached to the most recent observation.
    pub last_timestamp: u64,
}

impl AssetStats {
    /// Folds `value`, observed at `timestamp`, into the running statistics.
    pub fn update(&mut self, value: f64, timestamp: u64) {
        let mean_before = self.mean;
        self.count += 1;
        self.mean = mean_before + (value - mean_before) / self.count as f64;
        // Welford: deviation from the old mean times deviation from the new mean.
        self.m2 += (value - mean_before) * (value - self.mean);
        self.last_value = value;
        self.last_timestamp = timestamp;
    }

    /// Sample variance `M2 / (count - 1)`; `0.0` when fewer than two observations exist.
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            observations => self.m2 / (observations - 1) as f64,
        }
    }

    /// Sample standard deviation, i.e. the square root of [`Self::variance`].
    pub fn std_dev(&self) -> f64 {
        f64::sqrt(self.variance())
    }
}
102
/// Running co-moment of two series, maintained with the online (Welford-style) update.
#[derive(Debug, Clone, Default)]
pub struct PairwiseStats {
    /// Number of paired samples folded in so far.
    pub count: u64,
    /// Running mean of the first series.
    pub mean_i: f64,
    /// Running mean of the second series.
    pub mean_j: f64,
    /// Running sum of products of deviations from the respective means.
    pub co_moment: f64,
}

impl PairwiseStats {
    /// Folds one paired sample `(value_i, value_j)` into the running co-moment.
    pub fn update(&mut self, value_i: f64, value_j: f64) {
        self.count += 1;
        let weight = self.count as f64;
        let (dev_i, dev_j) = (value_i - self.mean_i, value_j - self.mean_j);
        self.mean_i += dev_i / weight;
        self.mean_j += dev_j / weight;
        // Old-mean deviation of i times new-mean deviation of j keeps the co-moment exact.
        self.co_moment += dev_i * (value_j - self.mean_j);
    }

    /// Sample covariance `co_moment / (count - 1)`; `0.0` with fewer than two samples.
    pub fn covariance(&self) -> f64 {
        if self.count >= 2 {
            self.co_moment / (self.count - 1) as f64
        } else {
            0.0
        }
    }
}
142
143#[derive(Debug, Clone, Default)]
145pub struct CorrelationState {
146 pub config: CorrelationConfig,
148 pub asset_stats: Vec<AssetStats>,
150 pub pairwise_stats: Vec<PairwiseStats>,
153 pub correlation_matrix: Vec<f64>,
155 pub prev_correlation_matrix: Vec<f64>,
157 pub total_observations: u64,
159 pub asset_index: HashMap<u64, usize>,
161}
162
163impl CorrelationState {
164 pub fn new(config: CorrelationConfig) -> Self {
166 let n = config.n_assets;
167 let n_pairs = n * (n - 1) / 2;
168
169 Self {
170 config,
171 asset_stats: vec![AssetStats::default(); n],
172 pairwise_stats: vec![PairwiseStats::default(); n_pairs],
173 correlation_matrix: vec![0.0; n * n],
174 prev_correlation_matrix: vec![0.0; n * n],
175 total_observations: 0,
176 asset_index: HashMap::new(),
177 }
178 }
179
180 fn pair_index(&self, i: usize, j: usize) -> usize {
182 let (i, j) = if i < j { (i, j) } else { (j, i) };
183 let n = self.config.n_assets;
184 i * (2 * n - i - 1) / 2 + (j - i - 1)
185 }
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct CorrelationUpdate {
191 pub asset_id: u64,
193 pub value: f64,
195 pub timestamp: u64,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct CorrelationUpdateResult {
202 pub asset_id: u64,
204 pub correlations_updated: usize,
206 pub significant_changes: Vec<CorrelationChange>,
208 pub latency_us: u64,
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct CorrelationChange {
215 pub asset_i: u64,
217 pub asset_j: u64,
219 pub old_correlation: f64,
221 pub new_correlation: f64,
223 pub change: f64,
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct CorrelationMatrixResult {
230 pub n_assets: usize,
232 pub correlations: Vec<f64>,
234 pub observations: u64,
236 pub timestamp: u64,
238 pub compute_time_us: u64,
240}
241
242#[derive(Debug)]
248pub struct RealTimeCorrelation {
249 metadata: KernelMetadata,
250 state: std::sync::RwLock<CorrelationState>,
252}
253
254impl Clone for RealTimeCorrelation {
255 fn clone(&self) -> Self {
256 Self {
257 metadata: self.metadata.clone(),
258 state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
259 }
260 }
261}
262
263impl Default for RealTimeCorrelation {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269impl RealTimeCorrelation {
270 #[must_use]
272 pub fn new() -> Self {
273 Self {
274 metadata: KernelMetadata::ring("risk/realtime-correlation", Domain::RiskAnalytics)
275 .with_description("Streaming correlation matrix updates")
276 .with_throughput(500_000)
277 .with_latency_us(10.0),
278 state: std::sync::RwLock::new(CorrelationState::new(CorrelationConfig::default())),
279 }
280 }
281
282 #[must_use]
284 pub fn with_config(config: CorrelationConfig) -> Self {
285 Self {
286 metadata: KernelMetadata::ring("risk/realtime-correlation", Domain::RiskAnalytics)
287 .with_description("Streaming correlation matrix updates")
288 .with_throughput(500_000)
289 .with_latency_us(10.0),
290 state: std::sync::RwLock::new(CorrelationState::new(config)),
291 }
292 }
293
294 pub fn initialize(&self, asset_ids: &[u64]) {
296 let mut state = self.state.write().unwrap();
297 state.asset_index.clear();
298 for (idx, &id) in asset_ids.iter().enumerate() {
299 if idx < state.config.n_assets {
300 state.asset_index.insert(id, idx);
301 }
302 }
303 let n = state.config.n_assets;
305 state.asset_stats = vec![AssetStats::default(); n];
306 state.pairwise_stats = vec![PairwiseStats::default(); n * (n - 1) / 2];
307 state.correlation_matrix = vec![0.0; n * n];
308 state.prev_correlation_matrix = vec![0.0; n * n];
309 state.total_observations = 0;
310 }
311
312 pub fn update(&self, update: &CorrelationUpdate) -> CorrelationUpdateResult {
314 let start = Instant::now();
315 let mut state = self.state.write().unwrap();
316
317 let asset_idx = if let Some(&idx) = state.asset_index.get(&update.asset_id) {
319 idx
320 } else if state.asset_index.len() < state.config.n_assets {
321 let idx = state.asset_index.len();
322 state.asset_index.insert(update.asset_id, idx);
323 idx
324 } else {
325 return CorrelationUpdateResult {
327 asset_id: update.asset_id,
328 correlations_updated: 0,
329 significant_changes: Vec::new(),
330 latency_us: start.elapsed().as_micros() as u64,
331 };
332 };
333
334 state.asset_stats[asset_idx].update(update.value, update.timestamp);
336 state.total_observations += 1;
337
338 let n = state.config.n_assets;
340 let mut correlations_updated = 0;
341 let mut significant_changes = Vec::new();
342
343 for other_idx in 0..state.asset_index.len() {
347 if other_idx == asset_idx {
348 continue;
349 }
350
351 let other_stats = &state.asset_stats[other_idx];
352 if other_stats.count == 0 {
353 continue;
354 }
355
356 let (i, j) = if asset_idx < other_idx {
358 (asset_idx, other_idx)
359 } else {
360 (other_idx, asset_idx)
361 };
362 let pair_idx = state.pair_index(i, j);
363
364 let value_i = if asset_idx == i {
366 update.value
367 } else {
368 state.asset_stats[i].last_value
369 };
370 let value_j = if asset_idx == j {
371 update.value
372 } else {
373 state.asset_stats[j].last_value
374 };
375
376 state.pairwise_stats[pair_idx].update(value_i, value_j);
377
378 if state.pairwise_stats[pair_idx].count >= state.config.min_observations as u64 {
380 let cov = state.pairwise_stats[pair_idx].covariance();
381 let std_i = state.asset_stats[i].std_dev();
382 let std_j = state.asset_stats[j].std_dev();
383
384 let new_corr = if std_i > 1e-10 && std_j > 1e-10 {
385 (cov / (std_i * std_j)).clamp(-1.0, 1.0)
386 } else {
387 0.0
388 };
389
390 let old_corr = state.correlation_matrix[i * n + j];
392 state.prev_correlation_matrix[i * n + j] = old_corr;
393 state.prev_correlation_matrix[j * n + i] = old_corr;
394 state.correlation_matrix[i * n + j] = new_corr;
395 state.correlation_matrix[j * n + i] = new_corr;
396
397 correlations_updated += 1;
398
399 let change = (new_corr - old_corr).abs();
401 if change >= state.config.change_threshold {
402 let id_i = state
404 .asset_index
405 .iter()
406 .find(|&(_, idx)| *idx == i)
407 .map(|(&id, _)| id)
408 .unwrap_or(0);
409 let id_j = state
410 .asset_index
411 .iter()
412 .find(|&(_, idx)| *idx == j)
413 .map(|(&id, _)| id)
414 .unwrap_or(0);
415
416 significant_changes.push(CorrelationChange {
417 asset_i: id_i,
418 asset_j: id_j,
419 old_correlation: old_corr,
420 new_correlation: new_corr,
421 change,
422 });
423 }
424 }
425 }
426
427 state.correlation_matrix[asset_idx * n + asset_idx] = 1.0;
429
430 CorrelationUpdateResult {
431 asset_id: update.asset_id,
432 correlations_updated,
433 significant_changes,
434 latency_us: start.elapsed().as_micros() as u64,
435 }
436 }
437
438 pub fn update_batch(&self, updates: &[CorrelationUpdate]) -> Vec<CorrelationUpdateResult> {
440 updates.iter().map(|u| self.update(u)).collect()
441 }
442
443 pub fn get_correlation(&self, asset_i: u64, asset_j: u64) -> Option<f64> {
445 let state = self.state.read().unwrap();
446 let idx_i = state.asset_index.get(&asset_i)?;
447 let idx_j = state.asset_index.get(&asset_j)?;
448 let n = state.config.n_assets;
449 Some(state.correlation_matrix[idx_i * n + idx_j])
450 }
451
452 pub fn get_matrix(&self) -> CorrelationMatrixResult {
454 let start = Instant::now();
455 let state = self.state.read().unwrap();
456
457 CorrelationMatrixResult {
458 n_assets: state.asset_index.len(),
459 correlations: state.correlation_matrix.clone(),
460 observations: state.total_observations,
461 timestamp: state
462 .asset_stats
463 .iter()
464 .map(|s| s.last_timestamp)
465 .max()
466 .unwrap_or(0),
467 compute_time_us: start.elapsed().as_micros() as u64,
468 }
469 }
470
471 pub fn get_row(&self, asset_id: u64) -> Option<Vec<(u64, f64)>> {
473 let state = self.state.read().unwrap();
474 let idx = state.asset_index.get(&asset_id)?;
475 let n = state.config.n_assets;
476
477 Some(
478 state
479 .asset_index
480 .iter()
481 .map(|(&id, &j)| (id, state.correlation_matrix[idx * n + j]))
482 .collect(),
483 )
484 }
485
486 pub fn reset(&self) {
488 let mut state = self.state.write().unwrap();
489 let config = state.config.clone();
490 *state = CorrelationState::new(config);
491 }
492
493 pub fn compute_from_returns(returns: &[Vec<f64>]) -> CorrelationMatrixResult {
495 let start = Instant::now();
496
497 if returns.is_empty() || returns[0].is_empty() {
498 return CorrelationMatrixResult {
499 n_assets: 0,
500 correlations: Vec::new(),
501 observations: 0,
502 timestamp: 0,
503 compute_time_us: start.elapsed().as_micros() as u64,
504 };
505 }
506
507 let n = returns.len();
508 let t = returns[0].len();
509
510 let means: Vec<f64> = returns
512 .iter()
513 .map(|r| r.iter().sum::<f64>() / t as f64)
514 .collect();
515
516 let stds: Vec<f64> = returns
518 .iter()
519 .zip(means.iter())
520 .map(|(r, &mean)| {
521 let var = r.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (t - 1) as f64;
522 var.sqrt()
523 })
524 .collect();
525
526 let mut correlations = vec![0.0; n * n];
528
529 for i in 0..n {
530 correlations[i * n + i] = 1.0; for j in (i + 1)..n {
533 let cov: f64 = returns[i]
534 .iter()
535 .zip(returns[j].iter())
536 .map(|(&xi, &xj)| (xi - means[i]) * (xj - means[j]))
537 .sum::<f64>()
538 / (t - 1) as f64;
539
540 let corr = if stds[i] > 1e-10 && stds[j] > 1e-10 {
541 (cov / (stds[i] * stds[j])).clamp(-1.0, 1.0)
542 } else {
543 0.0
544 };
545
546 correlations[i * n + j] = corr;
547 correlations[j * n + i] = corr;
548 }
549 }
550
551 CorrelationMatrixResult {
552 n_assets: n,
553 correlations,
554 observations: t as u64,
555 timestamp: 0,
556 compute_time_us: start.elapsed().as_micros() as u64,
557 }
558 }
559}
560
561impl GpuKernel for RealTimeCorrelation {
562 fn metadata(&self) -> &KernelMetadata {
563 &self.metadata
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Submits one observation and returns the kernel's result.
    fn observe(
        kernel: &RealTimeCorrelation,
        asset_id: u64,
        value: f64,
        timestamp: u64,
    ) -> CorrelationUpdateResult {
        kernel.update(&CorrelationUpdate {
            asset_id,
            value,
            timestamp,
        })
    }

    /// Kernel with room for 10 assets that publishes correlations after two pair samples.
    fn small_kernel(change_threshold: f64) -> RealTimeCorrelation {
        RealTimeCorrelation::with_config(CorrelationConfig {
            n_assets: 10,
            min_observations: 2,
            change_threshold,
            ..Default::default()
        })
    }

    #[test]
    fn test_realtime_correlation_metadata() {
        let kernel = RealTimeCorrelation::new();
        assert_eq!(kernel.metadata().id, "risk/realtime-correlation");
        assert_eq!(kernel.metadata().domain, Domain::RiskAnalytics);
    }

    #[test]
    fn test_asset_stats_welford() {
        let mut stats = AssetStats::default();
        for v in [2.0, 4.0, 6.0, 8.0, 10.0] {
            stats.update(v, 0);
        }

        assert!((stats.mean - 6.0).abs() < 1e-10);
        assert!((stats.variance() - 10.0).abs() < 1e-10);
        assert!((stats.std_dev() - (10.0_f64).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_pairwise_stats_covariance() {
        let mut stats = PairwiseStats::default();
        assert_eq!(stats.covariance(), 0.0);
        for (x, y) in [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)] {
            stats.update(x, y);
        }
        assert!((stats.covariance() - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_pair_index_is_dense_bijection() {
        let state = CorrelationState::new(CorrelationConfig {
            n_assets: 5,
            ..Default::default()
        });
        let mut seen = vec![false; state.pairwise_stats.len()];
        for i in 0..5 {
            for j in (i + 1)..5 {
                let idx = state.pair_index(i, j);
                assert_eq!(idx, state.pair_index(j, i), "pair_index must be symmetric");
                assert!(!seen[idx], "pair ({i}, {j}) collides with another pair");
                seen[idx] = true;
            }
        }
        assert!(seen.iter().all(|&s| s), "every packed slot must be used");
    }

    #[test]
    fn test_initialize_assets() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[100, 101, 102]);

        let state = kernel.state.read().unwrap();
        assert_eq!(state.asset_index.len(), 3);
    }

    #[test]
    fn test_streaming_updates() {
        let kernel = small_kernel(0.1);
        kernel.initialize(&[1, 2]);

        for t in 0..50u64 {
            let r1 = t as f64 * 0.01;
            let r2 = r1 * 0.8 + 0.002;
            observe(&kernel, 1, r1, t);
            observe(&kernel, 2, r2, t);
        }

        let corr = kernel.get_correlation(1, 2).unwrap();
        assert!(corr > 0.9, "Expected high correlation, got: {}", corr);
    }

    #[test]
    fn test_uncorrelated_assets() {
        let kernel = small_kernel(0.1);
        kernel.initialize(&[1, 2]);

        for t in 0..100u64 {
            let r1 = if t % 2 == 0 { 0.01 } else { -0.01 };
            let r2 = if t % 3 == 0 { 0.01 } else { -0.01 };
            observe(&kernel, 1, r1, t);
            observe(&kernel, 2, r2, t);
        }

        let corr = kernel.get_correlation(1, 2).unwrap();
        assert!(corr.abs() < 0.5, "Expected low correlation, got: {}", corr);
    }

    #[test]
    fn test_correlation_matrix_diagonal() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[1, 2, 3]);

        for t in 0..30u64 {
            observe(&kernel, 1, t as f64 * 0.01, t);
            observe(&kernel, 2, t as f64 * 0.02, t);
            observe(&kernel, 3, t as f64 * 0.015, t);
        }

        for id in [1, 2, 3] {
            let corr = kernel.get_correlation(id, id).unwrap();
            assert!((corr - 1.0).abs() < 1e-10, "diagonal for {id} was {corr}");
        }
    }

    #[test]
    fn test_batch_correlation() {
        let returns = vec![
            vec![0.01, 0.02, -0.01, 0.03, 0.01, -0.02, 0.01, 0.02, -0.01, 0.01],
            vec![0.02, 0.03, -0.02, 0.04, 0.02, -0.03, 0.02, 0.03, -0.02, 0.02],
            vec![-0.01, 0.01, 0.02, -0.02, 0.03, 0.01, -0.01, 0.02, 0.01, -0.01],
        ];

        let result = RealTimeCorrelation::compute_from_returns(&returns);

        assert_eq!(result.n_assets, 3);
        assert_eq!(result.observations, 10);

        let n = result.n_assets;
        for i in 0..n {
            assert!((result.correlations[i * n + i] - 1.0).abs() < 1e-10);
        }
        for i in 0..n {
            for j in 0..n {
                let diff = (result.correlations[i * n + j] - result.correlations[j * n + i]).abs();
                assert!(diff < 1e-10);
            }
        }

        // Entry (0, 1) sits at row 0, column 1.
        let corr_01 = result.correlations[1];
        assert!(corr_01 > 0.9, "Expected high correlation: {}", corr_01);
    }

    #[test]
    fn test_significant_change_detection() {
        let kernel = small_kernel(0.3);
        kernel.initialize(&[1, 2]);

        let mut changes = Vec::new();
        for t in 0..50u64 {
            let first = observe(&kernel, 1, t as f64 * 0.01, t);
            let second = observe(&kernel, 2, t as f64 * 0.01 + 0.001, t);
            changes.extend(first.significant_changes);
            changes.extend(second.significant_changes);
        }

        // The first published correlation jumps from 0 well past the 0.3 threshold.
        assert!(!changes.is_empty(), "expected at least one significant change");
        for change in &changes {
            assert_eq!((change.asset_i, change.asset_j), (1, 2));
            assert!(change.change >= 0.3, "reported change below threshold: {}", change.change);
            let recomputed = (change.new_correlation - change.old_correlation).abs();
            assert!((recomputed - change.change).abs() < 1e-12);
        }

        let baseline_corr = kernel.get_correlation(1, 2).unwrap();
        assert!(
            baseline_corr > 0.9,
            "Expected high positive correlation: {}",
            baseline_corr
        );
    }

    #[test]
    fn test_update_beyond_capacity_is_dropped() {
        let kernel = RealTimeCorrelation::with_config(CorrelationConfig {
            n_assets: 2,
            min_observations: 2,
            ..Default::default()
        });
        kernel.initialize(&[1, 2]);

        let result = observe(&kernel, 3, 0.5, 0);
        assert_eq!(result.asset_id, 3);
        assert_eq!(result.correlations_updated, 0);
        assert!(result.significant_changes.is_empty());
        assert_eq!(kernel.get_correlation(3, 1), None);
        assert_eq!(kernel.get_matrix().observations, 0);
    }

    #[test]
    fn test_get_row() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[1, 2, 3]);

        for t in 0..30u64 {
            observe(&kernel, 1, t as f64, t);
            observe(&kernel, 2, t as f64 * 2.0, t);
            observe(&kernel, 3, t as f64 * 1.5, t);
        }

        let row = kernel.get_row(1).unwrap();
        assert_eq!(row.len(), 3);

        let self_corr = row.iter().find(|(id, _)| *id == 1).map(|(_, c)| *c);
        assert!((self_corr.unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_reset() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[1, 2]);

        for t in 0..30u64 {
            observe(&kernel, 1, t as f64, t);
        }

        let matrix_before = kernel.get_matrix();
        assert!(matrix_before.observations > 0);

        kernel.reset();

        let matrix_after = kernel.get_matrix();
        assert_eq!(matrix_after.observations, 0);
    }

    #[test]
    fn test_empty_returns() {
        let result = RealTimeCorrelation::compute_from_returns(&[]);
        assert_eq!(result.n_assets, 0);

        let empty_inner: Vec<Vec<f64>> = vec![vec![]];
        let result2 = RealTimeCorrelation::compute_from_returns(&empty_inner);
        assert_eq!(result2.n_assets, 0);
    }
}