1use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::error::{Result as SklResult, SklearsError};
8use std::marker::PhantomData;
9
/// Type-level marker for the MCAR (missing completely at random) mechanism.
///
/// Carries no data; it only appears as the mechanism parameter of
/// `TypedArray`. `Hash` and `Default` are derived next to the comparison
/// traits so the marker can be used as a map key or created generically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MCAR;
13
/// Type-level marker for the MAR (missing at random) mechanism.
///
/// Carries no data; it only appears as the mechanism parameter of
/// `TypedArray`. `Hash` and `Default` are derived next to the comparison
/// traits so the marker can be used as a map key or created generically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MAR;
17
/// Type-level marker for the MNAR (missing not at random) mechanism.
///
/// Carries no data; it only appears as the mechanism parameter of
/// `TypedArray`. `Hash` and `Default` are derived next to the comparison
/// traits so the marker can be used as a map key or created generically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MNAR;
21
/// Type-level marker used when no missingness mechanism has been assigned.
///
/// It is the mechanism parameter of arrays built with `new_complete` and of
/// unclassified arrays that can later go through `classify_mechanism`.
/// `Hash` and `Default` are derived next to the comparison traits so the
/// marker can be used as a map key or created generically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct UnknownMechanism;
25
/// Type-level state marker for arrays constructed without a missing mask.
///
/// Carries no data; it only appears as the state parameter of `TypedArray`.
/// `Hash` and `Default` are derived next to the comparison traits so the
/// marker can be used as a map key or created generically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Complete;
29
/// Type-level state marker for arrays constructed together with a missing mask.
///
/// Carries no data; it only appears as the state parameter of `TypedArray`.
/// `Hash` and `Default` are derived next to the comparison traits so the
/// marker can be used as a map key or created generically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WithMissing;
33
34#[derive(Debug, Clone)]
36pub struct TypedArray<T, M, S> {
37 data: Array2<T>,
38 missing_mask: Option<Array2<bool>>,
39 _mechanism: PhantomData<M>,
40 _state: PhantomData<S>,
41}
42
43pub type CompleteArray<T> = TypedArray<T, UnknownMechanism, Complete>;
45
46pub type MCARArray<T> = TypedArray<T, MCAR, WithMissing>;
48
49pub type MARArray<T> = TypedArray<T, MAR, WithMissing>;
51
52pub type MNARArray<T> = TypedArray<T, MNAR, WithMissing>;
54
/// One distinct row-wise missingness pattern and how often it occurs.
///
/// `PartialEq` is derived so patterns can be compared directly (for example
/// in assertions); `Eq` is not available because `frequency` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct MissingPattern {
    /// Mask values of one row, one entry per column (`true` = missing).
    pub pattern: Vec<bool>,
    /// Number of rows that share this pattern.
    pub count: usize,
    /// `count` divided by the total number of rows.
    pub frequency: f64,
}
65
66pub trait MissingPatternValidator<M> {
68 fn validate_assumptions(&self) -> SklResult<()>;
69 fn recommended_imputers(&self) -> Vec<&'static str>;
70}
71
72impl<T: Clone + PartialEq> TypedArray<T, UnknownMechanism, Complete> {
73 pub fn new_complete(data: Array2<T>) -> Self {
75 Self {
76 data,
77 missing_mask: None,
78 _mechanism: PhantomData,
79 _state: PhantomData,
80 }
81 }
82}
83
84impl<T: Clone + PartialEq> TypedArray<T, UnknownMechanism, WithMissing> {
85 pub fn new_with_missing(data: Array2<T>, missing_mask: Array2<bool>) -> Self {
87 Self {
88 data,
89 missing_mask: Some(missing_mask),
90 _mechanism: PhantomData,
91 _state: PhantomData,
92 }
93 }
94}
95
96impl<T: Clone + PartialEq> TypedArray<T, MCAR, WithMissing> {
97 pub fn new_with_missing(data: Array2<T>, missing_mask: Array2<bool>) -> Self {
99 Self {
100 data,
101 missing_mask: Some(missing_mask),
102 _mechanism: PhantomData,
103 _state: PhantomData,
104 }
105 }
106}
107
108impl<T, M, S> TypedArray<T, M, S> {
109 pub fn data(&self) -> &Array2<T> {
111 &self.data
112 }
113
114 pub fn missing_mask(&self) -> Option<&Array2<bool>> {
116 self.missing_mask.as_ref()
117 }
118
119 pub fn shape(&self) -> (usize, usize) {
121 self.data.dim()
122 }
123
124 pub fn nrows(&self) -> usize {
126 self.data.nrows()
127 }
128
129 pub fn ncols(&self) -> usize {
131 self.data.ncols()
132 }
133}
134
135impl<T: Clone + PartialEq> TypedArray<T, UnknownMechanism, WithMissing> {
136 pub fn classify_mechanism(self) -> SklResult<ClassifiedArray<T>> {
138 let mechanism = self.infer_mechanism()?;
139 Ok(ClassifiedArray::new(
140 self.data,
141 self.missing_mask.unwrap(),
142 mechanism,
143 ))
144 }
145
146 fn infer_mechanism(&self) -> SklResult<MissingMechanism> {
148 let missing_mask = self.missing_mask.as_ref().unwrap();
150 let missing_rate =
151 missing_mask.iter().filter(|&&x| x).count() as f64 / missing_mask.len() as f64;
152
153 if missing_rate < 0.05 {
154 Ok(MissingMechanism::MCAR)
155 } else if missing_rate < 0.2 {
156 Ok(MissingMechanism::MAR)
157 } else {
158 Ok(MissingMechanism::MNAR)
159 }
160 }
161}
162
/// Runtime counterpart of the mechanism marker types, as produced by
/// `classify_mechanism`.
///
/// `Hash` is derived in addition to the comparison traits so the value can be
/// used as a `HashMap`/`HashSet` key (for example to group results by
/// mechanism).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MissingMechanism {
    /// Missing completely at random.
    MCAR,
    /// Missing at random.
    MAR,
    /// Missing not at random.
    MNAR,
}
173
174#[derive(Debug, Clone)]
176pub struct ClassifiedArray<T> {
177 data: Array2<T>,
178 missing_mask: Array2<bool>,
179 mechanism: MissingMechanism,
180}
181
182impl<T: Clone> ClassifiedArray<T> {
183 pub fn new(data: Array2<T>, missing_mask: Array2<bool>, mechanism: MissingMechanism) -> Self {
184 Self {
185 data,
186 missing_mask,
187 mechanism,
188 }
189 }
190
191 pub fn mechanism(&self) -> MissingMechanism {
192 self.mechanism
193 }
194
195 pub fn data(&self) -> &Array2<T> {
196 &self.data
197 }
198
199 pub fn missing_mask(&self) -> &Array2<bool> {
200 &self.missing_mask
201 }
202}
203
204impl<T> MissingPatternValidator<MCAR> for TypedArray<T, MCAR, WithMissing> {
205 fn validate_assumptions(&self) -> SklResult<()> {
206 Ok(())
209 }
210
211 fn recommended_imputers(&self) -> Vec<&'static str> {
212 vec!["SimpleImputer", "KNNImputer", "MatrixFactorization"]
213 }
214}
215
216impl<T> MissingPatternValidator<MAR> for TypedArray<T, MAR, WithMissing> {
217 fn validate_assumptions(&self) -> SklResult<()> {
218 Ok(())
221 }
222
223 fn recommended_imputers(&self) -> Vec<&'static str> {
224 vec![
225 "IterativeImputer",
226 "BayesianImputer",
227 "GaussianProcessImputer",
228 ]
229 }
230}
231
232impl<T> MissingPatternValidator<MNAR> for TypedArray<T, MNAR, WithMissing> {
233 fn validate_assumptions(&self) -> SklResult<()> {
234 Ok(())
237 }
238
239 fn recommended_imputers(&self) -> Vec<&'static str> {
240 vec!["PatternMixtureModel", "SelectionModel", "BayesianImputer"]
241 }
242}
243
244pub trait TypeSafeMissingOps<T, M, S> {
246 fn is_complete(&self) -> bool;
248
249 fn count_missing(&self) -> usize;
251
252 fn missing_rate_per_feature(&self) -> Array1<f64>;
254
255 fn analyze_patterns(&self) -> Vec<MissingPattern>;
257}
258
259impl<T: Clone + PartialEq> TypeSafeMissingOps<T, UnknownMechanism, WithMissing>
260 for TypedArray<T, UnknownMechanism, WithMissing>
261{
262 fn is_complete(&self) -> bool {
263 self.missing_mask
264 .as_ref()
265 .map_or(true, |mask| !mask.iter().any(|&x| x))
266 }
267
268 fn count_missing(&self) -> usize {
269 self.missing_mask
270 .as_ref()
271 .map_or(0, |mask| mask.iter().filter(|&&x| x).count())
272 }
273
274 fn missing_rate_per_feature(&self) -> Array1<f64> {
275 if let Some(mask) = &self.missing_mask {
276 let n_rows = mask.nrows() as f64;
277 let mut rates = Array1::zeros(mask.ncols());
278
279 for j in 0..mask.ncols() {
280 let missing_count = mask.column(j).iter().filter(|&&x| x).count() as f64;
281 rates[j] = missing_count / n_rows;
282 }
283
284 rates
285 } else {
286 Array1::zeros(self.data.ncols())
287 }
288 }
289
290 fn analyze_patterns(&self) -> Vec<MissingPattern> {
291 if let Some(mask) = &self.missing_mask {
292 let mut pattern_counts = std::collections::HashMap::new();
293 let n_rows = mask.nrows();
294
295 for row in mask.rows() {
296 let pattern: Vec<bool> = row.to_vec();
297 *pattern_counts.entry(pattern).or_insert(0) += 1;
298 }
299
300 pattern_counts
301 .into_iter()
302 .map(|(pattern, count)| MissingPattern {
303 pattern,
304 count,
305 frequency: count as f64 / n_rows as f64,
306 })
307 .collect()
308 } else {
309 vec![]
310 }
311 }
312}
313
314pub trait FixedSizeValidation<const N: usize, const M: usize> {
316 fn validate_dimensions(&self) -> SklResult<()>;
317}
318
319#[derive(Debug, Clone)]
321pub struct FixedSizeArray<T, const N: usize, const M: usize> {
322 data: Array2<T>,
323 _phantom: PhantomData<(T, [(); N], [(); M])>,
324}
325
326impl<T: Clone, const N: usize, const M: usize> FixedSizeArray<T, N, M> {
327 pub fn new(data: Array2<T>) -> SklResult<Self> {
328 if data.nrows() != N || data.ncols() != M {
329 return Err(SklearsError::InvalidInput(format!(
330 "Array dimensions {}x{} do not match required {}x{}",
331 data.nrows(),
332 data.ncols(),
333 N,
334 M
335 )));
336 }
337
338 Ok(Self {
339 data,
340 _phantom: PhantomData,
341 })
342 }
343
344 pub fn data(&self) -> &Array2<T> {
345 &self.data
346 }
347}
348
349impl<T, const N: usize, const M: usize> FixedSizeValidation<N, M> for FixedSizeArray<T, N, M> {
350 fn validate_dimensions(&self) -> SklResult<()> {
351 if self.data.nrows() != N || self.data.ncols() != M {
352 Err(SklearsError::InvalidInput(format!(
353 "Invalid dimensions: expected {}x{}, got {}x{}",
354 N,
355 M,
356 self.data.nrows(),
357 self.data.ncols()
358 )))
359 } else {
360 Ok(())
361 }
362 }
363}
364
/// Strategy for deciding whether a single value of type `T` counts as missing.
pub trait MissingValueDetector<T> {
    /// Returns `true` when `value` should be treated as missing.
    fn is_missing(&self, value: &T) -> bool;
}

