1#[cfg(feature = "parallel")]
7use rayon::prelude::*;
8
9#[cfg(feature = "parallel")]
10use num_cpus;
11
12use crate::{Float, SklResult};
13use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
15use scirs2_core::random::{seq::SliceRandom, SeedableRng};
16use std::sync::Arc;
17
/// Controls how explanation workloads are split across threads.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Worker thread count; `None` lets the runtime pick (all CPUs).
    pub n_threads: Option<usize>,
    /// Smallest workload for which parallelism is worthwhile.
    pub min_batch_size: usize,
    /// Hard switch to single-threaded execution.
    pub force_sequential: bool,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        ParallelConfig {
            n_threads: None,
            min_batch_size: 100,
            force_sequential: false,
        }
    }
}

impl ParallelConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Pins the worker count to an explicit number of threads.
    pub fn with_threads(mut self, n_threads: usize) -> Self {
        self.n_threads = Some(n_threads);
        self
    }

    /// Overrides the minimum data size that triggers parallelism.
    pub fn with_min_batch_size(mut self, min_batch_size: usize) -> Self {
        self.min_batch_size = min_batch_size;
        self
    }

    /// Forces all work onto the current thread.
    pub fn sequential(mut self) -> Self {
        self.force_sequential = true;
        self
    }

    /// Effective thread count: 1 when sequential or when the
    /// `parallel` feature is disabled, otherwise the configured or
    /// detected CPU count.
    pub fn get_n_threads(&self) -> usize {
        if self.force_sequential {
            return 1;
        }
        #[cfg(feature = "parallel")]
        {
            self.n_threads.unwrap_or_else(num_cpus::get)
        }
        #[cfg(not(feature = "parallel"))]
        {
            1
        }
    }

    /// True when `data_size` is large enough to justify fan-out.
    pub fn should_parallelize(&self, data_size: usize) -> bool {
        data_size >= self.min_batch_size && !self.force_sequential
    }
}
84
/// Strategy trait for explanation algorithms that can fan a batch of
/// independent inputs out over multiple threads.
pub trait ParallelExplanation {
    /// One unit of work (e.g. a single instance to explain).
    type Input;
    /// Result produced for one input.
    type Output;
    /// Per-call configuration forwarded to every computation.
    type Config;

    /// Computes the explanation for a single input; the only method
    /// implementors must provide.
    fn compute_single(&self, input: &Self::Input, config: &Self::Config)
        -> SklResult<Self::Output>;

    /// Dispatches to the parallel or sequential path based on
    /// `parallel_config.should_parallelize(inputs.len())`.
    fn compute_parallel(
        &self,
        inputs: &[Self::Input],
        config: &Self::Config,
        parallel_config: &ParallelConfig,
    ) -> SklResult<Vec<Self::Output>>
    where
        Self: Sync,
        Self::Input: Sync,
        Self::Config: Sync,
        Self::Output: Send,
    {
        if parallel_config.should_parallelize(inputs.len()) {
            self.compute_parallel_impl(inputs, config, parallel_config)
        } else {
            self.compute_sequential(inputs, config)
        }
    }

    /// Processes every input on the current thread, stopping at the
    /// first error.
    fn compute_sequential(
        &self,
        inputs: &[Self::Input],
        config: &Self::Config,
    ) -> SklResult<Vec<Self::Output>> {
        inputs
            .iter()
            .map(|input| self.compute_single(input, config))
            .collect()
    }

    /// Parallel implementation backed by rayon (only with the
    /// `parallel` feature); output order matches input order.
    #[cfg(feature = "parallel")]
    fn compute_parallel_impl(
        &self,
        inputs: &[Self::Input],
        config: &Self::Config,
        _parallel_config: &ParallelConfig,
    ) -> SklResult<Vec<Self::Output>>
    where
        Self: Sync,
        Self::Input: Sync,
        Self::Config: Sync,
        Self::Output: Send,
    {
        inputs
            .par_iter()
            .map(|input| self.compute_single(input, config))
            .collect()
    }

    /// Fallback used when the `parallel` feature is disabled: simply
    /// runs sequentially.
    #[cfg(not(feature = "parallel"))]
    fn compute_parallel_impl(
        &self,
        inputs: &[Self::Input],
        config: &Self::Config,
        _parallel_config: &ParallelConfig,
    ) -> SklResult<Vec<Self::Output>> {
        self.compute_sequential(inputs, config)
    }
}
158
/// Permutation-importance engine that scores features via the
/// [`ParallelExplanation`] trait.
pub struct ParallelPermutationImportance<F> {
    // Shared prediction function; `Arc` lets worker threads share it.
    model: Arc<F>,
    // Metric comparing true targets against predictions; higher is
    // assumed better (importance is reported as a score drop).
    scoring_fn: fn(&ArrayView1<Float>, &ArrayView1<Float>) -> Float,
}

