sklears_impute/
type_safe.rs

1//! Type-safe missing data operations with phantom types for compile-time validation
2//!
3//! This module provides zero-cost abstractions for missing data handling using Rust's type system
4//! to prevent common errors at compile time.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::error::{Result as SklResult, SklearsError};
8use std::marker::PhantomData;
9
/// Mechanism marker: values are Missing Completely At Random (MCAR) — whether a
/// value is missing is unrelated to the data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MCAR;

/// Mechanism marker: values are Missing At Random (MAR) — missingness depends
/// on observed values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MAR;

/// Mechanism marker: values are Missing Not At Random (MNAR) — missingness
/// depends on unobserved values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MNAR;

/// Mechanism marker: the missing-data mechanism has not been determined.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct UnknownMechanism;
25
/// State marker: the array contains no missing values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Complete;

/// State marker: the array may contain missing values, tracked by a missing mask.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct WithMissing;
33
/// Type-safe wrapper for a 2-D array whose missing-data mechanism (`M`) and
/// completeness state (`S`) are tracked at the type level.
///
/// `M` and `S` are held only through zero-sized `PhantomData`, so they cost
/// nothing at runtime.
#[derive(Debug, Clone)]
pub struct TypedArray<T, M, S> {
    /// Underlying data matrix.
    data: Array2<T>,
    /// Per-cell missing flags (`true` = missing); `None` for arrays built as complete.
    missing_mask: Option<Array2<bool>>,
    /// Compile-time mechanism marker (e.g. `MCAR`, `UnknownMechanism`).
    _mechanism: PhantomData<M>,
    /// Compile-time state marker (`Complete` or `WithMissing`).
    _state: PhantomData<S>,
}
42
/// Type alias for fully observed arrays; the mechanism marker is always
/// [`UnknownMechanism`].
pub type CompleteArray<T> = TypedArray<T, UnknownMechanism, Complete>;

/// Type alias for arrays with missing values declared MCAR
/// (Missing Completely At Random).
pub type MCARArray<T> = TypedArray<T, MCAR, WithMissing>;

/// Type alias for arrays with missing values declared MAR (Missing At Random).
pub type MARArray<T> = TypedArray<T, MAR, WithMissing>;

