1use crate::types::{AnomalyResult, DataMatrix};
8use rand::prelude::*;
9use rand::{Rng, rng};
10use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct StreamingConfig {
21 pub n_trees: usize,
23 pub sample_size: usize,
25 pub window_size: usize,
27 pub rebuild_interval: usize,
29 pub contamination: f64,
31 pub use_sliding_window: bool,
33}
34
35impl Default for StreamingConfig {
36 fn default() -> Self {
37 Self {
38 n_trees: 100,
39 sample_size: 256,
40 window_size: 10000,
41 rebuild_interval: 1000,
42 contamination: 0.1,
43 use_sliding_window: true,
44 }
45 }
46}
47
48#[derive(Debug, Clone)]
50pub struct StreamingState {
51 window: VecDeque<Vec<f64>>,
53 n_features: usize,
55 trees: Vec<StreamingITree>,
57 samples_since_rebuild: usize,
59 total_samples: usize,
61 score_stats: OnlineStats,
63 threshold: f64,
65}
66
67impl StreamingState {
68 pub fn new(n_features: usize) -> Self {
70 Self {
71 window: VecDeque::new(),
72 n_features,
73 trees: Vec::new(),
74 samples_since_rebuild: 0,
75 total_samples: 0,
76 score_stats: OnlineStats::new(),
77 threshold: 0.5,
78 }
79 }
80
81 pub fn window_size(&self) -> usize {
83 self.window.len()
84 }
85
86 pub fn total_samples(&self) -> usize {
88 self.total_samples
89 }
90
91 pub fn threshold(&self) -> f64 {
93 self.threshold
94 }
95}
96
/// Single-pass accumulator for count, mean, variance and extrema.
///
/// Mean and variance follow Welford's update, so no samples are stored.
#[derive(Debug, Clone)]
struct OnlineStats {
    /// Number of values folded in so far.
    count: u64,
    /// Running arithmetic mean.
    mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
    /// Smallest value seen (`+∞` until the first update).
    min: f64,
    /// Largest value seen (`-∞` until the first update).
    max: f64,
}

impl Default for OnlineStats {
    /// Equivalent to [`OnlineStats::new`].
    ///
    /// A derived `Default` would zero `min`/`max`, which makes an accumulator
    /// fed only positive (or only negative) values report a wrong extremum.
    fn default() -> Self {
        Self::new()
    }
}