impl<F> ParallelPermutationImportance<F> {
    /// Builds an engine from a shared model and a scoring function.
    pub fn new(
        model: Arc<F>,
        scoring_fn: fn(&ArrayView1<Float>, &ArrayView1<Float>) -> Float,
    ) -> Self {
        Self { model, scoring_fn }
    }
}
174
/// Work item describing one feature's permutation-importance job.
#[derive(Debug, Clone)]
pub struct PermutationInput {
    /// Column index of the feature to permute.
    pub feature_idx: usize,
    /// Feature matrix (rows are samples); cloned once per repeat.
    pub x_data: Array2<Float>,
    /// Ground-truth targets aligned with the rows of `x_data`.
    pub y_true: Array1<Float>,
    /// Number of independent shuffles whose score drops are collected.
    pub n_repeats: usize,
    /// RNG seed making the shuffles reproducible.
    pub random_state: u64,
}
189
190impl<F> ParallelExplanation for ParallelPermutationImportance<F>
191where
192 F: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>> + Send + Sync,
193{
194 type Input = PermutationInput;
195 type Output = Vec<Float>;
196 type Config = ();
197
198 fn compute_single(
199 &self,
200 input: &Self::Input,
201 _config: &Self::Config,
202 ) -> SklResult<Self::Output> {
203 let mut importances = Vec::with_capacity(input.n_repeats);
204 let mut rng = scirs2_core::random::ChaCha8Rng::seed_from_u64(input.random_state);
205
206 let y_pred_baseline = (self.model)(&input.x_data.view())?;
208 let baseline_score = (self.scoring_fn)(&input.y_true.view(), &y_pred_baseline.view());
209
210 for _ in 0..input.n_repeats {
211 let mut x_permuted = input.x_data.clone();
212 let mut column = x_permuted.column_mut(input.feature_idx);
213
214 let mut indices: Vec<usize> = (0..column.len()).collect();
216 indices.shuffle(&mut rng);
217
218 let original_values: Vec<Float> = column.to_vec();
219 for (i, &new_idx) in indices.iter().enumerate() {
220 column[i] = original_values[new_idx];
221 }
222
223 let y_pred_permuted = (self.model)(&x_permuted.view())?;
225 let permuted_score = (self.scoring_fn)(&input.y_true.view(), &y_pred_permuted.view());
226
227 importances.push(baseline_score - permuted_score);
228 }
229
230 Ok(importances)
231 }
232}
233
/// Computes per-sample SHAP values, fanning samples out across rayon
/// worker threads when `config` deems the batch large enough.
///
/// Returns an `(n_samples, n_features)` matrix of attributions; the
/// first model error aborts the whole computation.
#[cfg(feature = "parallel")]
pub fn compute_shap_parallel<F>(
    model: F,
    X: &ArrayView2<Float>,
    baseline: &ArrayView1<Float>,
    config: &ParallelConfig,
) -> SklResult<Array2<Float>>
where
    F: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>> + Send + Sync + Clone,
{
    let n_samples = X.nrows();
    let n_features = X.ncols();

    // Small batches: thread fan-out would cost more than it saves.
    if !config.should_parallelize(n_samples) {
        return compute_shap_sequential(model, X, baseline);
    }

    // Each row is explained independently; result order is preserved.
    let results: SklResult<Vec<_>> = (0..n_samples)
        .into_par_iter()
        .map(|i| {
            let instance = X.row(i);
            compute_shap_single_instance(model.clone(), &instance, baseline)
        })
        .collect();

    // Reassemble the per-row attributions into one matrix.
    let shap_values = results?;
    let mut result = Array2::zeros((n_samples, n_features));
    for (i, values) in shap_values.into_iter().enumerate() {
        result.row_mut(i).assign(&values);
    }

    Ok(result)
}
268
269fn compute_shap_sequential<F>(
271 model: F,
272 X: &ArrayView2<Float>,
273 baseline: &ArrayView1<Float>,
274) -> SklResult<Array2<Float>>
275where
276 F: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>> + Clone,
277{
278 let n_samples = X.nrows();
279 let n_features = X.ncols();
280 let mut result = Array2::zeros((n_samples, n_features));
281
282 for i in 0..n_samples {
283 let instance = X.row(i);
284 let shap_values = compute_shap_single_instance(model.clone(), &instance, baseline)?;
285 result.row_mut(i).assign(&shap_values);
286 }
287
288 Ok(result)
289}
290
291fn compute_shap_single_instance<F>(
293 model: F,
294 instance: &ArrayView1<Float>,
295 baseline: &ArrayView1<Float>,
296) -> SklResult<Array1<Float>>
297where
298 F: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>>,
299{
300 let n_features = instance.len();
301 let mut shap_values = Array1::zeros(n_features);
302
303 for i in 0..n_features {
305 let mut coalition_without = baseline.to_owned();
307 for j in 0..n_features {
308 if j != i {
309 coalition_without[j] = instance[j];
310 }
311 }
312
313 let mut coalition_with = coalition_without.clone();
315 coalition_with[i] = instance[i];
316
317 let pred_without = model(&coalition_without.view().insert_axis(Axis(0)))?;
319 let pred_with = model(&coalition_with.view().insert_axis(Axis(0)))?;
320
321 shap_values[i] = pred_with[0] - pred_without[0];
322 }
323
324 Ok(shap_values)
325}
326
/// Tuning knobs for splitting work into memory-bounded batches.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Batch size used when dynamic sizing is disabled.
    pub base_batch_size: usize,
    /// Upper bound on any dynamically chosen batch size.
    pub max_batch_size: usize,
    /// Lower bound on any dynamically chosen batch size.
    pub min_batch_size: usize,
    /// Memory budget per batch, in megabytes.
    pub memory_limit_mb: usize,
    /// When true, batch size is derived from the memory budget.
    pub dynamic_sizing: bool,
    /// When true, callers should report per-batch progress.
    pub enable_progress: bool,
}

impl Default for BatchConfig {
    fn default() -> Self {
        BatchConfig {
            base_batch_size: 1000,
            max_batch_size: 10000,
            min_batch_size: 100,
            memory_limit_mb: 512,
            dynamic_sizing: true,
            enable_progress: false,
        }
    }
}

impl BatchConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the fixed batch size used when dynamic sizing is off.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.base_batch_size = batch_size;
        self
    }

    /// Sets the per-batch memory budget in megabytes.
    pub fn with_memory_limit(mut self, memory_mb: usize) -> Self {
        self.memory_limit_mb = memory_mb;
        self
    }

    /// Enables or disables memory-driven batch sizing.
    pub fn with_dynamic_sizing(mut self, enabled: bool) -> Self {
        self.dynamic_sizing = enabled;
        self
    }

    /// Enables or disables progress reporting.
    pub fn with_progress(mut self, enabled: bool) -> Self {
        self.enable_progress = enabled;
        self
    }

    /// Picks a batch size: with dynamic sizing, as many items as fit in
    /// the memory budget, clamped to `[min_batch_size, max_batch_size]`;
    /// otherwise `base_batch_size`. Zero-sized items fall back to the
    /// base size since the memory budget cannot constrain them.
    pub fn calculate_optimal_batch_size(&self, item_size_bytes: usize) -> usize {
        if item_size_bytes == 0 || !self.dynamic_sizing {
            return self.base_batch_size;
        }
        let budget_bytes = self.memory_limit_mb * 1024 * 1024;
        (budget_bytes / item_size_bytes)
            .min(self.max_batch_size)
            .max(self.min_batch_size)
    }
}
407
/// Running aggregate statistics for batch processing.
#[derive(Debug, Clone)]
pub struct BatchStats {
    /// Total number of items processed so far.
    pub total_items: usize,
    /// Number of batches processed so far.
    pub num_batches: usize,
    /// Mean items per batch.
    pub avg_batch_size: f64,
    /// Cumulative wall-clock processing time in milliseconds.
    pub total_time_ms: u128,
    /// Mean processing time per batch in milliseconds.
    pub avg_time_per_batch_ms: f64,
    /// Mean processing time per item in microseconds.
    pub avg_time_per_item_us: f64,
    /// Highest memory usage observed across batches, in megabytes.
    pub peak_memory_mb: usize,
}