/// Type alias for arrays with missing values declared MNAR
/// (Missing Not At Random).
pub type MNARArray<T> = TypedArray<T, MNAR, WithMissing>;
54
/// One distinct row-wise missingness pattern and how often it occurs.
#[derive(Clone, Debug)]
pub struct MissingPattern {
    /// Per-column missing flags for this pattern (`true` = missing).
    pub pattern: Vec<bool>,
    /// Number of rows exhibiting this pattern.
    pub count: usize,
    /// `count` divided by the total number of rows.
    pub frequency: f64,
}
65
/// Checks the assumptions of the missing-data mechanism `M` and suggests
/// imputers suited to it.
///
/// Which mechanism applies is selected at compile time through `M`; the
/// checks themselves run at runtime.
pub trait MissingPatternValidator<M> {
    /// Check that the data is consistent with mechanism `M`.
    ///
    /// # Errors
    ///
    /// Intended to return an error when the mechanism's assumptions do not hold.
    fn validate_assumptions(&self) -> SklResult<()>;
    /// Names of imputers recommended for data under mechanism `M`.
    fn recommended_imputers(&self) -> Vec<&'static str>;
}
71
72impl<T: Clone + PartialEq> TypedArray<T, UnknownMechanism, Complete> {
73    /// Create a new complete array from ndarray
74    pub fn new_complete(data: Array2<T>) -> Self {
75        Self {
76            data,
77            missing_mask: None,
78            _mechanism: PhantomData,
79            _state: PhantomData,
80        }
81    }
82}
83
84impl<T: Clone + PartialEq> TypedArray<T, UnknownMechanism, WithMissing> {
85    /// Create a new array with missing values from ndarray
86    pub fn new_with_missing(data: Array2<T>, missing_mask: Array2<bool>) -> Self {
87        Self {
88            data,
89            missing_mask: Some(missing_mask),
90            _mechanism: PhantomData,
91            _state: PhantomData,
92        }
93    }
94}
95
96impl<T: Clone + PartialEq> TypedArray<T, MCAR, WithMissing> {
97    /// Create a new MCAR array with missing values from ndarray
98    pub fn new_with_missing(data: Array2<T>, missing_mask: Array2<bool>) -> Self {
99        Self {
100            data,
101            missing_mask: Some(missing_mask),
102            _mechanism: PhantomData,
103            _state: PhantomData,
104        }
105    }
106}
107
108impl<T, M, S> TypedArray<T, M, S> {
109    /// Get the underlying data array
110    pub fn data(&self) -> &Array2<T> {
111        &self.data
112    }
113
114    /// Get the missing mask if available
115    pub fn missing_mask(&self) -> Option<&Array2<bool>> {
116        self.missing_mask.as_ref()
117    }
118
119    /// Get the shape of the data
120    pub fn shape(&self) -> (usize, usize) {
121        self.data.dim()
122    }
123
124    /// Get the number of rows
125    pub fn nrows(&self) -> usize {
126        self.data.nrows()
127    }
128
129    /// Get the number of columns
130    pub fn ncols(&self) -> usize {
131        self.data.ncols()
132    }
133}
134
135impl<T: Clone + PartialEq> TypedArray<T, UnknownMechanism, WithMissing> {
136    /// Classify the missing data mechanism and return typed array
137    pub fn classify_mechanism(self) -> SklResult<ClassifiedArray<T>> {
138        let mechanism = self.infer_mechanism()?;
139        Ok(ClassifiedArray::new(
140            self.data,
141            self.missing_mask.unwrap(),
142            mechanism,
143        ))
144    }
145
146    /// Infer the missing data mechanism using statistical tests
147    fn infer_mechanism(&self) -> SklResult<MissingMechanism> {
148        // Simplified mechanism inference - in practice this would use statistical tests
149        let missing_mask = self.missing_mask.as_ref().unwrap();
150        let missing_rate =
151            missing_mask.iter().filter(|&&x| x).count() as f64 / missing_mask.len() as f64;
152
153        if missing_rate < 0.05 {
154            Ok(MissingMechanism::MCAR)
155        } else if missing_rate < 0.2 {
156            Ok(MissingMechanism::MAR)
157        } else {
158            Ok(MissingMechanism::MNAR)
159        }
160    }
161}
162
/// Runtime classification of a missing-data mechanism, as produced by
/// `classify_mechanism`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MissingMechanism {
    /// Missing Completely At Random.
    MCAR,
    /// Missing At Random (missingness depends on observed values).
    MAR,
    /// Missing Not At Random (missingness depends on unobserved values).
    MNAR,
}
173
174/// Classified array with known missing data mechanism
175#[derive(Debug, Clone)]
176pub struct ClassifiedArray<T> {
177    data: Array2<T>,
178    missing_mask: Array2<bool>,
179    mechanism: MissingMechanism,
180}
181
182impl<T: Clone> ClassifiedArray<T> {
183    pub fn new(data: Array2<T>, missing_mask: Array2<bool>, mechanism: MissingMechanism) -> Self {
184        Self {
185            data,
186            missing_mask,
187            mechanism,
188        }
189    }
190
191    pub fn mechanism(&self) -> MissingMechanism {
192        self.mechanism
193    }
194
195    pub fn data(&self) -> &Array2<T> {
196        &self.data
197    }
198
199    pub fn missing_mask(&self) -> &Array2<bool> {
200        &self.missing_mask
201    }
202}
203
204impl<T> MissingPatternValidator<MCAR> for TypedArray<T, MCAR, WithMissing> {
205    fn validate_assumptions(&self) -> SklResult<()> {
206        // For MCAR, missing values should be randomly distributed
207        // This is a simplified check
208        Ok(())
209    }
210
211    fn recommended_imputers(&self) -> Vec<&'static str> {
212        vec!["SimpleImputer", "KNNImputer", "MatrixFactorization"]
213    }
214}
215
216impl<T> MissingPatternValidator<MAR> for TypedArray<T, MAR, WithMissing> {
217    fn validate_assumptions(&self) -> SklResult<()> {
218        // For MAR, missingness depends on observed values
219        // This would check dependencies between observed and missing values
220        Ok(())
221    }
222
223    fn recommended_imputers(&self) -> Vec<&'static str> {
224        vec![
225            "IterativeImputer",
226            "BayesianImputer",
227            "GaussianProcessImputer",
228        ]
229    }
230}
231
232impl<T> MissingPatternValidator<MNAR> for TypedArray<T, MNAR, WithMissing> {
233    fn validate_assumptions(&self) -> SklResult<()> {
234        // For MNAR, missingness depends on unobserved values
235        // This would need domain knowledge validation
236        Ok(())
237    }
238
239    fn recommended_imputers(&self) -> Vec<&'static str> {
240        vec!["PatternMixtureModel", "SelectionModel", "BayesianImputer"]
241    }
242}
243
/// Missingness queries over typed arrays.
///
/// The type parameters mirror `TypedArray<T, M, S>`, so implementations can be
/// restricted to particular mechanism/state combinations. In this module it is
/// implemented for `TypedArray<T, UnknownMechanism, WithMissing>`.
pub trait TypeSafeMissingOps<T, M, S> {
    /// Return `true` if no value is missing.
    fn is_complete(&self) -> bool;

