1use rustkernel_core::traits::GpuKernel;
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::time::Instant;
13
/// Correlation estimator requested through [`CorrelationConfig::correlation_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum CorrelationType {
    /// Pearson correlation. This is the `Default` variant.
    #[default]
    Pearson,
    /// Exponentially weighted correlation.
    // NOTE(review): presumably weighted by `CorrelationConfig::decay_factor` — confirm
    // against the code that consumes this variant.
    Exponential,
}
27
/// Tunable parameters for the streaming correlation state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationConfig {
    /// Capacity of the tracker: the number of asset slots the state vectors and the
    /// `n_assets * n_assets` correlation matrix are sized for.
    pub n_assets: usize,
    /// Which correlation estimator is requested.
    pub correlation_type: CorrelationType,
    /// Decay factor.
    // NOTE(review): presumably used for exponential weighting — confirm against the
    // code that reads this field.
    pub decay_factor: f64,
    /// Number of pairwise observations required before a correlation is computed
    /// for a pair.
    pub min_observations: usize,
    /// Absolute change in a pair's correlation at or above which the change is
    /// reported as significant.
    pub change_threshold: f64,
}
43
44impl Default for CorrelationConfig {
45 fn default() -> Self {
46 Self {
47 n_assets: 100,
48 correlation_type: CorrelationType::Pearson,
49 decay_factor: 0.94, min_observations: 30,
51 change_threshold: 0.1, }
53 }
54}
55
/// Running statistics for a single asset, maintained incrementally with
/// Welford's online algorithm.
#[derive(Debug, Clone, Default)]
pub struct AssetStats {
    /// Number of observations folded in so far.
    pub count: u64,
    /// Running mean of the observed values.
    pub mean: f64,
    /// Running sum of squared deviations from the mean.
    pub m2: f64,
    /// Most recently observed value.
    pub last_value: f64,
    /// Timestamp supplied with the most recent observation.
    pub last_timestamp: u64,
}

impl AssetStats {
    /// Folds one observation into the running mean and squared-deviation sum and
    /// records it as the latest value/timestamp.
    pub fn update(&mut self, value: f64, timestamp: u64) {
        // Deviation from the mean *before* this sample is absorbed.
        let residual_before = value - self.mean;

        self.count += 1;
        self.mean += residual_before / self.count as f64;

        // Deviation from the mean *after* this sample is absorbed.
        let residual_after = value - self.mean;
        self.m2 += residual_before * residual_after;

        self.last_timestamp = timestamp;
        self.last_value = value;
    }

    /// Returns the sample variance (divisor `count - 1`), or `0.0` while fewer
    /// than two observations have been seen.
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            seen => self.m2 / (seen - 1) as f64,
        }
    }

    /// Returns the sample standard deviation, i.e. the square root of
    /// [`AssetStats::variance`].
    pub fn std_dev(&self) -> f64 {
        f64::sqrt(self.variance())
    }
}
97
/// Running co-moment statistics for one pair of assets, updated one joint
/// observation at a time.
#[derive(Debug, Clone, Default)]
pub struct PairwiseStats {
    /// Number of joint observations folded in so far.
    pub count: u64,
    /// Running mean of the first series.
    pub mean_i: f64,
    /// Running mean of the second series.
    pub mean_j: f64,
    /// Running sum of cross-deviations between the two series.
    pub co_moment: f64,
}

impl PairwiseStats {
    /// Folds one joint observation `(value_i, value_j)` into the running means
    /// and the co-moment.
    pub fn update(&mut self, value_i: f64, value_j: f64) {
        self.count += 1;
        let samples = self.count as f64;

        // Deviation of the first series from its mean before this sample.
        let dev_i_prior = value_i - self.mean_i;
        self.mean_i += dev_i_prior / samples;

        self.mean_j += (value_j - self.mean_j) / samples;
        // Deviation of the second series from its mean after this sample; pairing
        // the "before" and "after" deviations keeps the co-moment update exact.
        let dev_j_post = value_j - self.mean_j;

        self.co_moment += dev_i_prior * dev_j_post;
    }