impl Default for BatchStats {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchStats {
    /// Creates an empty statistics accumulator.
    pub fn new() -> Self {
        BatchStats {
            total_items: 0,
            num_batches: 0,
            avg_batch_size: 0.0,
            total_time_ms: 0,
            avg_time_per_batch_ms: 0.0,
            avg_time_per_item_us: 0.0,
            peak_memory_mb: 0,
        }
    }

    /// Folds one finished batch into the running totals and refreshes
    /// all derived averages.
    pub fn update(&mut self, batch_size: usize, batch_time_ms: u128, memory_mb: usize) {
        self.total_items += batch_size;
        self.num_batches += 1;
        self.total_time_ms += batch_time_ms;
        if memory_mb > self.peak_memory_mb {
            self.peak_memory_mb = memory_mb;
        }

        let items = self.total_items as f64;
        let batches = self.num_batches as f64;
        let millis = self.total_time_ms as f64;
        self.avg_batch_size = items / batches;
        self.avg_time_per_batch_ms = millis / batches;
        self.avg_time_per_item_us = millis * 1000.0 / items;
    }

    /// Items processed per second; 0.0 before any timed work.
    pub fn throughput(&self) -> f64 {
        match self.total_time_ms {
            0 => 0.0,
            ms => self.total_items as f64 * 1000.0 / ms as f64,
        }
    }

    /// Heuristic 0..=1 score blending throughput (capped at 1000
    /// items/s) with memory frugality relative to a 512 MB budget.
    pub fn efficiency(&self) -> f64 {
        if self.num_batches == 0 || self.total_time_ms == 0 {
            return 0.0;
        }
        let throughput_score = (self.throughput() / 1000.0).min(1.0);
        let memory_score = (512.0 / self.peak_memory_mb as f64).min(1.0);
        (throughput_score + memory_score) / 2.0
    }
}
481
/// Callback invoked as `(completed_batches, total_batches)` while a
/// batched operation runs; must be callable from worker threads.
pub type ProgressCallback = Box<dyn Fn(usize, usize) + Send + Sync>;
484
/// Processes `data` in `batch_size` chunks, fanning the chunks out
/// across rayon worker threads.
///
/// Falls back to a single sequential `processor` call when `config`
/// says the input is too small to parallelize. Results preserve input
/// order; the first batch error aborts the whole operation.
#[cfg(feature = "parallel")]
pub fn process_batches_parallel<T, R, F>(
    data: &[T],
    batch_size: usize,
    config: &ParallelConfig,
    processor: F,
) -> SklResult<Vec<R>>
where
    T: Send + Sync,
    R: Send,
    F: Fn(&[T]) -> SklResult<Vec<R>> + Send + Sync,
{
    if !config.should_parallelize(data.len()) {
        return processor(data);
    }

    // `par_chunks` splits the slice lazily, avoiding the intermediate
    // Vec of chunk references the previous implementation collected
    // just to call `into_par_iter` on it.
    let results: SklResult<Vec<_>> = data.par_chunks(batch_size).map(processor).collect();

    Ok(results?.into_iter().flatten().collect())
}
512
/// Processes `data` with memory-aware batch sizing, optional progress
/// reporting, and per-run statistics.
///
/// Small inputs (per `parallel_config`) are handled in one sequential
/// call. Otherwise the data is split into chunks sized by
/// `batch_config.calculate_optimal_batch_size` and processed on the
/// rayon pool; results preserve input order.
///
/// # Errors
/// Returns the first error produced by `processor`.
#[cfg(feature = "parallel")]
pub fn process_batches_optimized<T, R, F>(
    data: &[T],
    batch_config: &BatchConfig,
    parallel_config: &ParallelConfig,
    processor: F,
    progress_callback: Option<ProgressCallback>,
) -> SklResult<(Vec<R>, BatchStats)>
where
    T: Send + Sync,
    R: Send,
    F: Fn(&[T]) -> SklResult<Vec<R>> + Send + Sync,
{
    let mut stats = BatchStats::new();
    let total_items = data.len();

    // Small workloads: one sequential batch, no thread-pool overhead.
    if !parallel_config.should_parallelize(total_items) {
        let start_time = std::time::Instant::now();
        let results = processor(data)?;
        let elapsed = start_time.elapsed().as_millis();
        stats.update(total_items, elapsed, batch_config.memory_limit_mb);
        return Ok((results, stats));
    }

    // NOTE(review): size_of::<T>() ignores heap data owned by T, so the
    // memory-based batch size is only an estimate for boxed payloads.
    let estimated_item_size = std::mem::size_of::<T>();
    let optimal_batch_size = batch_config.calculate_optimal_batch_size(estimated_item_size);

    let batches: Vec<_> = data.chunks(optimal_batch_size).collect();
    let total_batches = batches.len();

    let start_time = std::time::Instant::now();
    let processed_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));

    let results: SklResult<Vec<_>> = batches
        .into_par_iter()
        .map(|batch| {
            let batch_result = processor(batch);

            // Progress is reported in completion order, not
            // submission order, since batches finish concurrently.
            if let Some(ref callback) = progress_callback {
                let completed =
                    processed_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
                callback(completed, total_batches);
            }

            batch_result
        })
        .collect();

    let total_time = start_time.elapsed().as_millis();
    let final_results: Vec<R> = results?.into_iter().flatten().collect();

    stats.update(total_items, total_time, batch_config.memory_limit_mb);
    stats.num_batches = total_batches;
    // Report the real average (the last chunk may be partial) instead
    // of the nominal chunk size the previous code stored here.
    stats.avg_batch_size = total_items as f64 / total_batches as f64;

    Ok((final_results, stats))
}
579
/// Accumulates items incrementally and processes them in batches on
/// demand, retaining all results and running statistics.
#[cfg(feature = "parallel")]
pub struct StreamingBatchProcessor<T, R> {
    // Sizing/memory settings used when draining the buffer.
    batch_config: BatchConfig,
    // Parallelism settings. NOTE(review): `process_buffer` currently
    // drains chunks sequentially and never consults this — confirm
    // whether parallel draining was intended.
    parallel_config: ParallelConfig,
    // Items pushed since the last `process_buffer` call.
    buffer: Vec<T>,
    // All results produced over the processor's lifetime.
    results: Vec<R>,
    // Aggregate timing/throughput statistics.
    stats: BatchStats,
}
589
#[cfg(feature = "parallel")]
impl<T, R> StreamingBatchProcessor<T, R>
where
    T: Send + Sync,
    R: Send,
{
    /// Creates an empty processor with the given configurations.
    pub fn new(batch_config: BatchConfig, parallel_config: ParallelConfig) -> Self {
        Self {
            batch_config,
            parallel_config,
            buffer: Vec::new(),
            results: Vec::new(),
            stats: BatchStats::new(),
        }
    }

    /// Queues one item for the next `process_buffer` call.
    pub fn push(&mut self, item: T) {
        self.buffer.push(item);
    }

    /// Drains the buffer in memory-sized chunks, appends the outputs
    /// to the retained results, and returns this call's outputs.
    pub fn process_buffer<F>(&mut self, processor: F) -> SklResult<Vec<R>>
    where
        F: Fn(&[T]) -> SklResult<Vec<R>> + Send + Sync,
        R: Clone,
    {
        if self.buffer.is_empty() {
            return Ok(Vec::new());
        }

        let chunk_len = self
            .batch_config
            .calculate_optimal_batch_size(std::mem::size_of::<T>());
        let started = std::time::Instant::now();

        let mut produced = Vec::new();
        for chunk in self.buffer.chunks(chunk_len) {
            produced.extend(processor(chunk)?);
        }

        self.stats.update(
            self.buffer.len(),
            started.elapsed().as_millis(),
            self.batch_config.memory_limit_mb,
        );

        self.buffer.clear();
        // Keep one copy in the retained results, hand the other back.
        self.results.extend(produced.iter().cloned());
        Ok(produced)
    }

    /// Running statistics across all `process_buffer` calls.
    pub fn stats(&self) -> &BatchStats {
        &self.stats
    }

    /// Every result produced since construction or the last `reset`.
    pub fn results(&self) -> &[R] {
        &self.results
    }

    /// Clears buffered items, retained results, and statistics.
    pub fn reset(&mut self) {
        self.buffer.clear();
        self.results.clear();
        self.stats = BatchStats::new();
    }
}
665
666#[derive(Debug, Clone)]
668pub struct AdaptiveBatchConfig {
669 pub base_config: BatchConfig,
671 pub monitor_cpu: bool,
673 pub monitor_memory: bool,
675 pub cpu_threshold: f64,
677 pub memory_threshold: f64,
679 pub sizing_factor: f64,
681}
682
683impl Default for AdaptiveBatchConfig {
684 fn default() -> Self {
685 Self {
686 base_config: BatchConfig::default(),
687 monitor_cpu: true,
688 monitor_memory: true,
689 cpu_threshold: 0.8,
690 memory_threshold: 0.8,
691 sizing_factor: 0.5,
692 }
693 }
694}
695
696impl AdaptiveBatchConfig {
697 pub fn new() -> Self {
699 Self::default()
700 }
701
702 pub fn with_base_config(mut self, config: BatchConfig) -> Self {
704 self.base_config = config;
705 self
706 }
707
708 pub fn with_cpu_monitoring(mut self, enabled: bool) -> Self {
710 self.monitor_cpu = enabled;
711 self
712 }
713
714 pub fn with_memory_monitoring(mut self, enabled: bool) -> Self {
716 self.monitor_memory = enabled;
717 self
718 }
719
720 pub fn with_cpu_threshold(mut self, threshold: f64) -> Self {
722 self.cpu_threshold = threshold.clamp(0.0, 1.0);
723 self
724 }
725
726 pub fn with_memory_threshold(mut self, threshold: f64) -> Self {
728 self.memory_threshold = threshold.clamp(0.0, 1.0);
729 self
730 }
731
732 pub fn calculate_adaptive_batch_size(&self, item_size_bytes: usize) -> usize {
734 let base_size = self
735 .base_config
736 .calculate_optimal_batch_size(item_size_bytes);
737
738 let cpu_load = self.get_cpu_load();
740 let memory_load = self.get_memory_load();
741
742 let mut scaling_factor = 1.0;
743
744 if self.monitor_cpu && cpu_load > self.cpu_threshold {
745 scaling_factor *= self.sizing_factor;
746 }
747
748 if self.monitor_memory && memory_load > self.memory_threshold {
749 scaling_factor *= self.sizing_factor;
750 }
751
752 let adaptive_size = (base_size as f64 * scaling_factor) as usize;
753 adaptive_size
754 .max(self.base_config.min_batch_size)
755 .min(self.base_config.max_batch_size)
756 }
757
758 fn get_cpu_load(&self) -> f64 {
760 0.5 }
764
765 fn get_memory_load(&self) -> f64 {
767 0.6 }
771}
772
/// Simple free-list of `Vec<T>` buffers that lets hot paths reuse
/// allocations instead of hitting the allocator on every request.
#[derive(Debug)]
pub struct MemoryPool<T> {
    /// Returned buffers awaiting reuse.
    pool: Vec<Vec<T>>,
    /// Cap on how many buffers are retained; extras are dropped.
    max_pool_size: usize,
    /// Count of buffers created fresh.
    total_allocations: usize,
    /// Count of buffers served from the pool.
    total_reuses: usize,
}