    /// Return the total number of missing values.
    fn count_missing(&self) -> usize;

    /// Return, for each column (feature), the fraction of rows that are missing.
    fn missing_rate_per_feature(&self) -> Array1<f64>;

    /// Return the distinct row-wise missingness patterns with their counts and
    /// relative frequencies.
    fn analyze_patterns(&self) -> Vec<MissingPattern>;
}
258
259impl<T: Clone + PartialEq> TypeSafeMissingOps<T, UnknownMechanism, WithMissing>
260    for TypedArray<T, UnknownMechanism, WithMissing>
261{
262    fn is_complete(&self) -> bool {
263        self.missing_mask
264            .as_ref()
265            .map_or(true, |mask| !mask.iter().any(|&x| x))
266    }
267
268    fn count_missing(&self) -> usize {
269        self.missing_mask
270            .as_ref()
271            .map_or(0, |mask| mask.iter().filter(|&&x| x).count())
272    }
273
274    fn missing_rate_per_feature(&self) -> Array1<f64> {
275        if let Some(mask) = &self.missing_mask {
276            let n_rows = mask.nrows() as f64;
277            let mut rates = Array1::zeros(mask.ncols());
278
279            for j in 0..mask.ncols() {
280                let missing_count = mask.column(j).iter().filter(|&&x| x).count() as f64;
281                rates[j] = missing_count / n_rows;
282            }
283
284            rates
285        } else {
286            Array1::zeros(self.data.ncols())
287        }
288    }
289
290    fn analyze_patterns(&self) -> Vec<MissingPattern> {
291        if let Some(mask) = &self.missing_mask {
292            let mut pattern_counts = std::collections::HashMap::new();
293            let n_rows = mask.nrows();
294
295            for row in mask.rows() {
296                let pattern: Vec<bool> = row.to_vec();
297                *pattern_counts.entry(pattern).or_insert(0) += 1;
298            }
299
300            pattern_counts
301                .into_iter()
302                .map(|(pattern, count)| MissingPattern {
303                    pattern,
304                    count,
305                    frequency: count as f64 / n_rows as f64,
306                })
307                .collect()
308        } else {
309            vec![]
310        }
311    }
312}
313
/// Runtime check of a value against the compile-time dimensions `N` rows x `M` columns.
pub trait FixedSizeValidation<const N: usize, const M: usize> {
    /// Return `Ok(())` if the data is exactly `N x M`.
    ///
    /// # Errors
    ///
    /// Intended to return an error describing the mismatch otherwise.
    fn validate_dimensions(&self) -> SklResult<()>;
}
318
/// A 2-D array whose required shape (`N` rows, `M` columns) is part of its type.
///
/// The shape is checked when the value is built via `FixedSizeArray::new`.
#[derive(Debug, Clone)]
pub struct FixedSizeArray<T, const N: usize, const M: usize> {
    /// Underlying data, expected to be exactly `N x M`.
    data: Array2<T>,
    // NOTE(review): `T` is already used by `data`, and unused const parameters need
    // no phantom marker, so this field looks redundant — consider removing it.
    _phantom: PhantomData<(T, [(); N], [(); M])>,
}
325
326impl<T: Clone, const N: usize, const M: usize> FixedSizeArray<T, N, M> {
327    pub fn new(data: Array2<T>) -> SklResult<Self> {
328        if data.nrows() != N || data.ncols() != M {
329            return Err(SklearsError::InvalidInput(format!(
330                "Array dimensions {}x{} do not match required {}x{}",
331                data.nrows(),
332                data.ncols(),
333                N,
334                M
335            )));
336        }
337
338        Ok(Self {
339            data,
340            _phantom: PhantomData,
341        })
342    }
343
344    pub fn data(&self) -> &Array2<T> {
345        &self.data
346    }
347}
348
349impl<T, const N: usize, const M: usize> FixedSizeValidation<N, M> for FixedSizeArray<T, N, M> {
350    fn validate_dimensions(&self) -> SklResult<()> {
351        if self.data.nrows() != N || self.data.ncols() != M {
352            Err(SklearsError::InvalidInput(format!(
353                "Invalid dimensions: expected {}x{}, got {}x{}",
354                N,
355                M,
356                self.data.nrows(),
357                self.data.ncols()
358            )))
359        } else {
360            Ok(())
361        }
362    }
363}
364
/// Decides whether a single value of type `T` represents a missing entry.
///
/// Used as a generic parameter (e.g. by `TypeSafeMeanImputer`), so calls are
/// statically dispatched.
pub trait MissingValueDetector<T> {
    /// Return `true` if `value` should be treated as missing.
    fn is_missing(&self, value: &T) -> bool;
}

