sklears_impute/
sampling.rs

//! Sampling-based imputation methods
//!
//! This module provides imputation methods based on various sampling techniques
//! including importance sampling, stratified sampling, and adaptive sampling.
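//!
//! A minimal usage sketch (illustrative only; the module path
//! `sklears_impute::sampling` is assumed from this file's location, and the
//! fence is marked `ignore` because the sketch is not wired into doc tests):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//! use sklears_core::traits::{Fit, Transform};
//! use sklears_impute::sampling::{SamplingSimpleImputer, SamplingStrategy};
//!
//! let x = array![[1.0, 2.0], [f64::NAN, 4.0]];
//! let imputer = SamplingSimpleImputer::new()
//!     .strategy("mean".to_string())
//!     .sampling_strategy(SamplingStrategy::Bootstrap);
//! let fitted = imputer.fit(&x.view(), &()).unwrap();
//! let x_imputed = fitted.transform(&x.view()).unwrap();
//! assert!(!x_imputed[[1, 0]].is_nan());
//! ```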

// ✅ SciRS2 Policy compliant imports
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::random::{Random, Rng};
// use scirs2_core::parallel::{ParallelExecutor, ChunkStrategy}; // Note: not available

// use rayon::prelude::*; // Unused
use serde::{Deserialize, Serialize};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Configuration for sampling-based imputation
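///
/// A minimal sketch of overriding selected defaults (mirrors the pattern used
/// in this module's tests; marked `ignore` so it stays illustrative):
///
/// ```ignore
/// let config = SamplingConfig {
///     n_samples: 500,
///     strategy: SamplingStrategy::LatinHypercube,
///     ..Default::default()
/// };
/// ```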
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Number of samples to draw
    pub n_samples: usize,
    /// Sampling strategy to use
    pub strategy: SamplingStrategy,
    /// Use importance sampling
    pub importance_sampling: bool,
    /// Weight function for importance sampling
    pub weight_function: WeightFunction,
    /// Stratification variables for stratified sampling
    pub stratify_by: Option<Vec<usize>>,
    /// Number of strata for stratified sampling
    pub n_strata: usize,
    /// Use quasi-random sequences (low-discrepancy)
    pub use_quasi_random: bool,
    /// Sequence type for quasi-random sampling
    pub quasi_sequence_type: QuasiSequenceType,
    /// Enable adaptive sampling
    pub adaptive_sampling: bool,
    /// Target confidence level for adaptive sampling
    pub confidence_level: f64,
    /// Maximum sampling iterations
    pub max_iterations: usize,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            n_samples: 1000,
            strategy: SamplingStrategy::Simple,
            importance_sampling: false,
            weight_function: WeightFunction::Uniform,
            stratify_by: None,
            n_strata: 5,
            use_quasi_random: false,
            quasi_sequence_type: QuasiSequenceType::Halton,
            adaptive_sampling: false,
            confidence_level: 0.95,
            max_iterations: 100,
        }
    }
}

/// Sampling strategies
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamplingStrategy {
    /// Simple random sampling
    Simple,
    /// Stratified sampling
    Stratified,
    /// Cluster sampling
    Cluster,
    /// Systematic sampling
    Systematic,
    /// Importance sampling
    Importance,
    /// Latin hypercube sampling
    LatinHypercube,
    /// Bootstrap sampling
    Bootstrap,
    /// Reservoir sampling
    Reservoir,
}

/// Weight functions for importance sampling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightFunction {
    /// Uniform weights (equivalent to simple random sampling)
    Uniform,
    /// Inverse probability weighting
    InverseProbability,
    /// Density-based weighting
    DensityBased,
    /// Distance-based weighting
    DistanceBased,
    /// Custom weights provided by user
    Custom(Vec<f64>),
}

/// Quasi-random sequence types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuasiSequenceType {
    /// Halton sequence
    Halton,
    /// Sobol sequence
    Sobol,
    /// Faure sequence
    Faure,
    /// Niederreiter sequence
    Niederreiter,
}
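
// The imputers below accept `use_quasi_random` and `quasi_sequence_type` but do
// not yet consume them, so the helper below is a minimal illustrative sketch of
// the radical inverse that underlies the Halton sequence; nothing in this
// module calls it.
#[allow(dead_code)]
fn halton_radical_inverse(mut index: usize, base: usize) -> f64 {
    let mut fraction = 1.0;
    let mut result = 0.0;
    // Reflect the base-`base` digits of `index` about the radix point.
    while index > 0 {
        fraction /= base as f64;
        result += fraction * (index % base) as f64;
        index /= base;
    }
    result
}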

/// Sampling-based Simple Imputer
#[derive(Debug)]
pub struct SamplingSimpleImputer<S = Untrained> {
    state: S,
    strategy: String,
    missing_values: f64,
    config: SamplingConfig,
}

/// Trained state for sampling-based simple imputer
#[derive(Debug)]
pub struct SamplingSimpleImputerTrained {
    sample_statistics_: Array1<f64>,
    sample_distributions_: Vec<SampleDistribution>,
    n_features_in_: usize,
    config: SamplingConfig,
}

/// Sample distribution for a feature
#[derive(Debug, Clone)]
pub struct SampleDistribution {
    /// Sampled values that form the distribution's support
    pub values: Vec<f64>,
    /// Weight assigned to each value (normalized to sum to 1)
    pub weights: Vec<f64>,
    /// Running sum of `weights`, used for inverse-CDF sampling
    pub cumulative_weights: Vec<f64>,
    /// How the distribution was constructed
    pub distribution_type: DistributionType,
}

/// Type alias for stratified distributions result
type StratifiedDistributionsResult = Result<
    (
        HashMap<Vec<usize>, HashMap<usize, SampleDistribution>>,
        Vec<Array1<f64>>,
    ),
    SklearsError,
>;

/// Distribution types for sampling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributionType {
    /// Empirical distribution (discrete)
    Empirical,
    /// Kernel density estimate (continuous)
    KernelDensity,
    /// Parametric distribution
    Parametric(ParametricDistribution),
}

/// Parametric distribution types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParametricDistribution {
    /// Normal distribution with the given mean and standard deviation
    Normal { mean: f64, std: f64 },
    /// Log-normal distribution parameterized on the log scale
    LogNormal { mean_log: f64, std_log: f64 },
    /// Exponential distribution with the given rate
    Exponential { rate: f64 },
    /// Gamma distribution with shape and rate parameters
    Gamma { shape: f64, rate: f64 },
    /// Beta distribution with shape parameters alpha and beta
    Beta { alpha: f64, beta: f64 },
    /// Uniform distribution on [low, high]
    Uniform { low: f64, high: f64 },
}
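
// A minimal illustrative sketch (not called by the imputers in this module) of
// how the simpler parametric forms could be sampled by inverse transform and
// Box-Muller, given two independent uniforms `u1`, `u2` in (0, 1). LogNormal,
// Gamma, and Beta need dedicated samplers and are deliberately left out.
#[allow(dead_code)]
fn sample_parametric_sketch(dist: &ParametricDistribution, u1: f64, u2: f64) -> Option<f64> {
    match dist {
        ParametricDistribution::Normal { mean, std: sd } => {
            // Box-Muller transform: two uniforms -> one standard normal draw
            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
            Some(mean + sd * z)
        }
        // Inverse CDF of the exponential distribution: -ln(1 - u) / rate
        ParametricDistribution::Exponential { rate } => Some(-(1.0 - u1).ln() / rate),
        ParametricDistribution::Uniform { low, high } => Some(low + u1 * (high - low)),
        _ => None,
    }
}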

/// Stratified Sampling Imputer
#[derive(Debug)]
pub struct StratifiedSamplingImputer<S = Untrained> {
    state: S,
    missing_values: f64,
    config: SamplingConfig,
    stratification_features: Vec<usize>,
}

/// Trained state for stratified sampling imputer
#[derive(Debug)]
pub struct StratifiedSamplingImputerTrained {
    strata_distributions_: HashMap<Vec<usize>, HashMap<usize, SampleDistribution>>,
    feature_strata_: Vec<Array1<f64>>, // Stratification boundaries for each feature
    n_features_in_: usize,
    config: SamplingConfig,
}

/// Importance Sampling Imputer
#[derive(Debug)]
pub struct ImportanceSamplingImputer<S = Untrained> {
    state: S,
    missing_values: f64,
    config: SamplingConfig,
    proposal_distribution: ProposalDistribution,
}

/// Trained state for importance sampling imputer
#[derive(Debug)]
pub struct ImportanceSamplingImputerTrained {
    importance_weights_: Array2<f64>, // [feature, sample]
    proposal_samples_: Array2<f64>,
    target_density_: Array1<f64>,
    n_features_in_: usize,
    config: SamplingConfig,
}

/// Proposal distributions for importance sampling
#[derive(Debug, Clone)]
pub enum ProposalDistribution {
    /// Use empirical distribution as proposal
    Empirical,
    /// Use Gaussian mixture model
    GaussianMixture { n_components: usize },
    /// Use kernel density estimate
    KernelDensity { bandwidth: f64 },
}

/// Adaptive Sampling Imputer
#[derive(Debug)]
pub struct AdaptiveSamplingImputer<S = Untrained> {
    state: S,
    missing_values: f64,
    config: SamplingConfig,
    convergence_threshold: f64,
}

/// Trained state for adaptive sampling imputer
#[derive(Debug)]
pub struct AdaptiveSamplingImputerTrained {
    adaptive_samples_: Vec<Array1<f64>>, // Samples for each feature
    convergence_history_: Vec<f64>,
    final_estimates_: Array1<f64>,
    confidence_intervals_: Array2<f64>,
    n_features_in_: usize,
    config: SamplingConfig,
}
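
// A minimal illustrative sketch (not called by the imputers in this module) of
// the stopping rule that `AdaptiveSamplingImputer` is configured for: keep
// drawing until the mean's confidence-interval half-width falls below a
// tolerance, i.e. n >= (z * sigma / tolerance)^2. The z-value lookup covers
// only a few common confidence levels and is a simplification.
#[allow(dead_code)]
fn adaptive_samples_needed(std_dev: f64, tolerance: f64, confidence_level: f64) -> usize {
    // Two-sided standard normal quantiles for common confidence levels.
    let z = if confidence_level >= 0.99 {
        2.576
    } else if confidence_level >= 0.95 {
        1.96
    } else {
        1.645
    };
    (z * std_dev / tolerance).powi(2).ceil() as usize
}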

impl SamplingSimpleImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            strategy: "mean".to_string(),
            missing_values: f64::NAN,
            config: SamplingConfig::default(),
        }
    }

    pub fn strategy(mut self, strategy: String) -> Self {
        self.strategy = strategy;
        self
    }

    pub fn sampling_config(mut self, config: SamplingConfig) -> Self {
        self.config = config;
        self
    }

    pub fn n_samples(mut self, n_samples: usize) -> Self {
        self.config.n_samples = n_samples;
        self
    }

    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
        self.config.strategy = strategy;
        self
    }

    pub fn weight_function(mut self, weight_function: WeightFunction) -> Self {
        self.config.weight_function = weight_function;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for SamplingSimpleImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for SamplingSimpleImputer<Untrained> {
    type Config = SamplingConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for SamplingSimpleImputer<Untrained> {
    type Fitted = SamplingSimpleImputer<SamplingSimpleImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_n_samples, n_features) = X.dim();

        let (sample_statistics, sample_distributions) = self.compute_sample_statistics(&X)?;

        Ok(SamplingSimpleImputer {
            state: SamplingSimpleImputerTrained {
                sample_statistics_: sample_statistics,
                sample_distributions_: sample_distributions,
                n_features_in_: n_features,
                config: self.config,
            },
            strategy: self.strategy,
            missing_values: self.missing_values,
            config: Default::default(),
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for SamplingSimpleImputer<SamplingSimpleImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X;

        // Apply sampling-based imputation
        for i in 0..n_samples {
            for j in 0..n_features {
                if self.is_missing(X_imputed[[i, j]]) {
                    let imputed_value = self.sample_imputed_value(j)?;
                    X_imputed[[i, j]] = imputed_value;
                }
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl SamplingSimpleImputer<Untrained> {
    /// Compute sample statistics based on sampling strategy
    fn compute_sample_statistics(
        &self,
        X: &Array2<f64>,
    ) -> Result<(Array1<f64>, Vec<SampleDistribution>), SklearsError> {
        let (_, n_features) = X.dim();
        let mut sample_statistics = Array1::<f64>::zeros(n_features);
        let mut sample_distributions = Vec::new();

        for j in 0..n_features {
            let column = X.column(j);
            let valid_values: Vec<f64> = column
                .iter()
                .filter(|&&x| !self.is_missing(x))
                .copied()
                .collect();

            if valid_values.is_empty() {
                sample_statistics[j] = 0.0;
                sample_distributions.push(SampleDistribution {
                    values: vec![0.0],
                    weights: vec![1.0],
                    cumulative_weights: vec![1.0],
                    distribution_type: DistributionType::Empirical,
                });
                continue;
            }

            // Create sample distribution based on strategy
            let distribution = match self.config.strategy {
                SamplingStrategy::Simple => {
                    self.create_simple_sample_distribution(&valid_values)?
                }
                SamplingStrategy::Importance => {
                    self.create_importance_sample_distribution(&valid_values)?
                }
                SamplingStrategy::Bootstrap => {
                    self.create_bootstrap_sample_distribution(&valid_values)?
                }
                SamplingStrategy::LatinHypercube => {
                    self.create_latin_hypercube_distribution(&valid_values)?
                }
                _ => self.create_simple_sample_distribution(&valid_values)?,
            };

            // Compute primary statistic
            sample_statistics[j] = match self.strategy.as_str() {
                "mean" => self.compute_weighted_mean(&distribution),
                "median" => self.compute_weighted_median(&distribution),
                "mode" => self.compute_weighted_mode(&distribution),
                _ => self.compute_weighted_mean(&distribution),
            };

            sample_distributions.push(distribution);
        }

        Ok((sample_statistics, sample_distributions))
    }

    /// Create simple sample distribution
    fn create_simple_sample_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let n_samples = self.config.n_samples.min(values.len());
        let mut rng = Random::default();

        // Simple random sampling with replacement, uniform weights
        let mut sampled_values = Vec::new();
        let mut weights = Vec::new();

        for _ in 0..n_samples {
            let idx = rng.gen_range(0..values.len());
            sampled_values.push(values[idx]);
            weights.push(1.0 / n_samples as f64);
        }

        // Compute cumulative weights
        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: sampled_values,
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        })
    }

    /// Create importance sample distribution
    fn create_importance_sample_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let n_samples = self.config.n_samples.min(values.len());

        // Compute importance weights based on density
        let mut weighted_values = Vec::new();
        let mut importance_weights = Vec::new();

        // Use kernel density estimation for importance weights; note this
        // weights the first `n_samples` values rather than drawing randomly
        let bandwidth = self.compute_bandwidth(values);

        for &value in values.iter().take(n_samples) {
            let density = self.kernel_density_estimate(value, values, bandwidth);
            let importance_weight = 1.0 / (density + 1e-8); // Avoid division by zero

            weighted_values.push(value);
            importance_weights.push(importance_weight);
        }

        // Normalize weights so they sum to 1 (self-normalized importance sampling)
        let total_weight: f64 = importance_weights.iter().sum();
        for weight in &mut importance_weights {
            *weight /= total_weight;
        }

        // Compute cumulative weights
        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &importance_weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: weighted_values,
            weights: importance_weights,
            cumulative_weights,
            distribution_type: DistributionType::KernelDensity,
        })
    }

    /// Create bootstrap sample distribution
    fn create_bootstrap_sample_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let n_bootstrap = 100;
        let n_samples_per_bootstrap = values.len();
        let mut bootstrap_estimates = Vec::new();

        let mut rng = Random::default();

        for _ in 0..n_bootstrap {
            let mut bootstrap_sample = Vec::new();

            // Sample with replacement
            for _ in 0..n_samples_per_bootstrap {
                let idx = rng.gen_range(0..values.len());
                bootstrap_sample.push(values[idx]);
            }

            // Compute statistic for this bootstrap sample
            let estimate = match self.strategy.as_str() {
                "mean" => bootstrap_sample.iter().sum::<f64>() / bootstrap_sample.len() as f64,
                "median" => {
                    let mut sorted = bootstrap_sample.clone();
                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
                    let mid = sorted.len() / 2;
                    if sorted.len() % 2 == 0 {
                        (sorted[mid - 1] + sorted[mid]) / 2.0
                    } else {
                        sorted[mid]
                    }
                }
                _ => bootstrap_sample.iter().sum::<f64>() / bootstrap_sample.len() as f64,
            };

            bootstrap_estimates.push(estimate);
        }

        let uniform_weight = 1.0 / bootstrap_estimates.len() as f64;
        let weights = vec![uniform_weight; bootstrap_estimates.len()];

        // Compute cumulative weights
        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: bootstrap_estimates,
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        })
    }

    /// Create Latin hypercube sample distribution
    fn create_latin_hypercube_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let n_samples = self.config.n_samples.min(values.len());

        // Sort values to create quantile-based sampling
        let mut sorted_values = values.to_vec();
        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let mut lhs_values = Vec::new();
        let mut rng = Random::default();

        // Generate Latin hypercube samples: one uniform draw per stratum
        for i in 0..n_samples {
            let lower_bound = i as f64 / n_samples as f64;
            let upper_bound = (i + 1) as f64 / n_samples as f64;
            let uniform_sample: f64 = rng.gen();
            let stratified_sample = lower_bound + uniform_sample * (upper_bound - lower_bound);

            // Map to quantile
            let quantile_idx = (stratified_sample * (sorted_values.len() - 1) as f64) as usize;
            let quantile_idx = quantile_idx.min(sorted_values.len() - 1);

            lhs_values.push(sorted_values[quantile_idx]);
        }

        let uniform_weight = 1.0 / lhs_values.len() as f64;
        let weights = vec![uniform_weight; lhs_values.len()];

        // Compute cumulative weights
        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: lhs_values,
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        })
    }

    /// Compute bandwidth for kernel density estimation
    fn compute_bandwidth(&self, values: &[f64]) -> f64 {
        if values.len() < 2 {
            return 1.0;
        }

        // Silverman's rule of thumb: h = sigma * (4 / (3n))^(1/5)
        let n = values.len() as f64;
        let mean = values.iter().sum::<f64>() / n;
        let variance = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
        let std_dev = variance.sqrt();

        std_dev * (4.0 / (3.0 * n)).powf(0.2)
    }

    /// Gaussian kernel density estimate at a point:
    /// f(x) = 1 / (n h sqrt(2 pi)) * sum_i exp(-((x - x_i) / h)^2 / 2)
    fn kernel_density_estimate(&self, x: f64, values: &[f64], bandwidth: f64) -> f64 {
        let n = values.len() as f64;
        let sum: f64 = values
            .iter()
            .map(|&xi| {
                let u = (x - xi) / bandwidth;
                (-0.5 * u * u).exp() // Gaussian kernel
            })
            .sum();

        sum / (n * bandwidth * (2.0 * std::f64::consts::PI).sqrt())
    }

    /// Compute weighted mean
    fn compute_weighted_mean(&self, distribution: &SampleDistribution) -> f64 {
        distribution
            .values
            .iter()
            .zip(distribution.weights.iter())
            .map(|(&value, &weight)| value * weight)
            .sum()
    }

    /// Compute weighted median
    fn compute_weighted_median(&self, distribution: &SampleDistribution) -> f64 {
        // Find the value where the cumulative weight first reaches 0.5
        for (i, &cum_weight) in distribution.cumulative_weights.iter().enumerate() {
            if cum_weight >= 0.5 {
                return distribution.values[i];
            }
        }

        // Fallback to last value
        distribution.values.last().copied().unwrap_or(0.0)
    }

    /// Compute weighted mode (most heavily weighted value)
    fn compute_weighted_mode(&self, distribution: &SampleDistribution) -> f64 {
        let mut value_weights = HashMap::new();

        for (&value, &weight) in distribution.values.iter().zip(distribution.weights.iter()) {
            let key = (value * 1e6) as i64; // Quantize to handle floating point precision
            *value_weights.entry(key).or_insert(0.0) += weight;
        }

        value_weights
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .map(|(key, _)| key as f64 / 1e6)
            .unwrap_or(0.0)
    }
}

impl SamplingSimpleImputer<SamplingSimpleImputerTrained> {
    /// Sample an imputed value from the learned distribution
    fn sample_imputed_value(&self, feature_idx: usize) -> Result<f64, SklearsError> {
        let distribution = &self.state.sample_distributions_[feature_idx];
        let mut rng = Random::default();
        let random_value: f64 = rng.gen();

        // Inverse-CDF sampling: take the first cumulative weight >= the uniform draw
        for (i, &cum_weight) in distribution.cumulative_weights.iter().enumerate() {
            if random_value <= cum_weight {
                return Ok(distribution.values[i]);
            }
        }

        // Fallback to last value
        Ok(distribution.values.last().copied().unwrap_or(0.0))
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    /// Get the sample distribution for a feature
    pub fn distribution(&self, feature_idx: usize) -> Option<&SampleDistribution> {
        self.state.sample_distributions_.get(feature_idx)
    }

    /// Get sample statistics
    pub fn statistics(&self) -> &Array1<f64> {
        &self.state.sample_statistics_
    }
}

// Implement Stratified Sampling Imputer
impl StratifiedSamplingImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            missing_values: f64::NAN,
            config: SamplingConfig::default(),
            stratification_features: Vec::new(),
        }
    }

    pub fn sampling_config(mut self, config: SamplingConfig) -> Self {
        self.config = config;
        self
    }

    pub fn stratify_by(mut self, features: Vec<usize>) -> Self {
        self.stratification_features = features.clone();
        self.config.stratify_by = Some(features);
        self
    }

    pub fn n_strata(mut self, n_strata: usize) -> Self {
        self.config.n_strata = n_strata;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for StratifiedSamplingImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for StratifiedSamplingImputer<Untrained> {
    type Config = SamplingConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for StratifiedSamplingImputer<Untrained> {
    type Fitted = StratifiedSamplingImputer<StratifiedSamplingImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_n_samples, n_features) = X.dim();

        if self.stratification_features.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Stratification features must be specified".to_string(),
            ));
        }

        let (strata_distributions, feature_strata) = self.compute_stratified_distributions(&X)?;

        Ok(StratifiedSamplingImputer {
            state: StratifiedSamplingImputerTrained {
                strata_distributions_: strata_distributions,
                feature_strata_: feature_strata,
                n_features_in_: n_features,
                config: self.config,
            },
            missing_values: self.missing_values,
            config: Default::default(),
            stratification_features: Vec::new(),
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for StratifiedSamplingImputer<StratifiedSamplingImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X;

        for i in 0..n_samples {
            // Determine stratum for this sample
            let stratum_key = self.determine_stratum(&X_imputed.row(i).to_owned())?;

            for j in 0..n_features {
                if self.is_missing(X_imputed[[i, j]]) {
                    if let Some(stratum_dists) = self.state.strata_distributions_.get(&stratum_key)
                    {
                        if let Some(distribution) = stratum_dists.get(&j) {
                            let imputed_value = self.sample_from_distribution(distribution)?;
                            X_imputed[[i, j]] = imputed_value;
                        }
                    }
                }
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl StratifiedSamplingImputer<Untrained> {
    /// Compute distributions for each stratum
    fn compute_stratified_distributions(&self, X: &Array2<f64>) -> StratifiedDistributionsResult {
        let (n_samples, n_features) = X.dim();

        // Compute strata boundaries for stratification features
        let mut feature_strata = Vec::new();
        for &feature_idx in &self.stratification_features {
            let column = X.column(feature_idx);
            let valid_values: Vec<f64> = column
                .iter()
                .filter(|&&x| !self.is_missing(x))
                .copied()
                .collect();

            if valid_values.is_empty() {
                feature_strata.push(Array1::from_vec(vec![0.0, 1.0]));
                continue;
            }

            // Create quantile-based strata boundaries
            let mut sorted_values = valid_values;
            sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let mut boundaries = Vec::new();
            for i in 0..=self.config.n_strata {
                let quantile = i as f64 / self.config.n_strata as f64;
                let idx = ((sorted_values.len() - 1) as f64 * quantile) as usize;
                let idx = idx.min(sorted_values.len() - 1);
                boundaries.push(sorted_values[idx]);
            }

            feature_strata.push(Array1::from_vec(boundaries));
        }

        // Assign samples to strata and compute distributions
        let mut strata_samples: HashMap<Vec<usize>, Vec<Array1<f64>>> = HashMap::new();

        for i in 0..n_samples {
            let row = X.row(i).to_owned();
            let stratum_key = self.assign_to_stratum(&row, &feature_strata)?;

            strata_samples.entry(stratum_key).or_default().push(row);
        }

        // Compute distributions for each stratum and feature
        let mut strata_distributions = HashMap::new();

        for (stratum_key, samples) in strata_samples {
            let mut feature_distributions = HashMap::new();

            for j in 0..n_features {
                let feature_values: Vec<f64> = samples
                    .iter()
                    .map(|row| row[j])
                    .filter(|&x| !self.is_missing(x))
                    .collect();

                if !feature_values.is_empty() {
                    let distribution = self.create_empirical_distribution(&feature_values)?;
                    feature_distributions.insert(j, distribution);
                }
            }

            strata_distributions.insert(stratum_key, feature_distributions);
        }

        Ok((strata_distributions, feature_strata))
    }

    /// Assign a sample to a stratum
    fn assign_to_stratum(
        &self,
        row: &Array1<f64>,
        feature_strata: &[Array1<f64>],
    ) -> Result<Vec<usize>, SklearsError> {
        let mut stratum_key = Vec::new();

        for (i, &feature_idx) in self.stratification_features.iter().enumerate() {
            let value = row[feature_idx];
            if self.is_missing(value) {
                stratum_key.push(0); // Default stratum for missing values
                continue;
            }

            let boundaries = &feature_strata[i];
            let mut stratum = 0;

            // Place the value in the first stratum whose upper boundary contains it
            for k in 1..boundaries.len() {
                if value <= boundaries[k] {
                    stratum = k - 1;
                    break;
                }
            }

            stratum_key.push(stratum);
        }

        Ok(stratum_key)
    }

    /// Create empirical distribution from values
    fn create_empirical_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let uniform_weight = 1.0 / values.len() as f64;
        let weights = vec![uniform_weight; values.len()];

        // Compute cumulative weights
        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: values.to_vec(),
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        })
    }
}

impl StratifiedSamplingImputer<StratifiedSamplingImputerTrained> {
    /// Determine stratum for a sample
    fn determine_stratum(&self, row: &Array1<f64>) -> Result<Vec<usize>, SklearsError> {
        if let Some(ref stratify_features) = self.state.config.stratify_by {
            let mut stratum_key = Vec::new();

            for (i, &feature_idx) in stratify_features.iter().enumerate() {
                let value = row[feature_idx];
                if self.is_missing(value) {
                    stratum_key.push(0); // Default stratum for missing values
                    continue;
                }

                let boundaries = &self.state.feature_strata_[i];
                let mut stratum = 0;

                for k in 1..boundaries.len() {
                    if value <= boundaries[k] {
                        stratum = k - 1;
                        break;
                    }
                }

                stratum_key.push(stratum);
            }

            Ok(stratum_key)
        } else {
            Ok(vec![0]) // Default stratum
        }
    }

    /// Sample from a distribution
    fn sample_from_distribution(
        &self,
        distribution: &SampleDistribution,
    ) -> Result<f64, SklearsError> {
        let mut rng = Random::default();
        let random_value: f64 = rng.gen();

        // Inverse-CDF sampling: take the first cumulative weight >= the uniform draw
        for (i, &cum_weight) in distribution.cumulative_weights.iter().enumerate() {
            if random_value <= cum_weight {
                return Ok(distribution.values[i]);
            }
        }

        // Fallback to last value
        Ok(distribution.values.last().copied().unwrap_or(0.0))
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_sampling_simple_imputer() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let imputer = SamplingSimpleImputer::new()
            .strategy("mean".to_string())
            .n_samples(100)
            .sampling_strategy(SamplingStrategy::Simple);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        // Check that NaN was replaced
        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_bootstrap_sampling() {
        let X = array![[1.0, 2.0], [3.0, f64::NAN], [5.0, 6.0], [7.0, 8.0]];

        let imputer = SamplingSimpleImputer::new()
            .strategy("mean".to_string())
            .sampling_strategy(SamplingStrategy::Bootstrap);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_stratified_sampling_imputer() {
        let X = array![
            [1.0, 2.0, 0.0],      // stratum 0
            [2.0, f64::NAN, 0.0], // stratum 0
            [8.0, 9.0, 1.0],      // stratum 1
            [9.0, 10.0, 1.0]      // stratum 1
        ];

        let imputer = StratifiedSamplingImputer::new()
            .stratify_by(vec![2]) // Stratify by the third column
            .n_strata(2);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_latin_hypercube_sampling() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let imputer = SamplingSimpleImputer::new()
            .strategy("mean".to_string())
            .sampling_strategy(SamplingStrategy::LatinHypercube)
            .n_samples(3);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sample_distribution() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let weights = vec![0.1, 0.2, 0.4, 0.2, 0.1];
        let cumulative_weights = vec![0.1, 0.3, 0.7, 0.9, 1.0];

        let distribution = SampleDistribution {
            values,
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        };

        assert_eq!(distribution.values.len(), 5);
        assert_eq!(distribution.weights.len(), 5);
        assert_eq!(distribution.cumulative_weights.len(), 5);
        assert!((distribution.cumulative_weights.last().unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_sampling_config() {
        let config = SamplingConfig {
            n_samples: 500,
            strategy: SamplingStrategy::Importance,
            importance_sampling: true,
            ..Default::default()
        };

        let imputer = SamplingSimpleImputer::new().sampling_config(config.clone());

        assert_eq!(imputer.config.n_samples, 500);
        assert!(matches!(
            imputer.config.strategy,
            SamplingStrategy::Importance
        ));
        assert!(imputer.config.importance_sampling);
    }
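
    // Sanity check for the illustrative Halton helper defined near the
    // `QuasiSequenceType` enum; the expected values are the well-known first
    // base-2 Halton points.
    #[test]
    fn test_halton_radical_inverse_sketch() {
        // First base-2 Halton values: 1 -> 0.5, 2 -> 0.25, 3 -> 0.75.
        assert_abs_diff_eq!(halton_radical_inverse(1, 2), 0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(halton_radical_inverse(2, 2), 0.25, epsilon = 1e-12);
        assert_abs_diff_eq!(halton_radical_inverse(3, 2), 0.75, epsilon = 1e-12);
    }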
}