impl<T> MemoryPool<T> {
    /// Creates a pool that retains at most `max_pool_size` buffers.
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pool: Vec::new(),
            max_pool_size,
            total_allocations: 0,
            total_reuses: 0,
        }
    }

    /// Returns an empty `Vec` with at least `capacity` capacity,
    /// reusing a pooled buffer when one is available.
    pub fn get_vec(&mut self, capacity: usize) -> Vec<T> {
        if let Some(mut vec) = self.pool.pop() {
            vec.clear();
            // BUG FIX: the previous code reserved
            // `capacity - vec.capacity()`, which (with len == 0 after
            // `clear`) only guarantees that much TOTAL capacity, so a
            // reused buffer could come back smaller than requested.
            // `reserve(capacity)` guarantees the full amount.
            if vec.capacity() < capacity {
                vec.reserve(capacity);
            }
            self.total_reuses += 1;
            vec
        } else {
            self.total_allocations += 1;
            Vec::with_capacity(capacity)
        }
    }

    /// Hands a buffer back for reuse; dropped silently when the pool
    /// is already full.
    pub fn return_vec(&mut self, vec: Vec<T>) {
        if self.pool.len() < self.max_pool_size {
            self.pool.push(vec);
        }
    }

    /// Returns `(fresh_allocations, reuses, reuse_rate)`; the rate is
    /// in `[0, 1]` and 0.0 before any request.
    pub fn stats(&self) -> (usize, usize, f64) {
        let total_requests = self.total_allocations + self.total_reuses;
        let reuse_rate = if total_requests > 0 {
            self.total_reuses as f64 / total_requests as f64
        } else {
            0.0
        };
        (self.total_allocations, self.total_reuses, reuse_rate)
    }
}
826
/// Batch processor combining adaptive sizing, rayon parallelism, and
/// buffer pooling.
#[cfg(feature = "parallel")]
pub struct HighPerformanceBatchProcessor<T, R> {
    // Load-aware batch sizing rules.
    adaptive_config: AdaptiveBatchConfig,
    // Decides when to fan work out to the rayon pool.
    parallel_config: ParallelConfig,
    // Pool of reusable input buffers.
    // NOTE(review): never used by `process_adaptive` — confirm whether
    // input-side pooling was intended.
    memory_pool: MemoryPool<T>,
    // Pool of reusable output buffers; vectors are taken but never
    // returned here, so reuse depends on callers returning them.
    result_pool: MemoryPool<R>,
    // Aggregate timing statistics.
    stats: BatchStats,
}

