sklears_impute/
approximate.rs

1//! Approximate imputation algorithms for fast processing
2//!
3//! This module provides fast approximation methods for imputation when speed
4//! is more important than perfect accuracy. These methods trade off some
5//! accuracy for significant performance gains.
6
7// ✅ SciRS2 Policy compliant imports
8use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
9use scirs2_core::random::{Random, Rng};
10// use scirs2_core::simd::{SimdOps}; // Note: SimdArray and auto_vectorize not available
11// use scirs2_core::parallel::{}; // Note: ParallelExecutor, ChunkStrategy not available
12
13use crate::core::Imputer;
14use rayon::prelude::*;
15use serde::{Deserialize, Serialize};
16use sklears_core::{
17    error::{Result as SklResult, SklearsError},
18    traits::{Estimator, Fit, Transform, Untrained},
19    types::Float,
20};
21use std::collections::{HashMap, HashSet};
22use std::time::Duration;
23
24/// Configuration for approximate imputation algorithms
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ApproximateConfig {
27    /// Target accuracy vs speed trade-off (0.0 = fastest, 1.0 = most accurate)
28    pub accuracy_level: f64,
29    /// Maximum processing time per feature (in seconds)
30    pub max_time_per_feature: Duration,
31    /// Sample size for approximation algorithms
32    pub sample_size: usize,
33    /// Use randomized algorithms
34    pub use_randomization: bool,
35    /// Enable early stopping
36    pub early_stopping: bool,
37    /// Convergence tolerance for iterative methods
38    pub tolerance: f64,
39    /// Maximum number of iterations
40    pub max_iterations: usize,
41}
42
43impl Default for ApproximateConfig {
44    fn default() -> Self {
45        Self {
46            accuracy_level: 0.8,
47            max_time_per_feature: Duration::from_secs(1),
48            sample_size: 1000,
49            use_randomization: true,
50            early_stopping: true,
51            tolerance: 1e-3,
52            max_iterations: 10,
53        }
54    }
55}
56
57/// Fast approximation strategies
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum ApproximationStrategy {
60    /// Random sampling approximation
61    RandomSampling,
62    /// Sketching-based approximation
63    Sketching,
64    /// Local approximation using nearest chunks
65    LocalApproximation,
66    /// Linear approximation
67    LinearApproximation,
68    /// Hash-based approximation
69    HashBased,
70}
71
72/// Approximate KNN Imputer with fast neighbor search
73#[derive(Debug)]
74pub struct ApproximateKNNImputer<S = Untrained> {
75    state: S,
76    n_neighbors: usize,
77    weights: String,
78    missing_values: f64,
79    config: ApproximateConfig,
80    strategy: ApproximationStrategy,
81}
82
83/// Trained state for approximate KNN imputer
84#[derive(Debug)]
85pub struct ApproximateKNNImputerTrained {
86    reference_samples: Array2<f64>,
87    sample_indices: Vec<usize>,
88    n_features_in_: usize,
89    config: ApproximateConfig,
90    strategy: ApproximationStrategy,
91    locality_hash: Option<LocalityHashTable>,
92}
93
94/// Locality-sensitive hash table for fast neighbor search
95#[derive(Debug)]
96pub struct LocalityHashTable {
97    hash_functions: Vec<RandomHashFunction>,
98    buckets: HashMap<Vec<u32>, Vec<usize>>,
99    num_hash_functions: usize,
100    bucket_width: f64,
101}
102
103/// Random hash function for LSH
104#[derive(Debug, Clone)]
105pub struct RandomHashFunction {
106    random_vector: Array1<f64>,
107    offset: f64,
108    bucket_width: f64,
109}
110
111/// Approximate Simple Imputer with sampling
112#[derive(Debug)]
113pub struct ApproximateSimpleImputer<S = Untrained> {
114    state: S,
115    strategy: String,
116    missing_values: f64,
117    config: ApproximateConfig,
118}
119
120/// Trained state for approximate simple imputer
121#[derive(Debug)]
122pub struct ApproximateSimpleImputerTrained {
123    approximate_statistics_: Array1<f64>,
124    confidence_intervals_: Array2<f64>, // [feature, (lower, upper)]
125    n_features_in_: usize,
126    config: ApproximateConfig,
127}
128
129/// Sketching-based Imputer
130#[derive(Debug)]
131pub struct SketchingImputer<S = Untrained> {
132    state: S,
133    sketch_size: usize,
134    missing_values: f64,
135    config: ApproximateConfig,
136    hash_family: HashFamily,
137}
138
139/// Trained state for sketching imputer
140#[derive(Debug)]
141pub struct SketchingImputerTrained {
142    sketches: Vec<CountSketch>,
143    n_features_in_: usize,
144    config: ApproximateConfig,
145}
146
147/// Count sketch data structure
148#[derive(Debug, Clone)]
149pub struct CountSketch {
150    sketch: Array1<f64>,
151    hash_functions: Vec<(usize, i32)>, // (hash_function_index, sign)
152    size: usize,
153}
154
/// Families of hash functions a sketch can draw from.
#[derive(Debug, Clone)]
pub enum HashFamily {
    /// The universal hash family.
    Universal,
    /// The polynomial hash family.
    Polynomial,
    /// The MurmurHash family.
    Murmur,
}
165
166/// Randomized Iterative Imputer
167#[derive(Debug)]
168pub struct RandomizedIterativeImputer<S = Untrained> {
169    state: S,
170    max_iter: usize,
171    missing_values: f64,
172    config: ApproximateConfig,
173    random_order: bool,
174    subsample_features: f64,
175}
176
177/// Trained state for randomized iterative imputer
178pub struct RandomizedIterativeImputerTrained {
179    estimators_: Vec<Box<dyn Imputer>>,
180    feature_order: Vec<usize>,
181    n_features_in_: usize,
182    config: ApproximateConfig,
183}
184
185impl ApproximateKNNImputer<Untrained> {
186    pub fn new() -> Self {
187        Self {
188            state: Untrained,
189            n_neighbors: 5,
190            weights: "uniform".to_string(),
191            missing_values: f64::NAN,
192            config: ApproximateConfig::default(),
193            strategy: ApproximationStrategy::RandomSampling,
194        }
195    }
196
197    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
198        self.n_neighbors = n_neighbors;
199        self
200    }
201
202    pub fn weights(mut self, weights: String) -> Self {
203        self.weights = weights;
204        self
205    }
206
207    pub fn approximate_config(mut self, config: ApproximateConfig) -> Self {
208        self.config = config;
209        self
210    }
211
212    pub fn strategy(mut self, strategy: ApproximationStrategy) -> Self {
213        self.strategy = strategy;
214        self
215    }
216
217    pub fn accuracy_level(mut self, level: f64) -> Self {
218        self.config.accuracy_level = level.clamp(0.0, 1.0);
219        self
220    }
221
222    pub fn sample_size(mut self, size: usize) -> Self {
223        self.config.sample_size = size;
224        self
225    }
226
227    fn is_missing(&self, value: f64) -> bool {
228        if self.missing_values.is_nan() {
229            value.is_nan()
230        } else {
231            (value - self.missing_values).abs() < f64::EPSILON
232        }
233    }
234}
235
236impl Default for ApproximateKNNImputer<Untrained> {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
242impl Estimator for ApproximateKNNImputer<Untrained> {
243    type Config = ApproximateConfig;
244    type Error = SklearsError;
245    type Float = Float;
246
247    fn config(&self) -> &Self::Config {
248        &self.config
249    }
250}
251
252impl Fit<ArrayView2<'_, Float>, ()> for ApproximateKNNImputer<Untrained> {
253    type Fitted = ApproximateKNNImputer<ApproximateKNNImputerTrained>;
254
255    #[allow(non_snake_case)]
256    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
257        let X = X.mapv(|x| x);
258        let (n_samples, n_features) = X.dim();
259
260        // Determine sample size based on accuracy level
261        let effective_sample_size = ((self.config.sample_size as f64 * self.config.accuracy_level)
262            as usize)
263            .min(n_samples)
264            .max(self.n_neighbors * 10); // Ensure minimum samples
265
266        // Sample training data for approximation
267        let (reference_samples, sample_indices) =
268            self.sample_training_data(&X, effective_sample_size)?;
269
270        // Build locality hash table if using hash-based strategy
271        let locality_hash = match self.strategy {
272            ApproximationStrategy::HashBased => {
273                Some(self.build_locality_hash_table(&reference_samples)?)
274            }
275            _ => None,
276        };
277
278        Ok(ApproximateKNNImputer {
279            state: ApproximateKNNImputerTrained {
280                reference_samples,
281                sample_indices,
282                n_features_in_: n_features,
283                config: self.config,
284                strategy: self.strategy,
285                locality_hash,
286            },
287            n_neighbors: self.n_neighbors,
288            weights: self.weights,
289            missing_values: self.missing_values,
290            config: Default::default(),
291            strategy: ApproximationStrategy::RandomSampling,
292        })
293    }
294}
295
296impl Transform<ArrayView2<'_, Float>, Array2<Float>>
297    for ApproximateKNNImputer<ApproximateKNNImputerTrained>
298{
299    #[allow(non_snake_case)]
300    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
301        let X = X.mapv(|x| x);
302        let (_n_samples, n_features) = X.dim();
303
304        if n_features != self.state.n_features_in_ {
305            return Err(SklearsError::InvalidInput(format!(
306                "Number of features {} does not match training features {}",
307                n_features, self.state.n_features_in_
308            )));
309        }
310
311        let mut X_imputed = X.clone();
312
313        // Process samples in parallel
314        X_imputed
315            .axis_iter_mut(Axis(0))
316            .into_par_iter()
317            .enumerate()
318            .for_each(|(_i, mut row)| {
319                for j in 0..n_features {
320                    if self.is_missing(row[j]) {
321                        // Find approximate neighbors
322                        if let Ok(neighbors) = self.find_approximate_neighbors(&row.to_owned(), j) {
323                            if !neighbors.is_empty() {
324                                if let Ok(imputed_value) = self.compute_weighted_average(&neighbors)
325                                {
326                                    row[j] = imputed_value;
327                                }
328                            }
329                        }
330                    }
331                }
332            });
333
334        Ok(X_imputed.mapv(|x| x as Float))
335    }
336}
337
338impl ApproximateKNNImputer<Untrained> {
339    /// Sample training data for approximation
340    fn sample_training_data(
341        &self,
342        X: &Array2<f64>,
343        sample_size: usize,
344    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError> {
345        let n_samples = X.nrows();
346
347        if sample_size >= n_samples {
348            return Ok((X.clone(), (0..n_samples).collect()));
349        }
350
351        // Create random sample indices
352        let mut rng = Random::default();
353        let mut indices: Vec<usize> = (0..n_samples).collect();
354
355        // Fisher-Yates shuffle for random sampling
356        for i in (1..indices.len()).rev() {
357            let j = rng.gen_range(0..i + 1);
358            indices.swap(i, j);
359        }
360
361        indices.truncate(sample_size);
362        indices.sort(); // Keep sorted for consistent results
363
364        // Extract sampled rows
365        let mut sampled_data = Array2::<f64>::zeros((sample_size, X.ncols()));
366        for (new_idx, &orig_idx) in indices.iter().enumerate() {
367            sampled_data.row_mut(new_idx).assign(&X.row(orig_idx));
368        }
369
370        Ok((sampled_data, indices))
371    }
372
373    /// Build locality-sensitive hash table
374    fn build_locality_hash_table(
375        &self,
376        data: &Array2<f64>,
377    ) -> Result<LocalityHashTable, SklearsError> {
378        let n_features = data.ncols();
379        let num_hash_functions = (self.config.accuracy_level * 10.0) as usize + 2;
380        let bucket_width = 1.0 / (self.config.accuracy_level + 0.1);
381
382        let mut hash_functions = Vec::new();
383        let mut rng = Random::default();
384
385        // Create random hash functions
386        for _ in 0..num_hash_functions {
387            let mut random_vector = Array1::<f64>::zeros(n_features);
388            for i in 0..n_features {
389                // Generate standard normal using Box-Muller transform
390                let u1: f64 = rng.gen();
391                let u2: f64 = rng.gen();
392                let z = (-2.0_f64 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
393                random_vector[i] = z;
394            }
395            let offset: f64 = rng.gen::<f64>() * bucket_width;
396
397            hash_functions.push(RandomHashFunction {
398                random_vector,
399                offset,
400                bucket_width,
401            });
402        }
403
404        // Hash all data points
405        let mut buckets = HashMap::new();
406        for (row_idx, row) in data.rows().into_iter().enumerate() {
407            let hash_values = self.compute_hash_values(&row.to_owned(), &hash_functions);
408            buckets
409                .entry(hash_values)
410                .or_insert_with(Vec::new)
411                .push(row_idx);
412        }
413
414        Ok(LocalityHashTable {
415            hash_functions,
416            buckets,
417            num_hash_functions,
418            bucket_width,
419        })
420    }
421
422    /// Compute hash values for a data point
423    fn compute_hash_values(
424        &self,
425        point: &Array1<f64>,
426        hash_functions: &[RandomHashFunction],
427    ) -> Vec<u32> {
428        hash_functions
429            .iter()
430            .map(|hash_fn| {
431                let dot_product: f64 = point
432                    .iter()
433                    .zip(hash_fn.random_vector.iter())
434                    .filter(|(&x, _)| !self.is_missing(x))
435                    .map(|(&x, &h)| x * h)
436                    .sum();
437
438                ((dot_product + hash_fn.offset) / hash_fn.bucket_width).floor() as u32
439            })
440            .collect()
441    }
442}
443
444impl ApproximateKNNImputer<ApproximateKNNImputerTrained> {
445    /// Find approximate neighbors for a query point
446    fn find_approximate_neighbors(
447        &self,
448        query_row: &Array1<f64>,
449        target_feature: usize,
450    ) -> Result<Vec<(f64, f64)>, SklearsError> {
451        match self.state.strategy {
452            ApproximationStrategy::RandomSampling => {
453                self.find_neighbors_random_sampling(query_row, target_feature)
454            }
455            ApproximationStrategy::HashBased => {
456                self.find_neighbors_hash_based(query_row, target_feature)
457            }
458            ApproximationStrategy::LocalApproximation => {
459                self.find_neighbors_local_approximation(query_row, target_feature)
460            }
461            _ => self.find_neighbors_random_sampling(query_row, target_feature),
462        }
463    }
464
465    /// Find neighbors using random sampling
466    fn find_neighbors_random_sampling(
467        &self,
468        query_row: &Array1<f64>,
469        target_feature: usize,
470    ) -> Result<Vec<(f64, f64)>, SklearsError> {
471        let mut neighbors = Vec::new();
472        let max_candidates = (self.n_neighbors * 3).min(self.state.reference_samples.nrows());
473
474        // Randomly sample candidates for distance computation
475        let mut rng = Random::default();
476        let mut candidate_indices: Vec<usize> = (0..self.state.reference_samples.nrows()).collect();
477
478        for i in (1..candidate_indices.len()).rev() {
479            let j = rng.gen_range(0..i + 1);
480            candidate_indices.swap(i, j);
481        }
482
483        candidate_indices.truncate(max_candidates);
484
485        for &idx in &candidate_indices {
486            let ref_row = self.state.reference_samples.row(idx);
487
488            if self.is_missing(ref_row[target_feature]) {
489                continue;
490            }
491
492            let distance = self.compute_approximate_distance(query_row, &ref_row.to_owned());
493            if distance.is_finite() {
494                neighbors.push((distance, ref_row[target_feature]));
495            }
496        }
497
498        // Sort by distance and take k nearest
499        neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
500        neighbors.truncate(self.n_neighbors);
501
502        Ok(neighbors)
503    }
504
505    /// Find neighbors using hash-based approach
506    fn find_neighbors_hash_based(
507        &self,
508        query_row: &Array1<f64>,
509        target_feature: usize,
510    ) -> Result<Vec<(f64, f64)>, SklearsError> {
511        if let Some(ref hash_table) = self.state.locality_hash {
512            let query_hash = self.compute_query_hash_values(query_row, &hash_table.hash_functions);
513            let mut candidates = HashSet::new();
514
515            // Get candidates from the same bucket
516            if let Some(bucket_candidates) = hash_table.buckets.get(&query_hash) {
517                candidates.extend(bucket_candidates);
518            }
519
520            // If not enough candidates, check neighboring buckets
521            if candidates.len() < self.n_neighbors * 2 {
522                for (hash_key, bucket_candidates) in &hash_table.buckets {
523                    let hamming_distance = self.hamming_distance(&query_hash, hash_key);
524                    if hamming_distance <= 2 {
525                        // Allow some hash collisions
526                        candidates.extend(bucket_candidates);
527                    }
528                    if candidates.len() >= self.n_neighbors * 3 {
529                        break;
530                    }
531                }
532            }
533
534            if candidates.is_empty() {
535                return self.find_neighbors_random_sampling(query_row, target_feature);
536            }
537
538            // Compute distances to candidates
539            let mut neighbors = Vec::new();
540            for &idx in &candidates {
541                let ref_row = self.state.reference_samples.row(idx);
542
543                if self.is_missing(ref_row[target_feature]) {
544                    continue;
545                }
546
547                let distance = self.compute_approximate_distance(query_row, &ref_row.to_owned());
548                if distance.is_finite() {
549                    neighbors.push((distance, ref_row[target_feature]));
550                }
551            }
552
553            neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
554            neighbors.truncate(self.n_neighbors);
555
556            if neighbors.is_empty() {
557                return self.find_neighbors_random_sampling(query_row, target_feature);
558            }
559
560            Ok(neighbors)
561        } else {
562            self.find_neighbors_random_sampling(query_row, target_feature)
563        }
564    }
565
566    /// Find neighbors using local approximation
567    fn find_neighbors_local_approximation(
568        &self,
569        query_row: &Array1<f64>,
570        target_feature: usize,
571    ) -> Result<Vec<(f64, f64)>, SklearsError> {
572        // Use a subset of features for distance computation
573        let n_features = query_row.len();
574        let subset_size = ((n_features as f64 * self.state.config.accuracy_level) as usize).max(1);
575
576        let mut rng = Random::default();
577        let mut feature_indices: Vec<usize> = (0..n_features).collect();
578        for i in (1..feature_indices.len()).rev() {
579            let j = rng.gen_range(0..i + 1);
580            feature_indices.swap(i, j);
581        }
582        feature_indices.truncate(subset_size);
583        feature_indices.sort();
584
585        // Compute distances using subset of features
586        let mut neighbors = Vec::new();
587        for ref_row in self.state.reference_samples.rows() {
588            if self.is_missing(ref_row[target_feature]) {
589                continue;
590            }
591
592            let distance =
593                self.compute_subset_distance(query_row, &ref_row.to_owned(), &feature_indices);
594            if distance.is_finite() {
595                neighbors.push((distance, ref_row[target_feature]));
596            }
597        }
598
599        neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
600        neighbors.truncate(self.n_neighbors);
601
602        Ok(neighbors)
603    }
604
605    /// Compute hash values for query point
606    fn compute_query_hash_values(
607        &self,
608        query_row: &Array1<f64>,
609        hash_functions: &[RandomHashFunction],
610    ) -> Vec<u32> {
611        hash_functions
612            .iter()
613            .map(|hash_fn| {
614                let dot_product: f64 = query_row
615                    .iter()
616                    .zip(hash_fn.random_vector.iter())
617                    .filter(|(&x, _)| !self.is_missing(x))
618                    .map(|(&x, &h)| x * h)
619                    .sum();
620
621                ((dot_product + hash_fn.offset) / hash_fn.bucket_width).floor() as u32
622            })
623            .collect()
624    }
625
626    /// Compute Hamming distance between hash values
627    fn hamming_distance(&self, hash1: &[u32], hash2: &[u32]) -> usize {
628        hash1
629            .iter()
630            .zip(hash2.iter())
631            .map(|(a, b)| if a == b { 0 } else { 1 })
632            .sum()
633    }
634
635    /// Compute approximate distance (using fewer features)
636    fn compute_approximate_distance(&self, row1: &Array1<f64>, row2: &Array1<f64>) -> f64 {
637        let mut sum_sq = 0.0;
638        let mut valid_count = 0;
639
640        // Use sampling to reduce computation
641        let sample_rate = self.state.config.accuracy_level;
642        let mut rng = Random::default();
643
644        for (&x1, &x2) in row1.iter().zip(row2.iter()) {
645            // Skip some features based on sampling rate
646            if rng.gen::<f64>() > sample_rate {
647                continue;
648            }
649
650            if !self.is_missing(x1) && !self.is_missing(x2) {
651                sum_sq += (x1 - x2).powi(2);
652                valid_count += 1;
653            }
654        }
655
656        if valid_count > 0 {
657            (sum_sq / valid_count as f64).sqrt()
658        } else {
659            f64::INFINITY
660        }
661    }
662
663    /// Compute distance using subset of features
664    fn compute_subset_distance(
665        &self,
666        row1: &Array1<f64>,
667        row2: &Array1<f64>,
668        feature_indices: &[usize],
669    ) -> f64 {
670        let mut sum_sq = 0.0;
671        let mut valid_count = 0;
672
673        for &idx in feature_indices {
674            let x1 = row1[idx];
675            let x2 = row2[idx];
676
677            if !self.is_missing(x1) && !self.is_missing(x2) {
678                sum_sq += (x1 - x2).powi(2);
679                valid_count += 1;
680            }
681        }
682
683        if valid_count > 0 {
684            (sum_sq / valid_count as f64).sqrt()
685        } else {
686            f64::INFINITY
687        }
688    }
689
690    /// Compute weighted average of neighbor values
691    fn compute_weighted_average(&self, neighbors: &[(f64, f64)]) -> Result<f64, SklearsError> {
692        if neighbors.is_empty() {
693            return Ok(0.0);
694        }
695
696        match self.weights.as_str() {
697            "uniform" => {
698                let sum: f64 = neighbors.iter().map(|(_, value)| value).sum();
699                Ok(sum / neighbors.len() as f64)
700            }
701            "distance" => {
702                let mut weighted_sum = 0.0;
703                let mut weight_sum = 0.0;
704
705                for &(distance, value) in neighbors {
706                    let weight = if distance > 0.0 { 1.0 / distance } else { 1e6 };
707                    weighted_sum += weight * value;
708                    weight_sum += weight;
709                }
710
711                if weight_sum > 0.0 {
712                    Ok(weighted_sum / weight_sum)
713                } else {
714                    Ok(neighbors[0].1)
715                }
716            }
717            _ => Err(SklearsError::InvalidInput(format!(
718                "Unknown weights: {}",
719                self.weights
720            ))),
721        }
722    }
723
724    fn is_missing(&self, value: f64) -> bool {
725        if self.missing_values.is_nan() {
726            value.is_nan()
727        } else {
728            (value - self.missing_values).abs() < f64::EPSILON
729        }
730    }
731}
732
733// Implement Approximate Simple Imputer
734impl ApproximateSimpleImputer<Untrained> {
735    pub fn new() -> Self {
736        Self {
737            state: Untrained,
738            strategy: "mean".to_string(),
739            missing_values: f64::NAN,
740            config: ApproximateConfig::default(),
741        }
742    }
743
744    pub fn strategy(mut self, strategy: String) -> Self {
745        self.strategy = strategy;
746        self
747    }
748
749    pub fn approximate_config(mut self, config: ApproximateConfig) -> Self {
750        self.config = config;
751        self
752    }
753
754    pub fn sample_size(mut self, size: usize) -> Self {
755        self.config.sample_size = size;
756        self
757    }
758
759    fn is_missing(&self, value: f64) -> bool {
760        if self.missing_values.is_nan() {
761            value.is_nan()
762        } else {
763            (value - self.missing_values).abs() < f64::EPSILON
764        }
765    }
766}
767
768impl Default for ApproximateSimpleImputer<Untrained> {
769    fn default() -> Self {
770        Self::new()
771    }
772}
773
774impl Estimator for ApproximateSimpleImputer<Untrained> {
775    type Config = ApproximateConfig;
776    type Error = SklearsError;
777    type Float = Float;
778
779    fn config(&self) -> &Self::Config {
780        &self.config
781    }
782}
783
784impl Fit<ArrayView2<'_, Float>, ()> for ApproximateSimpleImputer<Untrained> {
785    type Fitted = ApproximateSimpleImputer<ApproximateSimpleImputerTrained>;
786
787    #[allow(non_snake_case)]
788    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
789        let X = X.mapv(|x| x);
790        let (n_samples, n_features) = X.dim();
791
792        // Determine sample size for approximation
793        let sample_size = (self.config.sample_size as f64 * self.config.accuracy_level) as usize;
794        let effective_sample_size = sample_size.min(n_samples);
795
796        // Compute approximate statistics using sampling
797        let (approximate_statistics, confidence_intervals) =
798            self.compute_approximate_statistics(&X, effective_sample_size)?;
799
800        Ok(ApproximateSimpleImputer {
801            state: ApproximateSimpleImputerTrained {
802                approximate_statistics_: approximate_statistics,
803                confidence_intervals_: confidence_intervals,
804                n_features_in_: n_features,
805                config: self.config,
806            },
807            strategy: self.strategy,
808            missing_values: self.missing_values,
809            config: Default::default(),
810        })
811    }
812}
813
814impl Transform<ArrayView2<'_, Float>, Array2<Float>>
815    for ApproximateSimpleImputer<ApproximateSimpleImputerTrained>
816{
817    #[allow(non_snake_case)]
818    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
819        let X = X.mapv(|x| x);
820        let (_n_samples, n_features) = X.dim();
821
822        if n_features != self.state.n_features_in_ {
823            return Err(SklearsError::InvalidInput(format!(
824                "Number of features {} does not match training features {}",
825                n_features, self.state.n_features_in_
826            )));
827        }
828
829        let mut X_imputed = X.clone();
830
831        // Apply imputation in parallel
832        X_imputed
833            .axis_iter_mut(Axis(0))
834            .into_par_iter()
835            .for_each(|mut row| {
836                for (j, value) in row.iter_mut().enumerate() {
837                    if self.is_missing(*value) {
838                        *value = self.state.approximate_statistics_[j];
839                    }
840                }
841            });
842
843        Ok(X_imputed.mapv(|x| x as Float))
844    }
845}
846
847impl ApproximateSimpleImputer<Untrained> {
848    /// Compute approximate statistics using sampling
849    fn compute_approximate_statistics(
850        &self,
851        X: &Array2<f64>,
852        sample_size: usize,
853    ) -> Result<(Array1<f64>, Array2<f64>), SklearsError> {
854        let (n_samples, n_features) = X.dim();
855        let mut approximate_statistics = Array1::<f64>::zeros(n_features);
856        let mut confidence_intervals = Array2::<f64>::zeros((n_features, 2)); // [lower, upper]
857
858        // Use bootstrap sampling for confidence intervals
859        let num_bootstrap_samples = 100;
860
861        for j in 0..n_features {
862            let mut bootstrap_estimates = Vec::new();
863
864            for _ in 0..num_bootstrap_samples {
865                // Sample with replacement
866                let mut rng = Random::default();
867                let mut sample_values = Vec::new();
868
869                for _ in 0..sample_size {
870                    let sample_idx = rng.gen_range(0..n_samples);
871                    let value = X[[sample_idx, j]];
872                    if !self.is_missing(value) {
873                        sample_values.push(value);
874                    }
875                }
876
877                if sample_values.is_empty() {
878                    continue;
879                }
880
881                let estimate = match self.strategy.as_str() {
882                    "mean" => sample_values.iter().sum::<f64>() / sample_values.len() as f64,
883                    "median" => {
884                        let mut sorted = sample_values.clone();
885                        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
886                        let mid = sorted.len() / 2;
887                        if sorted.len() % 2 == 0 {
888                            (sorted[mid - 1] + sorted[mid]) / 2.0
889                        } else {
890                            sorted[mid]
891                        }
892                    }
893                    _ => sample_values.iter().sum::<f64>() / sample_values.len() as f64,
894                };
895
896                bootstrap_estimates.push(estimate);
897            }
898
899            if !bootstrap_estimates.is_empty() {
900                // Main estimate
901                approximate_statistics[j] =
902                    bootstrap_estimates.iter().sum::<f64>() / bootstrap_estimates.len() as f64;
903
904                // Confidence interval (5th and 95th percentiles)
905                bootstrap_estimates.sort_by(|a, b| a.partial_cmp(b).unwrap());
906                let lower_idx = (bootstrap_estimates.len() as f64 * 0.05) as usize;
907                let upper_idx = (bootstrap_estimates.len() as f64 * 0.95) as usize;
908
909                confidence_intervals[[j, 0]] =
910                    bootstrap_estimates[lower_idx.min(bootstrap_estimates.len() - 1)];
911                confidence_intervals[[j, 1]] =
912                    bootstrap_estimates[upper_idx.min(bootstrap_estimates.len() - 1)];
913            }
914        }
915
916        Ok((approximate_statistics, confidence_intervals))
917    }
918}
919
920impl ApproximateSimpleImputer<ApproximateSimpleImputerTrained> {
921    fn is_missing(&self, value: f64) -> bool {
922        if self.missing_values.is_nan() {
923            value.is_nan()
924        } else {
925            (value - self.missing_values).abs() < f64::EPSILON
926        }
927    }
928
929    /// Get confidence intervals for imputed values
930    pub fn confidence_intervals(&self) -> &Array2<f64> {
931        &self.state.confidence_intervals_
932    }
933
934    /// Get approximate statistics
935    pub fn statistics(&self) -> &Array1<f64> {
936        &self.state.approximate_statistics_
937    }
938}
939
// Upper-case `X` names the feature matrix in these tests, hence the lint
// allowance; it applies to every item nested in the module.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Mean-strategy simple imputation replaces the NaN cell with a positive
    /// value and leaves observed cells unchanged.
    #[test]
    fn test_approximate_simple_imputer() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let model = ApproximateSimpleImputer::new()
            .strategy("mean".to_string())
            .sample_size(100)
            .fit(&X.view(), &())
            .unwrap();
        let X_imputed = model.transform(&X.view()).unwrap();

        // The formerly missing cell must now hold a real, positive number.
        let filled = X_imputed[[1, 1]];
        assert!(!filled.is_nan());
        assert!(filled > 0.0);

        // Observed cells pass through untouched.
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }

    /// Approximate KNN imputation with uniform weights fills the NaN cell and
    /// leaves observed cells unchanged.
    #[test]
    fn test_approximate_knn_imputer() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0]
        ];

        let model = ApproximateKNNImputer::new()
            .n_neighbors(2)
            .weights("uniform".to_string())
            .accuracy_level(0.8)
            .sample_size(3)
            .fit(&X.view(), &())
            .unwrap();
        let X_imputed = model.transform(&X.view()).unwrap();

        // The missing cell was filled in.
        let filled = X_imputed[[1, 1]];
        assert!(!filled.is_nan());

        // Observed cells pass through untouched.
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }

    /// A custom `ApproximateConfig` handed to the builder is stored on the imputer.
    #[test]
    fn test_approximate_config() {
        let imputer = ApproximateSimpleImputer::new().approximate_config(ApproximateConfig {
            accuracy_level: 0.5,
            sample_size: 500,
            use_randomization: false,
            ..Default::default()
        });

        let stored = &imputer.config;
        assert_eq!(stored.accuracy_level, 0.5);
        assert_eq!(stored.sample_size, 500);
        assert!(!stored.use_randomization);
    }

    /// The hash-based approximation strategy fills the NaN cell and leaves an
    /// observed cell unchanged.
    #[test]
    fn test_hash_based_strategy() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0]
        ];

        let model = ApproximateKNNImputer::new()
            .n_neighbors(2)
            .strategy(ApproximationStrategy::HashBased)
            .accuracy_level(0.9)
            .fit(&X.view(), &())
            .unwrap();
        let X_imputed = model.transform(&X.view()).unwrap();

        // The missing cell was filled in.
        let filled = X_imputed[[1, 1]];
        assert!(!filled.is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
    }

    /// Fitting yields one `[lower, upper]` interval per feature, with the
    /// lower bound never exceeding the upper bound.
    #[test]
    fn test_confidence_intervals() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let model = ApproximateSimpleImputer::new()
            .strategy("mean".to_string())
            .fit(&X.view(), &())
            .unwrap();
        let intervals = model.confidence_intervals();

        // Three features, two bounds each.
        assert_eq!(intervals.shape(), &[3, 2]);

        // Each row is one feature's interval: column 0 is lower, column 1 is upper.
        for (j, bounds) in intervals.outer_iter().enumerate() {
            let (lower, upper) = (bounds[0], bounds[1]);
            assert!(
                lower <= upper,
                "Lower bound should be <= upper bound for feature {}",
                j
            );
        }
    }
}