use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::Instant;

/// Configuration for large-scale processing.
#[derive(Debug, Clone)]
pub struct LargeScaleConfig {
    /// Maximum memory budget in bytes.
    pub max_memory_bytes: usize,
    /// Number of samples per processing chunk.
    pub chunk_size: usize,
    /// Number of parallel workers.
    pub n_workers: usize,
    /// Whether to use memory-mapped block access.
    pub use_memory_mapping: bool,
    /// Whether to enable compression of intermediate data.
    pub enable_compression: bool,
}

impl Default for LargeScaleConfig {
    fn default() -> Self {
        Self {
            max_memory_bytes: 1_073_741_824, // 1 GiB
            chunk_size: 10_000,
            n_workers: num_cpus::get(),
            use_memory_mapping: true,
            enable_compression: true,
        }
    }
}

/// Strategy for fitting on datasets too large to process in a single pass.
#[derive(Debug, Clone, PartialEq)]
pub enum LargeScaleStrategy {
    /// Process the data in fixed-size chunks, optionally overlapping.
    ChunkedProcessing { chunk_size: usize, overlap: usize },
    /// Stream the data in memory-mapped blocks, prefetching ahead.
    MemoryMapped {
        block_size: usize,
        prefetch_blocks: usize,
    },
    /// Keep a fixed-size random sample of the target stream.
    ReservoirSampling {
        reservoir_size: usize,
        replacement_rate: f64,
    },
    /// Summarize the targets with hashed frequency sketches.
    SketchBased {
        sketch_size: usize,
        hash_functions: usize,
    },
    /// Merge statistics computed independently on several nodes.
    Distributed {
        node_id: usize,
        total_nodes: usize,
        coordinator_address: String,
    },
}

/// Dummy estimator that handles large datasets via a pluggable strategy.
pub struct LargeScaleDummyEstimator {
    strategy: LargeScaleStrategy,
    config: LargeScaleConfig,
    state: RwLock<LargeScaleState>,
}

#[derive(Debug, Clone)]
struct LargeScaleState {
    /// Total number of samples seen so far.
    sample_count: usize,
    /// Running sum of target values.
    running_sum: f64,
    /// Running sum of squared target values.
    running_sum_squares: f64,
    /// Reservoir of sampled target values.
    reservoir: Vec<f64>,
    /// Frequency sketches keyed by hash-function index.
    sketches: HashMap<usize, Vec<f64>>,
    /// Per-node statistics collected in distributed mode.
    node_statistics: HashMap<usize, NodeStatistics>,
    /// Rough estimate of current memory usage in bytes.
    current_memory_usage: usize,
}

#[derive(Debug, Clone)]
struct NodeStatistics {
    sample_count: usize,
    mean: f64,
    variance: f64,
    last_update: Instant,
}

impl LargeScaleDummyEstimator {
    /// Create an estimator with the default configuration.
    pub fn new(strategy: LargeScaleStrategy) -> Self {
        Self::with_config(strategy, LargeScaleConfig::default())
    }

    /// Create an estimator with an explicit configuration.
    pub fn with_config(strategy: LargeScaleStrategy, config: LargeScaleConfig) -> Self {
        Self {
            strategy,
            config,
            state: RwLock::new(LargeScaleState {
                sample_count: 0,
                running_sum: 0.0,
                running_sum_squares: 0.0,
                reservoir: Vec::new(),
                sketches: HashMap::new(),
                node_statistics: HashMap::new(),
                current_memory_usage: 0,
            }),
        }
    }

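    /// Fit the estimator by dispatching to the configured strategy.
    ///
    /// A minimal usage sketch (the data here is hypothetical):
    ///
    /// ```ignore
    /// let x = Array2::<f64>::zeros((1_000, 4));
    /// let y = Array1::from_iter((0..1_000).map(|i| (i % 7) as f64));
    /// let est = LargeScaleDummyEstimator::new(LargeScaleStrategy::ChunkedProcessing {
    ///     chunk_size: 100,
    ///     overlap: 0,
    /// });
    /// est.fit_chunked(&x.view(), &y.view())?;
    /// assert!(est.get_mean() >= 0.0);
    /// ```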
    pub fn fit_chunked(&self, x: &ArrayView2<f64>, y: &ArrayView1<f64>) -> Result<()> {
        match &self.strategy {
            LargeScaleStrategy::ChunkedProcessing {
                chunk_size,
                overlap,
            } => self.process_chunked(x, y, *chunk_size, *overlap),
            LargeScaleStrategy::MemoryMapped {
                block_size,
                prefetch_blocks,
            } => self.process_memory_mapped(x, y, *block_size, *prefetch_blocks),
            LargeScaleStrategy::ReservoirSampling {
                reservoir_size,
                replacement_rate,
            } => self.process_reservoir_sampling(x, y, *reservoir_size, *replacement_rate),
            LargeScaleStrategy::SketchBased {
                sketch_size,
                hash_functions,
            } => self.process_sketch_based(x, y, *sketch_size, *hash_functions),
            LargeScaleStrategy::Distributed {
                node_id,
                total_nodes,
                coordinator_address,
            } => self.process_distributed(x, y, *node_id, *total_nodes, coordinator_address),
        }
    }

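    /// Process `x`/`y` in windows of `chunk_size` rows; consecutive windows
    /// share `overlap` rows, and each window updates the running statistics.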
    fn process_chunked(
        &self,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        chunk_size: usize,
        overlap: usize,
    ) -> Result<()> {
        let n_samples = x.nrows();
        // Advance by at least one sample so an overlap >= chunk_size cannot
        // stall the loop.
        let step = chunk_size.saturating_sub(overlap).max(1);
        let mut start_idx = 0;

        while start_idx < n_samples {
            let end_idx = (start_idx + chunk_size).min(n_samples);
            let chunk_x = x.slice(s![start_idx..end_idx, ..]);
            let chunk_y = y.slice(s![start_idx..end_idx]);

            self.update_statistics(&chunk_x, &chunk_y)?;

            start_idx += step;
        }

        Ok(())
    }

    fn process_memory_mapped(
        &self,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        block_size: usize,
        prefetch_blocks: usize,
    ) -> Result<()> {
        if !self.config.use_memory_mapping {
            return self.process_chunked(x, y, block_size, 0);
        }

        let n_samples = x.nrows();

        for block_start in (0..n_samples).step_by(block_size) {
            let block_end = (block_start + block_size).min(n_samples);

            let block_x = x.slice(s![block_start..block_end, ..]);
            let block_y = y.slice(s![block_start..block_end]);

            if block_end + prefetch_blocks * block_size < n_samples {
                // Placeholder: a real implementation would issue prefetch
                // hints for the next `prefetch_blocks` blocks here.
            }

            self.update_statistics(&block_x, &block_y)?;
        }

        Ok(())
    }

    fn process_reservoir_sampling(
        &self,
        _x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        reservoir_size: usize,
        replacement_rate: f64,
    ) -> Result<()> {
        let mut state = self.state.write().unwrap();
        let mut rng = thread_rng();

        if state.reservoir.is_empty() {
            state.reservoir.reserve(reservoir_size);
        }

        for &value in y.iter() {
            state.sample_count += 1;

            if state.reservoir.len() < reservoir_size {
                state.reservoir.push(value);
            } else {
                // Standard reservoir sampling (Algorithm R): keep the new
                // value with probability reservoir_size / sample_count.
                let k = rng.gen_range(0..state.sample_count);
                if k < reservoir_size {
                    state.reservoir[k] = value;
                } else if rng.gen::<f64>() < replacement_rate {
                    // Extra forced replacements keep the reservoir fresher
                    // on drifting streams.
                    let idx = rng.gen_range(0..reservoir_size);
                    state.reservoir[idx] = value;
                }
            }
        }

        // Recompute the running moments from the reservoir, scaled so that
        // get_mean()/get_variance() (which divide by sample_count) estimate
        // the full-stream moments from the sample.
        if !state.reservoir.is_empty() {
            let scale = state.sample_count as f64 / state.reservoir.len() as f64;
            state.running_sum = state.reservoir.iter().sum::<f64>() * scale;
            state.running_sum_squares =
                state.reservoir.iter().map(|&v| v * v).sum::<f64>() * scale;
        }

        Ok(())
    }

    fn process_sketch_based(
        &self,
        _x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        sketch_size: usize,
        hash_functions: usize,
    ) -> Result<()> {
        let mut state = self.state.write().unwrap();

        // One frequency sketch per hash function.
        for h in 0..hash_functions {
            state
                .sketches
                .entry(h)
                .or_insert_with(|| vec![0.0; sketch_size]);
        }

        // Each value increments one bucket per sketch, as in a count-min sketch.
        for &value in y.iter() {
            for h in 0..hash_functions {
                let hash = self.hash_function(value, h) % sketch_size;
                if let Some(sketch) = state.sketches.get_mut(&h) {
                    sketch[hash] += 1.0;
                }
            }
            state.sample_count += 1;
        }

        self.estimate_from_sketches(&mut state, y)?;

        Ok(())
    }

    fn process_distributed(
        &self,
        _x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        node_id: usize,
        _total_nodes: usize,
        _coordinator_address: &str,
    ) -> Result<()> {
        let mut state = self.state.write().unwrap();

        // Compute local statistics for this node's shard.
        let local_count = y.len();
        let local_sum: f64 = y.iter().sum();
        let local_mean = if local_count > 0 {
            local_sum / local_count as f64
        } else {
            0.0
        };
        let local_variance = if local_count > 1 {
            y.iter().map(|&x| (x - local_mean).powi(2)).sum::<f64>() / (local_count - 1) as f64
        } else {
            0.0
        };

        state.node_statistics.insert(
            node_id,
            NodeStatistics {
                sample_count: local_count,
                mean: local_mean,
                variance: local_variance,
                last_update: Instant::now(),
            },
        );

        self.combine_distributed_statistics(&mut state)?;

        Ok(())
    }

    fn update_statistics(&self, x: &ArrayView2<f64>, y: &ArrayView1<f64>) -> Result<()> {
        let mut state = self.state.write().unwrap();

        let chunk_count = y.len();
        let chunk_sum: f64 = y.iter().sum();
        let chunk_sum_squares: f64 = y.iter().map(|&x| x * x).sum();

        let new_count = state.sample_count + chunk_count;

        if new_count > 0 {
            state.running_sum += chunk_sum;
            state.running_sum_squares += chunk_sum_squares;
            state.sample_count = new_count;
        }

        // Track the data actually touched, not the size of the view structs.
        state.current_memory_usage += (x.len() + y.len()) * std::mem::size_of::<f64>();

        Ok(())
    }

    fn hash_function(&self, value: f64, seed: usize) -> usize {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        value.to_bits().hash(&mut hasher);
        seed.hash(&mut hasher);
        hasher.finish() as usize
    }

    fn estimate_from_sketches(
        &self,
        state: &mut LargeScaleState,
        y: &ArrayView1<f64>,
    ) -> Result<()> {
        // Placeholder: a faithful sketch-only estimator would reconstruct the
        // moments from the bucket frequencies; here the exact values are still
        // available, so they are used directly.
        if let Some(sketch_0) = state.sketches.get(&0) {
            let total_frequency: f64 = sketch_0.iter().sum();
            if total_frequency > 0.0 {
                state.running_sum = y.iter().sum();
                state.running_sum_squares = y.iter().map(|&x| x * x).sum();
            }
        }
        Ok(())
    }

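    /// Pool the per-node moments: with node counts `n_i`, means `m_i`, and
    /// variances `v_i`, the global sum is `Σ n_i·m_i`, and the global sum of
    /// squares follows from `E[X²] = Var(X) + E[X]²`, i.e. `Σ n_i·(v_i + m_i²)`.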
    fn combine_distributed_statistics(&self, state: &mut LargeScaleState) -> Result<()> {
        let mut total_count = 0;
        let mut weighted_sum = 0.0;
        let mut weighted_sum_squares = 0.0;

        for node_stats in state.node_statistics.values() {
            total_count += node_stats.sample_count;
            weighted_sum += node_stats.mean * node_stats.sample_count as f64;
            weighted_sum_squares += (node_stats.variance + node_stats.mean * node_stats.mean)
                * node_stats.sample_count as f64;
        }

        if total_count > 0 {
            state.sample_count = total_count;
            state.running_sum = weighted_sum;
            state.running_sum_squares = weighted_sum_squares;
        }

        Ok(())
    }

    /// Mean of all processed samples (0.0 if none were seen).
    pub fn get_mean(&self) -> f64 {
        let state = self.state.read().unwrap();
        if state.sample_count > 0 {
            state.running_sum / state.sample_count as f64
        } else {
            0.0
        }
    }

    /// Sample variance of all processed samples, with Bessel's correction.
    pub fn get_variance(&self) -> f64 {
        let state = self.state.read().unwrap();
        if state.sample_count > 1 {
            let mean = state.running_sum / state.sample_count as f64;
            let variance = state.running_sum_squares / state.sample_count as f64 - mean * mean;
            variance * state.sample_count as f64 / (state.sample_count - 1) as f64
        } else {
            0.0
        }
    }

    /// Rough estimate of the memory used so far, in bytes.
    pub fn get_memory_usage(&self) -> usize {
        self.state.read().unwrap().current_memory_usage
    }

    /// Snapshot of the current processing state.
    pub fn get_processing_stats(&self) -> ProcessingStats {
        let state = self.state.read().unwrap();
        ProcessingStats {
            total_samples_processed: state.sample_count,
            current_memory_usage: state.current_memory_usage,
            max_memory_limit: self.config.max_memory_bytes,
            reservoir_size: state.reservoir.len(),
            sketch_count: state.sketches.len(),
            distributed_nodes: state.node_statistics.len(),
        }
    }
}

/// Summary of processing progress and resource usage.
#[derive(Debug, Clone)]
pub struct ProcessingStats {
    /// Total number of samples processed.
    pub total_samples_processed: usize,
    /// Estimated current memory usage in bytes.
    pub current_memory_usage: usize,
    /// Configured memory limit in bytes.
    pub max_memory_limit: usize,
    /// Current number of values held in the reservoir.
    pub reservoir_size: usize,
    /// Number of active sketches.
    pub sketch_count: usize,
    /// Number of nodes that have reported statistics.
    pub distributed_nodes: usize,
}

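/// Online mean/variance tracker based on Welford's algorithm: after each new
/// value `x`, `mean += (x - mean) / n` and `m2 += (x - mean_old) * (x - mean_new)`,
/// so the variance `m2 / (n - 1)` is available at any time without a second
/// pass. An optional decay factor down-weights old samples (approximately).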
pub struct StreamingBaselineUpdater {
    /// Number of samples seen.
    count: usize,
    /// Running mean.
    mean: f64,
    /// Sum of squared deviations from the running mean (Welford's M2).
    m2: f64,
    /// Exponential decay factor in (0, 1]; 1.0 disables decay.
    decay_factor: f64,
    /// Minimum samples required before predictions are allowed.
    min_samples: usize,
}

impl StreamingBaselineUpdater {
    /// Create a new updater with the given decay factor and warm-up length.
    pub fn new(decay_factor: f64, min_samples: usize) -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
            decay_factor,
            min_samples,
        }
    }

    /// Incorporate one new observation.
    pub fn update(&mut self, value: f64) {
        self.count += 1;

        if self.decay_factor < 1.0 && self.count > 1 {
            // Decayed update: shrink the effective sample size so recent
            // values carry more weight. This is an approximation, not exact
            // exponentially weighted moments.
            let effective_count = (self.count as f64 * self.decay_factor).max(1.0);
            let delta = value - self.mean;
            self.mean += delta / effective_count;
            let delta2 = value - self.mean;
            self.m2 += delta * delta2;
        } else {
            // Standard Welford update.
            let delta = value - self.mean;
            self.mean += delta / self.count as f64;
            let delta2 = value - self.mean;
            self.m2 += delta * delta2;
        }
    }

    /// Current running mean.
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Sample variance with Bessel's correction.
    pub fn variance(&self) -> f64 {
        if self.count > 1 {
            self.m2 / (self.count - 1) as f64
        } else {
            0.0
        }
    }

    /// Sample standard deviation.
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Whether enough samples have been seen to make predictions.
    pub fn is_ready(&self) -> bool {
        self.count >= self.min_samples
    }

    /// Number of samples seen.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Reset all state to the initial values.
    pub fn reset(&mut self) {
        self.count = 0;
        self.mean = 0.0;
        self.m2 = 0.0;
    }

    /// Predict the baseline value (the running mean).
    pub fn predict(&self) -> Result<f64> {
        if !self.is_ready() {
            return Err(SklearsError::InvalidInput(format!(
                "Need at least {} samples before making predictions",
                self.min_samples
            )));
        }
        Ok(self.mean)
    }

    /// Predict with a normal-approximation confidence interval,
    /// returning `(prediction, lower, upper)`.
    pub fn predict_with_confidence(&self, confidence_level: f64) -> Result<(f64, f64, f64)> {
        if !self.is_ready() {
            return Err(SklearsError::InvalidInput(format!(
                "Need at least {} samples before making predictions",
                self.min_samples
            )));
        }

        let prediction = self.mean;
        let std_err = self.std_dev() / (self.count as f64).sqrt();

        // Standard normal quantiles for common confidence levels.
        let z_score = match confidence_level {
            level if level >= 0.99 => 2.576,
            level if level >= 0.95 => 1.96,
            level if level >= 0.90 => 1.645,
            _ => 1.0,
        };

        let margin = z_score * std_err;
        Ok((prediction, prediction - margin, prediction + margin))
    }
}

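/// Baseline statistics computed from a subsample rather than the full data.
///
/// A minimal usage sketch (values and the `y` array are illustrative):
///
/// ```ignore
/// let approx = ApproximateBaseline::new(
///     ApproximateMethod::Systematic { sampling_interval: 10 },
///     0.05, // error bound
///     0.95, // confidence level
/// )?;
/// let stats = approx.compute_approximate_stats(&y.view())?;
/// println!("mean ≈ {}", stats.estimated_mean);
/// ```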
pub struct ApproximateBaseline {
    method: ApproximateMethod,
    error_bound: f64,
    confidence_level: f64,
}

/// Sampling scheme used by [`ApproximateBaseline`].
#[derive(Debug, Clone)]
pub enum ApproximateMethod {
    /// Resample with replacement and average the resample means.
    Bootstrap { n_samples: usize },
    /// Sort into value strata and sample within each stratum.
    Stratified {
        n_strata: usize,
        samples_per_stratum: usize,
    },
    /// Take every k-th element starting from a random offset.
    Systematic { sampling_interval: usize },
    /// Split into contiguous clusters and sample whole clusters.
    Cluster { n_clusters: usize },
}

impl ApproximateBaseline {
    /// Validate the bounds and construct a new approximator.
    pub fn new(method: ApproximateMethod, error_bound: f64, confidence_level: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&error_bound) {
            return Err(SklearsError::InvalidInput(
                "Error bound must be between 0 and 1".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&confidence_level) {
            return Err(SklearsError::InvalidInput(
                "Confidence level must be between 0 and 1".to_string(),
            ));
        }

        Ok(Self {
            method,
            error_bound,
            confidence_level,
        })
    }

    /// Compute approximate statistics with the configured method.
    pub fn compute_approximate_stats(&self, y: &ArrayView1<f64>) -> Result<ApproximateStats> {
        match &self.method {
            ApproximateMethod::Bootstrap { n_samples } => self.bootstrap_stats(y, *n_samples),
            ApproximateMethod::Stratified {
                n_strata,
                samples_per_stratum,
            } => self.stratified_stats(y, *n_strata, *samples_per_stratum),
            ApproximateMethod::Systematic { sampling_interval } => {
                self.systematic_stats(y, *sampling_interval)
            }
            ApproximateMethod::Cluster { n_clusters } => self.cluster_stats(y, *n_clusters),
        }
    }

    fn bootstrap_stats(&self, y: &ArrayView1<f64>, n_samples: usize) -> Result<ApproximateStats> {
        let mut rng = thread_rng();
        let total_samples = y.len();

        if total_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Cannot bootstrap an empty dataset".to_string(),
            ));
        }

        let mut bootstrap_means = Vec::with_capacity(n_samples);
        // 0.632 is the expected fraction of distinct points in a bootstrap
        // resample (1 - 1/e).
        let sample_size = (total_samples as f64 * 0.632).ceil() as usize;

        for _ in 0..n_samples {
            let mut sample_sum = 0.0;

            for _ in 0..sample_size {
                let idx = rng.gen_range(0..total_samples);
                sample_sum += y[idx];
            }

            bootstrap_means.push(sample_sum / sample_size as f64);
        }

        let estimated_mean = bootstrap_means.iter().sum::<f64>() / bootstrap_means.len() as f64;
        let estimated_variance = bootstrap_means
            .iter()
            .map(|&x| (x - estimated_mean).powi(2))
            .sum::<f64>()
            / (bootstrap_means.len() - 1) as f64;

        Ok(ApproximateStats {
            estimated_mean,
            estimated_variance,
            confidence_interval: self.compute_confidence_interval(
                estimated_mean,
                estimated_variance.sqrt(),
                bootstrap_means.len(),
            ),
            sample_size_used: sample_size * n_samples,
            method_info: format!("Bootstrap with {} resamples", n_samples),
        })
    }

    fn stratified_stats(
        &self,
        y: &ArrayView1<f64>,
        n_strata: usize,
        samples_per_stratum: usize,
    ) -> Result<ApproximateStats> {
        let mut rng = thread_rng();
        let total_samples = y.len();

        if n_strata == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of strata must be at least 1".to_string(),
            ));
        }

        // Sort by value so that strata correspond to value ranges.
        let mut indexed_data: Vec<(usize, f64)> =
            y.iter().enumerate().map(|(i, &v)| (i, v)).collect();
        indexed_data.sort_by(|a, b| a.1.total_cmp(&b.1));

        let stratum_size = total_samples / n_strata;
        let mut stratum_means = Vec::new();
        let mut total_sampled = 0;

        for stratum in 0..n_strata {
            let start = stratum * stratum_size;
            let end = if stratum == n_strata - 1 {
                total_samples
            } else {
                (stratum + 1) * stratum_size
            };
            let stratum_data = &indexed_data[start..end];

            if stratum_data.is_empty() {
                continue;
            }

            let actual_samples = samples_per_stratum.min(stratum_data.len());
            let mut stratum_sum = 0.0;

            for _ in 0..actual_samples {
                let idx = rng.gen_range(0..stratum_data.len());
                stratum_sum += stratum_data[idx].1;
                total_sampled += 1;
            }

            stratum_means.push(stratum_sum / actual_samples as f64);
        }

        if stratum_means.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No data available for stratified sampling".to_string(),
            ));
        }

        let estimated_mean = stratum_means.iter().sum::<f64>() / stratum_means.len() as f64;
        let estimated_variance = if stratum_means.len() > 1 {
            stratum_means
                .iter()
                .map(|&x| (x - estimated_mean).powi(2))
                .sum::<f64>()
                / (stratum_means.len() - 1) as f64
        } else {
            0.0
        };

        Ok(ApproximateStats {
            estimated_mean,
            estimated_variance,
            confidence_interval: self.compute_confidence_interval(
                estimated_mean,
                estimated_variance.sqrt(),
                stratum_means.len(),
            ),
            sample_size_used: total_sampled,
            method_info: format!(
                "Stratified sampling with {} strata, {} samples per stratum",
                n_strata, samples_per_stratum
            ),
        })
    }

    fn systematic_stats(
        &self,
        y: &ArrayView1<f64>,
        sampling_interval: usize,
    ) -> Result<ApproximateStats> {
        let mut rng = thread_rng();
        let total_samples = y.len();

        if sampling_interval == 0 || sampling_interval >= total_samples {
            return Err(SklearsError::InvalidInput(
                "Sampling interval must be at least 1 and smaller than the dataset".to_string(),
            ));
        }

        // Random starting offset, then every `sampling_interval`-th element.
        let start = rng.gen_range(0..sampling_interval);
        let mut sample_sum = 0.0;
        let mut sample_count = 0;

        for i in (start..total_samples).step_by(sampling_interval) {
            sample_sum += y[i];
            sample_count += 1;
        }

        let estimated_mean = if sample_count > 0 {
            sample_sum / sample_count as f64
        } else {
            0.0
        };

        // Second pass over the same indices for the variance.
        let mut variance_sum = 0.0;
        for i in (start..total_samples).step_by(sampling_interval) {
            variance_sum += (y[i] - estimated_mean).powi(2);
        }
        let estimated_variance = if sample_count > 1 {
            variance_sum / (sample_count - 1) as f64
        } else {
            0.0
        };

        Ok(ApproximateStats {
            estimated_mean,
            estimated_variance,
            confidence_interval: self.compute_confidence_interval(
                estimated_mean,
                estimated_variance.sqrt(),
                sample_count,
            ),
            sample_size_used: sample_count,
            method_info: format!("Systematic sampling with interval {}", sampling_interval),
        })
    }

    fn cluster_stats(&self, y: &ArrayView1<f64>, n_clusters: usize) -> Result<ApproximateStats> {
        let mut rng = thread_rng();
        let total_samples = y.len();

        if n_clusters == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of clusters must be at least 1".to_string(),
            ));
        }

        let cluster_size = total_samples / n_clusters;
        if cluster_size == 0 {
            return Err(SklearsError::InvalidInput(
                "Too many clusters for dataset size".to_string(),
            ));
        }

        // Sample half of the clusters (at least one), chosen with replacement.
        let selected_clusters = (n_clusters / 2).max(1);
        let mut cluster_means = Vec::new();
        let mut total_sampled = 0;

        for _ in 0..selected_clusters {
            let cluster_id = rng.gen_range(0..n_clusters);
            let start = cluster_id * cluster_size;
            let end = if cluster_id == n_clusters - 1 {
                total_samples
            } else {
                (cluster_id + 1) * cluster_size
            };

            let cluster_sum: f64 = y.slice(s![start..end]).iter().sum();
            let cluster_mean = cluster_sum / (end - start) as f64;
            cluster_means.push(cluster_mean);
            total_sampled += end - start;
        }

        let estimated_mean = cluster_means.iter().sum::<f64>() / cluster_means.len() as f64;
        let estimated_variance = if cluster_means.len() > 1 {
            cluster_means
                .iter()
                .map(|&x| (x - estimated_mean).powi(2))
                .sum::<f64>()
                / (cluster_means.len() - 1) as f64
        } else {
            0.0
        };

        Ok(ApproximateStats {
            estimated_mean,
            estimated_variance,
            confidence_interval: self.compute_confidence_interval(
                estimated_mean,
                estimated_variance.sqrt(),
                cluster_means.len(),
            ),
            sample_size_used: total_sampled,
            method_info: format!(
                "Cluster sampling with {} selected from {} clusters",
                selected_clusters, n_clusters
            ),
        })
    }

    fn compute_confidence_interval(&self, mean: f64, std_error: f64, n: usize) -> (f64, f64) {
        // For large n, use standard normal quantiles; for small n, fall back
        // to rough t-distribution approximations.
        let t_value = if n > 30 {
            match self.confidence_level {
                level if level >= 0.99 => 2.576,
                level if level >= 0.95 => 1.96,
                level if level >= 0.90 => 1.645,
                _ => 1.0,
            }
        } else {
            match self.confidence_level {
                level if level >= 0.95 => 2.0,
                _ => 1.5,
            }
        };

        let margin = t_value * std_error / (n as f64).sqrt();
        (mean - margin, mean + margin)
    }
}

/// Result of an approximate statistics computation.
#[derive(Debug, Clone)]
pub struct ApproximateStats {
    /// Estimated mean.
    pub estimated_mean: f64,
    /// Estimated variance.
    pub estimated_variance: f64,
    /// Confidence interval `(lower, upper)` around the mean.
    pub confidence_interval: (f64, f64),
    /// Number of samples actually drawn.
    pub sample_size_used: usize,
    /// Human-readable description of the sampling method.
    pub method_info: String,
}

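/// Baseline computed from a uniform subsample of the targets, drawn with a
/// single reservoir-sampling pass. The sampling rate can adapt to the observed
/// coefficient of variation (see `adaptive_sample`).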
pub struct SamplingBasedBaseline {
    /// Fraction of the data to sample, in [0, 1].
    sampling_rate: f64,
    /// Lower bound on the sample size.
    min_samples: usize,
    /// Upper bound on the sample size.
    max_samples: usize,
    /// Whether `adaptive_sample` may adjust the sampling rate.
    adaptive: bool,
}

impl SamplingBasedBaseline {
    /// Validate the parameters and construct a new sampler.
    pub fn new(sampling_rate: f64, min_samples: usize, max_samples: usize) -> Result<Self> {
        if !(0.0..=1.0).contains(&sampling_rate) {
            return Err(SklearsError::InvalidInput(
                "Sampling rate must be between 0 and 1".to_string(),
            ));
        }
        if min_samples > max_samples {
            return Err(SklearsError::InvalidInput(
                "Min samples cannot exceed max samples".to_string(),
            ));
        }

        Ok(Self {
            sampling_rate,
            min_samples,
            max_samples,
            adaptive: true,
        })
    }

    /// Draw a reservoir sample of the targets and compute baseline statistics.
    pub fn compute_sampled_baseline(&self, y: &ArrayView1<f64>) -> Result<SampledBaselineResult> {
        let total_samples = y.len();
        let target_samples = (total_samples as f64 * self.sampling_rate) as usize;
        let actual_samples =
            target_samples.clamp(self.min_samples, self.max_samples.min(total_samples));

        if actual_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples to process".to_string(),
            ));
        }

        // Single-pass reservoir sampling over the targets.
        let mut rng = thread_rng();
        let mut sample = Vec::with_capacity(actual_samples);

        for (i, &value) in y.iter().enumerate() {
            if sample.len() < actual_samples {
                sample.push(value);
            } else {
                let j = rng.gen_range(0..i + 1);
                if j < actual_samples {
                    sample[j] = value;
                }
            }
        }

        let mean = sample.iter().sum::<f64>() / sample.len() as f64;
        let variance = if sample.len() > 1 {
            sample.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (sample.len() - 1) as f64
        } else {
            0.0
        };

        // Normal-approximation 95% interval around the sample mean.
        let standard_error = variance.sqrt() / (sample.len() as f64).sqrt();
        let confidence_95 = (mean - 1.96 * standard_error, mean + 1.96 * standard_error);

        Ok(SampledBaselineResult {
            mean,
            variance,
            standard_error,
            confidence_interval: confidence_95,
            sample_size: sample.len(),
            total_size: total_samples,
            sampling_efficiency: sample.len() as f64 / total_samples as f64,
        })
    }

    /// Sample once, then adjust the sampling rate based on the coefficient of
    /// variation of the estimate and sample again.
    pub fn adaptive_sample(&mut self, y: &ArrayView1<f64>) -> Result<SampledBaselineResult> {
        if !self.adaptive {
            return self.compute_sampled_baseline(y);
        }

        let initial_result = self.compute_sampled_baseline(y)?;

        // Coefficient of variation of the mean estimate.
        let cv = initial_result.standard_error / initial_result.mean.abs();

        if cv > 0.1 {
            // Estimate too noisy: sample more.
            self.sampling_rate = (self.sampling_rate * 1.5).min(1.0);
        } else if cv < 0.05 {
            // Estimate stable: sample less.
            self.sampling_rate = (self.sampling_rate * 0.8).max(0.01);
        }

        self.compute_sampled_baseline(y)
    }
}

/// Result of a sampling-based baseline computation.
#[derive(Debug, Clone)]
pub struct SampledBaselineResult {
    /// Sample mean.
    pub mean: f64,
    /// Sample variance (Bessel-corrected).
    pub variance: f64,
    /// Standard error of the mean.
    pub standard_error: f64,
    /// 95% confidence interval `(lower, upper)`.
    pub confidence_interval: (f64, f64),
    /// Number of samples drawn.
    pub sample_size: usize,
    /// Total number of samples available.
    pub total_size: usize,
    /// Fraction of the data actually sampled.
    pub sampling_efficiency: f64,
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_large_scale_chunked_processing() {
        let x = Array2::from_shape_vec((1000, 5), (0..5000).map(|i| i as f64).collect()).unwrap();
        let y = Array1::from_shape_vec(1000, (0..1000).map(|i| (i % 10) as f64).collect()).unwrap();

        let estimator = LargeScaleDummyEstimator::new(LargeScaleStrategy::ChunkedProcessing {
            chunk_size: 100,
            overlap: 10,
        });

        let result = estimator.fit_chunked(&x.view(), &y.view());
        assert!(result.is_ok());

        let mean = estimator.get_mean();
        assert!((0.0..=10.0).contains(&mean));
    }

    #[test]
    fn test_streaming_baseline_updater() {
        let mut updater = StreamingBaselineUpdater::new(0.95, 5);

        for value in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] {
            updater.update(value);
        }

        assert!(updater.is_ready());
        assert!((updater.mean() - 3.5).abs() < 0.1);
        assert!(updater.variance() > 0.0);
    }
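
    // Sketch test for the confidence-interval path: with decay disabled
    // (factor 1.0) the Welford mean of [2, 4, 6, 8] is exactly 5, and the
    // interval must bracket the prediction.
    #[test]
    fn test_streaming_confidence_interval() {
        let mut updater = StreamingBaselineUpdater::new(1.0, 3);
        assert!(updater.predict().is_err());

        for value in [2.0, 4.0, 6.0, 8.0] {
            updater.update(value);
        }

        let (pred, lower, upper) = updater.predict_with_confidence(0.95).unwrap();
        assert!((pred - 5.0).abs() < 1e-12);
        assert!(lower < pred && pred < upper);
    }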

    #[test]
    fn test_reservoir_sampling() {
        let x = Array2::zeros((1000, 5));
        let y = Array1::from_shape_vec(1000, (0..1000).map(|i| i as f64).collect()).unwrap();

        let estimator = LargeScaleDummyEstimator::new(LargeScaleStrategy::ReservoirSampling {
            reservoir_size: 100,
            replacement_rate: 0.1,
        });

        let result = estimator.fit_chunked(&x.view(), &y.view());
        assert!(result.is_ok());

        let stats = estimator.get_processing_stats();
        assert_eq!(stats.total_samples_processed, 1000);
        assert_eq!(stats.reservoir_size, 100);
    }

    #[test]
    fn test_approximate_bootstrap() {
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let approx =
            ApproximateBaseline::new(ApproximateMethod::Bootstrap { n_samples: 50 }, 0.05, 0.95)
                .unwrap();

        let result = approx.compute_approximate_stats(&y.view());
        assert!(result.is_ok());

        let stats = result.unwrap();
        assert!(stats.estimated_mean > 0.0);
        assert!(stats.estimated_variance >= 0.0);
        assert!(stats.confidence_interval.0 < stats.confidence_interval.1);
    }
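
    // Sketch test for systematic sampling: over 100 points with interval 10,
    // any starting offset yields exactly 10 samples, with a mean between the
    // smallest (45) and largest (54) possible offset means.
    #[test]
    fn test_approximate_systematic() {
        let y = Array1::from_shape_vec(100, (0..100).map(|i| i as f64).collect()).unwrap();

        let approx = ApproximateBaseline::new(
            ApproximateMethod::Systematic {
                sampling_interval: 10,
            },
            0.05,
            0.95,
        )
        .unwrap();

        let stats = approx.compute_approximate_stats(&y.view()).unwrap();
        assert_eq!(stats.sample_size_used, 10);
        assert!(stats.estimated_mean >= 45.0 && stats.estimated_mean <= 54.0);
    }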

    #[test]
    fn test_sampling_based_baseline() {
        let y = Array1::from_shape_vec(1000, (0..1000).map(|i| i as f64).collect()).unwrap();

        let baseline = SamplingBasedBaseline::new(0.1, 50, 200).unwrap();
        let result = baseline.compute_sampled_baseline(&y.view());

        assert!(result.is_ok());
        let stats = result.unwrap();
        assert!(stats.sample_size >= 50 && stats.sample_size <= 200);
        assert!(stats.sampling_efficiency > 0.0 && stats.sampling_efficiency <= 1.0);
    }

    #[test]
    fn test_distributed_processing() {
        let x = Array2::zeros((100, 3));
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let estimator = LargeScaleDummyEstimator::new(LargeScaleStrategy::Distributed {
            node_id: 0,
            total_nodes: 3,
            coordinator_address: "localhost:8080".to_string(),
        });

        let result = estimator.fit_chunked(&x.slice(s![..5, ..]), &y.view());
        assert!(result.is_ok());

        let stats = estimator.get_processing_stats();
        assert_eq!(stats.distributed_nodes, 1);
    }

    #[test]
    fn test_sketch_based_processing() {
        let x = Array2::zeros((200, 4));
        let y = Array1::from_shape_vec(200, (0..200).map(|i| (i % 20) as f64).collect()).unwrap();

        let estimator = LargeScaleDummyEstimator::new(LargeScaleStrategy::SketchBased {
            sketch_size: 32,
            hash_functions: 4,
        });

        let result = estimator.fit_chunked(&x.view(), &y.view());
        assert!(result.is_ok());

        let stats = estimator.get_processing_stats();
        assert_eq!(stats.sketch_count, 4);
    }
}