#[cfg(feature = "parallel")]
impl<T, R> HighPerformanceBatchProcessor<T, R>
where
    T: Send + Sync + Clone,
    R: Send + Clone,
{
    /// Creates a processor with small (10-buffer) pools.
    pub fn new(adaptive_config: AdaptiveBatchConfig, parallel_config: ParallelConfig) -> Self {
        Self {
            adaptive_config,
            parallel_config,
            memory_pool: MemoryPool::new(10), result_pool: MemoryPool::new(10),
            stats: BatchStats::new(),
        }
    }

    /// Processes `data` in adaptively sized chunks — in parallel when
    /// the workload is large enough — and records timing statistics.
    /// Results preserve input order; the first error aborts.
    pub fn process_adaptive<F>(&mut self, data: &[T], processor: F) -> SklResult<Vec<R>>
    where
        F: Fn(&[T]) -> SklResult<Vec<R>> + Send + Sync,
    {
        let start_time = std::time::Instant::now();
        // Chunk size reacts to the (stubbed) CPU/memory load signals.
        let item_size = std::mem::size_of::<T>();
        let adaptive_batch_size = self
            .adaptive_config
            .calculate_adaptive_batch_size(item_size);

        // Pooled output buffer, pre-sized for one result per input.
        let mut results = self.result_pool.get_vec(data.len());

        if !self.parallel_config.should_parallelize(data.len()) {
            // Small workload: one sequential call.
            let batch_results = processor(data)?;
            results.extend(batch_results);
        } else {
            let batches: Vec<_> = data.chunks(adaptive_batch_size).collect();
            let batch_results: SklResult<Vec<_>> = batches.into_par_iter().map(processor).collect();

            let processed_results = batch_results?;
            for batch_result in processed_results {
                results.extend(batch_result);
            }
        }

        let elapsed = start_time.elapsed().as_millis();
        self.stats.update(
            data.len(),
            elapsed,
            self.adaptive_config.base_config.memory_limit_mb,
        );

        Ok(results)
    }

    /// Returns `(input_pool_stats, output_pool_stats)` as
    /// `(allocations, reuses, reuse_rate)` tuples.
    pub fn memory_pool_stats(&self) -> ((usize, usize, f64), (usize, usize, f64)) {
        (self.memory_pool.stats(), self.result_pool.stats())
    }

    /// Aggregate statistics across all `process_adaptive` calls.
    pub fn stats(&self) -> &BatchStats {
        &self.stats
    }
}
902
/// Wrapper that tracks compression bookkeeping for a batch of values.
///
/// NOTE(review): no real compression is performed yet — data is stored
/// verbatim and the ratio is a fixed placeholder.
#[derive(Debug, Clone)]
pub struct CompressedBatch<T> {
    /// Stored items (currently uncompressed).
    data: Vec<T>,
    /// Item count at compression time.
    original_size: usize,
    /// Placeholder compressed/original ratio.
    compression_ratio: f64,
}