/// Treats NaN as missing; implemented for `f32` and `f64`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NaNDetector;

impl MissingValueDetector<f64> for NaNDetector {
    fn is_missing(&self, value: &f64) -> bool {
        value.is_nan()
    }
}

impl MissingValueDetector<f32> for NaNDetector {
    fn is_missing(&self, value: &f32) -> bool {
        value.is_nan()
    }
}

/// Treats any value equal (via `PartialEq`) to a fixed sentinel as missing.
///
/// Because the comparison uses `==`, a NaN sentinel never matches anything
/// (NaN != NaN). Use [`NaNDetector`] for NaN-encoded missing values.
#[derive(Debug, Clone, Copy)]
pub struct SentinelDetector<T> {
    sentinel: T,
}

impl<T: PartialEq> SentinelDetector<T> {
    /// Create a detector that flags values equal to `sentinel`.
    pub fn new(sentinel: T) -> Self {
        Self { sentinel }
    }
}

impl<T: PartialEq> MissingValueDetector<T> for SentinelDetector<T> {
    fn is_missing(&self, value: &T) -> bool {
        *value == self.sentinel
    }
}
401
/// Output of an imputation step together with provenance information.
#[derive(Debug, Clone)]
pub struct ImputationResult<T> {
    /// Data matrix after imputation.
    pub data: Array2<T>,
    /// Indices of the cells that were filled in — presumably `(row, column)`;
    /// confirm against producers.
    pub imputed_positions: Vec<(usize, usize)>,
    /// Name of the imputation method that produced `data`.
    pub imputation_method: String,
    /// Optional quality/uncertainty metrics for the imputed values.
    pub quality_metrics: Option<ImputationQualityMetrics>,
}
414
/// Optional quality and uncertainty measures attached to an imputation.
#[derive(Debug, Clone)]
pub struct ImputationQualityMetrics {
    /// Interval bounds per cell — presumably `(lower, upper)` aligned with the
    /// data matrix; confirm with producers.
    pub confidence_intervals: Option<Array2<(f64, f64)>>,
    /// Uncertainty estimates, presumably one per data cell
    /// (NOTE(review): definition/units not specified here).
    pub uncertainty_estimates: Option<Array2<f64>>,
    /// A single scalar variance summarizing the imputation.
    pub imputation_variance: Option<f64>,
}
425
/// An imputer applicable only to data whose missing-data mechanism is `M`.
///
/// Because `impute` takes a `TypedArray<T, M, WithMissing>`, passing data with a
/// mechanism the imputer has no implementation for fails to compile.
pub trait TypeSafeImputation<T, M> {
    /// Type produced by a successful imputation.
    type Output;