    /// Returns the sample covariance (divisor `count - 1`), or `0.0` while fewer
    /// than two joint observations have been seen.
    pub fn covariance(&self) -> f64 {
        if self.count >= 2 {
            self.co_moment / (self.count - 1) as f64
        } else {
            0.0
        }
    }
}
137
138#[derive(Debug, Clone, Default)]
140pub struct CorrelationState {
141 pub config: CorrelationConfig,
143 pub asset_stats: Vec<AssetStats>,
145 pub pairwise_stats: Vec<PairwiseStats>,
148 pub correlation_matrix: Vec<f64>,
150 pub prev_correlation_matrix: Vec<f64>,
152 pub total_observations: u64,
154 pub asset_index: HashMap<u64, usize>,
156}
157
158impl CorrelationState {
159 pub fn new(config: CorrelationConfig) -> Self {
161 let n = config.n_assets;
162 let n_pairs = n * (n - 1) / 2;
163
164 Self {
165 config,
166 asset_stats: vec![AssetStats::default(); n],
167 pairwise_stats: vec![PairwiseStats::default(); n_pairs],
168 correlation_matrix: vec![0.0; n * n],
169 prev_correlation_matrix: vec![0.0; n * n],
170 total_observations: 0,
171 asset_index: HashMap::new(),
172 }
173 }
174
175 fn pair_index(&self, i: usize, j: usize) -> usize {
177 let (i, j) = if i < j { (i, j) } else { (j, i) };
178 let n = self.config.n_assets;
179 i * (2 * n - i - 1) / 2 + (j - i - 1)
180 }
181}
182
/// A single observation for one asset, fed to `RealTimeCorrelation::update`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationUpdate {
    /// External identifier of the asset the observation belongs to.
    pub asset_id: u64,
    /// Observed value.
    pub value: f64,
    /// Caller-supplied timestamp of the observation.
    pub timestamp: u64,
}
193
/// Outcome of processing one [`CorrelationUpdate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationUpdateResult {
    /// Asset id copied from the update that produced this result.
    pub asset_id: u64,
    /// Number of pair correlations that were recomputed.
    pub correlations_updated: usize,
    /// Pairs whose correlation moved by at least the configured change threshold.
    pub significant_changes: Vec<CorrelationChange>,
    /// Wall-clock time spent processing the update, in microseconds.
    pub latency_us: u64,
}
206
/// A change in one pair's correlation that met the configured change threshold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationChange {
    /// External id of the pair member with the lower slot index.
    pub asset_i: u64,
    /// External id of the pair member with the higher slot index.
    pub asset_j: u64,
    /// Correlation stored before the update.
    pub old_correlation: f64,
    /// Correlation stored after the update.
    pub new_correlation: f64,
    /// Absolute difference between `new_correlation` and `old_correlation`.
    pub change: f64,
}
221
/// Snapshot of a correlation matrix together with bookkeeping information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationMatrixResult {
    /// Number of assets the result refers to.
    pub n_assets: usize,
    /// Correlation matrix flattened row-major.
    // NOTE(review): `RealTimeCorrelation::get_matrix` reports the tracked-asset count
    // in `n_assets` but returns the full buffer, whose row stride is the configured
    // capacity; only `compute_from_returns` guarantees a stride of `n_assets`.
    pub correlations: Vec<f64>,
    /// Number of observations behind the matrix.
    pub observations: u64,
    /// Timestamp associated with the snapshot (`0` for batch results).
    pub timestamp: u64,
    /// Wall-clock time spent producing the snapshot, in microseconds.
    pub compute_time_us: u64,
}
236
/// Kernel that maintains a correlation matrix incrementally from streamed
/// per-asset observations.
#[derive(Debug)]
pub struct RealTimeCorrelation {
    /// Kernel metadata returned by the `GpuKernel` implementation.
    metadata: KernelMetadata,
    /// Streaming state; the lock lets `&self` methods read and mutate it.
    state: std::sync::RwLock<CorrelationState>,
}
248
249impl Clone for RealTimeCorrelation {
250 fn clone(&self) -> Self {
251 Self {
252 metadata: self.metadata.clone(),
253 state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
254 }
255 }
256}
257
258impl Default for RealTimeCorrelation {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264impl RealTimeCorrelation {
265 #[must_use]
267 pub fn new() -> Self {
268 Self {
269 metadata: KernelMetadata::ring("risk/realtime-correlation", Domain::RiskAnalytics)
270 .with_description("Streaming correlation matrix updates")
271 .with_throughput(500_000)
272 .with_latency_us(10.0),
273 state: std::sync::RwLock::new(CorrelationState::new(CorrelationConfig::default())),
274 }
275 }
276
277 #[must_use]
279 pub fn with_config(config: CorrelationConfig) -> Self {
280 Self {
281 metadata: KernelMetadata::ring("risk/realtime-correlation", Domain::RiskAnalytics)
282 .with_description("Streaming correlation matrix updates")
283 .with_throughput(500_000)
284 .with_latency_us(10.0),
285 state: std::sync::RwLock::new(CorrelationState::new(config)),
286 }
287 }
288
289 pub fn initialize(&self, asset_ids: &[u64]) {
291 let mut state = self.state.write().unwrap();
292 state.asset_index.clear();
293 for (idx, &id) in asset_ids.iter().enumerate() {
294 if idx < state.config.n_assets {
295 state.asset_index.insert(id, idx);
296 }
297 }
298 let n = state.config.n_assets;
300 state.asset_stats = vec![AssetStats::default(); n];
301 state.pairwise_stats = vec![PairwiseStats::default(); n * (n - 1) / 2];
302 state.correlation_matrix = vec![0.0; n * n];
303 state.prev_correlation_matrix = vec![0.0; n * n];
304 state.total_observations = 0;
305 }
306
307 pub fn update(&self, update: &CorrelationUpdate) -> CorrelationUpdateResult {
309 let start = Instant::now();
310 let mut state = self.state.write().unwrap();
311
312 let asset_idx = if let Some(&idx) = state.asset_index.get(&update.asset_id) {
314 idx
315 } else if state.asset_index.len() < state.config.n_assets {
316 let idx = state.asset_index.len();
317 state.asset_index.insert(update.asset_id, idx);
318 idx
319 } else {
320 return CorrelationUpdateResult {
322 asset_id: update.asset_id,
323 correlations_updated: 0,
324 significant_changes: Vec::new(),
325 latency_us: start.elapsed().as_micros() as u64,
326 };
327 };
328
329 state.asset_stats[asset_idx].update(update.value, update.timestamp);
331 state.total_observations += 1;
332
333 let n = state.config.n_assets;
335 let mut correlations_updated = 0;
336 let mut significant_changes = Vec::new();
337
338 for other_idx in 0..state.asset_index.len() {
342 if other_idx == asset_idx {
343 continue;
344 }
345
346 let other_stats = &state.asset_stats[other_idx];
347 if other_stats.count == 0 {
348 continue;
349 }
350
351 let (i, j) = if asset_idx < other_idx {
353 (asset_idx, other_idx)
354 } else {
355 (other_idx, asset_idx)
356 };
357 let pair_idx = state.pair_index(i, j);
358
359 let value_i = if asset_idx == i {
361 update.value
362 } else {
363 state.asset_stats[i].last_value
364 };
365 let value_j = if asset_idx == j {
366 update.value
367 } else {
368 state.asset_stats[j].last_value
369 };
370
371 state.pairwise_stats[pair_idx].update(value_i, value_j);
372
373 if state.pairwise_stats[pair_idx].count >= state.config.min_observations as u64 {
375 let cov = state.pairwise_stats[pair_idx].covariance();
376 let std_i = state.asset_stats[i].std_dev();
377 let std_j = state.asset_stats[j].std_dev();
378
379 let new_corr = if std_i > 1e-10 && std_j > 1e-10 {
380 (cov / (std_i * std_j)).clamp(-1.0, 1.0)
381 } else {
382 0.0
383 };
384
385 let old_corr = state.correlation_matrix[i * n + j];
387 state.prev_correlation_matrix[i * n + j] = old_corr;
388 state.prev_correlation_matrix[j * n + i] = old_corr;
389 state.correlation_matrix[i * n + j] = new_corr;
390 state.correlation_matrix[j * n + i] = new_corr;
391
392 correlations_updated += 1;
393
394 let change = (new_corr - old_corr).abs();
396 if change >= state.config.change_threshold {
397 let id_i = state
399 .asset_index
400 .iter()
401 .find(|&(_, idx)| *idx == i)
402 .map(|(&id, _)| id)
403 .unwrap_or(0);
404 let id_j = state
405 .asset_index
406 .iter()
407 .find(|&(_, idx)| *idx == j)
408 .map(|(&id, _)| id)
409 .unwrap_or(0);
410
411 significant_changes.push(CorrelationChange {
412 asset_i: id_i,
413 asset_j: id_j,
414 old_correlation: old_corr,
415 new_correlation: new_corr,
416 change,
417 });
418 }
419 }
420 }
421
422 state.correlation_matrix[asset_idx * n + asset_idx] = 1.0;
424
425 CorrelationUpdateResult {
426 asset_id: update.asset_id,
427 correlations_updated,
428 significant_changes,
429 latency_us: start.elapsed().as_micros() as u64,
430 }
431 }
432
433 pub fn update_batch(&self, updates: &[CorrelationUpdate]) -> Vec<CorrelationUpdateResult> {
435 updates.iter().map(|u| self.update(u)).collect()
436 }
437
438 pub fn get_correlation(&self, asset_i: u64, asset_j: u64) -> Option<f64> {
440 let state = self.state.read().unwrap();
441 let idx_i = state.asset_index.get(&asset_i)?;
442 let idx_j = state.asset_index.get(&asset_j)?;
443 let n = state.config.n_assets;
444 Some(state.correlation_matrix[idx_i * n + idx_j])
445 }
446
447 pub fn get_matrix(&self) -> CorrelationMatrixResult {
449 let start = Instant::now();
450 let state = self.state.read().unwrap();
451
452 CorrelationMatrixResult {
453 n_assets: state.asset_index.len(),
454 correlations: state.correlation_matrix.clone(),
455 observations: state.total_observations,
456 timestamp: state
457 .asset_stats
458 .iter()
459 .map(|s| s.last_timestamp)
460 .max()
461 .unwrap_or(0),
462 compute_time_us: start.elapsed().as_micros() as u64,
463 }
464 }
465
466 pub fn get_row(&self, asset_id: u64) -> Option<Vec<(u64, f64)>> {
468 let state = self.state.read().unwrap();
469 let idx = state.asset_index.get(&asset_id)?;
470 let n = state.config.n_assets;
471
472 Some(
473 state
474 .asset_index
475 .iter()
476 .map(|(&id, &j)| (id, state.correlation_matrix[idx * n + j]))
477 .collect(),
478 )
479 }
480
481 pub fn reset(&self) {
483 let mut state = self.state.write().unwrap();
484 let config = state.config.clone();
485 *state = CorrelationState::new(config);
486 }
487
488 pub fn compute_from_returns(returns: &[Vec<f64>]) -> CorrelationMatrixResult {
490 let start = Instant::now();
491
492 if returns.is_empty() || returns[0].is_empty() {
493 return CorrelationMatrixResult {
494 n_assets: 0,
495 correlations: Vec::new(),
496 observations: 0,
497 timestamp: 0,
498 compute_time_us: start.elapsed().as_micros() as u64,
499 };
500 }
501
502 let n = returns.len();
503 let t = returns[0].len();
504
505 let means: Vec<f64> = returns
507 .iter()
508 .map(|r| r.iter().sum::<f64>() / t as f64)
509 .collect();
510
511 let stds: Vec<f64> = returns
513 .iter()
514 .zip(means.iter())
515 .map(|(r, &mean)| {
516 let var = r.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (t - 1) as f64;
517 var.sqrt()
518 })
519 .collect();
520
521 let mut correlations = vec![0.0; n * n];
523
524 for i in 0..n {
525 correlations[i * n + i] = 1.0; for j in (i + 1)..n {
528 let cov: f64 = returns[i]
529 .iter()
530 .zip(returns[j].iter())
531 .map(|(&xi, &xj)| (xi - means[i]) * (xj - means[j]))
532 .sum::<f64>()
533 / (t - 1) as f64;
534
535 let corr = if stds[i] > 1e-10 && stds[j] > 1e-10 {
536 (cov / (stds[i] * stds[j])).clamp(-1.0, 1.0)
537 } else {
538 0.0
539 };
540
541 correlations[i * n + j] = corr;
542 correlations[j * n + i] = corr;
543 }
544 }
545
546 CorrelationMatrixResult {
547 n_assets: n,
548 correlations,
549 observations: t as u64,
550 timestamp: 0,
551 compute_time_us: start.elapsed().as_micros() as u64,
552 }
553 }
554}
555
impl GpuKernel for RealTimeCorrelation {
    /// Returns the kernel metadata built when this kernel was constructed.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends a single observation to `kernel`, discarding the update result.
    fn feed(kernel: &RealTimeCorrelation, asset_id: u64, value: f64, timestamp: u64) {
        kernel.update(&CorrelationUpdate {
            asset_id,
            value,
            timestamp,
        });
    }

    /// Configuration with 10 asset slots and a 2-observation warm-up.
    fn small_config() -> CorrelationConfig {
        CorrelationConfig {
            n_assets: 10,
            min_observations: 2,
            ..Default::default()
        }
    }

    /// Default-configured kernel tracking assets 1, 2 and 3, fed 30 rounds of
    /// `step * scale` for each asset's scale.
    fn three_asset_kernel(scales: [f64; 3]) -> RealTimeCorrelation {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[1, 2, 3]);

        for step in 0..30 {
            let base = step as f64;
            feed(&kernel, 1, base * scales[0], step as u64);
            feed(&kernel, 2, base * scales[1], step as u64);
            feed(&kernel, 3, base * scales[2], step as u64);
        }
        kernel
    }

    #[test]
    fn test_realtime_correlation_metadata() {
        let kernel = RealTimeCorrelation::new();
        let meta = kernel.metadata();
        assert_eq!(meta.id, "risk/realtime-correlation");
        assert_eq!(meta.domain, Domain::RiskAnalytics);
    }

    #[test]
    fn test_asset_stats_welford() {
        let mut stats = AssetStats::default();
        [2.0, 4.0, 6.0, 8.0, 10.0]
            .iter()
            .for_each(|&v| stats.update(v, 0));

        assert!((stats.mean - 6.0).abs() < 1e-10);
        assert!((stats.variance() - 10.0).abs() < 1e-10);
        assert!((stats.std_dev() - (10.0_f64).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_initialize_assets() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[100, 101, 102]);

        let registered = kernel.state.read().unwrap().asset_index.len();
        assert_eq!(registered, 3);
    }

    #[test]
    fn test_streaming_updates() {
        let kernel = RealTimeCorrelation::with_config(small_config());
        kernel.initialize(&[1, 2]);

        // Asset 2 is an affine function of asset 1, so the two move together.
        for step in 0..50 {
            let r1 = (step as f64) * 0.01;
            let r2 = r1 * 0.8 + 0.002;
            feed(&kernel, 1, r1, step as u64);
            feed(&kernel, 2, r2, step as u64);
        }

        let corr = kernel.get_correlation(1, 2).unwrap();
        assert!(corr > 0.9, "Expected high correlation, got: {}", corr);
    }

    #[test]
    fn test_uncorrelated_assets() {
        let kernel = RealTimeCorrelation::with_config(small_config());
        kernel.initialize(&[1, 2]);

        // Sign patterns with periods 2 and 3 share little structure.
        for step in 0..100 {
            let r1 = if step % 2 == 0 { 0.01 } else { -0.01 };
            let r2 = if step % 3 == 0 { 0.01 } else { -0.01 };
            feed(&kernel, 1, r1, step as u64);
            feed(&kernel, 2, r2, step as u64);
        }

        let corr = kernel.get_correlation(1, 2).unwrap();
        assert!(corr.abs() < 0.5, "Expected low correlation, got: {}", corr);
    }

    #[test]
    fn test_correlation_matrix_diagonal() {
        let kernel = three_asset_kernel([0.01, 0.02, 0.015]);

        let corr_11 = kernel.get_correlation(1, 1).unwrap();
        let corr_22 = kernel.get_correlation(2, 2).unwrap();
        let corr_33 = kernel.get_correlation(3, 3).unwrap();

        assert!((corr_11 - 1.0).abs() < 1e-10);
        assert!((corr_22 - 1.0).abs() < 1e-10);
        assert!((corr_33 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_correlation() {
        let returns = vec![
            vec![
                0.01, 0.02, -0.01, 0.03, 0.01, -0.02, 0.01, 0.02, -0.01, 0.01,
            ],
            vec![
                0.02, 0.03, -0.02, 0.04, 0.02, -0.03, 0.02, 0.03, -0.02, 0.02,
            ],
            vec![
                -0.01, 0.01, 0.02, -0.02, 0.03, 0.01, -0.01, 0.02, 0.01, -0.01,
            ],
        ];

        let result = RealTimeCorrelation::compute_from_returns(&returns);

        assert_eq!(result.n_assets, 3);
        assert_eq!(result.observations, 10);

        let n = result.n_assets;
        let cell = |row: usize, col: usize| result.correlations[row * n + col];

        // Unit diagonal.
        for i in 0..n {
            assert!((cell(i, i) - 1.0).abs() < 1e-10);
        }
        // Symmetry.
        for i in 0..n {
            for j in 0..n {
                let diff = (cell(i, j) - cell(j, i)).abs();
                assert!(diff < 1e-10);
            }
        }
        // Row 0, column 1: the first two series move together.
        let corr_01 = result.correlations[1];
        assert!(corr_01 > 0.9, "Expected high correlation: {}", corr_01);
    }

    #[test]
    fn test_significant_change_detection() {
        let config = CorrelationConfig {
            change_threshold: 0.3,
            ..small_config()
        };
        let kernel = RealTimeCorrelation::with_config(config);
        kernel.initialize(&[1, 2]);

        for step in 0..50 {
            let base = step as f64 * 0.01;
            feed(&kernel, 1, base, step as u64);
            feed(&kernel, 2, base + 0.001, step as u64);
        }

        let baseline_corr = kernel.get_correlation(1, 2).unwrap();
        assert!(
            baseline_corr > 0.9,
            "Expected high positive correlation: {}",
            baseline_corr
        );
    }

    #[test]
    fn test_get_row() {
        let kernel = three_asset_kernel([1.0, 2.0, 1.5]);

        let row = kernel.get_row(1).unwrap();
        assert_eq!(row.len(), 3);

        let self_corr = row.iter().find(|(id, _)| *id == 1).map(|(_, c)| *c);
        assert!((self_corr.unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_reset() {
        let kernel = RealTimeCorrelation::new();
        kernel.initialize(&[1, 2]);

        for step in 0..30 {
            feed(&kernel, 1, step as f64, step as u64);
        }

        let matrix_before = kernel.get_matrix();
        assert!(matrix_before.observations > 0);

        kernel.reset();

        let matrix_after = kernel.get_matrix();
        assert_eq!(matrix_after.observations, 0);
    }

    #[test]
    fn test_empty_returns() {
        let result = RealTimeCorrelation::compute_from_returns(&[]);
        assert_eq!(result.n_assets, 0);

        let empty_inner: Vec<Vec<f64>> = vec![vec![]];
        let result2 = RealTimeCorrelation::compute_from_returns(&empty_inner);
        assert_eq!(result2.n_assets, 0);
    }
}