//! Streaming explanation support: chunked processing, online aggregation,
//! and memory-bounded caching.

use crate::memory::{CacheConfig, ExplanationCache};
use crate::types::*;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
use sklears_core::prelude::SklearsError;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

/// Configuration for streaming (chunked) explanation computation.
#[derive(Clone, Debug)]
pub struct StreamingConfig {
    /// Number of samples per chunk.
    pub chunk_size: usize,
    /// Number of chunks the explainer may keep buffered in memory.
    pub memory_chunks: usize,
    /// Whether per-chunk results are aggregated online.
    pub online_aggregation: bool,
    /// Chunks smaller than this are skipped.
    pub min_chunk_size: usize,
    /// Soft memory budget for buffers and cache, in megabytes.
    pub max_memory_mb: usize,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            chunk_size: 1000,
            memory_chunks: 3,
            online_aggregation: true,
            min_chunk_size: 100,
            max_memory_mb: 512,
        }
    }
}

/// Computes feature-importance explanations over a stream of data chunks
/// while keeping memory usage within a configured budget.
pub struct StreamingExplainer {
    config: StreamingConfig,
    cache: Arc<ExplanationCache>,
    chunk_buffer: Arc<Mutex<VecDeque<Array2<Float>>>>,
    stats: Arc<Mutex<StreamingStatistics>>,
}
52
53#[derive(Clone, Debug, Default)]
55pub struct StreamingStatistics {
56 pub chunks_processed: usize,
58 pub total_samples: usize,
60 pub current_memory_usage: usize,
62 pub peak_memory_usage: usize,
64 pub avg_chunk_time: f64,
66}
67
68#[derive(Clone, Debug)]
70pub struct StreamingExplanationResult {
71 pub feature_importance: Array1<Float>,
73 pub confidence_intervals: Array2<Float>,
75 pub statistics: StreamingStatistics,
77 pub chunks_used: usize,
79}

/// Aggregates per-chunk importance vectors online, keeping only the running
/// sum and sum of squares needed for the mean and its confidence interval.
pub struct OnlineAggregator {
    running_sum: Array1<Float>,
    running_sum_squared: Array1<Float>,
    count: usize,
    n_features: usize,
}
92
93impl StreamingExplainer {
94 pub fn new(config: StreamingConfig) -> Self {
96 let cache_config = CacheConfig {
97 max_cache_size_mb: config.max_memory_mb / 4, ..Default::default()
99 };
100
101 Self {
102 config,
103 cache: Arc::new(ExplanationCache::new(&cache_config)),
104 chunk_buffer: Arc::new(Mutex::new(VecDeque::new())),
105 stats: Arc::new(Mutex::new(StreamingStatistics::default())),
106 }
107 }
108
    /// Consumes `data_stream` chunk by chunk, computes per-chunk feature
    /// importance with `model`, and aggregates the results online.
    pub fn process_stream<F, I>(
        &self,
        data_stream: I,
        model: &F,
    ) -> crate::SklResult<StreamingExplanationResult>
    where
        F: Fn(&ArrayView2<Float>) -> crate::SklResult<Array1<Float>> + Sync + Send,
        I: Iterator<Item = Array2<Float>>,
    {
        let mut aggregator = None;
        let mut chunks_processed = 0;
        let start_time = std::time::Instant::now();

        for chunk in data_stream {
            // Skip chunks too small to yield meaningful statistics.
            if chunk.nrows() < self.config.min_chunk_size {
                continue;
            }

            // Lazily size the aggregator from the first usable chunk.
            if aggregator.is_none() {
                aggregator = Some(OnlineAggregator::new(chunk.ncols()));
            }

            let chunk_result = self.process_chunk(&chunk.view(), model)?;

            if let Some(ref mut agg) = aggregator {
                agg.update(&chunk_result)?;
            }

            chunks_processed += 1;

            // Update running statistics under a short-lived lock.
            {
                let mut stats = self.stats.lock().unwrap();
                stats.chunks_processed = chunks_processed;
                stats.total_samples += chunk.nrows();
                stats.current_memory_usage = self.estimate_memory_usage();
                stats.peak_memory_usage = stats.peak_memory_usage.max(stats.current_memory_usage);
                stats.avg_chunk_time = start_time.elapsed().as_secs_f64() / chunks_processed as f64;
            }

            self.manage_memory()?;
        }

        let aggregator = aggregator
            .ok_or_else(|| SklearsError::InvalidInput("No valid chunks processed".to_string()))?;

        let (feature_importance, confidence_intervals) = aggregator.finalize();
        let statistics = self.stats.lock().unwrap().clone();

        Ok(StreamingExplanationResult {
            feature_importance,
            confidence_intervals,
            statistics,
            chunks_used: chunks_processed,
        })
    }

    fn process_chunk<F>(
        &self,
        chunk: &ArrayView2<Float>,
        model: &F,
    ) -> crate::SklResult<Array1<Float>>
    where
        F: Fn(&ArrayView2<Float>) -> crate::SklResult<Array1<Float>>,
    {
        // The stream is unlabeled, so zero placeholder targets are passed.
        crate::memory::cache_friendly_permutation_importance(
            chunk,
            &Array1::zeros(chunk.nrows()).view(),
            model,
            &self.cache,
            &CacheConfig::default(),
        )
    }

    fn estimate_memory_usage(&self) -> usize {
        let buffer_size = {
            let buffer = self.chunk_buffer.lock().unwrap();
            buffer
                .iter()
                .map(|chunk| chunk.len() * std::mem::size_of::<Float>())
                .sum::<usize>()
        };

        let cache_size = self.cache.get_statistics().total_size;

        buffer_size + cache_size
    }

    fn manage_memory(&self) -> crate::SklResult<()> {
        let max_usage = self.config.max_memory_mb * 1024 * 1024;

        if self.estimate_memory_usage() > max_usage {
            // Evict the oldest buffered chunks first. Freed bytes are tracked
            // locally: calling `estimate_memory_usage` while holding the
            // buffer lock would deadlock, since `Mutex` is not reentrant.
            let mut buffer = self.chunk_buffer.lock().unwrap();
            let cache_size = self.cache.get_statistics().total_size;
            let mut buffer_size: usize = buffer
                .iter()
                .map(|chunk| chunk.len() * std::mem::size_of::<Float>())
                .sum();

            while buffer_size + cache_size > max_usage {
                match buffer.pop_front() {
                    Some(chunk) => buffer_size -= chunk.len() * std::mem::size_of::<Float>(),
                    None => break,
                }
            }

            // If evicting buffered chunks was not enough, clear the cache too.
            if buffer_size + cache_size > max_usage {
                self.cache.clear_all();
            }
        }

        Ok(())
    }
}

impl OnlineAggregator {
    pub fn new(n_features: usize) -> Self {
        Self {
            running_sum: Array1::zeros(n_features),
            running_sum_squared: Array1::zeros(n_features),
            count: 0,
            n_features,
        }
    }

    pub fn update(&mut self, values: &Array1<Float>) -> crate::SklResult<()> {
        if values.len() != self.n_features {
            return Err(SklearsError::InvalidInput(
                "Feature dimension mismatch".to_string(),
            ));
        }

        self.running_sum += values;
        self.running_sum_squared += &values.mapv(|x| x * x);
        self.count += 1;

        Ok(())
    }

    pub fn finalize(self) -> (Array1<Float>, Array2<Float>) {
        if self.count == 0 {
            return (
                Array1::zeros(self.n_features),
                Array2::zeros((self.n_features, 2)),
            );
        }

        let count_f = self.count as Float;
        let mean = &self.running_sum / count_f;

        // E[X^2] - E[X]^2, clamped at zero so floating-point error cannot
        // yield a negative variance (and a NaN standard deviation).
        let variance =
            ((&self.running_sum_squared / count_f) - mean.mapv(|x| x * x)).mapv(|x| x.max(0.0));
        let std_dev = variance.mapv(|x| x.sqrt());

        // Normal-approximation 95% interval: mean +/- 1.96 * s / sqrt(n).
        let z_value = 1.96;
        let stderr = &std_dev / count_f.sqrt();
        let margin = &stderr * z_value;

        let mut confidence_intervals = Array2::zeros((self.n_features, 2));
        for i in 0..self.n_features {
            confidence_intervals[(i, 0)] = mean[i] - margin[i];
            confidence_intervals[(i, 1)] = mean[i] + margin[i];
        }

        (mean, confidence_intervals)
    }
}

/// Streaming SHAP-style explainer: maintains background statistics online
/// and attributes predictions with a proportional-deviation heuristic.
pub struct StreamingShapExplainer {
    config: StreamingConfig,
    sample_buffer: Arc<Mutex<VecDeque<Array1<Float>>>>,
    background_stats: Arc<Mutex<BackgroundStatistics>>,
}

/// Running background (reference) distribution statistics used as the
/// SHAP baseline.
#[derive(Clone, Debug, Default)]
pub struct BackgroundStatistics {
    pub feature_means: Array1<Float>,
    pub feature_stds: Array1<Float>,
    pub samples_seen: usize,
}

impl StreamingShapExplainer {
    pub fn new(config: StreamingConfig) -> Self {
        Self {
            config,
            sample_buffer: Arc::new(Mutex::new(VecDeque::new())),
            background_stats: Arc::new(Mutex::new(BackgroundStatistics::default())),
        }
    }

    /// Consumes `data_stream` chunk by chunk, computing per-chunk SHAP-style
    /// values and aggregating their per-feature means online.
    pub fn compute_shap_stream<F, I>(
        &self,
        data_stream: I,
        model: &F,
    ) -> crate::SklResult<StreamingExplanationResult>
    where
        F: Fn(&ArrayView2<Float>) -> crate::SklResult<Array1<Float>> + Sync + Send,
        I: Iterator<Item = Array2<Float>>,
    {
        let mut aggregator = None;
        let mut chunks_processed = 0;

        for chunk in data_stream {
            if chunk.nrows() < self.config.min_chunk_size {
                continue;
            }

            self.update_background_stats(&chunk.view())?;

            if aggregator.is_none() {
                aggregator = Some(OnlineAggregator::new(chunk.ncols()));
            }

            let shap_values = self.compute_chunk_shap(&chunk.view(), model)?;

            if let Some(ref mut agg) = aggregator {
                let mean_shap = shap_values.mean_axis(Axis(0)).unwrap();
                agg.update(&mean_shap)?;
            }

            chunks_processed += 1;
        }

        let aggregator = aggregator
            .ok_or_else(|| SklearsError::InvalidInput("No valid chunks processed".to_string()))?;

        let (feature_importance, confidence_intervals) = aggregator.finalize();

        Ok(StreamingExplanationResult {
            feature_importance,
            confidence_intervals,
            statistics: StreamingStatistics {
                chunks_processed,
                // Approximation: assumes every processed chunk was full-sized.
                total_samples: chunks_processed * self.config.chunk_size,
                ..Default::default()
            },
            chunks_used: chunks_processed,
        })
    }

    fn update_background_stats(&self, chunk: &ArrayView2<Float>) -> crate::SklResult<()> {
        let mut stats = self.background_stats.lock().unwrap();

        if stats.samples_seen == 0 {
            stats.feature_means = chunk.mean_axis(Axis(0)).ok_or_else(|| {
                SklearsError::InvalidInput("Cannot compute feature means".to_string())
            })?;
            stats.feature_stds = chunk.std_axis(Axis(0), 0.0);
            stats.samples_seen = chunk.nrows();
        } else {
            let chunk_means = chunk.mean_axis(Axis(0)).ok_or_else(|| {
                SklearsError::InvalidInput("Cannot compute feature means".to_string())
            })?;

            // Fold the new chunk into the running mean, weighting each side
            // by its sample count. Note that `feature_stds` stays fixed at
            // the first chunk's estimate.
            let total_samples = stats.samples_seen + chunk.nrows();
            let weight_old = stats.samples_seen as Float / total_samples as Float;
            let weight_new = chunk.nrows() as Float / total_samples as Float;

            stats.feature_means = &stats.feature_means * weight_old + &chunk_means * weight_new;
            stats.samples_seen = total_samples;
        }

        Ok(())
    }

    /// Approximates per-sample attributions by distributing the difference
    /// between the model's prediction on a sample and its prediction on the
    /// background mean, in proportion to each feature's absolute deviation
    /// from that mean. This is a fast heuristic, not exact Shapley values.
    fn compute_chunk_shap<F>(
        &self,
        chunk: &ArrayView2<Float>,
        model: &F,
    ) -> crate::SklResult<Array2<Float>>
    where
        F: Fn(&ArrayView2<Float>) -> crate::SklResult<Array1<Float>>,
    {
        let n_samples = chunk.nrows();
        let n_features = chunk.ncols();

        let background_means = {
            let stats = self.background_stats.lock().unwrap();
            stats.feature_means.clone()
        };

        let mut shap_values = Array2::zeros((n_samples, n_features));

        for sample_idx in 0..n_samples {
            let sample = chunk.row(sample_idx);

            // Prediction with every feature at its background mean.
            let baseline_data = background_means.clone().insert_axis(Axis(0));
            let baseline_pred = model(&baseline_data.view())?;
            let baseline_value = baseline_pred[0];

            // Prediction on the actual sample.
            let full_pred = model(&sample.insert_axis(Axis(0)))?;
            let full_value = full_pred[0];

            let total_contribution = full_value - baseline_value;

            let deviations = &sample.to_owned() - &background_means;
            let total_deviation = deviations.mapv(|x| x.abs()).sum();

            for feature_idx in 0..n_features {
                let feature_contrib = if total_deviation > 0.0 {
                    total_contribution * (deviations[feature_idx].abs() / total_deviation)
                } else {
                    // Sample equals the background: split the difference evenly.
                    total_contribution / n_features as Float
                };

                shap_values[(sample_idx, feature_idx)] = feature_contrib;
            }
        }

        Ok(shap_values)
    }
}

/// Splits `data` row-wise into owned chunks of at most `chunk_size` samples.
pub fn create_data_chunks(data: &ArrayView2<Float>, chunk_size: usize) -> Vec<Array2<Float>> {
    let mut chunks = Vec::new();
    let n_samples = data.nrows();

    for start in (0..n_samples).step_by(chunk_size) {
        let end = (start + chunk_size).min(n_samples);
        let chunk = data.slice(s![start..end, ..]).to_owned();
        chunks.push(chunk);
    }

    chunks
}

/// Adapts an in-memory array into a chunked iterator, e.g. for feeding
/// `StreamingExplainer::process_stream`.
pub struct StreamingDataIterator {
    position: usize,
    data: Array2<Float>,
    chunk_size: usize,
}

impl StreamingDataIterator {
    pub fn new(data: Array2<Float>, chunk_size: usize) -> Self {
        Self {
            position: 0,
            data,
            chunk_size,
        }
    }
}

impl Iterator for StreamingDataIterator {
    type Item = Array2<Float>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.data.nrows() {
            return None;
        }

        let end = (self.position + self.chunk_size).min(self.data.nrows());
        let chunk = self.data.slice(s![self.position..end, ..]).to_owned();
        self.position = end;

        Some(chunk)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_streaming_config_default() {
        let config = StreamingConfig::default();
        assert_eq!(config.chunk_size, 1000);
        assert_eq!(config.memory_chunks, 3);
        assert!(config.online_aggregation);
    }

    #[test]
    fn test_online_aggregator() {
        let mut aggregator = OnlineAggregator::new(2);

        aggregator.update(&array![1.0, 2.0]).unwrap();
        aggregator.update(&array![3.0, 4.0]).unwrap();

        let (mean, confidence_intervals) = aggregator.finalize();

        assert_abs_diff_eq!(mean[0], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(mean[1], 3.0, epsilon = 1e-6);
        assert_eq!(confidence_intervals.shape(), &[2, 2]);
    }

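    // Checks the interval arithmetic in `finalize` against the closed form
    // used above: for samples {1.0, 3.0}, s = 1.0 and n = 2, so the
    // half-width is 1.96 / sqrt(2).
    #[test]
    fn test_online_aggregator_confidence_interval() {
        let mut aggregator = OnlineAggregator::new(1);
        aggregator.update(&array![1.0]).unwrap();
        aggregator.update(&array![3.0]).unwrap();

        let (mean, ci) = aggregator.finalize();
        let margin = 1.96 / (2.0 as Float).sqrt();

        assert_abs_diff_eq!(mean[0], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(ci[(0, 0)], 2.0 - margin, epsilon = 1e-6);
        assert_abs_diff_eq!(ci[(0, 1)], 2.0 + margin, epsilon = 1e-6);
    }

    // A value vector of the wrong length should be rejected, not broadcast.
    #[test]
    fn test_online_aggregator_dimension_mismatch() {
        let mut aggregator = OnlineAggregator::new(2);
        assert!(aggregator.update(&array![1.0]).is_err());
    }
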
    #[test]
    fn test_streaming_explainer_creation() {
        let config = StreamingConfig::default();
        let explainer = StreamingExplainer::new(config);

        let stats = explainer.stats.lock().unwrap();
        assert_eq!(stats.chunks_processed, 0);
        assert_eq!(stats.total_samples, 0);
    }

    #[test]
    fn test_create_data_chunks() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let chunks = create_data_chunks(&data.view(), 2);

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].nrows(), 2);
        assert_eq!(chunks[1].nrows(), 2);
    }

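    // A sample count that is not a multiple of `chunk_size` should produce a
    // short trailing chunk rather than dropping rows.
    #[test]
    fn test_create_data_chunks_with_remainder() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        let chunks = create_data_chunks(&data.view(), 2);

        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[2].nrows(), 1);
        assert_eq!(chunks[2][(0, 0)], 9.0);
    }
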
    #[test]
    fn test_streaming_data_iterator() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut iterator = StreamingDataIterator::new(data, 2);

        let chunk1 = iterator.next().unwrap();
        assert_eq!(chunk1.nrows(), 2);

        let chunk2 = iterator.next().unwrap();
        assert_eq!(chunk2.nrows(), 1);

        assert!(iterator.next().is_none());
    }

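    // An empty input should terminate immediately.
    #[test]
    fn test_streaming_data_iterator_empty() {
        let data = Array2::<Float>::zeros((0, 2));
        let mut iterator = StreamingDataIterator::new(data, 2);
        assert!(iterator.next().is_none());
    }
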
    #[test]
    fn test_streaming_shap_explainer() {
        let config = StreamingConfig::default();
        let explainer = StreamingShapExplainer::new(config);

        let stats = explainer.background_stats.lock().unwrap();
        assert_eq!(stats.samples_seen, 0);
    }

    #[test]
    fn test_background_statistics_update() {
        let config = StreamingConfig::default();
        let explainer = StreamingShapExplainer::new(config);

        let chunk = array![[1.0, 2.0], [3.0, 4.0]];
        explainer.update_background_stats(&chunk.view()).unwrap();

        let stats = explainer.background_stats.lock().unwrap();
        assert_eq!(stats.samples_seen, 2);
        assert_abs_diff_eq!(stats.feature_means[0], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(stats.feature_means[1], 3.0, epsilon = 1e-6);
    }

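    // A second chunk should fold into the running mean with sample-count
    // weights: ([2, 3] * 2 + [5, 6] * 1) / 3 = [3, 4].
    #[test]
    fn test_background_statistics_weighted_update() {
        let config = StreamingConfig::default();
        let explainer = StreamingShapExplainer::new(config);

        let first = array![[1.0, 2.0], [3.0, 4.0]];
        let second = array![[5.0, 6.0]];
        explainer.update_background_stats(&first.view()).unwrap();
        explainer.update_background_stats(&second.view()).unwrap();

        let stats = explainer.background_stats.lock().unwrap();
        assert_eq!(stats.samples_seen, 3);
        assert_abs_diff_eq!(stats.feature_means[0], 3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(stats.feature_means[1], 4.0, epsilon = 1e-6);
    }
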
    #[test]
    fn test_streaming_statistics_default() {
        let stats = StreamingStatistics::default();
        assert_eq!(stats.chunks_processed, 0);
        assert_eq!(stats.total_samples, 0);
        assert_eq!(stats.current_memory_usage, 0);
    }

    #[test]
    fn test_streaming_explanation_result() {
        let result = StreamingExplanationResult {
            feature_importance: array![0.5, 0.3],
            confidence_intervals: array![[0.4, 0.6], [0.2, 0.4]],
            statistics: StreamingStatistics::default(),
            chunks_used: 3,
        };

        assert_eq!(result.feature_importance.len(), 2);
        assert_eq!(result.confidence_intervals.shape(), &[2, 2]);
        assert_eq!(result.chunks_used, 3);
    }

    #[test]
    fn test_process_chunk_computation() {
        let config = StreamingConfig::default();
        let explainer = StreamingExplainer::new(config);

        let chunk = array![[1.0, 2.0], [3.0, 4.0]];
        let model =
            |_: &ArrayView2<Float>| -> crate::SklResult<Array1<Float>> { Ok(array![0.5, 0.7]) };

        let result = explainer.process_chunk(&chunk.view(), &model).unwrap();
        assert_eq!(result.len(), 2);
    }
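
    // End-to-end sketch for `process_stream`, in the spirit of
    // `test_process_chunk_computation` above: the model returns a constant
    // vector, so only the chunking/aggregation plumbing is asserted, not the
    // importance values produced by the underlying permutation routine.
    // `min_chunk_size` is lowered so the tiny test chunks are not skipped.
    #[test]
    fn test_process_stream_end_to_end() {
        let config = StreamingConfig {
            chunk_size: 2,
            min_chunk_size: 1,
            ..Default::default()
        };
        let explainer = StreamingExplainer::new(config);

        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let stream = StreamingDataIterator::new(data, 2);
        let model = |x: &ArrayView2<Float>| -> crate::SklResult<Array1<Float>> {
            Ok(Array1::zeros(x.nrows()))
        };

        let result = explainer.process_stream(stream, &model).unwrap();
        assert_eq!(result.chunks_used, 2);
        assert_eq!(result.feature_importance.len(), 2);
        assert_eq!(result.statistics.total_samples, 4);
    }

    // Smoke test for the streaming SHAP path with a linear model (row sum);
    // only shape and finiteness are asserted because the per-feature
    // attribution is the proportional heuristic described above.
    #[test]
    fn test_compute_shap_stream_smoke() {
        let config = StreamingConfig {
            chunk_size: 2,
            min_chunk_size: 1,
            ..Default::default()
        };
        let explainer = StreamingShapExplainer::new(config);

        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let stream = StreamingDataIterator::new(data, 2);
        let model = |x: &ArrayView2<Float>| -> crate::SklResult<Array1<Float>> {
            Ok(x.map_axis(Axis(1), |row| row.sum()))
        };

        let result = explainer.compute_shap_stream(stream, &model).unwrap();
        assert_eq!(result.feature_importance.len(), 2);
        assert!(result.feature_importance.iter().all(|v| v.is_finite()));
    }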
}