sklears_inspection/streaming.rs

//! Streaming explanation computation for large datasets
//!
//! This module provides streaming algorithms for explanation computation
//! that can handle datasets larger than memory by processing them in chunks.
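//!
//! # Example
//!
//! A minimal sketch of the streaming workflow; the `model` closure and the
//! `data` array are illustrative placeholders, not items provided by this
//! module:
//!
//! ```ignore
//! let config = StreamingConfig::default();
//! let chunk_size = config.chunk_size;
//! let explainer = StreamingExplainer::new(config);
//!
//! // Any prediction function with this signature can serve as the model.
//! let model = |x: &ArrayView2<Float>| -> crate::SklResult<Array1<Float>> {
//!     Ok(Array1::zeros(x.nrows()))
//! };
//!
//! // Split a large in-memory array into chunks and stream them through.
//! let chunks = create_data_chunks(&data.view(), chunk_size);
//! let result = explainer.process_stream(chunks.into_iter(), &model)?;
//! println!("importance: {:?}", result.feature_importance);
//! ```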

use crate::memory::{CacheConfig, ExplanationCache};
use crate::types::*;
// ✅ SciRS2 Policy Compliant Imports
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
use sklears_core::prelude::SklearsError;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

/// Configuration for streaming explanation computation
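///
/// Individual fields can be overridden with struct-update syntax:
///
/// ```ignore
/// let config = StreamingConfig {
///     chunk_size: 256,
///     max_memory_mb: 128,
///     ..Default::default()
/// };
/// ```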
#[derive(Clone, Debug)]
pub struct StreamingConfig {
    /// Number of samples per processing chunk
    pub chunk_size: usize,
    /// Number of chunks to keep in memory
    pub memory_chunks: usize,
    /// Enable online aggregation
    pub online_aggregation: bool,
    /// Minimum chunk size to process (smaller chunks are skipped)
    pub min_chunk_size: usize,
    /// Maximum memory usage in MB
    pub max_memory_mb: usize,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            chunk_size: 1000,
            memory_chunks: 3,
            online_aggregation: true,
            min_chunk_size: 100,
            max_memory_mb: 512,
        }
    }
}

/// Streaming explanation processor
pub struct StreamingExplainer {
    /// Configuration
    config: StreamingConfig,
    /// Cache for repeated computations
    cache: Arc<ExplanationCache>,
    /// Chunk buffer
    chunk_buffer: Arc<Mutex<VecDeque<Array2<Float>>>>,
    /// Accumulated statistics
    stats: Arc<Mutex<StreamingStatistics>>,
}

/// Statistics for streaming computation
#[derive(Clone, Debug, Default)]
pub struct StreamingStatistics {
    /// Number of chunks processed
    pub chunks_processed: usize,
    /// Total samples processed
    pub total_samples: usize,
    /// Current memory usage in bytes
    pub current_memory_usage: usize,
    /// Peak memory usage in bytes
    pub peak_memory_usage: usize,
    /// Average processing time per chunk in seconds
    pub avg_chunk_time: f64,
}

/// Streaming explanation result
#[derive(Clone, Debug)]
pub struct StreamingExplanationResult {
    /// Aggregated feature importance
    pub feature_importance: Array1<Float>,
    /// Per-feature 95% confidence intervals (lower and upper bounds)
    pub confidence_intervals: Array2<Float>,
    /// Processing statistics
    pub statistics: StreamingStatistics,
    /// Number of chunks used
    pub chunks_used: usize,
}

/// Online aggregator for streaming results
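///
/// Keeps only running sums, so the mean and a normal-approximation confidence
/// interval can be produced without storing individual observations:
///
/// ```ignore
/// let mut agg = OnlineAggregator::new(2);
/// agg.update(&array![1.0, 2.0])?;
/// agg.update(&array![3.0, 4.0])?;
/// let (mean, ci) = agg.finalize(); // mean = [2.0, 3.0], ci has shape (2, 2)
/// ```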
pub struct OnlineAggregator {
    /// Running sum of feature importance
    running_sum: Array1<Float>,
    /// Running sum of squared values
    running_sum_squared: Array1<Float>,
    /// Number of observations
    count: usize,
    /// Number of features
    n_features: usize,
}

impl StreamingExplainer {
    /// Create a new streaming explainer
    pub fn new(config: StreamingConfig) -> Self {
        let cache_config = CacheConfig {
            max_cache_size_mb: config.max_memory_mb / 4, // Use 1/4 of memory for cache
            ..Default::default()
        };

        Self {
            config,
            cache: Arc::new(ExplanationCache::new(&cache_config)),
            chunk_buffer: Arc::new(Mutex::new(VecDeque::new())),
            stats: Arc::new(Mutex::new(StreamingStatistics::default())),
        }
    }

    /// Process data stream and compute explanations
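    ///
    /// The `model` closure receives each chunk as an `ArrayView2` and is
    /// expected to return one prediction per row. Chunks with fewer than
    /// `config.min_chunk_size` rows are skipped, and results are aggregated
    /// online so only running sums are held in memory.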
    pub fn process_stream<F, I>(
        &self,
        data_stream: I,
        model: &F,
    ) -> crate::SklResult<StreamingExplanationResult>
    where
        F: Fn(&ArrayView2<Float>) -> crate::SklResult<Array1<Float>> + Sync + Send,
        I: Iterator<Item = Array2<Float>>,
    {
        let mut aggregator = None;
        let mut chunks_processed = 0;
        let start_time = std::time::Instant::now();

        for chunk in data_stream {
            if chunk.nrows() < self.config.min_chunk_size {
                continue;
            }

            // Initialize aggregator with first chunk
            if aggregator.is_none() {
                aggregator = Some(OnlineAggregator::new(chunk.ncols()));
            }

            // Process chunk
            let chunk_result = self.process_chunk(&chunk.view(), model)?;

            // Update aggregator
            if let Some(ref mut agg) = aggregator {
                agg.update(&chunk_result)?;
            }

            chunks_processed += 1;

            // Update statistics
            {
                let mut stats = self.stats.lock().unwrap();
                stats.chunks_processed = chunks_processed;
                stats.total_samples += chunk.nrows();
                stats.current_memory_usage = self.estimate_memory_usage();
                stats.peak_memory_usage = stats.peak_memory_usage.max(stats.current_memory_usage);
                stats.avg_chunk_time = start_time.elapsed().as_secs_f64() / chunks_processed as f64;
            }

            // Manage memory
            self.manage_memory()?;
        }

        // Finalize results
        let aggregator = aggregator
            .ok_or_else(|| SklearsError::InvalidInput("No valid chunks processed".to_string()))?;

        let (feature_importance, confidence_intervals) = aggregator.finalize();
        let statistics = self.stats.lock().unwrap().clone();

        Ok(StreamingExplanationResult {
            feature_importance,
            confidence_intervals,
            statistics,
            chunks_used: chunks_processed,
        })
    }

    /// Process a single chunk of data
    fn process_chunk<F>(
        &self,
        chunk: &ArrayView2<Float>,
        model: &F,
    ) -> crate::SklResult<Array1<Float>>
    where
        F: Fn(&ArrayView2<Float>) -> crate::SklResult<Array1<Float>>,
    {
        // Use cache-friendly computation
        crate::memory::cache_friendly_permutation_importance(
            chunk,
            &Array1::zeros(chunk.nrows()).view(), // Dummy y for feature importance
            model,
            &self.cache,
            &CacheConfig::default(),
        )
    }

    /// Estimate current memory usage
    fn estimate_memory_usage(&self) -> usize {
        let buffer_size = {
            let buffer = self.chunk_buffer.lock().unwrap();
            buffer
                .iter()
                .map(|chunk| chunk.len() * std::mem::size_of::<Float>())
                .sum::<usize>()
        };

        let cache_size = self.cache.get_statistics().total_size;

        buffer_size + cache_size
    }

    /// Manage memory usage by evicting old chunks
    fn manage_memory(&self) -> crate::SklResult<()> {
        let current_usage = self.estimate_memory_usage();
        let max_usage = self.config.max_memory_mb * 1024 * 1024;

        if current_usage > max_usage {
            // Evict oldest chunks. Track the buffer size locally instead of
            // calling `estimate_memory_usage` while the buffer lock is held,
            // which would deadlock on the non-reentrant mutex.
            let cache_size = self.cache.get_statistics().total_size;
            {
                let mut buffer = self.chunk_buffer.lock().unwrap();
                let mut buffer_size: usize = buffer
                    .iter()
                    .map(|chunk| chunk.len() * std::mem::size_of::<Float>())
                    .sum();
                while !buffer.is_empty() && buffer_size + cache_size > max_usage {
                    if let Some(evicted) = buffer.pop_front() {
                        buffer_size -= evicted.len() * std::mem::size_of::<Float>();
                    }
                }
            }

            // Clear cache if still over limit
            if self.estimate_memory_usage() > max_usage {
                self.cache.clear_all();
            }
        }

        Ok(())
    }
}

impl OnlineAggregator {
    /// Create a new online aggregator
    pub fn new(n_features: usize) -> Self {
        Self {
            running_sum: Array1::zeros(n_features),
            running_sum_squared: Array1::zeros(n_features),
            count: 0,
            n_features,
        }
    }

    /// Update aggregator with new values
    pub fn update(&mut self, values: &Array1<Float>) -> crate::SklResult<()> {
        if values.len() != self.n_features {
            return Err(SklearsError::InvalidInput(
                "Feature dimension mismatch".to_string(),
            ));
        }

        // Update running sums
        self.running_sum += values;
        self.running_sum_squared += &values.mapv(|x| x * x);
        self.count += 1;

        Ok(())
    }

    /// Finalize aggregation and return mean and confidence intervals
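    ///
    /// Returns the per-feature mean and an `(n_features, 2)` array whose
    /// columns are the lower and upper bounds of a 95% normal-approximation
    /// confidence interval. Both are zero if no values were aggregated.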
    pub fn finalize(self) -> (Array1<Float>, Array2<Float>) {
        if self.count == 0 {
            return (
                Array1::zeros(self.n_features),
                Array2::zeros((self.n_features, 2)),
            );
        }

        let count_f = self.count as Float;
        let mean = &self.running_sum / count_f;

        // Compute standard deviation (clamp at zero to guard against small
        // negative variances caused by floating-point round-off)
        let variance = (&self.running_sum_squared / count_f) - mean.mapv(|x| x * x);
        let std_dev = variance.mapv(|x| x.max(0.0).sqrt());

        // Compute 95% confidence intervals
        let t_value = 1.96; // For large samples, approximating t-distribution with normal
        let stderr = &std_dev / (count_f.sqrt());
        let margin = &stderr * t_value;

        let mut confidence_intervals = Array2::zeros((self.n_features, 2));
        for i in 0..self.n_features {
            confidence_intervals[(i, 0)] = mean[i] - margin[i]; // Lower bound
            confidence_intervals[(i, 1)] = mean[i] + margin[i]; // Upper bound
        }

        (mean, confidence_intervals)
    }
}

/// Streaming SHAP computation
pub struct StreamingShapExplainer {
    /// Base configuration
    config: StreamingConfig,
    /// Sample buffer for baseline computation
    sample_buffer: Arc<Mutex<VecDeque<Array1<Float>>>>,
    /// Background statistics
    background_stats: Arc<Mutex<BackgroundStatistics>>,
}

/// Background statistics for SHAP computation
#[derive(Clone, Debug, Default)]
pub struct BackgroundStatistics {
    /// Feature means
    pub feature_means: Array1<Float>,
    /// Feature standard deviations
    pub feature_stds: Array1<Float>,
    /// Number of samples seen
    pub samples_seen: usize,
}

impl StreamingShapExplainer {
    /// Create a new streaming SHAP explainer
    pub fn new(config: StreamingConfig) -> Self {
        Self {
            config,
            sample_buffer: Arc::new(Mutex::new(VecDeque::new())),
            background_stats: Arc::new(Mutex::new(BackgroundStatistics::default())),
        }
    }

    /// Compute SHAP values for a stream of data
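    ///
    /// Note: the per-chunk attribution used here is a simplified
    /// approximation that distributes the difference between each sample's
    /// prediction and the background-mean prediction proportionally to each
    /// feature's absolute deviation from the background mean; it is not an
    /// exact Shapley value computation.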
    pub fn compute_shap_stream<F, I>(
        &self,
        data_stream: I,
        model: &F,
    ) -> crate::SklResult<StreamingExplanationResult>
    where
        F: Fn(&ArrayView2<Float>) -> crate::SklResult<Array1<Float>> + Sync + Send,
        I: Iterator<Item = Array2<Float>>,
    {
        let mut aggregator = None;
        let mut chunks_processed = 0;
        let mut total_samples = 0;

        for chunk in data_stream {
            if chunk.nrows() < self.config.min_chunk_size {
                continue;
            }

            // Update background statistics
            self.update_background_stats(&chunk.view())?;

            // Initialize aggregator
            if aggregator.is_none() {
                aggregator = Some(OnlineAggregator::new(chunk.ncols()));
            }

            // Compute SHAP values for chunk
            let shap_values = self.compute_chunk_shap(&chunk.view(), model)?;

            // Aggregate results
            if let Some(ref mut agg) = aggregator {
                let mean_shap = shap_values.mean_axis(Axis(0)).unwrap();
                agg.update(&mean_shap)?;
            }

            chunks_processed += 1;
            total_samples += chunk.nrows();
        }

        // Finalize results
        let aggregator = aggregator
            .ok_or_else(|| SklearsError::InvalidInput("No valid chunks processed".to_string()))?;

        let (feature_importance, confidence_intervals) = aggregator.finalize();

        Ok(StreamingExplanationResult {
            feature_importance,
            confidence_intervals,
            statistics: StreamingStatistics {
                chunks_processed,
                // Count the rows actually processed rather than assuming full chunks
                total_samples,
                ..Default::default()
            },
            chunks_used: chunks_processed,
        })
    }

    /// Update background statistics with new data
    fn update_background_stats(&self, chunk: &ArrayView2<Float>) -> crate::SklResult<()> {
        let mut stats = self.background_stats.lock().unwrap();

        if stats.samples_seen == 0 {
            // Initialize with first chunk
            stats.feature_means = chunk.mean_axis(Axis(0)).ok_or_else(|| {
                SklearsError::InvalidInput("Cannot compute feature means".to_string())
            })?;
            stats.feature_stds = chunk.std_axis(Axis(0), 0.0);
            stats.samples_seen = chunk.nrows();
        } else {
            // Update statistics incrementally
            let chunk_means = chunk.mean_axis(Axis(0)).ok_or_else(|| {
                SklearsError::InvalidInput("Cannot compute feature means".to_string())
            })?;

            let total_samples = stats.samples_seen + chunk.nrows();
            let weight_old = stats.samples_seen as Float / total_samples as Float;
            let weight_new = chunk.nrows() as Float / total_samples as Float;

            // Update means
            stats.feature_means = &stats.feature_means * weight_old + &chunk_means * weight_new;
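            // NOTE: `feature_stds` is left at its first-chunk estimate; an exact
            // incremental update would additionally require tracking running
            // sums of squares.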
            stats.samples_seen = total_samples;
        }

        Ok(())
    }

    /// Compute SHAP values for a single chunk
    fn compute_chunk_shap<F>(
        &self,
        chunk: &ArrayView2<Float>,
        model: &F,
    ) -> crate::SklResult<Array2<Float>>
    where
        F: Fn(&ArrayView2<Float>) -> crate::SklResult<Array1<Float>>,
    {
        let n_samples = chunk.nrows();
        let n_features = chunk.ncols();

        // Get background statistics
        let background_means = {
            let stats = self.background_stats.lock().unwrap();
            stats.feature_means.clone()
        };

        // Baseline prediction (using background means); it is identical for
        // every sample in the chunk, so compute it once outside the loop
        let baseline_data = background_means.clone().insert_axis(Axis(0));
        let baseline_pred = model(&baseline_data.view())?;
        let baseline_value = baseline_pred[0];

        // Compute simplified SHAP values
        let mut shap_values = Array2::zeros((n_samples, n_features));

        for sample_idx in 0..n_samples {
            let sample = chunk.row(sample_idx);

            // Full prediction
            let full_pred = model(&sample.insert_axis(Axis(0)))?;
            let full_value = full_pred[0];

            // Total contribution to distribute across features
            let total_contribution = full_value - baseline_value;

            // Simple attribution: proportional to absolute deviation from the
            // baseline; split evenly if the sample matches the baseline exactly
            let deviations = &sample.to_owned() - &background_means;
            let total_deviation = deviations.mapv(|x| x.abs()).sum();

            for feature_idx in 0..n_features {
                let feature_contrib = if total_deviation > 0.0 {
                    total_contribution * (deviations[feature_idx].abs() / total_deviation)
                } else {
                    total_contribution / n_features as Float
                };

                shap_values[(sample_idx, feature_idx)] = feature_contrib;
            }
        }

        Ok(shap_values)
    }
}

/// Utility function to create data chunks from large arrays
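///
/// The final chunk may contain fewer than `chunk_size` rows; `chunk_size`
/// must be non-zero.
///
/// ```ignore
/// let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let chunks = create_data_chunks(&data.view(), 2);
/// assert_eq!(chunks.len(), 2); // rows [0..2) and [2..3)
/// assert_eq!(chunks[1].nrows(), 1);
/// ```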
pub fn create_data_chunks(data: &ArrayView2<Float>, chunk_size: usize) -> Vec<Array2<Float>> {
    let mut chunks = Vec::new();
    let n_samples = data.nrows();

    for start in (0..n_samples).step_by(chunk_size) {
        let end = (start + chunk_size).min(n_samples);
        let chunk = data.slice(s![start..end, ..]).to_owned();
        chunks.push(chunk);
    }

    chunks
}

/// Streaming data iterator that yields fixed-size chunks from an in-memory array
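///
/// ```ignore
/// let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let mut chunks = StreamingDataIterator::new(data, 2);
/// assert_eq!(chunks.next().unwrap().nrows(), 2);
/// assert_eq!(chunks.next().unwrap().nrows(), 1);
/// assert!(chunks.next().is_none());
/// ```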
pub struct StreamingDataIterator {
    /// Current position in data
    position: usize,
    /// Data source
    data: Array2<Float>,
    /// Chunk size
    chunk_size: usize,
}

impl StreamingDataIterator {
    /// Create a new streaming data iterator
    pub fn new(data: Array2<Float>, chunk_size: usize) -> Self {
        Self {
            position: 0,
            data,
            chunk_size,
        }
    }
}

impl Iterator for StreamingDataIterator {
    type Item = Array2<Float>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.data.nrows() {
            return None;
        }

        let end = (self.position + self.chunk_size).min(self.data.nrows());
        let chunk = self.data.slice(s![self.position..end, ..]).to_owned();
        self.position = end;

        Some(chunk)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    // ✅ SciRS2 Policy Compliant Import
    use scirs2_core::ndarray::array;

    #[test]
    fn test_streaming_config_default() {
        let config = StreamingConfig::default();
        assert_eq!(config.chunk_size, 1000);
        assert_eq!(config.memory_chunks, 3);
        assert!(config.online_aggregation);
    }

    #[test]
    fn test_online_aggregator() {
        let mut aggregator = OnlineAggregator::new(2);

        // Add some values
        aggregator.update(&array![1.0, 2.0]).unwrap();
        aggregator.update(&array![3.0, 4.0]).unwrap();

        let (mean, confidence_intervals) = aggregator.finalize();

        assert_abs_diff_eq!(mean[0], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(mean[1], 3.0, epsilon = 1e-6);
        assert_eq!(confidence_intervals.shape(), &[2, 2]);
    }

    #[test]
    fn test_streaming_explainer_creation() {
        let config = StreamingConfig::default();
        let explainer = StreamingExplainer::new(config);

        let stats = explainer.stats.lock().unwrap();
        assert_eq!(stats.chunks_processed, 0);
        assert_eq!(stats.total_samples, 0);
    }

    #[test]
    fn test_create_data_chunks() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let chunks = create_data_chunks(&data.view(), 2);

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].nrows(), 2);
        assert_eq!(chunks[1].nrows(), 2);
    }

    #[test]
    fn test_streaming_data_iterator() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut iterator = StreamingDataIterator::new(data, 2);

        let chunk1 = iterator.next().unwrap();
        assert_eq!(chunk1.nrows(), 2);

        let chunk2 = iterator.next().unwrap();
        assert_eq!(chunk2.nrows(), 1);

        assert!(iterator.next().is_none());
    }

    #[test]
    fn test_streaming_shap_explainer() {
        let config = StreamingConfig::default();
        let explainer = StreamingShapExplainer::new(config);

        let stats = explainer.background_stats.lock().unwrap();
        assert_eq!(stats.samples_seen, 0);
    }

    #[test]
    fn test_background_statistics_update() {
        let config = StreamingConfig::default();
        let explainer = StreamingShapExplainer::new(config);

        let chunk = array![[1.0, 2.0], [3.0, 4.0]];
        explainer.update_background_stats(&chunk.view()).unwrap();

        let stats = explainer.background_stats.lock().unwrap();
        assert_eq!(stats.samples_seen, 2);
        assert_abs_diff_eq!(stats.feature_means[0], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(stats.feature_means[1], 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_streaming_statistics_default() {
        let stats = StreamingStatistics::default();
        assert_eq!(stats.chunks_processed, 0);
        assert_eq!(stats.total_samples, 0);
        assert_eq!(stats.current_memory_usage, 0);
    }

    #[test]
    fn test_streaming_explanation_result() {
        let result = StreamingExplanationResult {
            feature_importance: array![0.5, 0.3],
            confidence_intervals: array![[0.4, 0.6], [0.2, 0.4]],
            statistics: StreamingStatistics::default(),
            chunks_used: 3,
        };

        assert_eq!(result.feature_importance.len(), 2);
        assert_eq!(result.confidence_intervals.shape(), &[2, 2]);
        assert_eq!(result.chunks_used, 3);
    }

    #[test]
    fn test_process_chunk_computation() {
        let config = StreamingConfig::default();
        let explainer = StreamingExplainer::new(config);

        let chunk = array![[1.0, 2.0], [3.0, 4.0]];
        let model =
            |_: &ArrayView2<Float>| -> crate::SklResult<Array1<Float>> { Ok(array![0.5, 0.7]) };

        let result = explainer.process_chunk(&chunk.view(), &model).unwrap();
        assert_eq!(result.len(), 2);
    }
}