use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::random::{Random, Rng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;

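/// Statistical type of a feature column, used to select the imputation
/// model for that column.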
#[derive(Debug, Clone, PartialEq)]
pub enum VariableType {
    /// Unbounded continuous variable.
    Continuous,
    /// Ordered discrete variable taking the given levels.
    Ordinal(Vec<f64>),
    /// Unordered discrete variable taking the given category codes.
    Categorical(Vec<f64>),
    /// Point mass at zero mixed with a continuous component.
    SemiContinuous { zero_probability: f64 },
    /// Continuous variable restricted to `[lower, upper]`.
    Bounded { lower: f64, upper: f64 },
    /// 0/1 indicator variable.
    Binary,
}

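/// Metadata describing how a single variable should be treated during
/// imputation.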
#[derive(Debug, Clone)]
pub struct VariableMetadata {
    /// Statistical type of the variable.
    pub variable_type: VariableType,
    /// Free-form label for the column's missingness pattern.
    pub missing_pattern: String,
    /// Whether the variable is a modeling target rather than a predictor.
    pub is_target: bool,
}

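/// Imputer for data sets whose columns have heterogeneous statistical
/// types (continuous, ordinal, categorical, semi-continuous, bounded,
/// binary). Types can be supplied explicitly via `variable_types` or
/// auto-detected from the observed values during `fit`.
///
/// A minimal usage sketch (illustrative; `data` is assumed to be an
/// `Array2<f64>` with NaN marking missing entries):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut types = HashMap::new();
/// types.insert(0, VariableType::Continuous);
/// types.insert(1, VariableType::Binary);
///
/// let fitted = HeterogeneousImputer::new()
///     .variable_types(types)
///     .max_iter(50)
///     .fit(&data.view(), &())?;
/// let completed = fitted.transform(&data.view())?;
/// ```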
43#[derive(Debug, Clone)]
75pub struct HeterogeneousImputer<S = Untrained> {
76 state: S,
77 variable_types: HashMap<usize, VariableType>,
78 max_iter: usize,
79 tol: f64,
80 random_state: Option<u64>,
81 missing_values: f64,
82}
83
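/// State learned by `HeterogeneousImputer` during `fit`: the resolved
/// per-column variable types and their estimated parameters.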
#[derive(Debug, Clone)]
pub struct HeterogeneousImputerTrained {
    variable_types: HashMap<usize, VariableType>,
    learned_parameters: HashMap<usize, VariableParameters>,
    n_features_in_: usize,
}

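/// Parameters learned for a single variable; each arm corresponds to a
/// `VariableType`.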
#[derive(Debug, Clone)]
pub enum VariableParameters {
    /// Gaussian summary of a continuous variable, with optional
    /// regression coefficients for conditional prediction.
    ContinuousParams {
        mean: f64,
        std: f64,
        coefficients: Option<Array1<f64>>,
    },
    /// Level frequencies of an ordinal variable, with an optional
    /// empirical transition matrix between adjacent observations.
    OrdinalParams {
        levels: Vec<f64>,
        probabilities: Array1<f64>,
        transition_matrix: Option<Array2<f64>>,
    },
    /// Category frequencies of a nominal variable.
    CategoricalParams {
        categories: Vec<f64>,
        probabilities: Array1<f64>,
    },
    /// Zero-inflation probability plus a Gaussian summary of the non-zero part.
    SemiContinuousParams {
        zero_prob: f64,
        continuous_mean: f64,
        continuous_std: f64,
        threshold: f64,
    },
    /// Beta-distribution parameters for a variable bounded on `[lower, upper]`.
    BoundedParams {
        lower: f64,
        upper: f64,
        beta_alpha: f64,
        beta_beta: f64,
    },
    /// Bernoulli success probability for a 0/1 variable.
    BinaryParams { probability: f64 },
}

impl HeterogeneousImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            variable_types: HashMap::new(),
            max_iter: 100,
            tol: 1e-4,
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    pub fn variable_types(mut self, variable_types: HashMap<usize, VariableType>) -> Self {
        self.variable_types = variable_types;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for HeterogeneousImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for HeterogeneousImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for HeterogeneousImputer<Untrained> {
    type Fitted = HeterogeneousImputer<HeterogeneousImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        // Use caller-supplied types when given; otherwise detect them
        // from the observed values.
        let variable_types = if self.variable_types.is_empty() {
            self.auto_detect_variable_types(&X)?
        } else {
            self.variable_types.clone()
        };

        let mut learned_parameters = HashMap::new();

        for (&feature_idx, var_type) in &variable_types {
            if feature_idx < n_features {
                let column = X.column(feature_idx);
                let observed_values: Vec<f64> = column
                    .iter()
                    .filter(|&&x| !self.is_missing(x))
                    .cloned()
                    .collect();

                if !observed_values.is_empty() {
                    let params = self.learn_variable_parameters(var_type, &observed_values)?;
                    learned_parameters.insert(feature_idx, params);
                }
            }
        }

        Ok(HeterogeneousImputer {
            state: HeterogeneousImputerTrained {
                variable_types,
                learned_parameters,
                n_features_in_: n_features,
            },
            variable_types: self.variable_types,
            max_iter: self.max_iter,
            tol: self.tol,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

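// Fitting helpers: heuristic variable-type detection and per-variable
// parameter estimation from the observed (non-missing) values.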
impl HeterogeneousImputer<Untrained> {
    fn auto_detect_variable_types(
        &self,
        X: &Array2<f64>,
    ) -> SklResult<HashMap<usize, VariableType>> {
        let mut variable_types = HashMap::new();
        let (_, n_features) = X.dim();

        for j in 0..n_features {
            let column = X.column(j);
            let observed_values: Vec<f64> = column
                .iter()
                .filter(|&&x| !self.is_missing(x))
                .cloned()
                .collect();

            if observed_values.is_empty() {
                continue;
            }

            let var_type = self.detect_variable_type(&observed_values);
            variable_types.insert(j, var_type);
        }

        Ok(variable_types)
    }

    fn detect_variable_type(&self, values: &[f64]) -> VariableType {
        // Count distinct values up to three decimal places.
        let unique_values: std::collections::HashSet<_> = values
            .iter()
            .map(|&x| (x * 1000.0).round() as i64)
            .collect();

        if unique_values.len() == 2 {
            return VariableType::Binary;
        }

        // A small number of integer-valued levels is treated as ordinal.
        let all_integers = values.iter().all(|&x| x.fract() == 0.0);

        if all_integers && unique_values.len() <= 10 {
            let mut sorted_values: Vec<f64> = values.to_vec();
            sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            sorted_values.dedup();
            return VariableType::Ordinal(sorted_values);
        }

        // A sizeable but not dominant share of exact zeros suggests a
        // zero-inflated (semi-continuous) variable.
        let zero_count = values.iter().filter(|&&x| x == 0.0).count();
        let zero_proportion = zero_count as f64 / values.len() as f64;

        if zero_proportion > 0.1 && zero_proportion < 0.9 {
            return VariableType::SemiContinuous {
                zero_probability: zero_proportion,
            };
        }

        let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

        // Values confined to the unit interval are modeled as bounded.
        if min_val >= 0.0 && max_val <= 1.0 {
            return VariableType::Bounded {
                lower: 0.0,
                upper: 1.0,
            };
        }

        VariableType::Continuous
    }

    fn learn_variable_parameters(
        &self,
        var_type: &VariableType,
        observed_values: &[f64],
    ) -> SklResult<VariableParameters> {
        match var_type {
            VariableType::Continuous => {
                let mean = observed_values.iter().sum::<f64>() / observed_values.len() as f64;
                let variance = observed_values
                    .iter()
                    .map(|&x| (x - mean).powi(2))
                    .sum::<f64>()
                    / (observed_values.len() as f64 - 1.0).max(1.0);
                let std = variance.sqrt();

                Ok(VariableParameters::ContinuousParams {
                    mean,
                    std,
                    coefficients: None,
                })
            }
            VariableType::Ordinal(levels) => {
                let mut probabilities = Array1::zeros(levels.len());
                let total_count = observed_values.len() as f64;

                for &value in observed_values {
                    if let Some(idx) = levels
                        .iter()
                        .position(|&level| (level - value).abs() < 1e-10)
                    {
                        probabilities[idx] += 1.0 / total_count;
                    }
                }

                Ok(VariableParameters::OrdinalParams {
                    levels: levels.clone(),
                    probabilities,
                    transition_matrix: None,
                })
            }
            VariableType::Categorical(categories) => {
                let mut probabilities = Array1::zeros(categories.len());
                let total_count = observed_values.len() as f64;

                for &value in observed_values {
                    if let Some(idx) = categories
                        .iter()
                        .position(|&cat| (cat - value).abs() < 1e-10)
                    {
                        probabilities[idx] += 1.0 / total_count;
                    }
                }

                Ok(VariableParameters::CategoricalParams {
                    categories: categories.clone(),
                    probabilities,
                })
            }
            VariableType::SemiContinuous {
                zero_probability: _,
            } => {
                let zero_count = observed_values.iter().filter(|&&x| x == 0.0).count();
                let zero_prob = zero_count as f64 / observed_values.len() as f64;

                let non_zero_values: Vec<f64> = observed_values
                    .iter()
                    .filter(|&&x| x != 0.0)
                    .cloned()
                    .collect();

                let (continuous_mean, continuous_std) = if non_zero_values.is_empty() {
                    (0.0, 1.0)
                } else {
                    let mean = non_zero_values.iter().sum::<f64>() / non_zero_values.len() as f64;
                    let variance = non_zero_values
                        .iter()
                        .map(|&x| (x - mean).powi(2))
                        .sum::<f64>()
                        / (non_zero_values.len() as f64 - 1.0).max(1.0);
                    (mean, variance.sqrt())
                };

                Ok(VariableParameters::SemiContinuousParams {
                    zero_prob,
                    continuous_mean,
                    continuous_std,
                    threshold: 0.0,
                })
            }
            VariableType::Bounded { lower, upper } => {
                let mean = observed_values.iter().sum::<f64>() / observed_values.len() as f64;
                let variance = observed_values
                    .iter()
                    .map(|&x| (x - mean).powi(2))
                    .sum::<f64>()
                    / (observed_values.len() as f64 - 1.0).max(1.0);

                // Method-of-moments fit of a Beta distribution on the
                // rescaled unit interval.
                let range = upper - lower;
                let scaled_mean = (mean - lower) / range;
                let scaled_variance = variance / (range * range);

                let alpha =
                    scaled_mean * (scaled_mean * (1.0 - scaled_mean) / scaled_variance - 1.0);
                let beta = (1.0 - scaled_mean)
                    * (scaled_mean * (1.0 - scaled_mean) / scaled_variance - 1.0);

                Ok(VariableParameters::BoundedParams {
                    lower: *lower,
                    upper: *upper,
                    beta_alpha: alpha.max(0.1),
                    beta_beta: beta.max(0.1),
                })
            }
            VariableType::Binary => {
                let ones = observed_values.iter().filter(|&&x| x == 1.0).count();
                let probability = ones as f64 / observed_values.len() as f64;

                Ok(VariableParameters::BinaryParams { probability })
            }
        }
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for HeterogeneousImputer<HeterogeneousImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();
        let mut rng = Random::default();

        // Iteratively refine imputations until no imputed cell moves by
        // more than `tol` within a full sweep.
        for iteration in 0..self.max_iter {
            let mut converged = true;

            for (&feature_idx, var_type) in &self.state.variable_types {
                if let Some(params) = self.state.learned_parameters.get(&feature_idx) {
                    for i in 0..n_samples {
                        if self.is_missing(X[[i, feature_idx]]) {
                            let imputed_value = self.impute_value(
                                var_type,
                                params,
                                &X_imputed,
                                i,
                                feature_idx,
                                &mut rng,
                            )?;

                            let old_value = X_imputed[[i, feature_idx]];
                            X_imputed[[i, feature_idx]] = imputed_value;

                            if (old_value - imputed_value).abs() > self.tol {
                                converged = false;
                            }
                        }
                    }
                }
            }

            if converged && iteration > 0 {
                break;
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl HeterogeneousImputer<HeterogeneousImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    fn impute_value(
        &self,
        var_type: &VariableType,
        params: &VariableParameters,
        X: &Array2<f64>,
        sample_idx: usize,
        feature_idx: usize,
        rng: &mut Random,
    ) -> SklResult<f64> {
        match (var_type, params) {
            (VariableType::Continuous, VariableParameters::ContinuousParams { mean, std, .. }) => {
                // Prefer a regression prediction from the other features;
                // otherwise fall back to a draw from N(mean, std) via
                // Box-Muller (rng.gen::<f64>() is uniform on [0, 1)).
                if let Some(predicted) = self.predict_continuous(X, sample_idx, feature_idx)? {
                    Ok(predicted)
                } else {
                    let u1: f64 = rng.gen::<f64>().max(f64::MIN_POSITIVE);
                    let u2: f64 = rng.gen::<f64>();
                    let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                    Ok(mean + std * z)
                }
            }
            (
                VariableType::Ordinal(levels),
                VariableParameters::OrdinalParams { probabilities, .. },
            ) => {
                // Inverse-CDF sampling over the learned level probabilities.
                let random_val: f64 = rng.gen();
                let mut cumulative = 0.0;

                for (i, &prob) in probabilities.iter().enumerate() {
                    cumulative += prob;
                    if random_val <= cumulative && i < levels.len() {
                        return Ok(levels[i]);
                    }
                }

                Ok(levels.first().copied().unwrap_or(0.0))
            }
            (
                VariableType::Categorical(categories),
                VariableParameters::CategoricalParams { probabilities, .. },
            ) => {
                let random_val: f64 = rng.gen();
                let mut cumulative = 0.0;

                for (i, &prob) in probabilities.iter().enumerate() {
                    cumulative += prob;
                    if random_val <= cumulative && i < categories.len() {
                        return Ok(categories[i]);
                    }
                }

                Ok(categories.first().copied().unwrap_or(0.0))
            }
            (
                VariableType::SemiContinuous { .. },
                VariableParameters::SemiContinuousParams {
                    zero_prob,
                    continuous_mean,
                    continuous_std,
                    ..
                },
            ) => {
                if rng.gen::<f64>() < *zero_prob {
                    Ok(0.0)
                } else {
                    // Sample the non-zero part from its fitted Gaussian.
                    let u1: f64 = rng.gen::<f64>().max(f64::MIN_POSITIVE);
                    let u2: f64 = rng.gen::<f64>();
                    let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                    Ok(continuous_mean + continuous_std * z)
                }
            }
            (
                VariableType::Bounded { .. },
                VariableParameters::BoundedParams {
                    lower,
                    upper,
                    beta_alpha,
                    beta_beta,
                },
            ) => {
                let beta_sample = self.sample_beta(*beta_alpha, *beta_beta, rng);
                Ok(lower + (upper - lower) * beta_sample)
            }
            (VariableType::Binary, VariableParameters::BinaryParams { probability }) => {
                if rng.gen::<f64>() < *probability {
                    Ok(1.0)
                } else {
                    Ok(0.0)
                }
            }
            _ => Err(SklearsError::InvalidInput(
                "Mismatched variable type and parameters".to_string(),
            )),
        }
    }

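    // Predicts a missing continuous cell by ordinary least squares on the
    // fully observed rows: an intercept plus one coefficient per other
    // feature, obtained from the normal equations X'X b = X'y.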
    fn predict_continuous(
        &self,
        X: &Array2<f64>,
        sample_idx: usize,
        target_feature: usize,
    ) -> SklResult<Option<f64>> {
        let mut predictors = Vec::new();
        let mut targets = Vec::new();

        // Collect training rows: every other sample whose target and
        // predictor values are all observed.
        for i in 0..X.nrows() {
            if i != sample_idx && !self.is_missing(X[[i, target_feature]]) {
                let mut predictor_row = Vec::new();
                let mut all_observed = true;

                for j in 0..X.ncols() {
                    if j != target_feature {
                        if self.is_missing(X[[i, j]]) {
                            all_observed = false;
                            break;
                        }
                        predictor_row.push(X[[i, j]]);
                    }
                }

                if all_observed && !predictor_row.is_empty() {
                    predictors.push(predictor_row);
                    targets.push(X[[i, target_feature]]);
                }
            }
        }

        if predictors.len() < 2 {
            return Ok(None);
        }

        let n_predictors = predictors[0].len();
        let n_samples = predictors.len();

        // Design matrix with a leading column of ones for the intercept.
        let mut design_matrix = Array2::ones((n_samples, n_predictors + 1));
        for (i, pred_row) in predictors.iter().enumerate() {
            for (j, &val) in pred_row.iter().enumerate() {
                design_matrix[[i, j + 1]] = val;
            }
        }

        let y = Array1::from_vec(targets);

        let xt = design_matrix.t();
        let xtx = xt.dot(&design_matrix);
        let xty = xt.dot(&y);

        if let Some(coefficients) = self.solve_linear_system(&xtx, &xty) {
            let mut pred_row = Vec::new();
            for j in 0..X.ncols() {
                if j != target_feature && !self.is_missing(X[[sample_idx, j]]) {
                    pred_row.push(X[[sample_idx, j]]);
                }
            }

            if pred_row.len() == n_predictors {
                let mut prediction = coefficients[0];
                for (i, &val) in pred_row.iter().enumerate() {
                    prediction += coefficients[i + 1] * val;
                }
                return Ok(Some(prediction));
            }
        }

        Ok(None)
    }

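    // Solves the square linear system A x = b: Cramer's rule as a fast
    // path for 2x2 systems, otherwise Gaussian elimination with partial
    // pivoting. Returns None for singular (or near-singular) systems.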
    fn solve_linear_system(&self, A: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
        let n = A.nrows();
        if n != A.ncols() || n != b.len() || n == 0 {
            return None;
        }

        // Fast path: 2x2 systems via Cramer's rule.
        if n == 2 {
            let det = A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]];
            if det.abs() < 1e-10 {
                return None;
            }

            let x0 = (A[[1, 1]] * b[0] - A[[0, 1]] * b[1]) / det;
            let x1 = (A[[0, 0]] * b[1] - A[[1, 0]] * b[0]) / det;

            return Some(Array1::from_vec(vec![x0, x1]));
        }

        // Build the augmented matrix [A | b].
        let mut augmented = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                augmented[[i, j]] = A[[i, j]];
            }
            augmented[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting.
        for i in 0..n {
            let mut max_row = i;
            for k in (i + 1)..n {
                if augmented[[k, i]].abs() > augmented[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..=n {
                    let temp = augmented[[i, j]];
                    augmented[[i, j]] = augmented[[max_row, j]];
                    augmented[[max_row, j]] = temp;
                }
            }

            if augmented[[i, i]].abs() < 1e-10 {
                return None;
            }

            for k in (i + 1)..n {
                let factor = augmented[[k, i]] / augmented[[i, i]];
                for j in i..=n {
                    augmented[[k, j]] -= factor * augmented[[i, j]];
                }
            }
        }

        // Back substitution.
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            x[i] = augmented[[i, n]];
            for j in (i + 1)..n {
                x[i] -= augmented[[i, j]] * x[j];
            }
            x[i] /= augmented[[i, i]];
        }

        Some(x)
    }

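    // Draws an approximate Beta(alpha, beta) sample via the Jöhnk-style
    // construction x / (x + y) with x = u1^(1/alpha), y = u2^(1/beta).
    // Note: the exact Jöhnk algorithm rejects pairs with x + y > 1; this
    // version omits that rejection step, so the draw is approximate.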
    fn sample_beta(&self, alpha: f64, beta: f64, rng: &mut Random) -> f64 {
        if alpha <= 0.0 || beta <= 0.0 {
            return rng.gen::<f64>();
        }

        // Beta(1, 1) is uniform on [0, 1].
        if (alpha - 1.0).abs() < 1e-10 && (beta - 1.0).abs() < 1e-10 {
            return rng.gen::<f64>();
        }

        let u1: f64 = rng.gen();
        let u2: f64 = rng.gen();

        let x = u1.powf(1.0 / alpha);
        let y = u2.powf(1.0 / beta);

        x / (x + y)
    }
}

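/// Multiple Imputation by Chained Equations (MICE) for mixed-type data.
/// Generates `n_imputations` completed data sets by repeatedly resampling
/// each missing cell from its column's fitted model, then pools them and
/// reports Rubin-style within/between/total variances per cell.
///
/// A minimal usage sketch (illustrative; `data` is assumed to be an
/// `Array2<f64>` with NaN marking missing entries):
///
/// ```ignore
/// let imputer = MixedTypeMICEImputer::new().n_imputations(5).max_iter(10);
/// let fitted = imputer.fit(&data.view(), &())?;
/// let results = fitted.transform_multiple(&data.view())?;
/// assert_eq!(results.imputations.len(), 5);
/// ```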
#[derive(Debug, Clone)]
pub struct MixedTypeMICEImputer<S = Untrained> {
    state: S,
    variable_types: HashMap<usize, VariableType>,
    n_imputations: usize,
    max_iter: usize,
    burn_in: usize,
    tol: f64,
    random_state: Option<u64>,
    missing_values: f64,
}

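/// State learned during `fit`, reusing the per-variable parameters
/// estimated by `HeterogeneousImputer`.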
#[derive(Debug, Clone)]
pub struct MixedTypeMICEImputerTrained {
    variable_types: HashMap<usize, VariableType>,
    learned_parameters: HashMap<usize, VariableParameters>,
    n_features_in_: usize,
}

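/// Output of `transform_multiple`: the individual completed data sets plus
/// per-cell pooled estimates and variance components (Rubin's rules:
/// total = within + (1 + 1/m) * between for m imputations).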
#[derive(Debug, Clone)]
pub struct MixedTypeMultipleImputationResults {
    /// The individual completed data sets.
    pub imputations: Vec<Array2<f64>>,
    /// Per-cell mean across imputations.
    pub pooled_estimates: Option<Array2<f64>>,
    /// Per-cell within-imputation variance.
    pub within_imputation_variance: Option<Array2<f64>>,
    /// Per-cell between-imputation variance.
    pub between_imputation_variance: Option<Array2<f64>>,
    /// Per-cell total variance combining both components.
    pub total_variance: Option<Array2<f64>>,
}

impl MixedTypeMICEImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            variable_types: HashMap::new(),
            n_imputations: 5,
            max_iter: 10,
            burn_in: 5,
            tol: 1e-4,
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    pub fn variable_types(mut self, variable_types: HashMap<usize, VariableType>) -> Self {
        self.variable_types = variable_types;
        self
    }

    pub fn n_imputations(mut self, n_imputations: usize) -> Self {
        self.n_imputations = n_imputations;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn burn_in(mut self, burn_in: usize) -> Self {
        self.burn_in = burn_in;
        self
    }

    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }
}

impl Default for MixedTypeMICEImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MixedTypeMICEImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for MixedTypeMICEImputer<Untrained> {
    type Fitted = MixedTypeMICEImputer<MixedTypeMICEImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        // Delegate type detection and parameter estimation to the
        // heterogeneous imputer.
        let hetero_imputer = HeterogeneousImputer::new()
            .variable_types(self.variable_types.clone())
            .random_state(self.random_state);

        let fitted_hetero = hetero_imputer.fit(&X.view(), &())?;

        Ok(MixedTypeMICEImputer {
            state: MixedTypeMICEImputerTrained {
                variable_types: fitted_hetero.state.variable_types.clone(),
                learned_parameters: fitted_hetero.state.learned_parameters.clone(),
                n_features_in_: n_features,
            },
            variable_types: self.variable_types,
            n_imputations: self.n_imputations,
            max_iter: self.max_iter,
            burn_in: self.burn_in,
            tol: self.tol,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for MixedTypeMICEImputer<MixedTypeMICEImputerTrained>
{
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        // Single-output transform: run the full multiple-imputation
        // procedure and return the first completed data set.
        let multiple_results = self.transform_multiple(X)?;
        if let Some(first_imputation) = multiple_results.imputations.first() {
            Ok(first_imputation.mapv(|x| x as Float))
        } else {
            Err(SklearsError::InvalidInput(
                "No imputations generated".to_string(),
            ))
        }
    }
}

impl MixedTypeMICEImputer<MixedTypeMICEImputerTrained> {
    #[allow(non_snake_case)]
    pub fn transform_multiple(
        &self,
        X: &ArrayView2<'_, Float>,
    ) -> SklResult<MixedTypeMultipleImputationResults> {
        let X = X.to_owned();
        let mut imputations = Vec::new();

        // TODO: derive a deterministic RNG from `self.random_state`; the
        // default generator is not seeded with it, so runs are not yet
        // reproducible.
        let mut base_rng = Random::default();

        for _m in 0..self.n_imputations {
            let imputation_seed = base_rng.random::<u64>();
            let imputation = self.generate_single_imputation(&X, imputation_seed)?;
            imputations.push(imputation);
        }

        let pooled_estimates = self.pool_imputations(&imputations);
        let (within_var, between_var, total_var) =
            self.calculate_imputation_variance(&imputations, &pooled_estimates);

        Ok(MixedTypeMultipleImputationResults {
            imputations,
            pooled_estimates: Some(pooled_estimates),
            within_imputation_variance: Some(within_var),
            between_imputation_variance: Some(between_var),
            total_variance: Some(total_var),
        })
    }

    fn generate_single_imputation(&self, X: &Array2<f64>, _seed: u64) -> SklResult<Array2<f64>> {
        let mut X_imputed = X.clone();
        // The per-imputation seed is currently unused; see the TODO above.
        let mut rng = Random::default();

        self.initialize_missing_values(&mut X_imputed, &mut rng)?;

        // Chained-equations sweeps: run burn-in iterations first, then
        // check convergence on the imputed cells only.
        for iteration in 0..(self.burn_in + self.max_iter) {
            let prev_X = X_imputed.clone();

            for (&feature_idx, var_type) in &self.state.variable_types {
                if let Some(params) = self.state.learned_parameters.get(&feature_idx) {
                    self.update_feature_mice(
                        &mut X_imputed,
                        X,
                        feature_idx,
                        var_type,
                        params,
                        &mut rng,
                    )?;
                }
            }

            if iteration >= self.burn_in {
                let max_change = self.calculate_max_change(&prev_X, &X_imputed, X);
                if max_change < self.tol {
                    break;
                }
            }
        }

        Ok(X_imputed)
    }

    fn initialize_missing_values(
        &self,
        X_imputed: &mut Array2<f64>,
        rng: &mut Random,
    ) -> SklResult<()> {
        let (n_samples, n_features) = X_imputed.dim();

        for j in 0..n_features {
            if let (Some(var_type), Some(params)) = (
                self.state.variable_types.get(&j),
                self.state.learned_parameters.get(&j),
            ) {
                for i in 0..n_samples {
                    if self.is_missing(X_imputed[[i, j]]) {
                        // Cheap starting values; the MICE sweeps refine them.
                        let initial_value = match (var_type, params) {
                            (
                                VariableType::Continuous,
                                VariableParameters::ContinuousParams { mean, .. },
                            ) => *mean,
                            (VariableType::Ordinal(levels), _) => {
                                let idx = rng.gen_range(0..levels.len());
                                levels[idx]
                            }
                            (VariableType::Categorical(categories), _) => {
                                let idx = rng.gen_range(0..categories.len());
                                categories[idx]
                            }
                            (
                                VariableType::SemiContinuous { .. },
                                VariableParameters::SemiContinuousParams {
                                    continuous_mean, ..
                                },
                            ) => *continuous_mean,
                            (VariableType::Bounded { lower, upper }, _) => {
                                lower + (upper - lower) * rng.gen::<f64>()
                            }
                            (
                                VariableType::Binary,
                                VariableParameters::BinaryParams { probability },
                            ) => {
                                if rng.gen::<f64>() < *probability {
                                    1.0
                                } else {
                                    0.0
                                }
                            }
                            _ => 0.0,
                        };
                        X_imputed[[i, j]] = initial_value;
                    }
                }
            }
        }

        Ok(())
    }

    fn update_feature_mice(
        &self,
        X_imputed: &mut Array2<f64>,
        X_original: &Array2<f64>,
        feature_idx: usize,
        var_type: &VariableType,
        params: &VariableParameters,
        rng: &mut Random,
    ) -> SklResult<()> {
        let (n_samples, _) = X_imputed.dim();

        // Borrow the single-value sampler from HeterogeneousImputer by
        // building a trained instance around the shared learned state.
        let hetero_imputer = HeterogeneousImputer {
            state: HeterogeneousImputerTrained {
                variable_types: self.state.variable_types.clone(),
                learned_parameters: self.state.learned_parameters.clone(),
                n_features_in_: self.state.n_features_in_,
            },
            variable_types: HashMap::new(),
            max_iter: 1,
            tol: self.tol,
            random_state: Some(rng.gen::<u64>()),
            missing_values: self.missing_values,
        };

        for i in 0..n_samples {
            if self.is_missing(X_original[[i, feature_idx]]) {
                let imputed_value = hetero_imputer.impute_value(
                    var_type,
                    params,
                    X_imputed,
                    i,
                    feature_idx,
                    rng,
                )?;
                X_imputed[[i, feature_idx]] = imputed_value;
            }
        }

        Ok(())
    }

    fn calculate_max_change(
        &self,
        prev_X: &Array2<f64>,
        current_X: &Array2<f64>,
        original_X: &Array2<f64>,
    ) -> f64 {
        // Only cells that were originally missing count toward convergence.
        let mut max_change: f64 = 0.0;

        for ((i, j), &orig_val) in original_X.indexed_iter() {
            if self.is_missing(orig_val) {
                let change = (prev_X[[i, j]] - current_X[[i, j]]).abs();
                max_change = max_change.max(change);
            }
        }

        max_change
    }

    fn pool_imputations(&self, imputations: &[Array2<f64>]) -> Array2<f64> {
        if imputations.is_empty() {
            return Array2::zeros((0, 0));
        }

        let (n_samples, n_features) = imputations[0].dim();
        let mut pooled = Array2::zeros((n_samples, n_features));

        for i in 0..n_samples {
            for j in 0..n_features {
                let sum: f64 = imputations.iter().map(|imp| imp[[i, j]]).sum();
                pooled[[i, j]] = sum / imputations.len() as f64;
            }
        }

        pooled
    }

    fn calculate_imputation_variance(
        &self,
        imputations: &[Array2<f64>],
        pooled: &Array2<f64>,
    ) -> (Array2<f64>, Array2<f64>, Array2<f64>) {
        if imputations.is_empty() {
            let zero_mat = Array2::zeros((0, 0));
            return (zero_mat.clone(), zero_mat.clone(), zero_mat);
        }

        let (n_samples, n_features) = pooled.dim();
        let m = imputations.len() as f64;

        // Per-cell sum of squared deviations from the pooled mean; the
        // within- and between-imputation variances below differ only in
        // their denominators (m vs. m - 1).
        let mut sum_sq_dev = Array2::zeros((n_samples, n_features));
        for imp in imputations {
            for i in 0..n_samples {
                for j in 0..n_features {
                    let diff = imp[[i, j]] - pooled[[i, j]];
                    sum_sq_dev[[i, j]] += diff * diff;
                }
            }
        }

        let within_var = sum_sq_dev.mapv(|v| v / m);
        let between_var = sum_sq_dev.mapv(|v| v / (m - 1.0));

        // Per-cell analogue of Rubin's rules:
        // total = within + (1 + 1/m) * between.
        let total_var = &within_var + &between_var * (1.0 + 1.0 / m);

        (within_var, between_var, total_var)
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

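/// Single-column imputer for ordinal variables. Supported `method` values
/// are "mode" (most frequent level), "proportional_odds" (sample from the
/// marginal level distribution), and "adjacent_categories" (sample from an
/// empirical transition matrix conditioned on the nearest observed value).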
#[derive(Debug, Clone)]
pub struct OrdinalImputer<S = Untrained> {
    state: S,
    levels: Vec<f64>,
    method: String,
    random_state: Option<u64>,
    missing_values: f64,
}

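/// State learned by `OrdinalImputer` during `fit`: the ordered levels,
/// their marginal and cumulative probabilities, and (for the
/// "adjacent_categories" method) an empirical transition matrix.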
#[derive(Debug, Clone)]
pub struct OrdinalImputerTrained {
    levels: Vec<f64>,
    level_probabilities: Array1<f64>,
    cumulative_probabilities: Array1<f64>,
    transition_matrix: Option<Array2<f64>>,
    n_features_in_: usize,
}

impl OrdinalImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            levels: Vec::new(),
            method: "mode".to_string(),
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    pub fn levels(mut self, levels: Vec<f64>) -> Self {
        self.levels = levels;
        self
    }

    pub fn method(mut self, method: String) -> Self {
        self.method = method;
        self
    }

    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }
}

impl Default for OrdinalImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for OrdinalImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for OrdinalImputer<Untrained> {
    type Fitted = OrdinalImputer<OrdinalImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        if n_features != 1 {
            return Err(SklearsError::InvalidInput(
                "OrdinalImputer only supports single-column input".to_string(),
            ));
        }

        let column = X.column(0);
        let observed_values: Vec<f64> = column
            .iter()
            .filter(|&&x| !self.is_missing(x))
            .cloned()
            .collect();

        if observed_values.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No observed values found".to_string(),
            ));
        }

        // Use caller-supplied levels when given; otherwise take the sorted
        // unique observed values.
        let levels = if self.levels.is_empty() {
            let mut unique_values = observed_values.clone();
            unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            unique_values.dedup();
            unique_values
        } else {
            self.levels.clone()
        };

        // Empirical level probabilities, normalized over the observed
        // (non-missing) values so they sum to one.
        let mut level_counts = Array1::<f64>::zeros(levels.len());
        let total_count = observed_values.len() as f64;

        for &value in column.iter() {
            if !self.is_missing(value) {
                if let Some(idx) = levels
                    .iter()
                    .position(|&level| (level - value).abs() < 1e-10)
                {
                    level_counts[idx] += 1.0;
                }
            }
        }

        let level_probabilities = level_counts.mapv(|count: f64| count / total_count);

        let mut cumulative_probabilities = Array1::<f64>::zeros(levels.len());
        cumulative_probabilities[0] = level_probabilities[0];
        for i in 1..levels.len() {
            cumulative_probabilities[i] = cumulative_probabilities[i - 1] + level_probabilities[i];
        }

        let transition_matrix = if self.method == "adjacent_categories" {
            Some(self.estimate_transition_matrix(&levels, &observed_values))
        } else {
            None
        };

        Ok(OrdinalImputer {
            state: OrdinalImputerTrained {
                levels,
                level_probabilities,
                cumulative_probabilities,
                transition_matrix,
                n_features_in_: n_features,
            },
            levels: self.levels,
            method: self.method,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

impl OrdinalImputer<Untrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    fn estimate_transition_matrix(&self, levels: &[f64], observed_values: &[f64]) -> Array2<f64> {
        let n_levels = levels.len();
        let mut transition_counts = Array2::zeros((n_levels, n_levels));

        // Count transitions between consecutive observed values.
        for window in observed_values.windows(2) {
            if let (Some(from_idx), Some(to_idx)) = (
                levels
                    .iter()
                    .position(|&level| (level - window[0]).abs() < 1e-10),
                levels
                    .iter()
                    .position(|&level| (level - window[1]).abs() < 1e-10),
            ) {
                transition_counts[[from_idx, to_idx]] += 1.0;
            }
        }

        // Row-normalize; rows with no observed transitions fall back to a
        // uniform distribution over the levels.
        let mut transition_matrix = Array2::zeros((n_levels, n_levels));
        for i in 0..n_levels {
            let row_sum: f64 = transition_counts.row(i).sum();
            if row_sum > 0.0 {
                for j in 0..n_levels {
                    transition_matrix[[i, j]] = transition_counts[[i, j]] / row_sum;
                }
            } else {
                for j in 0..n_levels {
                    transition_matrix[[i, j]] = 1.0 / n_levels as f64;
                }
            }
        }

        transition_matrix
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>> for OrdinalImputer<OrdinalImputerTrained> {
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();
        let mut rng = Random::default();

        for i in 0..n_samples {
            if self.is_missing(X_imputed[[i, 0]]) {
                let imputed_value = match self.method.as_str() {
                    "mode" => self.impute_mode(&mut rng),
                    "proportional_odds" => self.impute_proportional_odds(&mut rng),
                    "adjacent_categories" => {
                        self.impute_adjacent_categories(&X_imputed, i, &mut rng)
                    }
                    // Unrecognized methods fall back to mode imputation.
                    _ => self.impute_mode(&mut rng),
                };
                X_imputed[[i, 0]] = imputed_value;
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl OrdinalImputer<OrdinalImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    fn impute_mode(&self, _rng: &mut Random) -> f64 {
        // Return the most probable level.
        let max_idx = self
            .state
            .level_probabilities
            .iter()
            .enumerate()
            .max_by(|(_, &a), (_, &b)| a.partial_cmp(&b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        self.state.levels.get(max_idx).copied().unwrap_or(0.0)
    }

    fn impute_proportional_odds(&self, rng: &mut Random) -> f64 {
        // Inverse-CDF sampling over the marginal level distribution.
        let random_val: f64 = rng.gen();

        for (i, &cum_prob) in self.state.cumulative_probabilities.iter().enumerate() {
            if random_val <= cum_prob {
                return self.state.levels.get(i).copied().unwrap_or(0.0);
            }
        }

        self.state.levels.last().copied().unwrap_or(0.0)
    }

    fn impute_adjacent_categories(
        &self,
        X: &Array2<f64>,
        sample_idx: usize,
        rng: &mut Random,
    ) -> f64 {
        if let Some(ref transition_matrix) = self.state.transition_matrix {
            let column = X.column(0);

            // Find the observed value nearest (by row index) to the
            // missing sample.
            let mut closest_value = None;
            let mut min_distance = usize::MAX;

            for (i, &value) in column.iter().enumerate() {
                if !self.is_missing(value) {
                    let distance = i.abs_diff(sample_idx);
                    if distance < min_distance {
                        min_distance = distance;
                        closest_value = Some(value);
                    }
                }
            }

            // Sample the next level from that value's transition row.
            if let Some(closest_val) = closest_value {
                if let Some(from_idx) = self
                    .state
                    .levels
                    .iter()
                    .position(|&level| (level - closest_val).abs() < 1e-10)
                {
                    let random_val: f64 = rng.gen();
                    let mut cumulative = 0.0;

                    for (to_idx, &prob) in transition_matrix.row(from_idx).iter().enumerate() {
                        cumulative += prob;
                        if random_val <= cumulative {
                            return self.state.levels.get(to_idx).copied().unwrap_or(0.0);
                        }
                    }
                }
            }
        }

        // No transition matrix or no usable neighbor: fall back to
        // proportional-odds sampling.
        self.impute_proportional_odds(rng)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::Transform;

    #[test]
    fn test_heterogeneous_imputer_basic() {
        let data = array![[1.0, 2.0, 0.5], [f64::NAN, 3.0, 0.8], [3.0, f64::NAN, 0.0]];

        let mut variable_types = HashMap::new();
        variable_types.insert(0, VariableType::Continuous);
        variable_types.insert(1, VariableType::Ordinal(vec![1.0, 2.0, 3.0, 4.0, 5.0]));
        variable_types.insert(
            2,
            VariableType::Bounded {
                lower: 0.0,
                upper: 1.0,
            },
        );

        let imputer = HeterogeneousImputer::new()
            .variable_types(variable_types)
            .max_iter(10);

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let result = fitted.transform(&data.view()).unwrap();

        // All missing cells are filled.
        assert!(!result.iter().any(|&x| x.is_nan()));

        // Observed cells are left untouched.
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 2]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_mixed_type_mice_basic() {
        let data = array![[1.0, 2.0, 0.0], [f64::NAN, 3.0, 1.0], [3.0, f64::NAN, 0.0]];

        let mut variable_types = HashMap::new();
        variable_types.insert(0, VariableType::Continuous);
        variable_types.insert(1, VariableType::Ordinal(vec![1.0, 2.0, 3.0, 4.0, 5.0]));
        variable_types.insert(
            2,
            VariableType::SemiContinuous {
                zero_probability: 0.6,
            },
        );

        let imputer = MixedTypeMICEImputer::new()
            .variable_types(variable_types)
            .n_imputations(3)
            .max_iter(5);

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let results = fitted.transform_multiple(&data.view()).unwrap();

        assert_eq!(results.imputations.len(), 3);

        for imputation in &results.imputations {
            assert!(!imputation.iter().any(|&x| x.is_nan()));
        }

        assert!(results.pooled_estimates.is_some());
    }

    #[test]
    fn test_ordinal_imputer_basic() {
        let data = array![[1.0], [2.0], [f64::NAN], [3.0], [1.0]];
        let levels = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let imputer = OrdinalImputer::new()
            .levels(levels)
            .method("mode".to_string());

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let result = fitted.transform(&data.view()).unwrap();

        assert!(!result.iter().any(|&x| x.is_nan()));

        // Observed cells are left untouched.
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[3, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[4, 0]], 1.0, epsilon = 1e-10);

        // The imputed cell takes one of the declared levels.
        let imputed_val = result[[2, 0]];
        assert!([1.0, 2.0, 3.0, 4.0, 5.0].contains(&imputed_val));
    }

    #[test]
    fn test_variable_type_auto_detection() {
        let data = array![[1.0, 1.0, 0.5], [2.0, 0.0, 0.8], [3.0, 1.0, 0.0]];

        let imputer = HeterogeneousImputer::new().max_iter(5);
        let fitted = imputer.fit(&data.view(), &()).unwrap();

        let variable_types = &fitted.state.variable_types;

        // Column 0 holds a few integer levels; either detection is acceptable.
        assert!(
            matches!(
                variable_types.get(&0),
                Some(VariableType::Ordinal(_)) | Some(VariableType::Continuous)
            ),
            "Unexpected variable type for first column"
        );

        assert!(variable_types.contains_key(&1));

        if let Some(VariableType::Bounded { lower, upper }) = variable_types.get(&2) {
            assert_abs_diff_eq!(*lower, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(*upper, 1.0, epsilon = 1e-10);
        }
    }
}