1use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
9use scirs2_core::random::{Random, Rng};
10use crate::core::Imputer;
14use rayon::prelude::*;
15use serde::{Deserialize, Serialize};
16use sklears_core::{
17 error::{Result as SklResult, SklearsError},
18 traits::{Estimator, Fit, Transform, Untrained},
19 types::Float,
20};
21use std::collections::{HashMap, HashSet};
22use std::time::Duration;
23
/// Tunable knobs controlling the accuracy/speed trade-off of the
/// approximate imputers in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApproximateConfig {
    /// Target accuracy in [0, 1]; higher is slower but more accurate.
    pub accuracy_level: f64,
    /// Time budget per feature. NOTE(review): not enforced anywhere in the
    /// code visible in this file — confirm before relying on it.
    pub max_time_per_feature: Duration,
    /// Base number of training rows to sample.
    pub sample_size: usize,
    /// Whether randomized shortcuts may be used.
    pub use_randomization: bool,
    /// Stop iterating early once estimates stabilize (iterative strategies).
    pub early_stopping: bool,
    /// Convergence tolerance for iterative strategies.
    pub tolerance: f64,
    /// Hard cap on iterations for iterative strategies.
    pub max_iterations: usize,
}
42
43impl Default for ApproximateConfig {
44 fn default() -> Self {
45 Self {
46 accuracy_level: 0.8,
47 max_time_per_feature: Duration::from_secs(1),
48 sample_size: 1000,
49 use_randomization: true,
50 early_stopping: true,
51 tolerance: 1e-3,
52 max_iterations: 10,
53 }
54 }
55}
56
/// Strategy used to approximate the nearest-neighbor search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ApproximationStrategy {
    /// Score only a random subset of the reference rows.
    RandomSampling,
    /// Sketch-based approximation. NOTE(review): no dedicated KNN search
    /// path exists in this file; it falls back to random sampling.
    Sketching,
    /// Distances computed on a random subset of features.
    LocalApproximation,
    /// Linear approximation. NOTE(review): no dedicated search path in this
    /// file; it falls back to random sampling.
    LinearApproximation,
    /// Locality-sensitive hashing over random projections.
    HashBased,
}
71
/// Approximate K-nearest-neighbors imputer. `S` is a typestate marker:
/// `Untrained` before `fit`, `ApproximateKNNImputerTrained` after.
#[derive(Debug)]
pub struct ApproximateKNNImputer<S = Untrained> {
    /// Typestate payload.
    state: S,
    /// Number of neighbors used per imputation.
    n_neighbors: usize,
    /// Weighting scheme: "uniform" or "distance".
    weights: String,
    /// Value treated as missing (NaN by default).
    missing_values: f64,
    /// Accuracy/speed trade-off configuration.
    config: ApproximateConfig,
    /// Neighbor-search approximation strategy.
    strategy: ApproximationStrategy,
}
82
/// Fitted state for [`ApproximateKNNImputer`]: the retained reference
/// sample plus any strategy-specific index structure.
#[derive(Debug)]
pub struct ApproximateKNNImputerTrained {
    /// Subsampled training rows used as neighbor candidates.
    reference_samples: Array2<f64>,
    /// Original row indices of `reference_samples` in the training data.
    sample_indices: Vec<usize>,
    /// Number of features seen during `fit`.
    n_features_in_: usize,
    /// Configuration captured at fit time.
    config: ApproximateConfig,
    /// Approximation strategy captured at fit time.
    strategy: ApproximationStrategy,
    /// LSH table; only built for the `HashBased` strategy.
    locality_hash: Option<LocalityHashTable>,
}
93
/// Locality-sensitive hash table mapping bucket-id vectors to the indices
/// of reference rows that hashed there.
#[derive(Debug)]
pub struct LocalityHashTable {
    /// Random-projection hash functions (one bucket id each).
    hash_functions: Vec<RandomHashFunction>,
    /// Bucket-id vector -> reference-row indices.
    buckets: HashMap<Vec<u32>, Vec<usize>>,
    /// Number of hash functions (= length of each bucket-id vector).
    num_hash_functions: usize,
    /// Quantization width shared by all hash functions.
    bucket_width: f64,
}
102
/// One random-projection hash: bucket = floor((<x, v> + offset) / width).
#[derive(Debug, Clone)]
pub struct RandomHashFunction {
    /// Random projection vector (Box–Muller Gaussian draws at build time).
    random_vector: Array1<f64>,
    /// Random offset in [0, bucket_width).
    offset: f64,
    /// Quantization width.
    bucket_width: f64,
}
110
/// Approximate univariate imputer (typestate: `Untrained` or trained) that
/// estimates per-feature statistics from bootstrap subsamples.
#[derive(Debug)]
pub struct ApproximateSimpleImputer<S = Untrained> {
    /// Typestate payload.
    state: S,
    /// Statistic to impute with: "mean" or "median".
    strategy: String,
    /// Value treated as missing (NaN by default).
    missing_values: f64,
    /// Accuracy/speed trade-off configuration.
    config: ApproximateConfig,
}
119
120#[derive(Debug)]
122pub struct ApproximateSimpleImputerTrained {
123 approximate_statistics_: Array1<f64>,
124 confidence_intervals_: Array2<f64>, n_features_in_: usize,
126 config: ApproximateConfig,
127}
128
/// Sketch-based imputer summarizing features with count sketches.
/// NOTE(review): only the type is declared here; no `Fit`/`Transform` impl
/// is visible in this file section.
#[derive(Debug)]
pub struct SketchingImputer<S = Untrained> {
    /// Typestate payload.
    state: S,
    /// Width (bucket count) of each sketch.
    sketch_size: usize,
    /// Value treated as missing (NaN by default).
    missing_values: f64,
    /// Accuracy/speed trade-off configuration.
    config: ApproximateConfig,
    /// Family of hash functions used to build sketches.
    hash_family: HashFamily,
}
138
/// Fitted state for [`SketchingImputer`]: one sketch per feature.
#[derive(Debug)]
pub struct SketchingImputerTrained {
    /// Per-feature count sketches.
    sketches: Vec<CountSketch>,
    /// Number of features seen during `fit`.
    n_features_in_: usize,
    /// Configuration captured at fit time.
    config: ApproximateConfig,
}
146
147#[derive(Debug, Clone)]
149pub struct CountSketch {
150 sketch: Array1<f64>,
151 hash_functions: Vec<(usize, i32)>, size: usize,
153}
154
/// Family of hash functions available for sketching.
#[derive(Debug, Clone)]
pub enum HashFamily {
    /// Universal hashing.
    Universal,
    /// Polynomial hashing.
    Polynomial,
    /// MurmurHash-style hashing.
    Murmur,
}
165
/// Iterative (round-robin) imputer that randomizes feature order and may
/// subsample features for speed.
/// NOTE(review): only the type is declared here; no `Fit`/`Transform` impl
/// is visible in this file section.
#[derive(Debug)]
pub struct RandomizedIterativeImputer<S = Untrained> {
    /// Typestate payload.
    state: S,
    /// Maximum number of imputation rounds.
    max_iter: usize,
    /// Value treated as missing (NaN by default).
    missing_values: f64,
    /// Accuracy/speed trade-off configuration.
    config: ApproximateConfig,
    /// Whether to visit features in a random order.
    random_order: bool,
    /// Fraction of features used per round — presumably in (0, 1];
    /// TODO(review): confirm against the (unseen) fit implementation.
    subsample_features: f64,
}
176
/// Fitted state for [`RandomizedIterativeImputer`].
/// NOTE(review): unlike its sibling trained types this one has no
/// `#[derive(Debug)]` — presumably because `dyn Imputer` does not require
/// `Debug`; confirm before adding the derive.
pub struct RandomizedIterativeImputerTrained {
    /// One fitted estimator per imputed feature.
    estimators_: Vec<Box<dyn Imputer>>,
    /// Order in which features were visited during fitting.
    feature_order: Vec<usize>,
    /// Number of features seen during `fit`.
    n_features_in_: usize,
    /// Configuration captured at fit time.
    config: ApproximateConfig,
}
184
185impl ApproximateKNNImputer<Untrained> {
186 pub fn new() -> Self {
187 Self {
188 state: Untrained,
189 n_neighbors: 5,
190 weights: "uniform".to_string(),
191 missing_values: f64::NAN,
192 config: ApproximateConfig::default(),
193 strategy: ApproximationStrategy::RandomSampling,
194 }
195 }
196
197 pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
198 self.n_neighbors = n_neighbors;
199 self
200 }
201
202 pub fn weights(mut self, weights: String) -> Self {
203 self.weights = weights;
204 self
205 }
206
207 pub fn approximate_config(mut self, config: ApproximateConfig) -> Self {
208 self.config = config;
209 self
210 }
211
212 pub fn strategy(mut self, strategy: ApproximationStrategy) -> Self {
213 self.strategy = strategy;
214 self
215 }
216
217 pub fn accuracy_level(mut self, level: f64) -> Self {
218 self.config.accuracy_level = level.clamp(0.0, 1.0);
219 self
220 }
221
222 pub fn sample_size(mut self, size: usize) -> Self {
223 self.config.sample_size = size;
224 self
225 }
226
227 fn is_missing(&self, value: f64) -> bool {
228 if self.missing_values.is_nan() {
229 value.is_nan()
230 } else {
231 (value - self.missing_values).abs() < f64::EPSILON
232 }
233 }
234}
235
236impl Default for ApproximateKNNImputer<Untrained> {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
impl Estimator for ApproximateKNNImputer<Untrained> {
    type Config = ApproximateConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Returns the current approximation configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
251
252impl Fit<ArrayView2<'_, Float>, ()> for ApproximateKNNImputer<Untrained> {
253 type Fitted = ApproximateKNNImputer<ApproximateKNNImputerTrained>;
254
255 #[allow(non_snake_case)]
256 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
257 let X = X.mapv(|x| x);
258 let (n_samples, n_features) = X.dim();
259
260 let effective_sample_size = ((self.config.sample_size as f64 * self.config.accuracy_level)
262 as usize)
263 .min(n_samples)
264 .max(self.n_neighbors * 10); let (reference_samples, sample_indices) =
268 self.sample_training_data(&X, effective_sample_size)?;
269
270 let locality_hash = match self.strategy {
272 ApproximationStrategy::HashBased => {
273 Some(self.build_locality_hash_table(&reference_samples)?)
274 }
275 _ => None,
276 };
277
278 Ok(ApproximateKNNImputer {
279 state: ApproximateKNNImputerTrained {
280 reference_samples,
281 sample_indices,
282 n_features_in_: n_features,
283 config: self.config,
284 strategy: self.strategy,
285 locality_hash,
286 },
287 n_neighbors: self.n_neighbors,
288 weights: self.weights,
289 missing_values: self.missing_values,
290 config: Default::default(),
291 strategy: ApproximationStrategy::RandomSampling,
292 })
293 }
294}
295
impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for ApproximateKNNImputer<ApproximateKNNImputerTrained>
{
    /// Imputes missing entries of `X` row-by-row (in parallel) using
    /// approximate nearest neighbors from the fitted reference sample.
    ///
    /// # Errors
    /// Returns `InvalidInput` when `X` has a different feature count than
    /// the data seen at fit time.
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.mapv(|x| x);
        let (_n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();

        // Rows are imputed independently, so process them in parallel.
        // NOTE(review): neighbor-search failures and empty neighbor sets
        // are silently skipped, leaving the missing value in place —
        // presumably intentional best-effort behavior; confirm.
        X_imputed
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(_i, mut row)| {
                for j in 0..n_features {
                    if self.is_missing(row[j]) {
                        // `row.to_owned()` snapshots the row so the search
                        // sees it while we still hold the mutable view.
                        if let Ok(neighbors) = self.find_approximate_neighbors(&row.to_owned(), j) {
                            if !neighbors.is_empty() {
                                if let Ok(imputed_value) = self.compute_weighted_average(&neighbors)
                                {
                                    row[j] = imputed_value;
                                }
                            }
                        }
                    }
                }
            });

        Ok(X_imputed.mapv(|x| x as Float))
    }
}
337
338impl ApproximateKNNImputer<Untrained> {
339 fn sample_training_data(
341 &self,
342 X: &Array2<f64>,
343 sample_size: usize,
344 ) -> Result<(Array2<f64>, Vec<usize>), SklearsError> {
345 let n_samples = X.nrows();
346
347 if sample_size >= n_samples {
348 return Ok((X.clone(), (0..n_samples).collect()));
349 }
350
351 let mut rng = Random::default();
353 let mut indices: Vec<usize> = (0..n_samples).collect();
354
355 for i in (1..indices.len()).rev() {
357 let j = rng.gen_range(0..i + 1);
358 indices.swap(i, j);
359 }
360
361 indices.truncate(sample_size);
362 indices.sort(); let mut sampled_data = Array2::<f64>::zeros((sample_size, X.ncols()));
366 for (new_idx, &orig_idx) in indices.iter().enumerate() {
367 sampled_data.row_mut(new_idx).assign(&X.row(orig_idx));
368 }
369
370 Ok((sampled_data, indices))
371 }
372
373 fn build_locality_hash_table(
375 &self,
376 data: &Array2<f64>,
377 ) -> Result<LocalityHashTable, SklearsError> {
378 let n_features = data.ncols();
379 let num_hash_functions = (self.config.accuracy_level * 10.0) as usize + 2;
380 let bucket_width = 1.0 / (self.config.accuracy_level + 0.1);
381
382 let mut hash_functions = Vec::new();
383 let mut rng = Random::default();
384
385 for _ in 0..num_hash_functions {
387 let mut random_vector = Array1::<f64>::zeros(n_features);
388 for i in 0..n_features {
389 let u1: f64 = rng.gen();
391 let u2: f64 = rng.gen();
392 let z = (-2.0_f64 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
393 random_vector[i] = z;
394 }
395 let offset: f64 = rng.gen::<f64>() * bucket_width;
396
397 hash_functions.push(RandomHashFunction {
398 random_vector,
399 offset,
400 bucket_width,
401 });
402 }
403
404 let mut buckets = HashMap::new();
406 for (row_idx, row) in data.rows().into_iter().enumerate() {
407 let hash_values = self.compute_hash_values(&row.to_owned(), &hash_functions);
408 buckets
409 .entry(hash_values)
410 .or_insert_with(Vec::new)
411 .push(row_idx);
412 }
413
414 Ok(LocalityHashTable {
415 hash_functions,
416 buckets,
417 num_hash_functions,
418 bucket_width,
419 })
420 }
421
422 fn compute_hash_values(
424 &self,
425 point: &Array1<f64>,
426 hash_functions: &[RandomHashFunction],
427 ) -> Vec<u32> {
428 hash_functions
429 .iter()
430 .map(|hash_fn| {
431 let dot_product: f64 = point
432 .iter()
433 .zip(hash_fn.random_vector.iter())
434 .filter(|(&x, _)| !self.is_missing(x))
435 .map(|(&x, &h)| x * h)
436 .sum();
437
438 ((dot_product + hash_fn.offset) / hash_fn.bucket_width).floor() as u32
439 })
440 .collect()
441 }
442}
443
444impl ApproximateKNNImputer<ApproximateKNNImputerTrained> {
445 fn find_approximate_neighbors(
447 &self,
448 query_row: &Array1<f64>,
449 target_feature: usize,
450 ) -> Result<Vec<(f64, f64)>, SklearsError> {
451 match self.state.strategy {
452 ApproximationStrategy::RandomSampling => {
453 self.find_neighbors_random_sampling(query_row, target_feature)
454 }
455 ApproximationStrategy::HashBased => {
456 self.find_neighbors_hash_based(query_row, target_feature)
457 }
458 ApproximationStrategy::LocalApproximation => {
459 self.find_neighbors_local_approximation(query_row, target_feature)
460 }
461 _ => self.find_neighbors_random_sampling(query_row, target_feature),
462 }
463 }
464
465 fn find_neighbors_random_sampling(
467 &self,
468 query_row: &Array1<f64>,
469 target_feature: usize,
470 ) -> Result<Vec<(f64, f64)>, SklearsError> {
471 let mut neighbors = Vec::new();
472 let max_candidates = (self.n_neighbors * 3).min(self.state.reference_samples.nrows());
473
474 let mut rng = Random::default();
476 let mut candidate_indices: Vec<usize> = (0..self.state.reference_samples.nrows()).collect();
477
478 for i in (1..candidate_indices.len()).rev() {
479 let j = rng.gen_range(0..i + 1);
480 candidate_indices.swap(i, j);
481 }
482
483 candidate_indices.truncate(max_candidates);
484
485 for &idx in &candidate_indices {
486 let ref_row = self.state.reference_samples.row(idx);
487
488 if self.is_missing(ref_row[target_feature]) {
489 continue;
490 }
491
492 let distance = self.compute_approximate_distance(query_row, &ref_row.to_owned());
493 if distance.is_finite() {
494 neighbors.push((distance, ref_row[target_feature]));
495 }
496 }
497
498 neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
500 neighbors.truncate(self.n_neighbors);
501
502 Ok(neighbors)
503 }
504
505 fn find_neighbors_hash_based(
507 &self,
508 query_row: &Array1<f64>,
509 target_feature: usize,
510 ) -> Result<Vec<(f64, f64)>, SklearsError> {
511 if let Some(ref hash_table) = self.state.locality_hash {
512 let query_hash = self.compute_query_hash_values(query_row, &hash_table.hash_functions);
513 let mut candidates = HashSet::new();
514
515 if let Some(bucket_candidates) = hash_table.buckets.get(&query_hash) {
517 candidates.extend(bucket_candidates);
518 }
519
520 if candidates.len() < self.n_neighbors * 2 {
522 for (hash_key, bucket_candidates) in &hash_table.buckets {
523 let hamming_distance = self.hamming_distance(&query_hash, hash_key);
524 if hamming_distance <= 2 {
525 candidates.extend(bucket_candidates);
527 }
528 if candidates.len() >= self.n_neighbors * 3 {
529 break;
530 }
531 }
532 }
533
534 if candidates.is_empty() {
535 return self.find_neighbors_random_sampling(query_row, target_feature);
536 }
537
538 let mut neighbors = Vec::new();
540 for &idx in &candidates {
541 let ref_row = self.state.reference_samples.row(idx);
542
543 if self.is_missing(ref_row[target_feature]) {
544 continue;
545 }
546
547 let distance = self.compute_approximate_distance(query_row, &ref_row.to_owned());
548 if distance.is_finite() {
549 neighbors.push((distance, ref_row[target_feature]));
550 }
551 }
552
553 neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
554 neighbors.truncate(self.n_neighbors);
555
556 if neighbors.is_empty() {
557 return self.find_neighbors_random_sampling(query_row, target_feature);
558 }
559
560 Ok(neighbors)
561 } else {
562 self.find_neighbors_random_sampling(query_row, target_feature)
563 }
564 }
565
566 fn find_neighbors_local_approximation(
568 &self,
569 query_row: &Array1<f64>,
570 target_feature: usize,
571 ) -> Result<Vec<(f64, f64)>, SklearsError> {
572 let n_features = query_row.len();
574 let subset_size = ((n_features as f64 * self.state.config.accuracy_level) as usize).max(1);
575
576 let mut rng = Random::default();
577 let mut feature_indices: Vec<usize> = (0..n_features).collect();
578 for i in (1..feature_indices.len()).rev() {
579 let j = rng.gen_range(0..i + 1);
580 feature_indices.swap(i, j);
581 }
582 feature_indices.truncate(subset_size);
583 feature_indices.sort();
584
585 let mut neighbors = Vec::new();
587 for ref_row in self.state.reference_samples.rows() {
588 if self.is_missing(ref_row[target_feature]) {
589 continue;
590 }
591
592 let distance =
593 self.compute_subset_distance(query_row, &ref_row.to_owned(), &feature_indices);
594 if distance.is_finite() {
595 neighbors.push((distance, ref_row[target_feature]));
596 }
597 }
598
599 neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
600 neighbors.truncate(self.n_neighbors);
601
602 Ok(neighbors)
603 }
604
605 fn compute_query_hash_values(
607 &self,
608 query_row: &Array1<f64>,
609 hash_functions: &[RandomHashFunction],
610 ) -> Vec<u32> {
611 hash_functions
612 .iter()
613 .map(|hash_fn| {
614 let dot_product: f64 = query_row
615 .iter()
616 .zip(hash_fn.random_vector.iter())
617 .filter(|(&x, _)| !self.is_missing(x))
618 .map(|(&x, &h)| x * h)
619 .sum();
620
621 ((dot_product + hash_fn.offset) / hash_fn.bucket_width).floor() as u32
622 })
623 .collect()
624 }
625
626 fn hamming_distance(&self, hash1: &[u32], hash2: &[u32]) -> usize {
628 hash1
629 .iter()
630 .zip(hash2.iter())
631 .map(|(a, b)| if a == b { 0 } else { 1 })
632 .sum()
633 }
634
635 fn compute_approximate_distance(&self, row1: &Array1<f64>, row2: &Array1<f64>) -> f64 {
637 let mut sum_sq = 0.0;
638 let mut valid_count = 0;
639
640 let sample_rate = self.state.config.accuracy_level;
642 let mut rng = Random::default();
643
644 for (&x1, &x2) in row1.iter().zip(row2.iter()) {
645 if rng.gen::<f64>() > sample_rate {
647 continue;
648 }
649
650 if !self.is_missing(x1) && !self.is_missing(x2) {
651 sum_sq += (x1 - x2).powi(2);
652 valid_count += 1;
653 }
654 }
655
656 if valid_count > 0 {
657 (sum_sq / valid_count as f64).sqrt()
658 } else {
659 f64::INFINITY
660 }
661 }
662
663 fn compute_subset_distance(
665 &self,
666 row1: &Array1<f64>,
667 row2: &Array1<f64>,
668 feature_indices: &[usize],
669 ) -> f64 {
670 let mut sum_sq = 0.0;
671 let mut valid_count = 0;
672
673 for &idx in feature_indices {
674 let x1 = row1[idx];
675 let x2 = row2[idx];
676
677 if !self.is_missing(x1) && !self.is_missing(x2) {
678 sum_sq += (x1 - x2).powi(2);
679 valid_count += 1;
680 }
681 }
682
683 if valid_count > 0 {
684 (sum_sq / valid_count as f64).sqrt()
685 } else {
686 f64::INFINITY
687 }
688 }
689
690 fn compute_weighted_average(&self, neighbors: &[(f64, f64)]) -> Result<f64, SklearsError> {
692 if neighbors.is_empty() {
693 return Ok(0.0);
694 }
695
696 match self.weights.as_str() {
697 "uniform" => {
698 let sum: f64 = neighbors.iter().map(|(_, value)| value).sum();
699 Ok(sum / neighbors.len() as f64)
700 }
701 "distance" => {
702 let mut weighted_sum = 0.0;
703 let mut weight_sum = 0.0;
704
705 for &(distance, value) in neighbors {
706 let weight = if distance > 0.0 { 1.0 / distance } else { 1e6 };
707 weighted_sum += weight * value;
708 weight_sum += weight;
709 }
710
711 if weight_sum > 0.0 {
712 Ok(weighted_sum / weight_sum)
713 } else {
714 Ok(neighbors[0].1)
715 }
716 }
717 _ => Err(SklearsError::InvalidInput(format!(
718 "Unknown weights: {}",
719 self.weights
720 ))),
721 }
722 }
723
724 fn is_missing(&self, value: f64) -> bool {
725 if self.missing_values.is_nan() {
726 value.is_nan()
727 } else {
728 (value - self.missing_values).abs() < f64::EPSILON
729 }
730 }
731}
732
733impl ApproximateSimpleImputer<Untrained> {
735 pub fn new() -> Self {
736 Self {
737 state: Untrained,
738 strategy: "mean".to_string(),
739 missing_values: f64::NAN,
740 config: ApproximateConfig::default(),
741 }
742 }
743
744 pub fn strategy(mut self, strategy: String) -> Self {
745 self.strategy = strategy;
746 self
747 }
748
749 pub fn approximate_config(mut self, config: ApproximateConfig) -> Self {
750 self.config = config;
751 self
752 }
753
754 pub fn sample_size(mut self, size: usize) -> Self {
755 self.config.sample_size = size;
756 self
757 }
758
759 fn is_missing(&self, value: f64) -> bool {
760 if self.missing_values.is_nan() {
761 value.is_nan()
762 } else {
763 (value - self.missing_values).abs() < f64::EPSILON
764 }
765 }
766}
767
768impl Default for ApproximateSimpleImputer<Untrained> {
769 fn default() -> Self {
770 Self::new()
771 }
772}
773
impl Estimator for ApproximateSimpleImputer<Untrained> {
    type Config = ApproximateConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Returns the current approximation configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
783
784impl Fit<ArrayView2<'_, Float>, ()> for ApproximateSimpleImputer<Untrained> {
785 type Fitted = ApproximateSimpleImputer<ApproximateSimpleImputerTrained>;
786
787 #[allow(non_snake_case)]
788 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
789 let X = X.mapv(|x| x);
790 let (n_samples, n_features) = X.dim();
791
792 let sample_size = (self.config.sample_size as f64 * self.config.accuracy_level) as usize;
794 let effective_sample_size = sample_size.min(n_samples);
795
796 let (approximate_statistics, confidence_intervals) =
798 self.compute_approximate_statistics(&X, effective_sample_size)?;
799
800 Ok(ApproximateSimpleImputer {
801 state: ApproximateSimpleImputerTrained {
802 approximate_statistics_: approximate_statistics,
803 confidence_intervals_: confidence_intervals,
804 n_features_in_: n_features,
805 config: self.config,
806 },
807 strategy: self.strategy,
808 missing_values: self.missing_values,
809 config: Default::default(),
810 })
811 }
812}
813
814impl Transform<ArrayView2<'_, Float>, Array2<Float>>
815 for ApproximateSimpleImputer<ApproximateSimpleImputerTrained>
816{
817 #[allow(non_snake_case)]
818 fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
819 let X = X.mapv(|x| x);
820 let (_n_samples, n_features) = X.dim();
821
822 if n_features != self.state.n_features_in_ {
823 return Err(SklearsError::InvalidInput(format!(
824 "Number of features {} does not match training features {}",
825 n_features, self.state.n_features_in_
826 )));
827 }
828
829 let mut X_imputed = X.clone();
830
831 X_imputed
833 .axis_iter_mut(Axis(0))
834 .into_par_iter()
835 .for_each(|mut row| {
836 for (j, value) in row.iter_mut().enumerate() {
837 if self.is_missing(*value) {
838 *value = self.state.approximate_statistics_[j];
839 }
840 }
841 });
842
843 Ok(X_imputed.mapv(|x| x as Float))
844 }
845}
846
847impl ApproximateSimpleImputer<Untrained> {
848 fn compute_approximate_statistics(
850 &self,
851 X: &Array2<f64>,
852 sample_size: usize,
853 ) -> Result<(Array1<f64>, Array2<f64>), SklearsError> {
854 let (n_samples, n_features) = X.dim();
855 let mut approximate_statistics = Array1::<f64>::zeros(n_features);
856 let mut confidence_intervals = Array2::<f64>::zeros((n_features, 2)); let num_bootstrap_samples = 100;
860
861 for j in 0..n_features {
862 let mut bootstrap_estimates = Vec::new();
863
864 for _ in 0..num_bootstrap_samples {
865 let mut rng = Random::default();
867 let mut sample_values = Vec::new();
868
869 for _ in 0..sample_size {
870 let sample_idx = rng.gen_range(0..n_samples);
871 let value = X[[sample_idx, j]];
872 if !self.is_missing(value) {
873 sample_values.push(value);
874 }
875 }
876
877 if sample_values.is_empty() {
878 continue;
879 }
880
881 let estimate = match self.strategy.as_str() {
882 "mean" => sample_values.iter().sum::<f64>() / sample_values.len() as f64,
883 "median" => {
884 let mut sorted = sample_values.clone();
885 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
886 let mid = sorted.len() / 2;
887 if sorted.len() % 2 == 0 {
888 (sorted[mid - 1] + sorted[mid]) / 2.0
889 } else {
890 sorted[mid]
891 }
892 }
893 _ => sample_values.iter().sum::<f64>() / sample_values.len() as f64,
894 };
895
896 bootstrap_estimates.push(estimate);
897 }
898
899 if !bootstrap_estimates.is_empty() {
900 approximate_statistics[j] =
902 bootstrap_estimates.iter().sum::<f64>() / bootstrap_estimates.len() as f64;
903
904 bootstrap_estimates.sort_by(|a, b| a.partial_cmp(b).unwrap());
906 let lower_idx = (bootstrap_estimates.len() as f64 * 0.05) as usize;
907 let upper_idx = (bootstrap_estimates.len() as f64 * 0.95) as usize;
908
909 confidence_intervals[[j, 0]] =
910 bootstrap_estimates[lower_idx.min(bootstrap_estimates.len() - 1)];
911 confidence_intervals[[j, 1]] =
912 bootstrap_estimates[upper_idx.min(bootstrap_estimates.len() - 1)];
913 }
914 }
915
916 Ok((approximate_statistics, confidence_intervals))
917 }
918}
919
920impl ApproximateSimpleImputer<ApproximateSimpleImputerTrained> {
921 fn is_missing(&self, value: f64) -> bool {
922 if self.missing_values.is_nan() {
923 value.is_nan()
924 } else {
925 (value - self.missing_values).abs() < f64::EPSILON
926 }
927 }
928
929 pub fn confidence_intervals(&self) -> &Array2<f64> {
931 &self.state.confidence_intervals_
932 }
933
934 pub fn statistics(&self) -> &Array1<f64> {
936 &self.state.approximate_statistics_
937 }
938}
939
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // Mean-strategy simple imputer: the missing cell is filled with a
    // positive value and observed cells are left untouched.
    #[test]
    #[allow(non_snake_case)]
    fn test_approximate_simple_imputer() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let imputer = ApproximateSimpleImputer::new()
            .strategy("mean".to_string())
            .sample_size(100);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert!(X_imputed[[1, 1]] > 0.0);
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }

    // Random-sampling KNN imputer fills the missing cell and keeps
    // observed cells unchanged.
    #[test]
    #[allow(non_snake_case)]
    fn test_approximate_knn_imputer() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0]
        ];

        let imputer = ApproximateKNNImputer::new()
            .n_neighbors(2)
            .weights("uniform".to_string())
            .accuracy_level(0.8)
            .sample_size(3);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }

    // Builder stores a custom configuration verbatim (struct-update syntax
    // fills the remaining fields from Default).
    #[test]
    fn test_approximate_config() {
        let config = ApproximateConfig {
            accuracy_level: 0.5,
            sample_size: 500,
            use_randomization: false,
            ..Default::default()
        };

        let imputer = ApproximateSimpleImputer::new().approximate_config(config.clone());

        assert_eq!(imputer.config.accuracy_level, 0.5);
        assert_eq!(imputer.config.sample_size, 500);
        assert!(!imputer.config.use_randomization);
    }

    // Hash-based (LSH) strategy also fills the missing cell; it may fall
    // back to random sampling internally when buckets come up empty.
    #[test]
    #[allow(non_snake_case)]
    fn test_hash_based_strategy() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0]
        ];

        let imputer = ApproximateKNNImputer::new()
            .n_neighbors(2)
            .strategy(ApproximationStrategy::HashBased)
            .accuracy_level(0.9);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
    }

    // Confidence intervals have shape (n_features, 2) and are ordered
    // lower <= upper for every feature.
    #[test]
    #[allow(non_snake_case)]
    fn test_confidence_intervals() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let imputer = ApproximateSimpleImputer::new().strategy("mean".to_string());

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let confidence_intervals = fitted.confidence_intervals();

        assert_eq!(confidence_intervals.shape(), &[3, 2]);

        for j in 0..3 {
            let lower = confidence_intervals[[j, 0]];
            let upper = confidence_intervals[[j, 1]];
            assert!(
                lower <= upper,
                "Lower bound should be <= upper bound for feature {}",
                j
            );
        }
    }
}
1065}