impl OnlineStats {
    /// Creates an empty accumulator.
    fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
            // Infinite sentinels are the identity for `min`/`max`, so even an
            // infinite input is tracked correctly.
            min: f64::INFINITY,
            max: f64::NEG_INFINITY,
        }
    }

    /// Folds `value` into the running statistics.
    fn update(&mut self, value: f64) {
        self.count += 1;
        // Welford: the deviation is taken before and after the mean moves.
        let delta = value - self.mean;
        self.mean += delta / self.count as f64;
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;
        self.min = self.min.min(value);
        self.max = self.max.max(value);
    }

    /// Unbiased sample variance (`n - 1` denominator); `0.0` with fewer than
    /// two values.
    fn variance(&self) -> f64 {
        if self.count < 2 {
            0.0
        } else {
            self.m2 / (self.count - 1) as f64
        }
    }

    /// Square root of [`Self::variance`].
    fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }
}
141
/// A node of an isolation tree.
#[derive(Debug, Clone)]
enum StreamingINode {
    /// Branch node: points whose `split_feature` value is strictly below
    /// `split_value` descend into `left`, all others into `right`.
    Internal {
        /// Index of the feature tested at this node.
        split_feature: usize,
        /// Cut point on that feature.
        split_value: f64,
        /// Subtree for values below the cut point.
        left: Box<StreamingINode>,
        /// Subtree for values at or above the cut point.
        right: Box<StreamingINode>,
    },
    /// Leaf node recording how many training samples ended up in it.
    External {
        /// Number of training samples that reached this leaf.
        size: usize,
    },
}
155
156#[derive(Debug, Clone)]
158#[allow(dead_code)]
159struct StreamingITree {
160 root: StreamingINode,
161 max_depth: usize,
162}
163
164impl StreamingITree {
165 fn build(samples: &[Vec<f64>], max_depth: usize) -> Self {
167 let root = Self::build_node(samples, 0, max_depth);
168 Self { root, max_depth }
169 }
170
171 fn build_node(samples: &[Vec<f64>], depth: usize, max_depth: usize) -> StreamingINode {
172 if samples.is_empty() || depth >= max_depth || samples.len() <= 1 {
173 return StreamingINode::External {
174 size: samples.len(),
175 };
176 }
177
178 let n_features = samples[0].len();
179 if n_features == 0 {
180 return StreamingINode::External {
181 size: samples.len(),
182 };
183 }
184
185 let mut rng = rng();
186 let feature = rng.random_range(0..n_features);
187
188 let values: Vec<f64> = samples.iter().map(|s| s[feature]).collect();
190 let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
191 let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
192
193 if (max_val - min_val).abs() < 1e-10 {
194 return StreamingINode::External {
195 size: samples.len(),
196 };
197 }
198
199 let split_value = rng.random_range(min_val..max_val);
200
201 let (left_samples, right_samples): (Vec<_>, Vec<_>) = samples
202 .iter()
203 .cloned()
204 .partition(|s| s[feature] < split_value);
205
206 StreamingINode::Internal {
207 split_feature: feature,
208 split_value,
209 left: Box::new(Self::build_node(&left_samples, depth + 1, max_depth)),
210 right: Box::new(Self::build_node(&right_samples, depth + 1, max_depth)),
211 }
212 }
213
214 fn path_length(&self, point: &[f64]) -> f64 {
216 self.path_length_node(&self.root, point, 0)
217 }
218
219 #[allow(clippy::only_used_in_recursion)]
220 fn path_length_node(&self, node: &StreamingINode, point: &[f64], depth: usize) -> f64 {
221 match node {
222 StreamingINode::External { size } => depth as f64 + Self::c_factor(*size),
223 StreamingINode::Internal {
224 split_feature,
225 split_value,
226 left,
227 right,
228 } => {
229 if point[*split_feature] < *split_value {
230 self.path_length_node(left, point, depth + 1)
231 } else {
232 self.path_length_node(right, point, depth + 1)
233 }
234 }
235 }
236 }
237
238 fn c_factor(n: usize) -> f64 {
240 if n <= 1 {
241 0.0
242 } else if n == 2 {
243 1.0
244 } else {
245 let n_f = n as f64;
246 2.0 * ((n_f - 1.0).ln() + 0.5772156649) - 2.0 * (n_f - 1.0) / n_f
248 }
249 }
250}
251
252#[derive(Debug, Clone)]
258pub struct StreamingIsolationForest {
259 metadata: KernelMetadata,
260}
261
262impl Default for StreamingIsolationForest {
263 fn default() -> Self {
264 Self::new()
265 }
266}
267
268impl StreamingIsolationForest {
269 #[must_use]
271 pub fn new() -> Self {
272 Self {
273 metadata: KernelMetadata::batch("ml/streaming-isolation-forest", Domain::StatisticalML)
274 .with_description("Online streaming anomaly detection with sliding window")
275 .with_throughput(50_000)
276 .with_latency_us(20.0),
277 }
278 }
279
280 pub fn init(n_features: usize) -> StreamingState {
282 StreamingState::new(n_features)
283 }
284
285 pub fn process_sample(
289 state: &mut StreamingState,
290 sample: Vec<f64>,
291 config: &StreamingConfig,
292 ) -> (f64, bool) {
293 if sample.len() != state.n_features && state.n_features > 0 {
294 return (0.0, false); }
296
297 if state.n_features == 0 {
298 state.n_features = sample.len();
299 }
300
301 state.window.push_back(sample.clone());
303 if config.use_sliding_window && state.window.len() > config.window_size {
304 state.window.pop_front();
305 }
306
307 state.total_samples += 1;
308 state.samples_since_rebuild += 1;
309
310 if state.trees.is_empty()
312 || (state.samples_since_rebuild >= config.rebuild_interval
313 && state.window.len() >= config.sample_size)
314 {
315 Self::rebuild_forest(state, config);
316 state.samples_since_rebuild = 0;
317 }
318
319 let score = if state.trees.is_empty() {
321 0.5 } else {
323 Self::compute_score(&state.trees, &sample, config.sample_size)
324 };
325
326 state.score_stats.update(score);
328
329 if state.score_stats.count > 100 {
331 let k = Self::contamination_to_k(config.contamination);
334 state.threshold = state.score_stats.mean + k * state.score_stats.std_dev();
335 state.threshold = state.threshold.clamp(0.0, 1.0);
336 }
337
338 let is_anomaly = score >= state.threshold;
339 (score, is_anomaly)
340 }
341
342 pub fn process_batch(
344 state: &mut StreamingState,
345 samples: &DataMatrix,
346 config: &StreamingConfig,
347 ) -> AnomalyResult {
348 let mut scores = Vec::with_capacity(samples.n_samples);
349 let mut labels = Vec::with_capacity(samples.n_samples);
350
351 for i in 0..samples.n_samples {
352 let sample = samples.row(i).to_vec();
353 let (score, is_anomaly) = Self::process_sample(state, sample, config);
354 scores.push(score);
355 labels.push(if is_anomaly { -1 } else { 1 });
356 }
357
358 AnomalyResult {
359 scores,
360 labels,
361 threshold: state.threshold,
362 }
363 }
364
365 fn rebuild_forest(state: &mut StreamingState, config: &StreamingConfig) {
367 if state.window.is_empty() {
368 return;
369 }
370
371 let samples: Vec<Vec<f64>> = state.window.iter().cloned().collect();
372 let sample_size = config.sample_size.min(samples.len());
373 let max_depth = (sample_size as f64).log2().ceil() as usize;
374
375 let mut rng = rng();
376 state.trees = (0..config.n_trees)
377 .map(|_| {
378 let subset: Vec<Vec<f64>> = samples
379 .choose_multiple(&mut rng, sample_size)
380 .cloned()
381 .collect();
382 StreamingITree::build(&subset, max_depth)
383 })
384 .collect();
385 }
386
387 fn compute_score(trees: &[StreamingITree], point: &[f64], sample_size: usize) -> f64 {
389 if trees.is_empty() {
390 return 0.5;
391 }
392
393 let avg_path_length: f64 = trees
394 .iter()
395 .map(|tree| tree.path_length(point))
396 .sum::<f64>()
397 / trees.len() as f64;
398
399 let c_n = StreamingITree::c_factor(sample_size);
400 if c_n.abs() < 1e-10 {
401 return 0.5;
402 }
403
404 (2.0_f64).powf(-avg_path_length / c_n)
405 }
406
407 fn contamination_to_k(contamination: f64) -> f64 {
409 if contamination <= 0.01 {
412 2.33
413 } else if contamination <= 0.05 {
414 1.65
415 } else if contamination <= 0.10 {
416 1.28
417 } else if contamination <= 0.20 {
418 0.84
419 } else {
420 0.5
421 }
422 }
423}
424
425impl GpuKernel for StreamingIsolationForest {
426 fn metadata(&self) -> &KernelMetadata {
427 &self.metadata
428 }
429}
430
431#[derive(Debug, Clone, Serialize, Deserialize)]
437pub struct AdaptiveThresholdConfig {
438 pub initial_threshold: f64,
440 pub window_size: usize,
442 pub target_fpr: f64,
444 pub learning_rate: f64,
446 pub min_threshold: f64,
448 pub max_threshold: f64,
450 pub detect_drift: bool,
452 pub drift_sensitivity: f64,
454}
455
456impl Default for AdaptiveThresholdConfig {
457 fn default() -> Self {
458 Self {
459 initial_threshold: 0.5,
460 window_size: 1000,
461 target_fpr: 0.05,
462 learning_rate: 0.01,
463 min_threshold: 0.1,
464 max_threshold: 0.9,
465 detect_drift: true,
466 drift_sensitivity: 2.0,
467 }
468 }
469}
470
471#[derive(Debug, Clone)]
473pub struct AdaptiveThresholdState {
474 threshold: f64,
476 score_window: VecDeque<f64>,
478 label_window: VecDeque<Option<bool>>,
480 stats: OnlineStats,
482 prev_window_stats: Option<WindowStats>,
484 curr_window_stats: WindowStats,
486 total_samples: usize,
488 drift_detected: bool,
490 drift_count: usize,
492}
493
/// Summary statistics of one score window.
#[derive(Debug, Clone, Default)]
struct WindowStats {
    /// Arithmetic mean of the scores in the window.
    mean: f64,
    /// Variance of the scores in the window.
    variance: f64,
    /// Number of scores the statistics were computed from.
    count: usize,
}
501
502impl AdaptiveThresholdState {
503 pub fn new(config: &AdaptiveThresholdConfig) -> Self {
505 Self {
506 threshold: config.initial_threshold,
507 score_window: VecDeque::new(),
508 label_window: VecDeque::new(),
509 stats: OnlineStats::new(),
510 prev_window_stats: None,
511 curr_window_stats: WindowStats::default(),
512 total_samples: 0,
513 drift_detected: false,
514 drift_count: 0,
515 }
516 }
517
518 pub fn threshold(&self) -> f64 {
520 self.threshold
521 }
522
523 pub fn total_samples(&self) -> usize {
525 self.total_samples
526 }
527
528 pub fn drift_detected(&self) -> bool {
530 self.drift_detected
531 }
532
533 pub fn drift_count(&self) -> usize {
535 self.drift_count
536 }
537}
538
539#[derive(Debug, Clone, Serialize, Deserialize)]
541pub struct ThresholdResult {
542 pub threshold: f64,
544 pub is_anomaly: bool,
546 pub estimated_fpr: f64,
548 pub drift_detected: bool,
550 pub confidence: f64,
552}
553
554#[derive(Debug, Clone)]
560pub struct AdaptiveThreshold {
561 metadata: KernelMetadata,
562}
563
564impl Default for AdaptiveThreshold {
565 fn default() -> Self {
566 Self::new()
567 }
568}
569
570impl AdaptiveThreshold {
571 #[must_use]
573 pub fn new() -> Self {
574 Self {
575 metadata: KernelMetadata::batch("ml/adaptive-threshold", Domain::StatisticalML)
576 .with_description("Self-adjusting anomaly thresholds with drift detection")
577 .with_throughput(100_000)
578 .with_latency_us(5.0),
579 }
580 }
581
582 pub fn init(config: &AdaptiveThresholdConfig) -> AdaptiveThresholdState {
584 AdaptiveThresholdState::new(config)
585 }
586
587 pub fn process_score(
589 state: &mut AdaptiveThresholdState,
590 score: f64,
591 ground_truth: Option<bool>,
592 config: &AdaptiveThresholdConfig,
593 ) -> ThresholdResult {
594 state.stats.update(score);
596 state.total_samples += 1;
597
598 state.score_window.push_back(score);
600 state.label_window.push_back(ground_truth);
601
602 if state.score_window.len() > config.window_size {
603 state.score_window.pop_front();
604 state.label_window.pop_front();
605 }
606
607 state.curr_window_stats = Self::compute_window_stats(&state.score_window);
609
610 state.drift_detected = false;
612 if config.detect_drift {
613 if let Some(prev) = &state.prev_window_stats {
614 let drift = Self::detect_drift(prev, &state.curr_window_stats, config);
615 if drift {
616 state.drift_detected = true;
617 state.drift_count += 1;
618 state.threshold = Self::estimate_threshold_from_window(
620 &state.score_window,
621 config.target_fpr,
622 );
623 }
624 }
625 }
626
627 if let Some(is_anomaly) = ground_truth {
629 Self::update_threshold_with_feedback(state, score, is_anomaly, config);
630 } else {
631 Self::update_threshold_quantile(state, config);
633 }
634
635 if state.score_window.len() == config.window_size
638 && (state.prev_window_stats.is_none() || state.drift_detected)
639 {
640 state.prev_window_stats = Some(state.curr_window_stats.clone());
641 }
642
643 let is_anomaly = score >= state.threshold;
644 let estimated_fpr = Self::estimate_fpr(state, config);
645 let confidence = Self::compute_confidence(state, config);
646
647 ThresholdResult {
648 threshold: state.threshold,
649 is_anomaly,
650 estimated_fpr,
651 drift_detected: state.drift_detected,
652 confidence,
653 }
654 }
655
656 fn compute_window_stats(window: &VecDeque<f64>) -> WindowStats {
658 if window.is_empty() {
659 return WindowStats::default();
660 }
661
662 let count = window.len();
663 let mean: f64 = window.iter().sum::<f64>() / count as f64;
664 let variance: f64 = window.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count as f64;
665
666 WindowStats {
667 mean,
668 variance,
669 count,
670 }
671 }
672
673 fn detect_drift(
675 prev: &WindowStats,
676 curr: &WindowStats,
677 config: &AdaptiveThresholdConfig,
678 ) -> bool {
679 if prev.count < 10 || curr.count < 10 {
680 return false;
681 }
682
683 let se = ((prev.variance / prev.count as f64) + (curr.variance / curr.count as f64)).sqrt();
685 if se.abs() < 1e-10 {
686 return false;
687 }
688
689 let t_stat = (curr.mean - prev.mean).abs() / se;
690 t_stat > config.drift_sensitivity
691 }
692
693 fn estimate_threshold_from_window(window: &VecDeque<f64>, target_fpr: f64) -> f64 {
695 if window.is_empty() {
696 return 0.5;
697 }
698
699 let mut sorted: Vec<f64> = window.iter().cloned().collect();
700 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
701
702 let idx = ((1.0 - target_fpr) * sorted.len() as f64) as usize;
703 let idx = idx.min(sorted.len() - 1);
704 sorted[idx]
705 }
706
707 fn update_threshold_with_feedback(
709 state: &mut AdaptiveThresholdState,
710 score: f64,
711 is_anomaly: bool,
712 config: &AdaptiveThresholdConfig,
713 ) {
714 if score >= state.threshold && !is_anomaly {
716 state.threshold += config.learning_rate * (score - state.threshold);
718 }
719 else if score < state.threshold && is_anomaly {
721 state.threshold -= config.learning_rate * (state.threshold - score);
723 }
724
725 state.threshold = state
726 .threshold
727 .clamp(config.min_threshold, config.max_threshold);
728 }
729
730 fn update_threshold_quantile(
732 state: &mut AdaptiveThresholdState,
733 config: &AdaptiveThresholdConfig,
734 ) {
735 if state.score_window.len() < 10 {
736 return;
737 }
738
739 let target = Self::estimate_threshold_from_window(&state.score_window, config.target_fpr);
740
741 state.threshold =
743 state.threshold * (1.0 - config.learning_rate) + target * config.learning_rate;
744 state.threshold = state
745 .threshold
746 .clamp(config.min_threshold, config.max_threshold);
747 }
748
749 fn estimate_fpr(state: &AdaptiveThresholdState, _config: &AdaptiveThresholdConfig) -> f64 {
751 if state.score_window.is_empty() {
752 return 0.0;
753 }
754
755 let above_threshold = state
756 .score_window
757 .iter()
758 .filter(|&&s| s >= state.threshold)
759 .count();
760
761 above_threshold as f64 / state.score_window.len() as f64
762 }
763
764 fn compute_confidence(state: &AdaptiveThresholdState, config: &AdaptiveThresholdConfig) -> f64 {
766 let sample_factor = (state.score_window.len() as f64 / config.window_size as f64).min(1.0);
768
769 let drift_factor = if state.drift_detected { 0.5 } else { 1.0 };
771
772 let bound_factor = if (state.threshold - config.min_threshold).abs() < 0.01
774 || (state.threshold - config.max_threshold).abs() < 0.01
775 {
776 0.7
777 } else {
778 1.0
779 };
780
781 sample_factor * drift_factor * bound_factor
782 }
783
784 pub fn process_batch(
786 state: &mut AdaptiveThresholdState,
787 scores: &[f64],
788 ground_truth: Option<&[bool]>,
789 config: &AdaptiveThresholdConfig,
790 ) -> Vec<ThresholdResult> {
791 scores
792 .iter()
793 .enumerate()
794 .map(|(i, &score)| {
795 let gt = ground_truth.map(|gt| gt[i]);
796 Self::process_score(state, score, gt, config)
797 })
798 .collect()
799 }
800}
801
802impl GpuKernel for AdaptiveThreshold {
803 fn metadata(&self) -> &KernelMetadata {
804 &self.metadata
805 }
806}
807
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws one value uniformly from `[0, upper)` using the thread-local RNG.
    fn uniform_below(upper: f64) -> f64 {
        rng().random_range(0.0..upper)
    }

    // The kernel advertises the expected identifier.
    #[test]
    fn test_streaming_isolation_forest_metadata() {
        let forest = StreamingIsolationForest::new();
        assert_eq!(forest.metadata().id, "ml/streaming-isolation-forest");
    }

    // Fifty random 2-D points followed by a far-away point.
    #[test]
    fn test_streaming_isolation_forest_basic() {
        let config = StreamingConfig {
            n_trees: 10,
            sample_size: 50,
            window_size: 100,
            rebuild_interval: 20,
            contamination: 0.1,
            use_sliding_window: true,
        };
        let mut state = StreamingIsolationForest::init(2);

        for _ in 0..50 {
            let x = uniform_below(1.0);
            let y = uniform_below(1.0);
            StreamingIsolationForest::process_sample(&mut state, vec![x, y], &config);
        }

        assert!(state.window_size() > 0);
        assert_eq!(state.total_samples(), 50);

        let outlier = vec![100.0, 100.0];
        let (score, _is_anomaly) =
            StreamingIsolationForest::process_sample(&mut state, outlier, &config);
        assert!(score > 0.0);
    }

    // The window never grows past its configured size.
    #[test]
    fn test_streaming_sliding_window() {
        let config = StreamingConfig {
            window_size: 10,
            use_sliding_window: true,
            ..Default::default()
        };
        let mut state = StreamingIsolationForest::init(1);

        (0..20).for_each(|i| {
            StreamingIsolationForest::process_sample(&mut state, vec![i as f64], &config);
        });

        assert_eq!(state.window_size(), 10);
        assert_eq!(state.total_samples(), 20);
    }

    // The kernel advertises the expected identifier.
    #[test]
    fn test_adaptive_threshold_metadata() {
        let threshold_kernel = AdaptiveThreshold::new();
        assert_eq!(threshold_kernel.metadata().id, "ml/adaptive-threshold");
    }

    // Low scores first, then one clearly high score must be flagged.
    #[test]
    fn test_adaptive_threshold_basic() {
        let config = AdaptiveThresholdConfig {
            initial_threshold: 0.5,
            window_size: 100,
            target_fpr: 0.1,
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut state = AdaptiveThreshold::init(&config);

        for _ in 0..50 {
            let low_score = uniform_below(0.4);
            AdaptiveThreshold::process_score(&mut state, low_score, None, &config);
        }

        let outcome = AdaptiveThreshold::process_score(&mut state, 0.9, None, &config);
        assert!(outcome.is_anomaly);
    }

    // Labelled feedback moves the threshold in the corrective direction.
    #[test]
    fn test_adaptive_threshold_feedback() {
        let config = AdaptiveThresholdConfig {
            initial_threshold: 0.5,
            learning_rate: 0.2,
            ..Default::default()
        };
        let mut state = AdaptiveThreshold::init(&config);

        // A false positive raises the threshold.
        let before_false_positive = state.threshold();
        AdaptiveThreshold::process_score(&mut state, 0.6, Some(false), &config);
        assert!(state.threshold() > before_false_positive);

        // A false negative lowers it again.
        let before_false_negative = state.threshold();
        AdaptiveThreshold::process_score(&mut state, 0.3, Some(true), &config);
        assert!(state.threshold() < before_false_negative);
    }

    // A jump from 0.15 to 0.85 must be reported as drift at least once.
    #[test]
    fn test_drift_detection() {
        let config = AdaptiveThresholdConfig {
            window_size: 10,
            detect_drift: true,
            drift_sensitivity: 1.5,
            ..Default::default()
        };
        let mut state = AdaptiveThreshold::init(&config);

        for _ in 0..10 {
            AdaptiveThreshold::process_score(&mut state, 0.15, None, &config);
        }

        // Every one of the 15 shifted scores is processed; count the reports.
        let drift_reports = (0..15)
            .filter(|_| {
                AdaptiveThreshold::process_score(&mut state, 0.85, None, &config).drift_detected
            })
            .count();
        let drift_found = drift_reports > 0;

        assert!(
            drift_found || state.drift_count() > 0,
            "Should detect drift between 0.15 and 0.85 score ranges"
        );
    }

    // Batch processing yields one score and one label per row.
    #[test]
    fn test_batch_processing() {
        let config = StreamingConfig::default();
        let mut state = StreamingIsolationForest::init(2);

        let values = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 10.0, 10.0];
        let data = DataMatrix::new(values, 4, 2);

        let result = StreamingIsolationForest::process_batch(&mut state, &data, &config);
        assert_eq!(result.scores.len(), 4);
        assert_eq!(result.labels.len(), 4);
    }

    // Welford accumulator against a hand-computed data set.
    #[test]
    fn test_online_stats() {
        let mut stats = OnlineStats::new();
        [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
            .into_iter()
            .for_each(|v| stats.update(v));

        assert!((stats.mean - 5.0).abs() < 0.01);
        assert!((stats.variance() - 4.57).abs() < 0.1);
        assert_eq!(stats.min, 2.0);
        assert_eq!(stats.max, 9.0);
    }
}