sklears_impute/mixed_type.rs

//! Mixed-type data imputation methods
//!
//! This module provides imputation strategies for datasets containing heterogeneous data types,
//! including ordinal variables, semi-continuous data, and bounded variables.
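//!
//! The estimators provided here are [`HeterogeneousImputer`] (iterative
//! single imputation over mixed variable types), [`MixedTypeMICEImputer`]
//! (multiple imputation by chained equations), and [`OrdinalImputer`]
//! (single-column imputation that respects level ordering).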

use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::random::{Random, Rng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Variable type enumeration for mixed-type data
#[derive(Debug, Clone, PartialEq)]
pub enum VariableType {
    /// Continuous numerical variable
    Continuous,
    /// Ordinal categorical variable with ordered levels
    Ordinal(Vec<f64>),
    /// Nominal categorical variable
    Categorical(Vec<f64>),
    /// Semi-continuous variable (mixture of continuous and discrete components)
    SemiContinuous { zero_probability: f64 },
    /// Bounded continuous variable
    Bounded { lower: f64, upper: f64 },
    /// Binary variable
    Binary,
}

/// Variable metadata for mixed-type imputation
#[derive(Debug, Clone)]
pub struct VariableMetadata {
    /// Declared type of the variable
    pub variable_type: VariableType,
    /// Description of the variable's missingness pattern
    pub missing_pattern: String,
    /// Whether this variable is an imputation target
    pub is_target: bool,
}

/// Heterogeneous Data Imputer
///
/// Imputation for datasets containing multiple data types including continuous,
/// ordinal, categorical, semi-continuous, and bounded variables.
///
/// # Parameters
///
/// * `variable_types` - Map from feature index to variable type
/// * `max_iter` - Maximum number of iterations for iterative methods
/// * `tol` - Tolerance for convergence
/// * `random_state` - Random state for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use std::collections::HashMap;
///
/// use sklears_impute::{HeterogeneousImputer, VariableType};
/// use sklears_core::traits::{Transform, Fit};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0, 3.0], [f64::NAN, 3.0, 4.0], [7.0, f64::NAN, 6.0]];
/// let mut variable_types = HashMap::new();
/// variable_types.insert(0, VariableType::Continuous);
/// variable_types.insert(1, VariableType::Ordinal(vec![1.0, 2.0, 3.0, 4.0, 5.0]));
/// variable_types.insert(2, VariableType::Bounded { lower: 0.0, upper: 10.0 });
///
/// let imputer = HeterogeneousImputer::new()
///     .variable_types(variable_types)
///     .max_iter(50);
/// let fitted = imputer.fit(&X.view(), &()).unwrap();
/// let X_imputed = fitted.transform(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct HeterogeneousImputer<S = Untrained> {
    state: S,
    variable_types: HashMap<usize, VariableType>,
    max_iter: usize,
    tol: f64,
    random_state: Option<u64>,
    missing_values: f64,
}

/// Trained state for HeterogeneousImputer
#[derive(Debug, Clone)]
pub struct HeterogeneousImputerTrained {
    variable_types: HashMap<usize, VariableType>,
    learned_parameters: HashMap<usize, VariableParameters>,
    n_features_in_: usize,
}

/// Parameters learned for each variable type
#[derive(Debug, Clone)]
pub enum VariableParameters {
    /// Parameters for continuous variables
    ContinuousParams {
        mean: f64,
        std: f64,
        coefficients: Option<Array1<f64>>,
    },
    /// Parameters for ordinal variables
    OrdinalParams {
        levels: Vec<f64>,
        probabilities: Array1<f64>,
        transition_matrix: Option<Array2<f64>>,
    },
    /// Parameters for nominal categorical variables
    CategoricalParams {
        categories: Vec<f64>,
        probabilities: Array1<f64>,
    },
    /// Parameters for semi-continuous variables
    SemiContinuousParams {
        zero_prob: f64,
        continuous_mean: f64,
        continuous_std: f64,
        threshold: f64,
    },
    /// Parameters for bounded variables (Beta-distribution fit)
    BoundedParams {
        lower: f64,
        upper: f64,
        beta_alpha: f64,
        beta_beta: f64,
    },
    /// Parameters for binary variables
    BinaryParams { probability: f64 },
}

impl HeterogeneousImputer<Untrained> {
    /// Create a new HeterogeneousImputer instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            variable_types: HashMap::new(),
            max_iter: 100,
            tol: 1e-4,
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    /// Set the variable types for each feature
    pub fn variable_types(mut self, variable_types: HashMap<usize, VariableType>) -> Self {
        self.variable_types = variable_types;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the tolerance for convergence
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Set the missing values placeholder
    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for HeterogeneousImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for HeterogeneousImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for HeterogeneousImputer<Untrained> {
    type Fitted = HeterogeneousImputer<HeterogeneousImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        // Auto-detect variable types if not provided
        let variable_types = if self.variable_types.is_empty() {
            self.auto_detect_variable_types(&X)?
        } else {
            self.variable_types.clone()
        };

        // Learn parameters for each variable type
        let mut learned_parameters = HashMap::new();

        for (&feature_idx, var_type) in &variable_types {
            if feature_idx < n_features {
                let column = X.column(feature_idx);
                let observed_values: Vec<f64> = column
                    .iter()
                    .filter(|&&x| !self.is_missing(x))
                    .cloned()
                    .collect();

                if !observed_values.is_empty() {
                    let params = self.learn_variable_parameters(var_type, &observed_values)?;
                    learned_parameters.insert(feature_idx, params);
                }
            }
        }

        Ok(HeterogeneousImputer {
            state: HeterogeneousImputerTrained {
                variable_types,
                learned_parameters,
                n_features_in_: n_features,
            },
            variable_types: self.variable_types,
            max_iter: self.max_iter,
            tol: self.tol,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

impl HeterogeneousImputer<Untrained> {
    fn auto_detect_variable_types(
        &self,
        X: &Array2<f64>,
    ) -> SklResult<HashMap<usize, VariableType>> {
        let mut variable_types = HashMap::new();
        let (_, n_features) = X.dim();

        for j in 0..n_features {
            let column = X.column(j);
            let observed_values: Vec<f64> = column
                .iter()
                .filter(|&&x| !self.is_missing(x))
                .cloned()
                .collect();

            if observed_values.is_empty() {
                continue;
            }

            let var_type = self.detect_variable_type(&observed_values);
            variable_types.insert(j, var_type);
        }

        Ok(variable_types)
    }

    fn detect_variable_type(&self, values: &[f64]) -> VariableType {
        let unique_values: std::collections::HashSet<_> = values
            .iter()
            .map(|&x| (x * 1000.0).round() as i64)
            .collect();

        // Check if binary
        if unique_values.len() == 2 {
            return VariableType::Binary;
        }

        // Check if all values are integers (potential ordinal/categorical)
        let all_integers = values.iter().all(|&x| x.fract() == 0.0);

        if all_integers && unique_values.len() <= 10 {
            // Assume ordinal if few unique integer values
            let mut sorted_values: Vec<f64> = values.to_vec();
            sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            sorted_values.dedup();
            return VariableType::Ordinal(sorted_values);
        }

        // Check for semi-continuous (many zeros)
        let zero_count = values.iter().filter(|&&x| x == 0.0).count();
        let zero_proportion = zero_count as f64 / values.len() as f64;

        if zero_proportion > 0.1 && zero_proportion < 0.9 {
            return VariableType::SemiContinuous {
                zero_probability: zero_proportion,
            };
        }

        // Check if bounded (all values in a specific range)
        let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

        if min_val >= 0.0 && max_val <= 1.0 {
            return VariableType::Bounded {
                lower: 0.0,
                upper: 1.0,
            };
        }

        // Default to continuous
        VariableType::Continuous
    }

    fn learn_variable_parameters(
        &self,
        var_type: &VariableType,
        observed_values: &[f64],
    ) -> SklResult<VariableParameters> {
        match var_type {
            VariableType::Continuous => {
                let mean = observed_values.iter().sum::<f64>() / observed_values.len() as f64;
                let variance = observed_values
                    .iter()
                    .map(|&x| (x - mean).powi(2))
                    .sum::<f64>()
                    / (observed_values.len() as f64 - 1.0).max(1.0);
                let std = variance.sqrt();

                Ok(VariableParameters::ContinuousParams {
                    mean,
                    std,
                    coefficients: None,
                })
            }
            VariableType::Ordinal(levels) => {
                let mut probabilities = Array1::zeros(levels.len());
                let total_count = observed_values.len() as f64;

                for &value in observed_values {
                    if let Some(idx) = levels
                        .iter()
                        .position(|&level| (level - value).abs() < 1e-10)
                    {
                        probabilities[idx] += 1.0 / total_count;
                    }
                }

                Ok(VariableParameters::OrdinalParams {
                    levels: levels.clone(),
                    probabilities,
                    transition_matrix: None,
                })
            }
            VariableType::Categorical(categories) => {
                let mut probabilities = Array1::zeros(categories.len());
                let total_count = observed_values.len() as f64;

                for &value in observed_values {
                    if let Some(idx) = categories
                        .iter()
                        .position(|&cat| (cat - value).abs() < 1e-10)
                    {
                        probabilities[idx] += 1.0 / total_count;
                    }
                }

                Ok(VariableParameters::CategoricalParams {
                    categories: categories.clone(),
                    probabilities,
                })
            }
            VariableType::SemiContinuous {
                zero_probability: _,
            } => {
                let zero_count = observed_values.iter().filter(|&&x| x == 0.0).count();
                let zero_prob = zero_count as f64 / observed_values.len() as f64;

                let non_zero_values: Vec<f64> = observed_values
                    .iter()
                    .filter(|&&x| x != 0.0)
                    .cloned()
                    .collect();

                let (continuous_mean, continuous_std) = if non_zero_values.is_empty() {
                    (0.0, 1.0)
                } else {
                    let mean = non_zero_values.iter().sum::<f64>() / non_zero_values.len() as f64;
                    let variance = non_zero_values
                        .iter()
                        .map(|&x| (x - mean).powi(2))
                        .sum::<f64>()
                        / (non_zero_values.len() as f64 - 1.0).max(1.0);
                    (mean, variance.sqrt())
                };

                Ok(VariableParameters::SemiContinuousParams {
                    zero_prob,
                    continuous_mean,
                    continuous_std,
                    threshold: 0.0,
                })
            }
            VariableType::Bounded { lower, upper } => {
                // Fit Beta distribution parameters by the method of moments:
                // with sample mean m and variance v rescaled to [0, 1],
                //   alpha = m * (m(1 - m) / v - 1)
                //   beta  = (1 - m) * (m(1 - m) / v - 1)
                let mean = observed_values.iter().sum::<f64>() / observed_values.len() as f64;
                let variance = observed_values
                    .iter()
                    .map(|&x| (x - mean).powi(2))
                    .sum::<f64>()
                    / (observed_values.len() as f64 - 1.0).max(1.0);

                // Transform to [0,1] scale for the Beta distribution
                let range = upper - lower;
                let scaled_mean = (mean - lower) / range;
                let scaled_variance = variance / (range * range);

                let alpha =
                    scaled_mean * (scaled_mean * (1.0 - scaled_mean) / scaled_variance - 1.0);
                let beta = (1.0 - scaled_mean)
                    * (scaled_mean * (1.0 - scaled_mean) / scaled_variance - 1.0);

                Ok(VariableParameters::BoundedParams {
                    lower: *lower,
                    upper: *upper,
                    // Clamp so the shape parameters stay valid when the
                    // moment estimates are degenerate
                    beta_alpha: alpha.max(0.1),
                    beta_beta: beta.max(0.1),
                })
            }
            VariableType::Binary => {
                let ones = observed_values.iter().filter(|&&x| x == 1.0).count();
                let probability = ones as f64 / observed_values.len() as f64;

                Ok(VariableParameters::BinaryParams { probability })
            }
        }
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for HeterogeneousImputer<HeterogeneousImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();
        // NOTE: `random_state` is not yet threaded through to the RNG here;
        // a default-initialized generator is used.
        let mut rng = Random::default();

        // Iterative imputation for mixed-type data. On the first sweep the
        // "old" values are still the missing placeholder (e.g. NaN), so the
        // `iteration > 0` guard below forces at least two sweeps.
        for iteration in 0..self.max_iter {
            let mut converged = true;

            for (&feature_idx, var_type) in &self.state.variable_types {
                if let Some(params) = self.state.learned_parameters.get(&feature_idx) {
                    for i in 0..n_samples {
                        if self.is_missing(X[[i, feature_idx]]) {
                            let imputed_value = self.impute_value(
                                var_type,
                                params,
                                &X_imputed,
                                i,
                                feature_idx,
                                &mut rng,
                            )?;

                            let old_value = X_imputed[[i, feature_idx]];
                            X_imputed[[i, feature_idx]] = imputed_value;

                            if (old_value - imputed_value).abs() > self.tol {
                                converged = false;
                            }
                        }
                    }
                }
            }

            if converged && iteration > 0 {
                break;
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl HeterogeneousImputer<HeterogeneousImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    fn impute_value(
        &self,
        var_type: &VariableType,
        params: &VariableParameters,
        X: &Array2<f64>,
        sample_idx: usize,
        feature_idx: usize,
        rng: &mut Random,
    ) -> SklResult<f64> {
        match (var_type, params) {
            (VariableType::Continuous, VariableParameters::ContinuousParams { mean, std, .. }) => {
                // Use regression if other features are available; otherwise draw
                // from N(mean, std) via the Box-Muller transform.
                if let Some(predicted) = self.predict_continuous(X, sample_idx, feature_idx)? {
                    Ok(predicted)
                } else {
                    let u1 = rng.gen::<f64>().max(f64::MIN_POSITIVE);
                    let u2: f64 = rng.gen();
                    let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                    Ok(mean + std * z)
                }
            }
            (
                VariableType::Ordinal(levels),
                VariableParameters::OrdinalParams { probabilities, .. },
            ) => {
                // Sample from learned probability distribution
                let random_val: f64 = rng.gen();
                let mut cumulative = 0.0;

                for (i, &prob) in probabilities.iter().enumerate() {
                    cumulative += prob;
                    if random_val <= cumulative && i < levels.len() {
                        return Ok(levels[i]);
                    }
                }

                // Fallback to first level
                Ok(levels.first().copied().unwrap_or(0.0))
            }
            (
                VariableType::Categorical(categories),
                VariableParameters::CategoricalParams { probabilities, .. },
            ) => {
                // Sample from categorical distribution
                let random_val: f64 = rng.gen();
                let mut cumulative = 0.0;

                for (i, &prob) in probabilities.iter().enumerate() {
                    cumulative += prob;
                    if random_val <= cumulative && i < categories.len() {
                        return Ok(categories[i]);
                    }
                }

                // Fallback to first category
                Ok(categories.first().copied().unwrap_or(0.0))
            }
            (
                VariableType::SemiContinuous { .. },
                VariableParameters::SemiContinuousParams {
                    zero_prob,
                    continuous_mean,
                    continuous_std,
                    ..
                },
            ) => {
                // Two-step process: first decide if zero, then draw the
                // continuous part from N(mean, std) via Box-Muller
                if rng.gen::<f64>() < *zero_prob {
                    Ok(0.0)
                } else {
                    let u1 = rng.gen::<f64>().max(f64::MIN_POSITIVE);
                    let u2: f64 = rng.gen();
                    let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                    Ok(continuous_mean + continuous_std * z)
                }
            }
            (
                VariableType::Bounded { .. },
                VariableParameters::BoundedParams {
                    lower,
                    upper,
                    beta_alpha,
                    beta_beta,
                },
            ) => {
                // Sample from Beta distribution and transform to bounds
                let beta_sample = self.sample_beta(*beta_alpha, *beta_beta, rng);
                Ok(lower + (upper - lower) * beta_sample)
            }
            (VariableType::Binary, VariableParameters::BinaryParams { probability }) => {
                if rng.gen::<f64>() < *probability {
                    Ok(1.0)
                } else {
                    Ok(0.0)
                }
            }
            _ => Err(SklearsError::InvalidInput(
                "Mismatched variable type and parameters".to_string(),
            )),
        }
    }

    fn predict_continuous(
        &self,
        X: &Array2<f64>,
        sample_idx: usize,
        target_feature: usize,
    ) -> SklResult<Option<f64>> {
        // Simple linear regression using other observed features
        let mut predictors = Vec::new();
        let mut targets = Vec::new();

        // Collect training data from other samples where target feature is observed
        for i in 0..X.nrows() {
            if i != sample_idx && !self.is_missing(X[[i, target_feature]]) {
                let mut predictor_row = Vec::new();
                let mut all_observed = true;

                for j in 0..X.ncols() {
                    if j != target_feature {
                        if self.is_missing(X[[i, j]]) {
                            all_observed = false;
                            break;
                        }
                        predictor_row.push(X[[i, j]]);
                    }
                }

                if all_observed && !predictor_row.is_empty() {
                    predictors.push(predictor_row);
                    targets.push(X[[i, target_feature]]);
                }
            }
        }

        if predictors.len() < 2 {
            return Ok(None);
        }

        // Simple linear regression (least squares)
        let n_predictors = predictors[0].len();
        let n_samples = predictors.len();

        // Build design matrix with intercept
        let mut design_matrix = Array2::ones((n_samples, n_predictors + 1));
        for (i, pred_row) in predictors.iter().enumerate() {
            for (j, &val) in pred_row.iter().enumerate() {
                design_matrix[[i, j + 1]] = val;
            }
        }

        let y = Array1::from_vec(targets);

        // Solve the normal equations: (X^T X)^{-1} X^T y
        let xt = design_matrix.t();
        let xtx = xt.dot(&design_matrix);
        let xty = xt.dot(&y);

        // Closed form for the 2x2 case; Gaussian elimination otherwise
        if let Some(coefficients) = self.solve_linear_system(&xtx, &xty) {
            // Make prediction for current sample
            let mut pred_row = Vec::new();
            for j in 0..X.ncols() {
                if j != target_feature && !self.is_missing(X[[sample_idx, j]]) {
                    pred_row.push(X[[sample_idx, j]]);
                }
            }

            if pred_row.len() == n_predictors {
                let mut prediction = coefficients[0]; // intercept
                for (i, &val) in pred_row.iter().enumerate() {
                    prediction += coefficients[i + 1] * val;
                }
                return Ok(Some(prediction));
            }
        }

        Ok(None)
    }

    fn solve_linear_system(&self, A: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
        let n = A.nrows();
        if n != A.ncols() || n != b.len() || n == 0 {
            return None;
        }

        // Simple 2x2 case (intercept + one predictor)
        if n == 2 {
            let det = A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]];
            if det.abs() < 1e-10 {
                return None;
            }

            let x0 = (A[[1, 1]] * b[0] - A[[0, 1]] * b[1]) / det;
            let x1 = (A[[0, 0]] * b[1] - A[[1, 0]] * b[0]) / det;

            return Some(Array1::from_vec(vec![x0, x1]));
        }

        // For larger systems, use simple Gaussian elimination
        let mut augmented = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                augmented[[i, j]] = A[[i, j]];
            }
            augmented[[i, n]] = b[i];
        }

        // Forward elimination
        for i in 0..n {
            // Find pivot
            let mut max_row = i;
            for k in (i + 1)..n {
                if augmented[[k, i]].abs() > augmented[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            // Swap rows
            if max_row != i {
                for j in 0..=n {
                    let temp = augmented[[i, j]];
                    augmented[[i, j]] = augmented[[max_row, j]];
                    augmented[[max_row, j]] = temp;
                }
            }

            // Check for singular matrix
            if augmented[[i, i]].abs() < 1e-10 {
                return None;
            }

            // Eliminate
            for k in (i + 1)..n {
                let factor = augmented[[k, i]] / augmented[[i, i]];
                for j in i..=n {
                    augmented[[k, j]] -= factor * augmented[[i, j]];
                }
            }
        }

        // Back substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            x[i] = augmented[[i, n]];
            for j in (i + 1)..n {
                x[i] -= augmented[[i, j]] * x[j];
            }
            x[i] /= augmented[[i, i]];
        }

        Some(x)
    }

    fn sample_beta(&self, alpha: f64, beta: f64, rng: &mut Random) -> f64 {
        // Approximate Beta sampling: a Jöhnk-style draw without the
        // acceptance step. Not exact, but adequate for basic cases.
        if alpha <= 0.0 || beta <= 0.0 {
            return rng.gen::<f64>();
        }

        // Beta(1, 1) is exactly Uniform(0, 1)
        if (alpha - 1.0).abs() < 1e-10 && (beta - 1.0).abs() < 1e-10 {
            return rng.gen::<f64>();
        }

        let u1: f64 = rng.gen();
        let u2: f64 = rng.gen();

        let x = u1.powf(1.0 / alpha);
        let y = u2.powf(1.0 / beta);

        x / (x + y)
    }
}

/// Mixed-Type MICE Imputer
///
/// Multiple Imputation by Chained Equations specifically designed for mixed-type data.
/// Handles different variable types appropriately during the chained imputation process.
///
/// # Parameters
///
/// * `variable_types` - Map from feature index to variable type
/// * `n_imputations` - Number of multiple imputations to generate
/// * `max_iter` - Maximum number of iterations for each imputation
/// * `burn_in` - Number of burn-in iterations before collecting imputations
/// * `tol` - Tolerance for convergence
/// * `random_state` - Random state for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use std::collections::HashMap;
///
/// use sklears_impute::{MixedTypeMICEImputer, VariableType};
/// use sklears_core::traits::{Transform, Fit};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0, 3.0], [f64::NAN, 3.0, 4.0], [7.0, f64::NAN, 6.0]];
/// let mut variable_types = HashMap::new();
/// variable_types.insert(0, VariableType::Continuous);
/// variable_types.insert(1, VariableType::Ordinal(vec![1.0, 2.0, 3.0, 4.0, 5.0]));
/// variable_types.insert(2, VariableType::SemiContinuous { zero_probability: 0.1 });
///
/// let imputer = MixedTypeMICEImputer::new()
///     .variable_types(variable_types)
///     .n_imputations(5)
///     .max_iter(20);
/// let fitted = imputer.fit(&X.view(), &()).unwrap();
/// let multiple_imputations = fitted.transform_multiple(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MixedTypeMICEImputer<S = Untrained> {
    state: S,
    variable_types: HashMap<usize, VariableType>,
    n_imputations: usize,
    max_iter: usize,
    burn_in: usize,
    tol: f64,
    random_state: Option<u64>,
    missing_values: f64,
}

/// Trained state for MixedTypeMICEImputer
#[derive(Debug, Clone)]
pub struct MixedTypeMICEImputerTrained {
    variable_types: HashMap<usize, VariableType>,
    learned_parameters: HashMap<usize, VariableParameters>,
    n_features_in_: usize,
}

/// Multiple imputation results for mixed-type data
#[derive(Debug, Clone)]
pub struct MixedTypeMultipleImputationResults {
    /// The individual completed datasets
    pub imputations: Vec<Array2<f64>>,
    /// Elementwise mean across imputations
    pub pooled_estimates: Option<Array2<f64>>,
    /// Within-imputation variance component
    pub within_imputation_variance: Option<Array2<f64>>,
    /// Between-imputation variance component
    pub between_imputation_variance: Option<Array2<f64>>,
    /// Total variance combined via Rubin's rule
    pub total_variance: Option<Array2<f64>>,
}

impl MixedTypeMICEImputer<Untrained> {
    /// Create a new MixedTypeMICEImputer instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            variable_types: HashMap::new(),
            n_imputations: 5,
            max_iter: 10,
            burn_in: 5,
            tol: 1e-4,
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    /// Set the variable types for each feature
    pub fn variable_types(mut self, variable_types: HashMap<usize, VariableType>) -> Self {
        self.variable_types = variable_types;
        self
    }

    /// Set the number of imputations
    pub fn n_imputations(mut self, n_imputations: usize) -> Self {
        self.n_imputations = n_imputations;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the burn-in period
    pub fn burn_in(mut self, burn_in: usize) -> Self {
        self.burn_in = burn_in;
        self
    }

    /// Set the tolerance for convergence
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Set the missing values placeholder
    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }
}

impl Default for MixedTypeMICEImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MixedTypeMICEImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for MixedTypeMICEImputer<Untrained> {
    type Fitted = MixedTypeMICEImputer<MixedTypeMICEImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        // Use HeterogeneousImputer for initial parameter learning
        let hetero_imputer = HeterogeneousImputer::new()
            .variable_types(self.variable_types.clone())
            .random_state(self.random_state);

        let fitted_hetero = hetero_imputer.fit(&X.view(), &())?;

        Ok(MixedTypeMICEImputer {
            state: MixedTypeMICEImputerTrained {
                variable_types: fitted_hetero.state.variable_types.clone(),
                learned_parameters: fitted_hetero.state.learned_parameters.clone(),
                n_features_in_: n_features,
            },
            variable_types: self.variable_types,
            n_imputations: self.n_imputations,
            max_iter: self.max_iter,
            burn_in: self.burn_in,
            tol: self.tol,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for MixedTypeMICEImputer<MixedTypeMICEImputerTrained>
{
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        // For single imputation, use the first imputation from multiple imputation
        let multiple_results = self.transform_multiple(X)?;
        if let Some(first_imputation) = multiple_results.imputations.first() {
            Ok(first_imputation.mapv(|x| x as Float))
        } else {
            Err(SklearsError::InvalidInput(
                "No imputations generated".to_string(),
            ))
        }
    }
}

impl MixedTypeMICEImputer<MixedTypeMICEImputerTrained> {
    /// Generate multiple imputations
    #[allow(non_snake_case)]
    pub fn transform_multiple(
        &self,
        X: &ArrayView2<'_, Float>,
    ) -> SklResult<MixedTypeMultipleImputationResults> {
        let X = X.to_owned();
        let mut imputations = Vec::new();

        // NOTE: `random_state` cannot yet be threaded into the RNG, so a
        // default-initialized generator is used regardless of the setting.
        let mut base_rng = Random::default();

        for _m in 0..self.n_imputations {
            let imputation_seed = base_rng.random::<u64>();
            let imputation = self.generate_single_imputation(&X, imputation_seed)?;
            imputations.push(imputation);
        }

        // Calculate pooled estimates using Rubin's rules
        let pooled_estimates = self.pool_imputations(&imputations);
        let (within_var, between_var, total_var) =
            self.calculate_imputation_variance(&imputations, &pooled_estimates);

        Ok(MixedTypeMultipleImputationResults {
            imputations,
            pooled_estimates: Some(pooled_estimates),
            within_imputation_variance: Some(within_var),
            between_imputation_variance: Some(between_var),
            total_variance: Some(total_var),
        })
    }

    fn generate_single_imputation(&self, X: &Array2<f64>, _seed: u64) -> SklResult<Array2<f64>> {
        let mut X_imputed = X.clone();
        // The per-imputation seed is accepted for future reproducibility
        // support but is not yet applied to the RNG
        let mut rng = Random::default();

        // Initialize missing values with simple imputation
        self.initialize_missing_values(&mut X_imputed, &mut rng)?;

        // MICE iterations
        for iteration in 0..(self.burn_in + self.max_iter) {
            let prev_X = X_imputed.clone();

            for (&feature_idx, var_type) in &self.state.variable_types {
                if let Some(params) = self.state.learned_parameters.get(&feature_idx) {
                    self.update_feature_mice(
                        &mut X_imputed,
                        X,
                        feature_idx,
                        var_type,
                        params,
                        &mut rng,
                    )?;
                }
            }

            // Check convergence after burn-in
            if iteration >= self.burn_in {
                let max_change = self.calculate_max_change(&prev_X, &X_imputed, X);
                if max_change < self.tol {
                    break;
                }
            }
        }

        Ok(X_imputed)
    }

    fn initialize_missing_values(
        &self,
        X_imputed: &mut Array2<f64>,
        rng: &mut Random,
    ) -> SklResult<()> {
        let (n_samples, n_features) = X_imputed.dim();

        for j in 0..n_features {
            if let (Some(var_type), Some(params)) = (
                self.state.variable_types.get(&j),
                self.state.learned_parameters.get(&j),
            ) {
                for i in 0..n_samples {
                    if self.is_missing(X_imputed[[i, j]]) {
                        let initial_value = match (var_type, params) {
                            (
                                VariableType::Continuous,
                                VariableParameters::ContinuousParams { mean, .. },
                            ) => *mean,
                            (VariableType::Ordinal(levels), _) => {
                                let idx = rng.gen_range(0..levels.len());
                                levels[idx]
                            }
                            (VariableType::Categorical(categories), _) => {
                                let idx = rng.gen_range(0..categories.len());
                                categories[idx]
                            }
                            (
                                VariableType::SemiContinuous { .. },
                                VariableParameters::SemiContinuousParams {
                                    continuous_mean, ..
                                },
                            ) => *continuous_mean,
                            (VariableType::Bounded { lower, upper }, _) => {
                                lower + (upper - lower) * rng.gen::<f64>()
                            }
                            (
                                VariableType::Binary,
                                VariableParameters::BinaryParams { probability },
                            ) => {
                                if rng.gen::<f64>() < *probability {
                                    1.0
                                } else {
                                    0.0
                                }
                            }
                            _ => 0.0,
                        };
                        X_imputed[[i, j]] = initial_value;
                    }
                }
            }
        }

        Ok(())
    }

    fn update_feature_mice(
        &self,
        X_imputed: &mut Array2<f64>,
        X_original: &Array2<f64>,
        feature_idx: usize,
        var_type: &VariableType,
        params: &VariableParameters,
        rng: &mut Random,
    ) -> SklResult<()> {
        let (n_samples, _) = X_imputed.dim();

        // Create a temporary imputer for this feature (this clones the learned
        // state on every call, which is acceptable for small feature sets)
        let hetero_imputer = HeterogeneousImputer {
            state: HeterogeneousImputerTrained {
                variable_types: self.state.variable_types.clone(),
                learned_parameters: self.state.learned_parameters.clone(),
                n_features_in_: self.state.n_features_in_,
            },
            variable_types: HashMap::new(),
            max_iter: 1,
            tol: self.tol,
            random_state: Some(rng.gen::<u64>()),
            missing_values: self.missing_values,
        };

        for i in 0..n_samples {
            if self.is_missing(X_original[[i, feature_idx]]) {
                let imputed_value = hetero_imputer.impute_value(
                    var_type,
                    params,
                    X_imputed,
                    i,
                    feature_idx,
                    rng,
                )?;
                X_imputed[[i, feature_idx]] = imputed_value;
            }
        }

        Ok(())
    }

    fn calculate_max_change(
        &self,
        prev_X: &Array2<f64>,
        current_X: &Array2<f64>,
        original_X: &Array2<f64>,
    ) -> f64 {
        let mut max_change: f64 = 0.0;

        for ((i, j), &orig_val) in original_X.indexed_iter() {
            if self.is_missing(orig_val) {
                let change = (prev_X[[i, j]] - current_X[[i, j]]).abs();
                max_change = max_change.max(change);
            }
        }

        max_change
    }

    fn pool_imputations(&self, imputations: &[Array2<f64>]) -> Array2<f64> {
        if imputations.is_empty() {
            return Array2::zeros((0, 0));
        }

        let (n_samples, n_features) = imputations[0].dim();
        let mut pooled = Array2::zeros((n_samples, n_features));

        for i in 0..n_samples {
            for j in 0..n_features {
                let sum: f64 = imputations.iter().map(|imp| imp[[i, j]]).sum();
                pooled[[i, j]] = sum / imputations.len() as f64;
            }
        }

        pooled
    }

    fn calculate_imputation_variance(
        &self,
        imputations: &[Array2<f64>],
        pooled: &Array2<f64>,
    ) -> (Array2<f64>, Array2<f64>, Array2<f64>) {
        if imputations.is_empty() {
            let zero_mat = Array2::zeros((0, 0));
            return (zero_mat.clone(), zero_mat.clone(), zero_mat);
        }

        let (n_samples, n_features) = pooled.dim();
        let m = imputations.len() as f64;

        let mut within_var = Array2::zeros((n_samples, n_features));
        let mut between_var = Array2::zeros((n_samples, n_features));

        // Within-imputation variance, approximated here by the mean squared
        // deviation from the pooled estimate (per-imputation sampling
        // variances are not available at this level)
        for imp in imputations {
            for i in 0..n_samples {
                for j in 0..n_features {
                    let diff = imp[[i, j]] - pooled[[i, j]];
                    within_var[[i, j]] += diff * diff;
                }
            }
        }
        within_var /= m;

        // Between-imputation variance: squared deviations from the pooled
        // estimate divided by m - 1
        for imp in imputations {
            for i in 0..n_samples {
                for j in 0..n_features {
                    let diff = imp[[i, j]] - pooled[[i, j]];
                    between_var[[i, j]] += diff * diff;
                }
            }
        }
        between_var /= m - 1.0;

        // Total variance via Rubin's rule: T = W + (1 + 1/m) * B
        let total_var = &within_var + &between_var * (1.0 + 1.0 / m);

        (within_var, between_var, total_var)
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

/// Ordinal Variable Imputer
///
/// Specialized imputation for ordinal categorical variables that respects
/// the ordered nature of the categories.
///
/// # Parameters
///
/// * `levels` - Ordered levels of the ordinal variable
/// * `method` - Imputation method ("mode", "proportional_odds", "adjacent_categories")
/// * `random_state` - Random state for reproducibility
///
/// # Examples
///
/// ```
/// use sklears_impute::OrdinalImputer;
/// use sklears_core::traits::{Transform, Fit};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0], [2.0], [f64::NAN], [3.0], [1.0]];
/// let levels = vec![1.0, 2.0, 3.0, 4.0, 5.0];
///
/// let imputer = OrdinalImputer::new()
///     .levels(levels)
///     .method("proportional_odds".to_string());
/// let fitted = imputer.fit(&X.view(), &()).unwrap();
/// let X_imputed = fitted.transform(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct OrdinalImputer<S = Untrained> {
    state: S,
    levels: Vec<f64>,
    method: String,
    random_state: Option<u64>,
    missing_values: f64,
}

/// Trained state for OrdinalImputer
#[derive(Debug, Clone)]
pub struct OrdinalImputerTrained {
    levels: Vec<f64>,
    level_probabilities: Array1<f64>,
    cumulative_probabilities: Array1<f64>,
    transition_matrix: Option<Array2<f64>>,
    n_features_in_: usize,
}

impl OrdinalImputer<Untrained> {
    /// Create a new OrdinalImputer instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            levels: Vec::new(),
            method: "mode".to_string(),
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    /// Set the ordered levels
    pub fn levels(mut self, levels: Vec<f64>) -> Self {
        self.levels = levels;
        self
    }

    /// Set the imputation method
    pub fn method(mut self, method: String) -> Self {
        self.method = method;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Set the missing values placeholder
    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }
}

impl Default for OrdinalImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for OrdinalImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for OrdinalImputer<Untrained> {
    type Fitted = OrdinalImputer<OrdinalImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        if n_features != 1 {
            return Err(SklearsError::InvalidInput(
                "OrdinalImputer only supports single-column input".to_string(),
            ));
        }

        let column = X.column(0);
        let observed_values: Vec<f64> = column
            .iter()
            .filter(|&&x| !self.is_missing(x))
            .cloned()
            .collect();

        if observed_values.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No observed values found".to_string(),
            ));
        }

        // Auto-detect levels if not provided
        let levels = if self.levels.is_empty() {
            let mut unique_values = observed_values.clone();
            unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            unique_values.dedup();
            unique_values
        } else {
            self.levels.clone()
        };

        // Calculate level probabilities over the observed (non-missing)
        // entries so that the distribution sums to one
        let mut level_counts = Array1::<f64>::zeros(levels.len());
        let total_count = observed_values.len() as f64;

        for &value in column.iter() {
            if !self.is_missing(value) {
                if let Some(idx) = levels
                    .iter()
                    .position(|&level| (level - value).abs() < 1e-10)
                {
                    level_counts[idx] += 1.0;
                }
            }
        }

        let level_probabilities = level_counts.mapv(|count: f64| count / total_count);

        // Calculate cumulative probabilities
        let mut cumulative_probabilities = Array1::<f64>::zeros(levels.len());
        cumulative_probabilities[0] = level_probabilities[0];
        for i in 1..levels.len() {
            cumulative_probabilities[i] = cumulative_probabilities[i - 1] + level_probabilities[i];
        }

        // Calculate transition matrix for adjacent categories method
        let transition_matrix = if self.method == "adjacent_categories" {
            Some(self.estimate_transition_matrix(&levels, &observed_values))
        } else {
            None
        };

        Ok(OrdinalImputer {
            state: OrdinalImputerTrained {
                levels,
                level_probabilities,
                cumulative_probabilities,
                transition_matrix,
                n_features_in_: n_features,
            },
            levels: self.levels,
            method: self.method,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

impl OrdinalImputer<Untrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    fn estimate_transition_matrix(&self, levels: &[f64], observed_values: &[f64]) -> Array2<f64> {
        let n_levels = levels.len();
        let mut transition_counts = Array2::zeros((n_levels, n_levels));

        // Count transitions between consecutive observed values (row order
        // is treated as a sequence, which is a heuristic for unordered data)
        for window in observed_values.windows(2) {
            if let (Some(from_idx), Some(to_idx)) = (
                levels
                    .iter()
                    .position(|&level| (level - window[0]).abs() < 1e-10),
                levels
                    .iter()
                    .position(|&level| (level - window[1]).abs() < 1e-10),
            ) {
                transition_counts[[from_idx, to_idx]] += 1.0;
            }
        }

        // Normalize to probabilities
        let mut transition_matrix = Array2::zeros((n_levels, n_levels));
        for i in 0..n_levels {
            let row_sum: f64 = transition_counts.row(i).sum();
            if row_sum > 0.0 {
                for j in 0..n_levels {
                    transition_matrix[[i, j]] = transition_counts[[i, j]] / row_sum;
                }
            } else {
                // Uniform distribution if no transitions observed
                for j in 0..n_levels {
                    transition_matrix[[i, j]] = 1.0 / n_levels as f64;
                }
            }
        }

        transition_matrix
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>> for OrdinalImputer<OrdinalImputerTrained> {
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();
        let mut rng = Random::default();

        for i in 0..n_samples {
            if self.is_missing(X_imputed[[i, 0]]) {
                let imputed_value = match self.method.as_str() {
                    "mode" => self.impute_mode(&mut rng),
                    "proportional_odds" => self.impute_proportional_odds(&mut rng),
                    "adjacent_categories" => {
                        self.impute_adjacent_categories(&X_imputed, i, &mut rng)
                    }
                    _ => self.impute_mode(&mut rng),
                };
                X_imputed[[i, 0]] = imputed_value;
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl OrdinalImputer<OrdinalImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    fn impute_mode(&self, _rng: &mut Random) -> f64 {
        // Return the level with highest probability
        let max_idx = self
            .state
            .level_probabilities
            .iter()
            .enumerate()
            .max_by(|(_, &a), (_, &b)| a.partial_cmp(&b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        self.state.levels.get(max_idx).copied().unwrap_or(0.0)
    }

    fn impute_proportional_odds(&self, rng: &mut Random) -> f64 {
        // Sample from cumulative distribution
        let random_val: f64 = rng.gen();

        for (i, &cum_prob) in self.state.cumulative_probabilities.iter().enumerate() {
            if random_val <= cum_prob {
                return self.state.levels.get(i).copied().unwrap_or(0.0);
            }
        }

        // Fallback to last level
        self.state.levels.last().copied().unwrap_or(0.0)
    }

    fn impute_adjacent_categories(
        &self,
        X: &Array2<f64>,
        sample_idx: usize,
        rng: &mut Random,
    ) -> f64 {
        // Find nearest observed values to inform imputation
        if let Some(ref transition_matrix) = self.state.transition_matrix {
            // Look for adjacent observed values
            let column = X.column(0);

            // Find closest observed value
            let mut closest_value = None;
            let mut min_distance = usize::MAX;

            for (i, &value) in column.iter().enumerate() {
                if !self.is_missing(value) {
                    let distance = (i as i32 - sample_idx as i32).unsigned_abs() as usize;
                    if distance < min_distance {
                        min_distance = distance;
                        closest_value = Some(value);
                    }
                }
            }

            if let Some(closest_val) = closest_value {
                if let Some(from_idx) = self
                    .state
                    .levels
                    .iter()
                    .position(|&level| (level - closest_val).abs() < 1e-10)
                {
                    // Sample from transition probabilities
                    let random_val: f64 = rng.gen();
                    let mut cumulative = 0.0;

                    for (to_idx, &prob) in transition_matrix.row(from_idx).iter().enumerate() {
                        cumulative += prob;
                        if random_val <= cumulative {
                            return self.state.levels.get(to_idx).copied().unwrap_or(0.0);
                        }
                    }
                }
            }
        }

        // Fallback to proportional odds
        self.impute_proportional_odds(rng)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::Transform;

    #[test]
    fn test_heterogeneous_imputer_basic() {
        let data = array![[1.0, 2.0, 0.5], [f64::NAN, 3.0, 0.8], [3.0, f64::NAN, 0.0]];

        let mut variable_types = HashMap::new();
        variable_types.insert(0, VariableType::Continuous);
        variable_types.insert(1, VariableType::Ordinal(vec![1.0, 2.0, 3.0, 4.0, 5.0]));
        variable_types.insert(
            2,
            VariableType::Bounded {
                lower: 0.0,
                upper: 1.0,
            },
        );

        let imputer = HeterogeneousImputer::new()
            .variable_types(variable_types)
            .max_iter(10);

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let result = fitted.transform(&data.view()).unwrap();

        // Should have no missing values
        assert!(!result.iter().any(|&x| x.is_nan()));

        // Non-missing values should be preserved
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 2]], 0.5, epsilon = 1e-10);
    }
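
    // Added sketch: a sanity check on the method-of-moments Beta fit in
    // `learn_variable_parameters`; for well-behaved bounded data the learned
    // shape parameters should be finite and positive.
    #[test]
    fn test_bounded_beta_method_of_moments() {
        let imputer = HeterogeneousImputer::new();
        let values = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];

        let params = imputer
            .learn_variable_parameters(
                &VariableType::Bounded {
                    lower: 0.0,
                    upper: 1.0,
                },
                &values,
            )
            .unwrap();

        match params {
            VariableParameters::BoundedParams {
                beta_alpha,
                beta_beta,
                ..
            } => {
                assert!(beta_alpha > 0.0 && beta_alpha.is_finite());
                assert!(beta_beta > 0.0 && beta_beta.is_finite());
            }
            _ => panic!("Expected BoundedParams"),
        }
    }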

    #[test]
    fn test_mixed_type_mice_basic() {
        let data = array![[1.0, 2.0, 0.0], [f64::NAN, 3.0, 1.0], [3.0, f64::NAN, 0.0]];

        let mut variable_types = HashMap::new();
        variable_types.insert(0, VariableType::Continuous);
        variable_types.insert(1, VariableType::Ordinal(vec![1.0, 2.0, 3.0, 4.0, 5.0]));
        variable_types.insert(
            2,
            VariableType::SemiContinuous {
                zero_probability: 0.6,
            },
        );

        let imputer = MixedTypeMICEImputer::new()
            .variable_types(variable_types)
            .n_imputations(3)
            .max_iter(5);

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let results = fitted.transform_multiple(&data.view()).unwrap();

        // Should generate requested number of imputations
        assert_eq!(results.imputations.len(), 3);

        // Each imputation should have no missing values
        for imputation in &results.imputations {
            assert!(!imputation.iter().any(|&x| x.is_nan()));
        }

        // Should have pooled estimates
        assert!(results.pooled_estimates.is_some());
    }
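
    // Added sketch exercising the Gaussian-elimination path of
    // `solve_linear_system` on a 3x3 system with a known solution; the
    // imputer is fitted on complete data only to obtain a trained handle.
    #[test]
    fn test_solve_linear_system_3x3() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = HeterogeneousImputer::new().fit(&data.view(), &()).unwrap();

        let a = array![[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]];
        let b = array![2.0, 6.0, 8.0];

        let x = fitted.solve_linear_system(&a, &b).unwrap();
        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x[2], 2.0, epsilon = 1e-10);
    }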

    #[test]
    fn test_ordinal_imputer_basic() {
        let data = array![[1.0], [2.0], [f64::NAN], [3.0], [1.0]];
        let levels = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let imputer = OrdinalImputer::new()
            .levels(levels)
            .method("mode".to_string());

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let result = fitted.transform(&data.view()).unwrap();

        // Should have no missing values
        assert!(!result.iter().any(|&x| x.is_nan()));

        // Non-missing values should be preserved
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[3, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[4, 0]], 1.0, epsilon = 1e-10);

        // Imputed value should be one of the levels
        let imputed_val = result[[2, 0]];
        assert!([1.0, 2.0, 3.0, 4.0, 5.0].contains(&imputed_val));
    }
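
    // Added sketch: with the `proportional_odds` method, imputed entries are
    // drawn from the observed level distribution, so they must be valid
    // levels while observed entries pass through unchanged.
    #[test]
    fn test_ordinal_imputer_proportional_odds() {
        let data = array![[1.0], [1.0], [2.0], [f64::NAN], [3.0]];

        let imputer = OrdinalImputer::new()
            .levels(vec![1.0, 2.0, 3.0])
            .method("proportional_odds".to_string());

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let result = fitted.transform(&data.view()).unwrap();

        assert!(!result.iter().any(|&x| x.is_nan()));
        assert!([1.0, 2.0, 3.0].contains(&result[[3, 0]]));
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
    }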

    #[test]
    fn test_variable_type_auto_detection() {
        let data = array![[1.0, 1.0, 0.5], [2.0, 0.0, 0.8], [3.0, 1.0, 0.0]];

        let imputer = HeterogeneousImputer::new().max_iter(5);
        let fitted = imputer.fit(&data.view(), &()).unwrap();

        // Should auto-detect variable types
        let variable_types = &fitted.state.variable_types;

        // First column should be detected as ordinal (few integer values)
        if let Some(VariableType::Ordinal(_)) = variable_types.get(&0) {
            // Expected
        } else if let Some(VariableType::Continuous) = variable_types.get(&0) {
            // Also acceptable
        } else {
            panic!("Unexpected variable type for first column");
        }

        // Second column should be detected as semi-continuous or binary
        assert!(variable_types.contains_key(&1));

        // Third column should be detected as bounded (values in [0,1])
        if let Some(VariableType::Bounded { lower, upper }) = variable_types.get(&2) {
            assert_abs_diff_eq!(*lower, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(*upper, 1.0, epsilon = 1e-10);
        }
    }
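
    // Added sketch verifying Rubin's rule as implemented: the reported total
    // variance should equal W + (1 + 1/m) * B elementwise, where W and B are
    // the returned within- and between-imputation variances.
    #[test]
    fn test_rubins_rule_total_variance() {
        let data = array![[1.0, 2.0], [f64::NAN, 3.0], [3.0, f64::NAN]];

        let imputer = MixedTypeMICEImputer::new().n_imputations(3).max_iter(3);
        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let results = fitted.transform_multiple(&data.view()).unwrap();

        let m = results.imputations.len() as f64;
        let w = results.within_imputation_variance.unwrap();
        let b = results.between_imputation_variance.unwrap();
        let t = results.total_variance.unwrap();

        for ((i, j), &t_ij) in t.indexed_iter() {
            assert_abs_diff_eq!(
                t_ij,
                w[[i, j]] + (1.0 + 1.0 / m) * b[[i, j]],
                epsilon = 1e-10
            );
        }
    }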
}