/// Detector that treats floating-point NaN as missing.
///
/// Stateless, so it derives the same common traits as the other unit types
/// in this file, plus `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct NaNDetector;

impl MissingValueDetector<f64> for NaNDetector {
    fn is_missing(&self, value: &f64) -> bool {
        value.is_nan()
    }
}

impl MissingValueDetector<f32> for NaNDetector {
    fn is_missing(&self, value: &f32) -> bool {
        value.is_nan()
    }
}

/// Detector that treats values equal to a fixed sentinel as missing.
///
/// Comparison uses `PartialEq`, so a NaN sentinel never matches anything
/// (NaN is not equal to itself); use [`NaNDetector`] for that case.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SentinelDetector<T> {
    /// Value that marks a missing entry.
    sentinel: T,
}

impl<T: PartialEq> SentinelDetector<T> {
    /// Creates a detector that flags values equal to `sentinel`.
    pub fn new(sentinel: T) -> Self {
        Self { sentinel }
    }
}

impl<T: PartialEq> MissingValueDetector<T> for SentinelDetector<T> {
    fn is_missing(&self, value: &T) -> bool {
        *value == self.sentinel
    }
}
401
402#[derive(Debug, Clone)]
404pub struct ImputationResult<T> {
405 pub data: Array2<T>,
407 pub imputed_positions: Vec<(usize, usize)>,
409 pub imputation_method: String,
411 pub quality_metrics: Option<ImputationQualityMetrics>,
413}
414
415#[derive(Debug, Clone)]
417pub struct ImputationQualityMetrics {
418 pub confidence_intervals: Option<Array2<(f64, f64)>>,
420 pub uncertainty_estimates: Option<Array2<f64>>,
422 pub imputation_variance: Option<f64>,
424}
425
426pub trait TypeSafeImputation<T, M> {
428 type Output;
429
430 fn impute(&self, data: &TypedArray<T, M, WithMissing>) -> SklResult<Self::Output>;
431}
432
433pub struct TypeSafeMeanImputer<D: MissingValueDetector<f64>> {
435 detector: D,
436}
437
438impl<D: MissingValueDetector<f64>> TypeSafeMeanImputer<D> {
439 pub fn new(detector: D) -> Self {
440 Self { detector }
441 }
442}
443
444impl<D: MissingValueDetector<f64>> TypeSafeImputation<f64, MCAR> for TypeSafeMeanImputer<D> {
445 type Output = CompleteArray<f64>;
446
447 fn impute(&self, data: &MCARArray<f64>) -> SklResult<Self::Output> {
448 let mut result = data.data().clone();
449 let mut imputed_positions = Vec::new();
450
451 let mut column_means = Array1::zeros(data.ncols());
453 for j in 0..data.ncols() {
454 let column = data.data().column(j);
455 let valid_values: Vec<f64> = column
456 .iter()
457 .filter(|&&x| !self.detector.is_missing(&x))
458 .copied()
459 .collect();
460
461 if !valid_values.is_empty() {
462 column_means[j] = valid_values.iter().sum::<f64>() / valid_values.len() as f64;
463 }
464 }
465
466 for ((i, j), value) in data.data().indexed_iter() {
468 if self.detector.is_missing(value) {
469 result[[i, j]] = column_means[j];
470 imputed_positions.push((i, j));
471 }
472 }
473
474 Ok(CompleteArray::new_complete(result))
475 }
476}
477
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Absolute tolerance shared by all floating-point assertions.
    const TOLERANCE: f64 = 1e-10;

    /// 3x2 data with a single NaN at (1, 0) and the matching mask.
    fn three_by_two_with_one_missing() -> (Array2<f64>, Array2<bool>) {
        let values = vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0];
        let flags = vec![false, false, true, false, false, false];
        (
            Array2::from_shape_vec((3, 2), values).unwrap(),
            Array2::from_shape_vec((3, 2), flags).unwrap(),
        )
    }

    /// 4x3 unclassified array in which every row has a different mask pattern.
    fn four_by_three_unclassified() -> TypedArray<f64, UnknownMechanism, WithMissing> {
        let values = vec![
            1.0,
            2.0,
            3.0,
            f64::NAN,
            5.0,
            6.0,
            7.0,
            f64::NAN,
            9.0,
            f64::NAN,
            11.0,
            f64::NAN,
        ];
        let flags = vec![
            false, false, false, true, false, false, false, true, false, true, false, true,
        ];

        let data = Array2::from_shape_vec((4, 3), values).unwrap();
        let missing_mask = Array2::from_shape_vec((4, 3), flags).unwrap();
        TypedArray::<f64, UnknownMechanism, WithMissing>::new_with_missing(data, missing_mask)
    }

    #[test]
    fn test_typed_array_creation() {
        let (data, missing_mask) = three_by_two_with_one_missing();
        let typed_array =
            TypedArray::<f64, UnknownMechanism, WithMissing>::new_with_missing(data, missing_mask);

        assert_eq!(typed_array.shape(), (3, 2));
        assert_eq!(typed_array.count_missing(), 1);
        assert!(!typed_array.is_complete());
    }

    #[test]
    fn test_fixed_size_array() {
        let data = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let fixed_array = FixedSizeArray::<f64, 2, 3>::new(data).unwrap();

        assert!(fixed_array.validate_dimensions().is_ok());
        assert_eq!(fixed_array.data().shape(), &[2, 3]);
    }

    #[test]
    fn test_fixed_size_array_validation() {
        // A 3x2 array must be rejected by a 2x3 fixed-size type.
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let result = FixedSizeArray::<f64, 2, 3>::new(data);

        assert!(result.is_err());
    }

    #[test]
    fn test_nan_detector() {
        let detector = NaNDetector;

        assert!(detector.is_missing(&f64::NAN));
        assert!(!detector.is_missing(&1.0));
        assert!(!detector.is_missing(&0.0));
    }

    #[test]
    fn test_sentinel_detector() {
        let detector = SentinelDetector::new(-999.0);

        assert!(detector.is_missing(&-999.0));
        assert!(!detector.is_missing(&1.0));
        assert!(!detector.is_missing(&0.0));
    }

    #[test]
    fn test_type_safe_mean_imputation() {
        let (data, missing_mask) = three_by_two_with_one_missing();
        let mcar_array = TypedArray::<f64, MCAR, WithMissing>::new_with_missing(data, missing_mask);
        let imputer = TypeSafeMeanImputer::new(NaNDetector);

        let result = imputer.impute(&mcar_array).unwrap();
        let imputed = result.data();

        // The NaN in column 0 becomes the mean of the remaining values 1.0 and 5.0.
        assert_abs_diff_eq!(imputed[[1, 0]], 3.0, epsilon = TOLERANCE);

        // Values that were present are left untouched.
        assert_abs_diff_eq!(imputed[[0, 0]], 1.0, epsilon = TOLERANCE);
        assert_abs_diff_eq!(imputed[[0, 1]], 2.0, epsilon = TOLERANCE);
    }

    #[test]
    fn test_missing_pattern_analysis() {
        let patterns = four_by_three_unclassified().analyze_patterns();

        // Four rows, four distinct patterns, each seen exactly once.
        assert_eq!(patterns.len(), 4);
        for entry in &patterns {
            assert_abs_diff_eq!(entry.frequency, 0.25, epsilon = TOLERANCE);
            assert_eq!(entry.count, 1);
        }
    }

    #[test]
    fn test_missing_rate_per_feature() {
        let missing_rates = four_by_three_unclassified().missing_rate_per_feature();

        // Column 0 is masked in two of four rows, columns 1 and 2 in one each.
        assert_abs_diff_eq!(missing_rates[0], 0.5, epsilon = TOLERANCE);
        assert_abs_diff_eq!(missing_rates[1], 0.25, epsilon = TOLERANCE);
        assert_abs_diff_eq!(missing_rates[2], 0.25, epsilon = TOLERANCE);
    }
}