    /// Impute the missing values in `data`.
    ///
    /// # Errors
    ///
    /// Implementation-defined.
    fn impute(&self, data: &TypedArray<T, M, WithMissing>) -> SklResult<Self::Output>;
}
432
433/// Example implementation of type-safe mean imputation
434pub struct TypeSafeMeanImputer<D: MissingValueDetector<f64>> {
435    detector: D,
436}
437
438impl<D: MissingValueDetector<f64>> TypeSafeMeanImputer<D> {
439    pub fn new(detector: D) -> Self {
440        Self { detector }
441    }
442}
443
444impl<D: MissingValueDetector<f64>> TypeSafeImputation<f64, MCAR> for TypeSafeMeanImputer<D> {
445    type Output = CompleteArray<f64>;
446
447    fn impute(&self, data: &MCARArray<f64>) -> SklResult<Self::Output> {
448        let mut result = data.data().clone();
449        let mut imputed_positions = Vec::new();
450
451        // Calculate column means
452        let mut column_means = Array1::zeros(data.ncols());
453        for j in 0..data.ncols() {
454            let column = data.data().column(j);
455            let valid_values: Vec<f64> = column
456                .iter()
457                .filter(|&&x| !self.detector.is_missing(&x))
458                .copied()
459                .collect();
460
461            if !valid_values.is_empty() {
462                column_means[j] = valid_values.iter().sum::<f64>() / valid_values.len() as f64;
463            }
464        }
465
466        // Impute missing values
467        for ((i, j), value) in data.data().indexed_iter() {
468            if self.detector.is_missing(value) {
469                result[[i, j]] = column_means[j];
470                imputed_positions.push((i, j));
471            }
472        }
473
474        Ok(CompleteArray::new_complete(result))
475    }
476}
477
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// 3x2 data with a single NaN at (1, 0), plus its matching mask.
    fn single_gap_fixture() -> (Array2<f64>, Array2<bool>) {
        let values =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();
        let mask =
            Array2::from_shape_vec((3, 2), vec![false, false, true, false, false, false]).unwrap();
        (values, mask)
    }

    /// 4x3 data in which every row has a different missingness pattern.
    ///
    /// Column 0 has two missing cells; columns 1 and 2 have one each.
    fn four_pattern_fixture() -> TypedArray<f64, UnknownMechanism, WithMissing> {
        let gap = f64::NAN;
        let values = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 2.0, 3.0, gap, 5.0, 6.0, 7.0, gap, 9.0, gap, 11.0, gap],
        )
        .unwrap();
        let mask = Array2::from_shape_vec(
            (4, 3),
            vec![
                false, false, false, true, false, false, false, true, false, true, false, true,
            ],
        )
        .unwrap();
        TypedArray::<f64, UnknownMechanism, WithMissing>::new_with_missing(values, mask)
    }

    #[test]
    fn test_typed_array_creation() {
        let (values, mask) = single_gap_fixture();
        let typed =
            TypedArray::<f64, UnknownMechanism, WithMissing>::new_with_missing(values, mask);

        assert_eq!(typed.shape(), (3, 2));
        assert_eq!(typed.count_missing(), 1);
        assert!(!typed.is_complete());
    }

    #[test]
    fn test_fixed_size_array() {
        let values = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let fixed = FixedSizeArray::<f64, 2, 3>::new(values).unwrap();

        assert!(fixed.validate_dimensions().is_ok());
        assert_eq!(fixed.data().shape(), &[2, 3]);
    }

    #[test]
    fn test_fixed_size_array_validation() {
        // Same six elements, laid out 3x2 instead of the required 2x3.
        let values = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(FixedSizeArray::<f64, 2, 3>::new(values).is_err());
    }

    #[test]
    fn test_nan_detector() {
        let detector = NaNDetector;

        assert!(detector.is_missing(&f64::NAN));
        assert!(!detector.is_missing(&1.0));
        assert!(!detector.is_missing(&0.0));
    }

    #[test]
    fn test_sentinel_detector() {
        let detector = SentinelDetector::new(-999.0);

        assert!(detector.is_missing(&-999.0));
        assert!(!detector.is_missing(&1.0));
        assert!(!detector.is_missing(&0.0));
    }

    #[test]
    fn test_type_safe_mean_imputation() {
        let (values, mask) = single_gap_fixture();
        let mcar = TypedArray::<f64, MCAR, WithMissing>::new_with_missing(values, mask);
        let imputed = TypeSafeMeanImputer::new(NaNDetector).impute(&mcar).unwrap();
        let out = imputed.data();

        // Column 0 mean over observed values: (1.0 + 5.0) / 2 = 3.0.
        assert_abs_diff_eq!(out[[1, 0]], 3.0, epsilon = 1e-10);

        // Observed values pass through unchanged.
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_missing_pattern_analysis() {
        let patterns = four_pattern_fixture().analyze_patterns();

        // Every row is its own pattern, so each occurs once out of four rows.
        assert_eq!(patterns.len(), 4);
        for p in &patterns {
            assert_abs_diff_eq!(p.frequency, 0.25, epsilon = 1e-10);
            assert_eq!(p.count, 1);
        }
    }

    #[test]
    fn test_missing_rate_per_feature() {
        let rates = four_pattern_fixture().missing_rate_per_feature();

        assert_abs_diff_eq!(rates[0], 0.5, epsilon = 1e-10); // 2 of 4 rows
        assert_abs_diff_eq!(rates[1], 0.25, epsilon = 1e-10); // 1 of 4 rows
        assert_abs_diff_eq!(rates[2], 0.25, epsilon = 1e-10); // 1 of 4 rows
    }
}