impl<T> CompressedBatch<T>
where
    T: Clone,
{
    /// "Compresses" a batch: records its size and stores the items
    /// unchanged with a fixed placeholder ratio of 0.7.
    pub fn compress(data: Vec<T>) -> Self {
        let original_size = data.len();
        Self {
            data,
            original_size,
            compression_ratio: 0.7,
        }
    }

    /// Recovers the stored items as an owned vector.
    pub fn decompress(&self) -> Vec<T> {
        self.data.to_vec()
    }

    /// Placeholder compressed/original size ratio.
    pub fn compression_ratio(&self) -> f64 {
        self.compression_ratio
    }

    /// Item count currently stored.
    pub fn compressed_size(&self) -> usize {
        self.data.len()
    }

    /// Item count at compression time.
    pub fn original_size(&self) -> usize {
        self.original_size
    }
}
951
952#[derive(Debug)]
954pub struct CacheAwareExplanationStore<T> {
955 hot_cache: std::collections::HashMap<u64, T>,
957 cold_storage: std::collections::HashMap<u64, CompressedBatch<T>>,
959 access_counts: std::collections::HashMap<u64, usize>,
961 max_hot_cache_size: usize,
963 cache_hits: usize,
965 cache_misses: usize,
966}
967
968impl<T> CacheAwareExplanationStore<T>
969where
970 T: Clone + std::hash::Hash,
971{
972 pub fn new(max_hot_cache_size: usize) -> Self {
974 Self {
975 hot_cache: std::collections::HashMap::new(),
976 cold_storage: std::collections::HashMap::new(),
977 access_counts: std::collections::HashMap::new(),
978 max_hot_cache_size,
979 cache_hits: 0,
980 cache_misses: 0,
981 }
982 }
983
984 pub fn store(&mut self, key: u64, value: T) {
986 if self.hot_cache.len() < self.max_hot_cache_size {
987 self.hot_cache.insert(key, value);
988 } else {
989 let compressed = CompressedBatch::compress(vec![value]);
991 self.cold_storage.insert(key, compressed);
992 }
993 self.access_counts.insert(key, 1);
994 }
995
996 pub fn get(&mut self, key: u64) -> Option<T> {
998 if let Some(value) = self.hot_cache.get(&key) {
1000 self.cache_hits += 1;
1001 *self.access_counts.entry(key).or_insert(0) += 1;
1002 return Some(value.clone());
1003 }
1004
1005 if let Some(compressed) = self.cold_storage.get(&key) {
1007 self.cache_hits += 1;
1008 let decompressed = compressed.decompress();
1009 let value = decompressed.into_iter().next()?;
1010
1011 let access_count = *self.access_counts.entry(key).or_insert(0) + 1;
1013 self.access_counts.insert(key, access_count);
1014
1015 if access_count > 3 && self.hot_cache.len() < self.max_hot_cache_size {
1016 self.hot_cache.insert(key, value.clone());
1017 self.cold_storage.remove(&key);
1018 }
1019
1020 return Some(value);
1021 }
1022
1023 self.cache_misses += 1;
1024 None
1025 }
1026
1027 pub fn cache_stats(&self) -> (usize, usize, f64, usize, usize) {
1029 let total_accesses = self.cache_hits + self.cache_misses;
1030 let hit_rate = if total_accesses > 0 {
1031 self.cache_hits as f64 / total_accesses as f64
1032 } else {
1033 0.0
1034 };
1035 (
1036 self.cache_hits,
1037 self.cache_misses,
1038 hit_rate,
1039 self.hot_cache.len(),
1040 self.cold_storage.len(),
1041 )
1042 }
1043}
1044
/// Sequential stand-in for `process_batches_parallel` used when the
/// `parallel` feature is disabled: the whole slice is handed to
/// `processor` in a single call, ignoring the batch size and config.
#[cfg(not(feature = "parallel"))]
pub fn process_batches_parallel<T, R, F>(
    data: &[T],
    _batch_size: usize,
    _config: &ParallelConfig,
    processor: F,
) -> SklResult<Vec<R>>
where
    F: Fn(&[T]) -> SklResult<Vec<R>>,
{
    processor(data)
}
1058
1059#[cfg(test)]
1060mod tests {
1061 use super::*;
1062 use scirs2_core::ndarray::array;
1064
1065 #[test]
1066 fn test_parallel_config_creation() {
1067 let config = ParallelConfig::new();
1068 assert!(!config.force_sequential);
1069 assert_eq!(config.min_batch_size, 100);
1070 }
1071
1072 #[test]
1073 fn test_parallel_config_with_threads() {
1074 let config = ParallelConfig::new().with_threads(4);
1075 assert_eq!(config.n_threads, Some(4));
1076 }
1077
1078 #[test]
1079 fn test_parallel_config_sequential() {
1080 let config = ParallelConfig::new().sequential();
1081 assert!(config.force_sequential);
1082 assert_eq!(config.get_n_threads(), 1);
1083 }
1084
1085 #[test]
1086 fn test_should_parallelize() {
1087 let config = ParallelConfig::new();
1088 assert!(config.should_parallelize(1000));
1089 assert!(!config.should_parallelize(50));
1090
1091 let sequential_config = ParallelConfig::new().sequential();
1092 assert!(!sequential_config.should_parallelize(1000));
1093 }
1094
1095 #[test]
1096 fn test_permutation_input_creation() {
1097 let x_data = array![[1.0, 2.0], [3.0, 4.0]];
1098 let y_true = array![1.0, 0.0];
1099
1100 let input = PermutationInput {
1101 feature_idx: 0,
1102 x_data,
1103 y_true,
1104 n_repeats: 5,
1105 random_state: 42,
1106 };
1107
1108 assert_eq!(input.feature_idx, 0);
1109 assert_eq!(input.n_repeats, 5);
1110 assert_eq!(input.random_state, 42);
1111 }
1112
1113 #[test]
1114 #[allow(non_snake_case)]
1115 fn test_shap_sequential_computation() {
1116 let model = |x: &ArrayView2<Float>| -> SklResult<Array1<Float>> { Ok(x.sum_axis(Axis(1))) };
1117
1118 let X = array![[1.0, 2.0], [3.0, 4.0]];
1119 let baseline = array![0.0, 0.0];
1120
1121 let result = compute_shap_sequential(model, &X.view(), &baseline.view());
1122 assert!(result.is_ok());
1123
1124 let shap_values = result.unwrap();
1125 assert_eq!(shap_values.shape(), &[2, 2]);
1126 }
1127
1128 #[cfg(feature = "parallel")]
1129 #[test]
1130 #[allow(non_snake_case)]
1131 fn test_parallel_shap_computation() {
1132 let model = |x: &ArrayView2<Float>| -> SklResult<Array1<Float>> { Ok(x.sum_axis(Axis(1))) };
1133
1134 let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
1135 let baseline = array![0.0, 0.0];
1136 let config = ParallelConfig::new().with_min_batch_size(1);
1137
1138 let result = compute_shap_parallel(model, &X.view(), &baseline.view(), &config);
1139 assert!(result.is_ok());
1140
1141 let shap_values = result.unwrap();
1142 assert_eq!(shap_values.shape(), &[3, 2]);
1143 }
1144
1145 #[test]
1146 fn test_single_instance_shap() {
1147 let model = |x: &ArrayView2<Float>| -> SklResult<Array1<Float>> { Ok(x.sum_axis(Axis(1))) };
1148
1149 let instance = array![1.0, 2.0];
1150 let baseline = array![0.0, 0.0];
1151
1152 let result = compute_shap_single_instance(model, &instance.view(), &baseline.view());
1153 assert!(result.is_ok());
1154
1155 let shap_values = result.unwrap();
1156 assert_eq!(shap_values.len(), 2);
1157 }
1158
1159 #[test]
1160 fn test_batch_config_creation() {
1161 let config = BatchConfig::new();
1162 assert_eq!(config.base_batch_size, 1000);
1163 assert_eq!(config.max_batch_size, 10000);
1164 assert_eq!(config.min_batch_size, 100);
1165 assert_eq!(config.memory_limit_mb, 512);
1166 assert!(config.dynamic_sizing);
1167 assert!(!config.enable_progress);
1168 }
1169
1170 #[test]
1171 fn test_batch_config_fluent_api() {
1172 let config = BatchConfig::new()
1173 .with_batch_size(500)
1174 .with_memory_limit(256)
1175 .with_dynamic_sizing(false)
1176 .with_progress(true);
1177
1178 assert_eq!(config.base_batch_size, 500);
1179 assert_eq!(config.memory_limit_mb, 256);
1180 assert!(!config.dynamic_sizing);
1181 assert!(config.enable_progress);
1182 }
1183
1184 #[test]
1185 fn test_batch_config_optimal_batch_size() {
1186 let config = BatchConfig::new()
1187 .with_memory_limit(1) .with_dynamic_sizing(true);
1189
1190 let optimal_size = config.calculate_optimal_batch_size(1024); assert_eq!(optimal_size, 1024); let optimal_size = config.calculate_optimal_batch_size(1024 * 1024 * 10); assert_eq!(optimal_size, config.min_batch_size);
1197
1198 let static_config = BatchConfig::new().with_dynamic_sizing(false);
1200 let optimal_size = static_config.calculate_optimal_batch_size(1024);
1201 assert_eq!(optimal_size, static_config.base_batch_size);
1202 }
1203
1204 #[test]
1205 fn test_batch_stats_creation() {
1206 let stats = BatchStats::new();
1207 assert_eq!(stats.total_items, 0);
1208 assert_eq!(stats.num_batches, 0);
1209 assert_eq!(stats.avg_batch_size, 0.0);
1210 assert_eq!(stats.total_time_ms, 0);
1211 assert_eq!(stats.throughput(), 0.0);
1212 assert_eq!(stats.efficiency(), 0.0);
1213 }
1214
1215 #[test]
1216 fn test_batch_stats_update() {
1217 let mut stats = BatchStats::new();
1218
1219 stats.update(100, 1000, 128); assert_eq!(stats.total_items, 100);
1222 assert_eq!(stats.num_batches, 1);
1223 assert_eq!(stats.avg_batch_size, 100.0);
1224 assert_eq!(stats.total_time_ms, 1000);
1225 assert_eq!(stats.peak_memory_mb, 128);
1226
1227 stats.update(200, 2000, 256); assert_eq!(stats.total_items, 300);
1230 assert_eq!(stats.num_batches, 2);
1231 assert_eq!(stats.avg_batch_size, 150.0);
1232 assert_eq!(stats.total_time_ms, 3000);
1233 assert_eq!(stats.peak_memory_mb, 256);
1234 }
1235
1236 #[test]
1237 fn test_batch_stats_throughput() {
1238 let mut stats = BatchStats::new();
1239 stats.update(1000, 2000, 128); let throughput = stats.throughput();
1242 assert!((throughput - 500.0).abs() < 0.001); }
1244
1245 #[test]
1246 fn test_batch_stats_efficiency() {
1247 let mut stats = BatchStats::new();
1248 stats.update(1000, 1000, 256); let efficiency = stats.efficiency();
1251 assert!(efficiency > 0.0);
1252 assert!(efficiency <= 1.0);
1253 }
1254
1255 #[cfg(feature = "parallel")]
1256 #[test]
1257 fn test_process_batches_optimized() {
1258 let data: Vec<i32> = (0..1000).collect();
1259 let batch_config = BatchConfig::new().with_batch_size(100);
1260 let parallel_config = ParallelConfig::new().with_min_batch_size(50);
1261
1262 let processor =
1263 |batch: &[i32]| -> SklResult<Vec<i32>> { Ok(batch.iter().map(|x| x * 2).collect()) };
1264
1265 let result =
1266 process_batches_optimized(&data, &batch_config, ¶llel_config, processor, None);
1267
1268 assert!(result.is_ok());
1269 let (results, stats) = result.unwrap();
1270 assert_eq!(results.len(), 1000);
1271 assert_eq!(results[0], 0);
1272 assert_eq!(results[999], 1998);
1273 assert!(stats.total_items > 0);
1274 assert!(stats.num_batches > 0);
1275 }
1276
1277 #[cfg(feature = "parallel")]
1278 #[test]
1279 fn test_process_batches_with_progress() {
1280 let data: Vec<i32> = (0..500).collect();
1281 let batch_config = BatchConfig::new().with_batch_size(100);
1282 let parallel_config = ParallelConfig::new().with_min_batch_size(50);
1283
1284 let progress_calls = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
1285 let progress_calls_clone = progress_calls.clone();
1286
1287 let progress_callback = Box::new(move |completed: usize, total: usize| {
1288 progress_calls_clone
1289 .lock()
1290 .unwrap()
1291 .push((completed, total));
1292 });
1293
1294 let processor =
1295 |batch: &[i32]| -> SklResult<Vec<i32>> { Ok(batch.iter().map(|x| x * 2).collect()) };
1296
1297 let result = process_batches_optimized(
1298 &data,
1299 &batch_config,
1300 ¶llel_config,
1301 processor,
1302 Some(progress_callback),
1303 );
1304
1305 assert!(result.is_ok());
1306 let (results, _stats) = result.unwrap();
1307 assert_eq!(results.len(), 500);
1308
1309 let calls = progress_calls.lock().unwrap();
1311 assert!(!calls.is_empty());
1312 }
1313
1314 #[cfg(feature = "parallel")]
1315 #[test]
1316 fn test_streaming_batch_processor() {
1317 let batch_config = BatchConfig::new().with_batch_size(50);
1318 let parallel_config = ParallelConfig::new();
1319 let mut processor = StreamingBatchProcessor::new(batch_config, parallel_config);
1320
1321 for i in 0..100 {
1323 processor.push(i);
1324 }
1325
1326 let process_fn =
1328 |batch: &[i32]| -> SklResult<Vec<i32>> { Ok(batch.iter().map(|x| x * 2).collect()) };
1329
1330 let result = processor.process_buffer(process_fn);
1331 assert!(result.is_ok());
1332
1333 let batch_results = result.unwrap();
1334 assert_eq!(batch_results.len(), 100);
1335 assert_eq!(batch_results[0], 0);
1336 assert_eq!(batch_results[99], 198);
1337
1338 let stats = processor.stats();
1340 assert_eq!(stats.total_items, 100);
1341 assert!(stats.num_batches > 0);
1342
1343 let all_results = processor.results();
1345 assert_eq!(all_results.len(), 100);
1346 }
1347
1348 #[cfg(feature = "parallel")]
1349 #[test]
1350 fn test_streaming_batch_processor_reset() {
1351 let batch_config = BatchConfig::new();
1352 let parallel_config = ParallelConfig::new();
1353 let mut processor = StreamingBatchProcessor::new(batch_config, parallel_config);
1354
1355 processor.push(1);
1357 processor.push(2);
1358 let process_fn = |batch: &[i32]| -> SklResult<Vec<i32>> { Ok(batch.to_vec()) };
1359 let _ = processor.process_buffer(process_fn);
1360
1361 assert_eq!(processor.results().len(), 2);
1363 assert!(processor.stats().total_items > 0);
1364
1365 processor.reset();
1367 assert_eq!(processor.results().len(), 0);
1368 assert_eq!(processor.stats().total_items, 0);
1369 }
1370
1371 #[cfg(feature = "parallel")]
1372 #[test]
1373 fn test_streaming_batch_processor_empty_buffer() {
1374 let batch_config = BatchConfig::new();
1375 let parallel_config = ParallelConfig::new();
1376 let mut processor = StreamingBatchProcessor::new(batch_config, parallel_config);
1377
1378 let process_fn = |batch: &[i32]| -> SklResult<Vec<i32>> { Ok(batch.to_vec()) };
1379
1380 let result = processor.process_buffer(process_fn);
1381 assert!(result.is_ok());
1382
1383 let batch_results = result.unwrap();
1384 assert!(batch_results.is_empty());
1385 }
1386
1387 #[test]
1390 fn test_adaptive_batch_config_creation() {
1391 let config = AdaptiveBatchConfig::new();
1392 assert!(config.monitor_cpu);
1393 assert!(config.monitor_memory);
1394 assert_eq!(config.cpu_threshold, 0.8);
1395 assert_eq!(config.memory_threshold, 0.8);
1396 assert_eq!(config.sizing_factor, 0.5);
1397 }
1398
1399 #[test]
1400 fn test_adaptive_batch_config_fluent_api() {
1401 let base_config = BatchConfig::new().with_batch_size(500);
1402 let adaptive_config = AdaptiveBatchConfig::new()
1403 .with_base_config(base_config)
1404 .with_cpu_monitoring(false)
1405 .with_memory_monitoring(true)
1406 .with_cpu_threshold(0.9)
1407 .with_memory_threshold(0.7);
1408
1409 assert!(!adaptive_config.monitor_cpu);
1410 assert!(adaptive_config.monitor_memory);
1411 assert_eq!(adaptive_config.cpu_threshold, 0.9);
1412 assert_eq!(adaptive_config.memory_threshold, 0.7);
1413 assert_eq!(adaptive_config.base_config.base_batch_size, 500);
1414 }
1415
1416 #[test]
1417 fn test_adaptive_batch_size_calculation() {
1418 let config = AdaptiveBatchConfig::new();
1419 let batch_size = config.calculate_adaptive_batch_size(1024);
1420
1421 assert!(batch_size >= config.base_config.min_batch_size);
1423 assert!(batch_size <= config.base_config.max_batch_size);
1424 }
1425
1426 #[test]
1427 fn test_memory_pool_creation() {
1428 let mut pool: MemoryPool<i32> = MemoryPool::new(5);
1429 let (allocations, reuses, reuse_rate) = pool.stats();
1430
1431 assert_eq!(allocations, 0);
1432 assert_eq!(reuses, 0);
1433 assert_eq!(reuse_rate, 0.0);
1434 }
1435
1436 #[test]
1437 fn test_memory_pool_reuse() {
1438 let mut pool: MemoryPool<i32> = MemoryPool::new(5);
1439
1440 let vec1 = pool.get_vec(10);
1442 assert_eq!(vec1.capacity(), 10);
1443
1444 pool.return_vec(vec1);
1446
1447 let vec2 = pool.get_vec(5);
1449 assert!(vec2.capacity() >= 5); let (allocations, reuses, reuse_rate) = pool.stats();
1452 assert_eq!(allocations, 1);
1453 assert_eq!(reuses, 1);
1454 assert!((reuse_rate - 0.5).abs() < 0.001); }
1456
1457 #[cfg(feature = "parallel")]
1458 #[test]
1459 fn test_high_performance_batch_processor() {
1460 let adaptive_config = AdaptiveBatchConfig::new();
1461 let parallel_config = ParallelConfig::new().with_min_batch_size(10);
1462 let mut processor = HighPerformanceBatchProcessor::new(adaptive_config, parallel_config);
1463
1464 let data: Vec<i32> = (0..100).collect();
1465 let process_fn =
1466 |batch: &[i32]| -> SklResult<Vec<i32>> { Ok(batch.iter().map(|x| x * 2).collect()) };
1467
1468 let result = processor.process_adaptive(&data, process_fn);
1469 assert!(result.is_ok());
1470
1471 let results = result.unwrap();
1472 assert_eq!(results.len(), 100);
1473 assert_eq!(results[0], 0);
1474 assert_eq!(results[99], 198);
1475
1476 let stats = processor.stats();
1478 assert_eq!(stats.total_items, 100);
1479
1480 let (pool_stats, result_pool_stats) = processor.memory_pool_stats();
1482 assert!(pool_stats.0 >= 0); assert!(result_pool_stats.0 >= 0); }
1485
1486 #[test]
1487 fn test_compressed_batch_creation() {
1488 let data = vec![1, 2, 3, 4, 5];
1489 let compressed = CompressedBatch::compress(data.clone());
1490
1491 assert_eq!(compressed.original_size(), 5);
1492 assert_eq!(compressed.compressed_size(), 5); assert_eq!(compressed.compression_ratio(), 0.7);
1494
1495 let decompressed = compressed.decompress();
1496 assert_eq!(decompressed, data);
1497 }
1498
1499 #[test]
1500 fn test_cache_aware_explanation_store() {
1501 let mut store: CacheAwareExplanationStore<i32> = CacheAwareExplanationStore::new(2);
1502
1503 store.store(1, 100);
1505 store.store(2, 200);
1506 store.store(3, 300); assert_eq!(store.get(1), Some(100));
1510 assert_eq!(store.get(2), Some(200));
1511
1512 assert_eq!(store.get(3), Some(300));
1514
1515 assert_eq!(store.get(4), None);
1517
1518 let (hits, misses, hit_rate, hot_size, cold_size) = store.cache_stats();
1520 assert!(hits >= 3);
1521 assert_eq!(misses, 1);
1522 assert!(hit_rate > 0.5);
1523 assert!(hot_size <= 2);
1524 }
1525
1526 #[test]
1527 fn test_cache_promotion() {
1528 let mut store: CacheAwareExplanationStore<i32> = CacheAwareExplanationStore::new(1);
1529
1530 store.store(1, 100);
1532 store.store(2, 200); for _ in 0..4 {
1536 assert_eq!(store.get(2), Some(200));
1537 }
1538
1539 let (_, _, _, hot_size, cold_size) = store.cache_stats();
1540 assert!(hot_size == 1); }
1543
1544 #[test]
1545 fn test_threshold_clamping() {
1546 let config = AdaptiveBatchConfig::new()
1547 .with_cpu_threshold(1.5) .with_memory_threshold(-0.5); assert_eq!(config.cpu_threshold, 1.0);
1551 assert_eq!(config.memory_threshold, 0.0);
1552 }
1553}