use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Fit, Transform},
    types::Float,
};
use std::collections::HashMap;

use crate::{parallel::ParallelConfig, KNNImputer, ParallelKNNImputer, SimpleImputer};

type PreprocessingResult = SklResult<(Option<Array1<Float>>, Option<Array1<Float>>)>;

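/// Fluent entry point for configuring and building an [`ImputationPipeline`].
///
/// A minimal usage sketch; the crate path `sklears_impute` and the `array!`
/// macro re-export are assumptions for illustration only:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_impute::builder::ImputationBuilder;
///
/// // Configure a KNN imputer through the fluent API, then run it.
/// let data = array![[1.0, 2.0], [f64::NAN, 4.0], [5.0, 6.0]];
/// let pipeline = ImputationBuilder::new()
///     .knn()
///     .n_neighbors(2)
///     .distance_weights()
///     .finish()
///     .cross_validate(5)
///     .build()
///     .unwrap();
/// let imputed = pipeline.fit_transform(&data.view()).unwrap();
/// assert!(!imputed.iter().any(|x| x.is_nan()));
/// ```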
#[derive(Debug, Clone)]
pub struct ImputationBuilder {
    method: ImputationMethod,
    validation: ValidationConfig,
    preprocessing: PreprocessingConfig,
    postprocessing: PostprocessingConfig,
    parallel_config: Option<ParallelConfig>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ImputationMethod {
    Simple(SimpleImputationConfig),
    KNN(KNNImputationConfig),
    Iterative(IterativeImputationConfig),
    GaussianProcess(GaussianProcessConfig),
    MatrixFactorization(MatrixFactorizationConfig),
    Bayesian(BayesianImputationConfig),
    Ensemble(EnsembleImputationConfig),
    DeepLearning(DeepLearningConfig),
    Custom(CustomImputationConfig),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SimpleImputationConfig {
    pub strategy: String,
    pub fill_value: Option<f64>,
    pub copy: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KNNImputationConfig {
    pub n_neighbors: usize,
    pub weights: String,
    pub metric: String,
    pub add_indicator: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct IterativeImputationConfig {
    pub max_iter: usize,
    pub tol: f64,
    pub n_nearest_features: Option<usize>,
    pub sample_posterior: bool,
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GaussianProcessConfig {
    pub kernel: String,
    pub alpha: f64,
    pub n_restarts_optimizer: usize,
    pub normalize_y: bool,
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MatrixFactorizationConfig {
    pub n_components: usize,
    pub max_iter: usize,
    pub tol: f64,
    pub regularization: f64,
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BayesianImputationConfig {
    pub n_imputations: usize,
    pub max_iter: usize,
    pub burn_in: usize,
    pub prior_variance: f64,
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EnsembleImputationConfig {
    pub method: String,
    pub n_estimators: usize,
    pub max_depth: Option<usize>,
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DeepLearningConfig {
    pub method: String,
    pub hidden_dims: Vec<usize>,
    pub learning_rate: f64,
    pub epochs: usize,
    pub batch_size: usize,
    pub device: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CustomImputationConfig {
    pub name: String,
    #[cfg(feature = "serde")]
    pub parameters: HashMap<String, serde_json::Value>,
    #[cfg(not(feature = "serde"))]
    pub parameters: HashMap<String, String>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ValidationConfig {
    pub cross_validation: bool,
    pub cv_folds: usize,
    pub holdout_fraction: Option<f64>,
    pub metrics: Vec<String>,
    pub synthetic_missing_patterns: Vec<String>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PreprocessingConfig {
    pub normalize: bool,
    pub scale: bool,
    pub remove_constant_features: bool,
    pub handle_outliers: bool,
    pub outlier_method: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PostprocessingConfig {
    pub clip_values: Option<(f64, f64)>,
    pub round_integers: bool,
    pub preserve_dtypes: bool,
    pub add_uncertainty_estimates: bool,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImputationPreset {
    Fast,
    Balanced,
    HighQuality,
    Memory,
    Parallel,
    Academic,
    Production,
}

impl Default for ImputationBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl ImputationBuilder {
    pub fn new() -> Self {
        Self {
            method: ImputationMethod::Simple(SimpleImputationConfig {
                strategy: "mean".to_string(),
                fill_value: None,
                copy: true,
            }),
            validation: ValidationConfig {
                cross_validation: false,
                cv_folds: 5,
                holdout_fraction: None,
                metrics: vec!["rmse".to_string()],
                synthetic_missing_patterns: vec!["mcar".to_string()],
            },
            preprocessing: PreprocessingConfig {
                normalize: false,
                scale: false,
                remove_constant_features: false,
                handle_outliers: false,
                outlier_method: "iqr".to_string(),
            },
            postprocessing: PostprocessingConfig {
                clip_values: None,
                round_integers: false,
                preserve_dtypes: true,
                add_uncertainty_estimates: false,
            },
            parallel_config: None,
        }
    }

    pub fn preset(mut self, preset: ImputationPreset) -> Self {
        match preset {
            ImputationPreset::Fast => {
                self.method = ImputationMethod::Simple(SimpleImputationConfig {
                    strategy: "mean".to_string(),
                    fill_value: None,
                    copy: true,
                });
            }
            ImputationPreset::Balanced => {
                self.method = ImputationMethod::KNN(KNNImputationConfig {
                    n_neighbors: 5,
                    weights: "uniform".to_string(),
                    metric: "euclidean".to_string(),
                    add_indicator: false,
                });
            }
            ImputationPreset::HighQuality => {
                self.method = ImputationMethod::Iterative(IterativeImputationConfig {
                    max_iter: 10,
                    tol: 1e-3,
                    n_nearest_features: None,
                    sample_posterior: true,
                    random_state: None,
                });
                self.validation.cross_validation = true;
                self.postprocessing.add_uncertainty_estimates = true;
            }
            ImputationPreset::Memory => {
                self.method = ImputationMethod::Simple(SimpleImputationConfig {
                    strategy: "median".to_string(),
                    fill_value: None,
                    copy: false,
                });
                self.preprocessing.remove_constant_features = true;
            }
            ImputationPreset::Parallel => {
                self.method = ImputationMethod::KNN(KNNImputationConfig {
                    n_neighbors: 3,
                    weights: "distance".to_string(),
                    metric: "euclidean".to_string(),
                    add_indicator: false,
                });
                self.parallel_config = Some(ParallelConfig::default());
            }
            ImputationPreset::Academic => {
                self.method = ImputationMethod::Bayesian(BayesianImputationConfig {
                    n_imputations: 5,
                    max_iter: 100,
                    burn_in: 20,
                    prior_variance: 1.0,
                    random_state: Some(42),
                });
                self.validation.cross_validation = true;
                self.validation.cv_folds = 10;
                self.validation.metrics = vec![
                    "rmse".to_string(),
                    "mae".to_string(),
                    "bias".to_string(),
                    "coverage".to_string(),
                ];
                self.postprocessing.add_uncertainty_estimates = true;
            }
            ImputationPreset::Production => {
                self.method = ImputationMethod::Ensemble(EnsembleImputationConfig {
                    method: "random_forest".to_string(),
                    n_estimators: 100,
                    max_depth: Some(10),
                    random_state: Some(42),
                });
                self.validation.cross_validation = true;
                self.preprocessing.handle_outliers = true;
                self.postprocessing.preserve_dtypes = true;
            }
        }
        self
    }

    pub fn simple(self) -> SimpleImputationBuilder {
        SimpleImputationBuilder::new(self)
    }

    pub fn knn(self) -> KNNImputationBuilder {
        KNNImputationBuilder::new(self)
    }

    pub fn iterative(self) -> IterativeImputationBuilder {
        IterativeImputationBuilder::new(self)
    }

    pub fn gaussian_process(self) -> GaussianProcessBuilder {
        GaussianProcessBuilder::new(self)
    }

    pub fn ensemble(self) -> EnsembleImputationBuilder {
        EnsembleImputationBuilder::new(self)
    }

    pub fn deep_learning(self) -> DeepLearningBuilder {
        DeepLearningBuilder::new(self)
    }

    pub fn parallel(mut self, config: Option<ParallelConfig>) -> Self {
        self.parallel_config = config.or_else(|| Some(ParallelConfig::default()));
        self
    }

    pub fn validation(mut self, config: ValidationConfig) -> Self {
        self.validation = config;
        self
    }

    pub fn cross_validate(mut self, folds: usize) -> Self {
        self.validation.cross_validation = true;
        self.validation.cv_folds = folds;
        self
    }

    pub fn preprocessing(mut self, config: PreprocessingConfig) -> Self {
        self.preprocessing = config;
        self
    }

    pub fn normalize(mut self) -> Self {
        self.preprocessing.normalize = true;
        self
    }

    pub fn scale(mut self) -> Self {
        self.preprocessing.scale = true;
        self
    }

    pub fn postprocessing(mut self, config: PostprocessingConfig) -> Self {
        self.postprocessing = config;
        self
    }

    pub fn with_uncertainty(mut self) -> Self {
        self.postprocessing.add_uncertainty_estimates = true;
        self
    }

    pub fn build(self) -> SklResult<ImputationPipeline> {
        ImputationPipeline::new(
            self.method,
            self.validation,
            self.preprocessing,
            self.postprocessing,
            self.parallel_config,
        )
    }
}

pub struct SimpleImputationBuilder {
    builder: ImputationBuilder,
    config: SimpleImputationConfig,
}

impl SimpleImputationBuilder {
    fn new(builder: ImputationBuilder) -> Self {
        Self {
            builder,
            config: SimpleImputationConfig {
                strategy: "mean".to_string(),
                fill_value: None,
                copy: true,
            },
        }
    }

    pub fn strategy(mut self, strategy: &str) -> Self {
        self.config.strategy = strategy.to_string();
        self
    }

    pub fn mean(mut self) -> Self {
        self.config.strategy = "mean".to_string();
        self
    }

    pub fn median(mut self) -> Self {
        self.config.strategy = "median".to_string();
        self
    }

    pub fn mode(mut self) -> Self {
        self.config.strategy = "most_frequent".to_string();
        self
    }

    pub fn constant(mut self, value: f64) -> Self {
        self.config.strategy = "constant".to_string();
        self.config.fill_value = Some(value);
        self
    }

    pub fn finish(mut self) -> ImputationBuilder {
        self.builder.method = ImputationMethod::Simple(self.config);
        self.builder
    }
}

pub struct KNNImputationBuilder {
    builder: ImputationBuilder,
    config: KNNImputationConfig,
}

impl KNNImputationBuilder {
    fn new(builder: ImputationBuilder) -> Self {
        Self {
            builder,
            config: KNNImputationConfig {
                n_neighbors: 5,
                weights: "uniform".to_string(),
                metric: "euclidean".to_string(),
                add_indicator: false,
            },
        }
    }

    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.config.n_neighbors = n_neighbors;
        self
    }

    pub fn weights(mut self, weights: &str) -> Self {
        self.config.weights = weights.to_string();
        self
    }

    pub fn uniform_weights(mut self) -> Self {
        self.config.weights = "uniform".to_string();
        self
    }

    pub fn distance_weights(mut self) -> Self {
        self.config.weights = "distance".to_string();
        self
    }

    pub fn metric(mut self, metric: &str) -> Self {
        self.config.metric = metric.to_string();
        self
    }

    pub fn euclidean(mut self) -> Self {
        self.config.metric = "euclidean".to_string();
        self
    }

    pub fn manhattan(mut self) -> Self {
        self.config.metric = "manhattan".to_string();
        self
    }

    pub fn add_indicator(mut self, add_indicator: bool) -> Self {
        self.config.add_indicator = add_indicator;
        self
    }

    pub fn finish(mut self) -> ImputationBuilder {
        self.builder.method = ImputationMethod::KNN(self.config);
        self.builder
    }
}

pub struct IterativeImputationBuilder {
    builder: ImputationBuilder,
    config: IterativeImputationConfig,
}

impl IterativeImputationBuilder {
    fn new(builder: ImputationBuilder) -> Self {
        Self {
            builder,
            config: IterativeImputationConfig {
                max_iter: 10,
                tol: 1e-3,
                n_nearest_features: None,
                sample_posterior: false,
                random_state: None,
            },
        }
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    pub fn tolerance(mut self, tol: f64) -> Self {
        self.config.tol = tol;
        self
    }

    pub fn n_nearest_features(mut self, n_features: usize) -> Self {
        self.config.n_nearest_features = Some(n_features);
        self
    }

    pub fn sample_posterior(mut self, sample: bool) -> Self {
        self.config.sample_posterior = sample;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }

    pub fn finish(mut self) -> ImputationBuilder {
        self.builder.method = ImputationMethod::Iterative(self.config);
        self.builder
    }
}

pub struct GaussianProcessBuilder {
    builder: ImputationBuilder,
    config: GaussianProcessConfig,
}

impl GaussianProcessBuilder {
    fn new(builder: ImputationBuilder) -> Self {
        Self {
            builder,
            config: GaussianProcessConfig {
                kernel: "rbf".to_string(),
                alpha: 1e-6,
                n_restarts_optimizer: 0,
                normalize_y: false,
                random_state: None,
            },
        }
    }

    pub fn kernel(mut self, kernel: &str) -> Self {
        self.config.kernel = kernel.to_string();
        self
    }

    pub fn rbf_kernel(mut self) -> Self {
        self.config.kernel = "rbf".to_string();
        self
    }

    pub fn matern_kernel(mut self) -> Self {
        self.config.kernel = "matern".to_string();
        self
    }

    pub fn alpha(mut self, alpha: f64) -> Self {
        self.config.alpha = alpha;
        self
    }

    pub fn n_restarts(mut self, n_restarts: usize) -> Self {
        self.config.n_restarts_optimizer = n_restarts;
        self
    }

    pub fn normalize_y(mut self, normalize: bool) -> Self {
        self.config.normalize_y = normalize;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }

    pub fn finish(mut self) -> ImputationBuilder {
        self.builder.method = ImputationMethod::GaussianProcess(self.config);
        self.builder
    }
}

pub struct EnsembleImputationBuilder {
    builder: ImputationBuilder,
    config: EnsembleImputationConfig,
}

impl EnsembleImputationBuilder {
    fn new(builder: ImputationBuilder) -> Self {
        Self {
            builder,
            config: EnsembleImputationConfig {
                method: "random_forest".to_string(),
                n_estimators: 100,
                max_depth: None,
                random_state: None,
            },
        }
    }

    pub fn random_forest(mut self) -> Self {
        self.config.method = "random_forest".to_string();
        self
    }

    pub fn gradient_boosting(mut self) -> Self {
        self.config.method = "gradient_boosting".to_string();
        self
    }

    pub fn extra_trees(mut self) -> Self {
        self.config.method = "extra_trees".to_string();
        self
    }

    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.config.max_depth = Some(max_depth);
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }

    pub fn finish(mut self) -> ImputationBuilder {
        self.builder.method = ImputationMethod::Ensemble(self.config);
        self.builder
    }
}

pub struct DeepLearningBuilder {
    builder: ImputationBuilder,
    config: DeepLearningConfig,
}

impl DeepLearningBuilder {
    fn new(builder: ImputationBuilder) -> Self {
        Self {
            builder,
            config: DeepLearningConfig {
                method: "autoencoder".to_string(),
                hidden_dims: vec![128, 64, 32],
                learning_rate: 0.001,
                epochs: 100,
                batch_size: 32,
                device: "cpu".to_string(),
            },
        }
    }

    pub fn autoencoder(mut self) -> Self {
        self.config.method = "autoencoder".to_string();
        self
    }

    pub fn vae(mut self) -> Self {
        self.config.method = "vae".to_string();
        self
    }

    pub fn gan(mut self) -> Self {
        self.config.method = "gan".to_string();
        self
    }

    pub fn hidden_dims(mut self, dims: Vec<usize>) -> Self {
        self.config.hidden_dims = dims;
        self
    }

    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.config.learning_rate = lr;
        self
    }

    pub fn epochs(mut self, epochs: usize) -> Self {
        self.config.epochs = epochs;
        self
    }

    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.config.batch_size = batch_size;
        self
    }

    pub fn device(mut self, device: &str) -> Self {
        self.config.device = device.to_string();
        self
    }

    pub fn finish(mut self) -> ImputationBuilder {
        self.builder.method = ImputationMethod::DeepLearning(self.config);
        self.builder
    }
}

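/// Executable pipeline produced by [`ImputationBuilder::build`]; it applies
/// preprocessing, the selected imputation method, and postprocessing in order.
///
/// A minimal sketch of running a built pipeline (the crate path
/// `sklears_impute` is an assumption for illustration):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
/// use sklears_impute::builder::ImputationBuilder;
///
/// let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();
/// let pipeline = ImputationBuilder::new().simple().median().finish().build().unwrap();
/// let completed = pipeline.fit_transform(&data.view()).unwrap();
/// assert_eq!(completed.dim(), (3, 2));
/// ```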
pub struct ImputationPipeline {
    method: ImputationMethod,
    validation: ValidationConfig,
    preprocessing: PreprocessingConfig,
    postprocessing: PostprocessingConfig,
    parallel_config: Option<ParallelConfig>,
}

impl ImputationPipeline {
    fn new(
        method: ImputationMethod,
        validation: ValidationConfig,
        preprocessing: PreprocessingConfig,
        postprocessing: PostprocessingConfig,
        parallel_config: Option<ParallelConfig>,
    ) -> SklResult<Self> {
        Ok(Self {
            method,
            validation,
            preprocessing,
            postprocessing,
            parallel_config,
        })
    }

    pub fn fit_transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        // Preprocess a working copy in place so the imputers see the
        // normalized/scaled data, and keep the statistics so postprocessing can
        // invert the transform afterwards.
        let mut X_preprocessed = X.to_owned();
        let (means, stds) = self.apply_preprocessing(&mut X_preprocessed)?;

        let X_imputed = match &self.method {
            ImputationMethod::Simple(config) => {
                let imputer = SimpleImputer::new().strategy(config.strategy.clone());
                let fitted = imputer.fit(&X_preprocessed.view(), &())?;
                fitted.transform(&X_preprocessed.view())
            }
            ImputationMethod::KNN(config) => {
                if let Some(parallel_config) = &self.parallel_config {
                    let imputer = ParallelKNNImputer::new()
                        .n_neighbors(config.n_neighbors)
                        .weights(config.weights.clone())
                        .metric(config.metric.clone())
                        .parallel_config(parallel_config.clone());
                    let fitted = imputer.fit(&X_preprocessed.view(), &())?;
                    fitted.transform(&X_preprocessed.view())
                } else {
                    let imputer = KNNImputer::new()
                        .n_neighbors(config.n_neighbors)
                        .weights(config.weights.clone())
                        .metric(config.metric.clone());
                    let fitted = imputer.fit(&X_preprocessed.view(), &())?;
                    fitted.transform(&X_preprocessed.view())
                }
            }
            _ => {
                // Methods without a dedicated implementation fall back to mean imputation.
                let imputer = SimpleImputer::new().strategy("mean".to_string());
                let fitted = imputer.fit(&X_preprocessed.view(), &())?;
                fitted.transform(&X_preprocessed.view())
            }
        }?;

        let X_final = self.apply_postprocessing(X_imputed, &means, &stds)?;

        Ok(X_final)
    }

    fn apply_preprocessing(&self, X: &mut Array2<Float>) -> PreprocessingResult {
        let mut means = None;
        let mut stds = None;

        if self.preprocessing.normalize || self.preprocessing.scale {
            let (n_samples, n_features) = X.dim();
            let mut feature_means = Array1::zeros(n_features);
            let mut feature_stds = Array1::ones(n_features);

            // Per-feature statistics over the observed (finite) entries only.
            for j in 0..n_features {
                let col = X.column(j);
                let valid_values: Vec<Float> =
                    col.iter().filter(|x| x.is_finite()).copied().collect();

                if !valid_values.is_empty() {
                    let mean = valid_values.iter().sum::<Float>() / valid_values.len() as Float;
                    feature_means[j] = mean;

                    if self.preprocessing.scale {
                        let variance = valid_values
                            .iter()
                            .map(|x| (x - mean).powi(2))
                            .sum::<Float>()
                            / valid_values.len() as Float;
                        feature_stds[j] = variance.sqrt().max(1e-8);
                    }
                }
            }

            // Standardize observed entries in place; missing entries stay NaN.
            for j in 0..n_features {
                for i in 0..n_samples {
                    if X[[i, j]].is_finite() {
                        X[[i, j]] = (X[[i, j]] - feature_means[j]) / feature_stds[j];
                    }
                }
            }

            means = Some(feature_means);
            stds = Some(feature_stds);
        }

        Ok((means, stds))
    }

    fn apply_postprocessing(
        &self,
        mut X: Array2<Float>,
        means: &Option<Array1<Float>>,
        stds: &Option<Array1<Float>>,
    ) -> SklResult<Array2<Float>> {
        let (n_samples, n_features) = X.dim();

        if let (Some(means_arr), Some(stds_arr)) = (means, stds) {
            for j in 0..n_features {
                for i in 0..n_samples {
                    X[[i, j]] = X[[i, j]] * stds_arr[j] + means_arr[j];
                }
            }
        }

        if let Some((min_val, max_val)) = self.postprocessing.clip_values {
            for value in X.iter_mut() {
                *value = value.clamp(min_val, max_val);
            }
        }

        if self.postprocessing.round_integers {
            for value in X.iter_mut() {
                *value = value.round();
            }
        }

        Ok(X)
    }

    #[cfg(feature = "serde")]
    pub fn to_json(&self) -> SklResult<String> {
        #[derive(serde::Serialize)]
        struct PipelineConfig<'a> {
            method: &'a ImputationMethod,
            validation: &'a ValidationConfig,
            preprocessing: &'a PreprocessingConfig,
            postprocessing: &'a PostprocessingConfig,
            parallel_config: &'a Option<ParallelConfig>,
        }

        let config = PipelineConfig {
            method: &self.method,
            validation: &self.validation,
            preprocessing: &self.preprocessing,
            postprocessing: &self.postprocessing,
            parallel_config: &self.parallel_config,
        };

        serde_json::to_string_pretty(&config).map_err(|e| {
            SklearsError::SerializationError(format!("Failed to serialize config: {}", e))
        })
    }

    #[cfg(not(feature = "serde"))]
    pub fn to_json(&self) -> SklResult<String> {
        Err(SklearsError::NotImplemented(
            "to_json requires serde feature".to_string(),
        ))
    }

    pub fn from_json(_json: &str) -> SklResult<Self> {
        Err(SklearsError::NotImplemented(
            "from_json not yet implemented".to_string(),
        ))
    }
}

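/// Convenience one-liners for the most common imputation calls.
///
/// A small sketch of the intended call pattern (the crate path
/// `sklears_impute` is an assumption for illustration):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
/// use sklears_impute::builder::quick;
///
/// let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();
/// let filled = quick::mean_impute(&data.view()).unwrap();
/// let filled_knn = quick::knn_impute(&data.view(), 2).unwrap();
/// ```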
pub mod quick {
    use super::*;

    pub fn mean_impute(X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        ImputationBuilder::new()
            .simple()
            .mean()
            .finish()
            .build()?
            .fit_transform(X)
    }

    pub fn knn_impute(X: &ArrayView2<'_, Float>, n_neighbors: usize) -> SklResult<Array2<Float>> {
        ImputationBuilder::new()
            .knn()
            .n_neighbors(n_neighbors)
            .finish()
            .build()?
            .fit_transform(X)
    }

    pub fn parallel_knn_impute(
        X: &ArrayView2<'_, Float>,
        n_neighbors: usize,
    ) -> SklResult<Array2<Float>> {
        ImputationBuilder::new()
            .knn()
            .n_neighbors(n_neighbors)
            .finish()
            .parallel(None)
            .build()?
            .fit_transform(X)
    }

    pub fn iterative_impute(X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        ImputationBuilder::new()
            .iterative()
            .max_iter(10)
            .finish()
            .build()?
            .fit_transform(X)
    }

    pub fn high_quality_impute(X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        ImputationBuilder::new()
            .preset(ImputationPreset::HighQuality)
            .build()?
            .fit_transform(X)
    }
}

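/// Extension points for plugging custom imputation strategies into the
/// pipeline machinery: modules describe and construct imputers, a registry
/// looks them up, and a composer chains them into multi-stage pipelines.
///
/// A minimal sketch of a custom module; `MeanModule` and `MeanInstance` are
/// hypothetical names used only for illustration, and the crate path
/// `sklears_impute` is an assumption:
///
/// ```ignore
/// use scirs2_core::ndarray::{Array2, ArrayView2};
/// use sklears_core::{error::Result as SklResult, types::Float};
/// use sklears_impute::builder::pluggable::*;
///
/// struct MeanModule;
/// struct MeanInstance {
///     means: Vec<Float>,
/// }
///
/// impl ImputationModule for MeanModule {
///     fn name(&self) -> &str { "mean" }
///     fn version(&self) -> &str { "0.1.0" }
///     fn can_handle(&self, info: &DataCharacteristics) -> bool {
///         !info.has_categorical
///     }
///     fn config_schema(&self) -> ModuleConfigSchema {
///         ModuleConfigSchema {
///             parameters: Default::default(),
///             required_parameters: vec![],
///             parameter_groups: vec![],
///         }
///     }
///     fn create_instance(&self, _config: &ModuleConfig) -> SklResult<Box<dyn ImputationInstance>> {
///         Ok(Box::new(MeanInstance { means: vec![] }))
///     }
/// }
///
/// impl ImputationInstance for MeanInstance {
///     fn fit(&mut self, x: &ArrayView2<Float>) -> SklResult<()> {
///         // Column means over finite entries only.
///         self.means = x
///             .columns()
///             .into_iter()
///             .map(|col| {
///                 let finite: Vec<Float> = col.iter().copied().filter(|v| v.is_finite()).collect();
///                 finite.iter().sum::<Float>() / finite.len().max(1) as Float
///             })
///             .collect();
///         Ok(())
///     }
///     fn transform(&self, x: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
///         let mut out = x.to_owned();
///         for ((_, j), v) in out.indexed_iter_mut() {
///             if v.is_nan() {
///                 *v = self.means[j];
///             }
///         }
///         Ok(out)
///     }
/// }
///
/// let mut registry = ModuleRegistry::new();
/// registry.register_module(Box::new(MeanModule)).unwrap();
/// assert!(registry.get_module("mean").is_some());
/// ```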
pub mod pluggable {
    use super::*;

    pub trait ImputationModule: Send + Sync {
        fn name(&self) -> &str;

        fn version(&self) -> &str;

        fn can_handle(&self, data_info: &DataCharacteristics) -> bool;

        fn config_schema(&self) -> ModuleConfigSchema;

        fn create_instance(&self, config: &ModuleConfig) -> SklResult<Box<dyn ImputationInstance>>;

        fn dependencies(&self) -> Vec<&str> {
            vec![]
        }

        fn priority(&self) -> i32 {
            0
        }
    }

    pub trait ImputationInstance: Send + Sync {
        fn fit(&mut self, X: &ArrayView2<Float>) -> SklResult<()>;

        fn transform(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>>;

        fn fit_transform(&mut self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            self.fit(X)?;
            self.transform(X)
        }

        fn transform_with_uncertainty(
            &self,
            X: &ArrayView2<Float>,
        ) -> SklResult<(Array2<Float>, Option<Array2<Float>>)> {
            let result = self.transform(X)?;
            Ok((result, None))
        }

        fn supports_uncertainty(&self) -> bool {
            false
        }

        fn feature_importance(&self) -> Option<Array1<Float>> {
            None
        }

        fn partial_fit(&mut self, _X: &ArrayView2<Float>) -> SklResult<()> {
            Err(SklearsError::NotImplemented(
                "Partial fit not supported".to_string(),
            ))
        }

        fn supports_partial_fit(&self) -> bool {
            false
        }
    }

    #[derive(Debug, Clone)]
    pub struct DataCharacteristics {
        pub n_samples: usize,
        pub n_features: usize,
        pub missing_rate: f64,
        pub missing_pattern: MissingPatternType,
        pub data_types: Vec<DataType>,
        pub has_categorical: bool,
        pub has_temporal: bool,
        pub is_sparse: bool,
        pub memory_constraints: Option<usize>,
    }

    #[derive(Debug, Clone, PartialEq)]
    pub enum MissingPatternType {
        MCAR,
        MAR,
        MNAR,
        Unknown,
        Block,
        Monotone,
    }

    #[derive(Debug, Clone, PartialEq)]
    pub enum DataType {
        Continuous,
        Categorical,
        Ordinal,
        Binary,
        Count,
        Temporal,
        Text,
    }

    #[derive(Debug, Clone)]
    pub struct ModuleConfigSchema {
        pub parameters: HashMap<String, ParameterSchema>,
        pub required_parameters: Vec<String>,
        pub parameter_groups: Vec<ParameterGroup>,
    }

    #[derive(Debug, Clone)]
    pub struct ParameterSchema {
        pub name: String,
        pub parameter_type: ParameterType,
        #[cfg(feature = "serde")]
        pub default_value: Option<serde_json::Value>,
        #[cfg(not(feature = "serde"))]
        pub default_value: Option<String>,
        pub valid_range: Option<ParameterRange>,
        pub description: String,
        pub dependencies: Vec<String>,
    }

    #[derive(Debug, Clone)]
    pub enum ParameterType {
        Integer,
        Float,
        String,
        Boolean,
        Array(Box<ParameterType>),
        Enum(Vec<String>),
        Object(HashMap<String, ParameterType>),
    }

    #[derive(Debug, Clone)]
    pub enum ParameterRange {
        IntRange { min: Option<i64>, max: Option<i64> },
        FloatRange { min: Option<f64>, max: Option<f64> },
        StringPattern(String),
        ArrayLength { min: Option<usize>, max: Option<usize> },
    }

    #[derive(Debug, Clone)]
    pub struct ParameterGroup {
        pub name: String,
        pub description: String,
        pub parameters: Vec<String>,
        pub optional: bool,
    }

    #[derive(Debug, Clone)]
    pub struct ModuleConfig {
        #[cfg(feature = "serde")]
        pub parameters: HashMap<String, serde_json::Value>,
        #[cfg(not(feature = "serde"))]
        pub parameters: HashMap<String, String>,
    }

    pub struct ModuleRegistry {
        modules: HashMap<String, Box<dyn ImputationModule>>,
        aliases: HashMap<String, String>,
    }

    impl Default for ModuleRegistry {
        fn default() -> Self {
            Self::new()
        }
    }

    impl ModuleRegistry {
        pub fn new() -> Self {
            Self {
                modules: HashMap::new(),
                aliases: HashMap::new(),
            }
        }

        pub fn register_module(&mut self, module: Box<dyn ImputationModule>) -> SklResult<()> {
            let name = module.name().to_string();
            if self.modules.contains_key(&name) {
                return Err(SklearsError::InvalidInput(format!(
                    "Module '{}' already registered",
                    name
                )));
            }
            self.modules.insert(name, module);
            Ok(())
        }

        pub fn register_alias(&mut self, alias: String, module_name: String) -> SklResult<()> {
            if !self.modules.contains_key(&module_name) {
                return Err(SklearsError::InvalidInput(format!(
                    "Module '{}' not found",
                    module_name
                )));
            }
            self.aliases.insert(alias, module_name);
            Ok(())
        }

        pub fn get_module(&self, name: &str) -> Option<&dyn ImputationModule> {
            if let Some(actual_name) = self.aliases.get(name) {
                self.modules.get(actual_name).map(|m| m.as_ref())
            } else {
                self.modules.get(name).map(|m| m.as_ref())
            }
        }

        pub fn list_modules(&self) -> Vec<&str> {
            self.modules.keys().map(|s| s.as_str()).collect()
        }

        pub fn find_suitable_modules(
            &self,
            data_info: &DataCharacteristics,
        ) -> Vec<&dyn ImputationModule> {
            let mut suitable: Vec<_> = self
                .modules
                .values()
                .filter(|m| m.can_handle(data_info))
                .map(|m| m.as_ref())
                .collect();

            suitable.sort_by_key(|b| std::cmp::Reverse(b.priority()));
            suitable
        }

        pub fn recommend_module(
            &self,
            data_info: &DataCharacteristics,
        ) -> Option<&dyn ImputationModule> {
            self.find_suitable_modules(data_info).into_iter().next()
        }
    }

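    /// Composes registered modules into a multi-stage pipeline whose stages can
    /// be gated by [`StageCondition`]s.
    ///
    /// A fragmentary sketch of conditional composition; it assumes a module
    /// named "mean" has already been registered (for example the hypothetical
    /// `MeanModule` from the module-level docs) and that `data` is an existing
    /// matrix with missing values:
    ///
    /// ```ignore
    /// let mut registry = ModuleRegistry::new();
    /// registry.register_module(Box::new(MeanModule)).unwrap();
    ///
    /// let mut composer = PipelineComposer::new(registry);
    /// composer.add_conditional_stage(
    ///     PipelineStage {
    ///         name: "fallback-mean".to_string(),
    ///         module_name: "mean".to_string(),
    ///         config: ModuleConfig { parameters: Default::default() },
    ///         condition: None,
    ///     },
    ///     // Run this stage only when more than 10% of the entries are missing.
    ///     StageCondition::MissingRate(0.1),
    /// );
    ///
    /// let mut pipeline = composer.build().unwrap();
    /// let completed = pipeline.fit_transform(&data.view()).unwrap();
    /// ```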
    pub struct PipelineComposer {
        stages: Vec<PipelineStage>,
        registry: ModuleRegistry,
    }

    #[derive(Debug, Clone)]
    pub struct PipelineStage {
        pub name: String,
        pub module_name: String,
        pub config: ModuleConfig,
        pub condition: Option<StageCondition>,
    }

    #[derive(Debug, Clone)]
    pub enum StageCondition {
        MissingRate(f64),
        FeatureCount(usize),
        DataType(DataType),
        Custom(String),
    }

    impl PipelineComposer {
        pub fn new(registry: ModuleRegistry) -> Self {
            Self {
                stages: Vec::new(),
                registry,
            }
        }

        pub fn add_stage(&mut self, stage: PipelineStage) -> &mut Self {
            self.stages.push(stage);
            self
        }

        pub fn add_conditional_stage(
            &mut self,
            stage: PipelineStage,
            condition: StageCondition,
        ) -> &mut Self {
            let mut stage = stage;
            stage.condition = Some(condition);
            self.stages.push(stage);
            self
        }

        pub fn build(&self) -> SklResult<ComposedPipeline> {
            let mut instances = Vec::new();

            for stage in &self.stages {
                let module = self
                    .registry
                    .get_module(&stage.module_name)
                    .ok_or_else(|| {
                        SklearsError::InvalidInput(format!(
                            "Module '{}' not found",
                            stage.module_name
                        ))
                    })?;

                let instance = module.create_instance(&stage.config)?;
                instances.push((stage.clone(), instance));
            }

            Ok(ComposedPipeline { stages: instances })
        }
    }

    pub struct ComposedPipeline {
        stages: Vec<(PipelineStage, Box<dyn ImputationInstance>)>,
    }

    impl ComposedPipeline {
        pub fn fit_transform(&mut self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            let mut data = X.to_owned();
            let data_info = self.analyze_data(&data.view())?;

            for (stage, instance) in &mut self.stages {
                if let Some(condition) = &stage.condition {
                    if !Self::evaluate_condition_static(condition, &data_info)? {
                        continue;
                    }
                }

                data = instance.fit_transform(&data.view())?;
            }

            Ok(data)
        }

        fn analyze_data(&self, X: &ArrayView2<Float>) -> SklResult<DataCharacteristics> {
            let (n_samples, n_features) = X.dim();
            let missing_count = X.iter().filter(|&&x| x.is_nan()).count();
            let missing_rate = missing_count as f64 / (n_samples * n_features) as f64;

            Ok(DataCharacteristics {
                n_samples,
                n_features,
                missing_rate,
                missing_pattern: MissingPatternType::Unknown,
                data_types: vec![DataType::Continuous; n_features],
                has_categorical: false,
                has_temporal: false,
                is_sparse: missing_rate > 0.5,
                memory_constraints: None,
            })
        }

        fn evaluate_condition_static(
            condition: &StageCondition,
            data_info: &DataCharacteristics,
        ) -> SklResult<bool> {
            Ok(match condition {
                StageCondition::MissingRate(threshold) => data_info.missing_rate > *threshold,
                StageCondition::FeatureCount(threshold) => data_info.n_features > *threshold,
                StageCondition::DataType(data_type) => data_info.data_types.contains(data_type),
                StageCondition::Custom(_) => true,
            })
        }
    }

    pub trait ImputationMiddleware: Send + Sync {
        fn before_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            Ok(X.to_owned())
        }

        fn after_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            Ok(X.to_owned())
        }

        fn on_error(&self, error: &SklearsError) -> SklResult<()> {
            Err(error.clone())
        }
    }

    pub struct ValidationMiddleware {
        pub validate_completeness: bool,
        pub validate_ranges: bool,
        pub expected_ranges: Option<HashMap<usize, (f64, f64)>>,
    }

    impl ImputationMiddleware for ValidationMiddleware {
        fn after_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            if self.validate_completeness && X.iter().any(|&x| x.is_nan()) {
                return Err(SklearsError::InvalidInput(
                    "Imputation failed: missing values remain".to_string(),
                ));
            }

            if self.validate_ranges {
                if let Some(ranges) = &self.expected_ranges {
                    for ((_, j), &value) in X.indexed_iter() {
                        if let Some((min_val, max_val)) = ranges.get(&j) {
                            if value < *min_val || value > *max_val {
                                return Err(SklearsError::InvalidInput(format!(
                                    "Imputed value {} out of expected range [{}, {}] for feature {}",
                                    value, min_val, max_val, j
                                )));
                            }
                        }
                    }
                }
            }

            Ok(X.to_owned())
        }
    }

    pub struct LoggingMiddleware {
        pub log_level: LogLevel,
        pub log_performance: bool,
    }

    #[derive(Debug, Clone)]
    pub enum LogLevel {
        Debug,
        Info,
        Warn,
        Error,
    }

    impl ImputationMiddleware for LoggingMiddleware {
        fn before_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            if matches!(self.log_level, LogLevel::Debug | LogLevel::Info) {
                let missing_count = X.iter().filter(|&&x| x.is_nan()).count();
                println!(
                    "Starting imputation: {} missing values in {}x{} matrix",
                    missing_count,
                    X.nrows(),
                    X.ncols()
                );
            }
            Ok(X.to_owned())
        }

        fn after_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            if matches!(self.log_level, LogLevel::Debug | LogLevel::Info) {
                let remaining_missing = X.iter().filter(|&&x| x.is_nan()).count();
                println!(
                    "Imputation completed: {} missing values remaining",
                    remaining_missing
                );
            }
            Ok(X.to_owned())
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_fluent_api_simple_imputation() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();

        let pipeline = ImputationBuilder::new()
            .simple()
            .mean()
            .finish()
            .build()
            .unwrap();

        let result = pipeline.fit_transform(&data.view()).unwrap();

        assert!(!result.iter().any(|&x| x.is_nan()));

        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_fluent_api_knn_imputation() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0, 7.0, 8.0])
                .unwrap();

        let pipeline = ImputationBuilder::new()
            .knn()
            .n_neighbors(2)
            .distance_weights()
            .finish()
            .build()
            .unwrap();

        let result = pipeline.fit_transform(&data.view()).unwrap();

        assert!(!result.iter().any(|&x| x.is_nan()));
    }

    #[test]
    fn test_preset_configurations() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();

        let pipeline = ImputationBuilder::new()
            .preset(ImputationPreset::Fast)
            .build()
            .unwrap();

        let result = pipeline.fit_transform(&data.view()).unwrap();
        assert!(!result.iter().any(|&x| x.is_nan()));

        let pipeline = ImputationBuilder::new()
            .preset(ImputationPreset::Balanced)
            .build()
            .unwrap();

        let result = pipeline.fit_transform(&data.view()).unwrap();
        assert!(!result.iter().any(|&x| x.is_nan()));
    }

    #[test]
    fn test_quick_functions() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();

        let result = quick::mean_impute(&data.view()).unwrap();
        assert!(!result.iter().any(|&x| x.is_nan()));

        let result = quick::knn_impute(&data.view(), 2).unwrap();
        assert!(!result.iter().any(|&x| x.is_nan()));
    }

    #[test]
    fn test_method_chaining() {
        let builder = ImputationBuilder::new()
            .normalize()
            .cross_validate(5)
            .with_uncertainty()
            .parallel(None);

        assert!(builder.validation.cross_validation);
        assert_eq!(builder.validation.cv_folds, 5);
        assert!(builder.preprocessing.normalize);
        assert!(builder.postprocessing.add_uncertainty_estimates);
        assert!(builder.parallel_config.is_some());
    }
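
    #[test]
    fn test_constant_strategy_configuration() {
        // Configuration-level check only: verifies that the fluent API records
        // the constant strategy and its fill value. It does not exercise
        // fit_transform, which currently forwards only the strategy string to
        // SimpleImputer.
        let builder = ImputationBuilder::new().simple().constant(0.0).finish();
        match builder.method {
            ImputationMethod::Simple(config) => {
                assert_eq!(config.strategy, "constant");
                assert_eq!(config.fill_value, Some(0.0));
            }
            _ => panic!("expected a Simple imputation method"),
        }
    }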
}