sklears_impute/
core.rs

//! Core types and traits for imputation operations.
//!
//! This module provides the fundamental types, error handling, and traits
//! that are used throughout the imputation framework.
5
6use std::fmt;
7
8/// Result type for imputation operations
9pub type ImputationResult<T> = Result<T, ImputationError>;
10
/// Error type shared by all imputation operations.
///
/// Every variant carries a free-form, human-readable message. The variant
/// itself identifies the error category; the `Display` implementation puts a
/// matching category label in front of the message.
///
/// `PartialEq`/`Eq` are derived so that errors can be compared directly, for
/// example with `assert_eq!` in tests or `==` in control flow. Two errors are
/// equal when they are the same variant and carry the same message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImputationError {
    /// Invalid parameter provided
    InvalidParameter(String),
    /// Insufficient data to perform imputation
    InsufficientData(String),
    /// Convergence failure in iterative methods
    ConvergenceFailure(String),
    /// Matrix operation error (e.g., singular matrix)
    MatrixError(String),
    /// Dimension mismatch between arrays
    DimensionMismatch(String),
    /// Numerical computation error
    NumericalError(String),
    /// Data validation error
    ValidationError(String),
    /// I/O operation error
    IOError(String),
    /// Memory allocation error
    MemoryError(String),
    /// Feature not implemented
    NotImplemented(String),
    /// General processing error
    ProcessingError(String),
    /// Invalid configuration error
    InvalidConfiguration(String),
}
39
40impl fmt::Display for ImputationError {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            ImputationError::InvalidParameter(msg) => {
44                write!(f, "Invalid parameter: {}", msg)
45            }
46            ImputationError::InsufficientData(msg) => {
47                write!(f, "Insufficient data: {}", msg)
48            }
49            ImputationError::ConvergenceFailure(msg) => {
50                write!(f, "Convergence failure: {}", msg)
51            }
52            ImputationError::MatrixError(msg) => {
53                write!(f, "Matrix error: {}", msg)
54            }
55            ImputationError::DimensionMismatch(msg) => {
56                write!(f, "Dimension mismatch: {}", msg)
57            }
58            ImputationError::NumericalError(msg) => {
59                write!(f, "Numerical error: {}", msg)
60            }
61            ImputationError::ValidationError(msg) => {
62                write!(f, "Validation error: {}", msg)
63            }
64            ImputationError::IOError(msg) => {
65                write!(f, "I/O error: {}", msg)
66            }
67            ImputationError::MemoryError(msg) => {
68                write!(f, "Memory error: {}", msg)
69            }
70            ImputationError::NotImplemented(msg) => {
71                write!(f, "Not implemented: {}", msg)
72            }
73            ImputationError::ProcessingError(msg) => {
74                write!(f, "Processing error: {}", msg)
75            }
76            ImputationError::InvalidConfiguration(msg) => {
77                write!(f, "Invalid configuration: {}", msg)
78            }
79        }
80    }
81}
82
83impl std::error::Error for ImputationError {}
84
85impl From<std::io::Error> for ImputationError {
86    fn from(err: std::io::Error) -> Self {
87        ImputationError::IOError(err.to_string())
88    }
89}
90
91impl From<sklears_core::error::SklearsError> for ImputationError {
92    fn from(err: sklears_core::error::SklearsError) -> Self {
93        ImputationError::ProcessingError(err.to_string())
94    }
95}
96
97impl From<ImputationError> for sklears_core::error::SklearsError {
98    fn from(err: ImputationError) -> Self {
99        match err {
100            ImputationError::InvalidParameter(msg) => {
101                sklears_core::error::SklearsError::InvalidInput(msg)
102            }
103            ImputationError::InsufficientData(msg) => {
104                sklears_core::error::SklearsError::InvalidInput(msg)
105            }
106            ImputationError::ConvergenceFailure(msg) => {
107                sklears_core::error::SklearsError::FitError(msg)
108            }
109            ImputationError::MatrixError(msg) => {
110                sklears_core::error::SklearsError::InvalidInput(msg)
111            }
112            ImputationError::DimensionMismatch(msg) => {
113                sklears_core::error::SklearsError::InvalidInput(msg)
114            }
115            ImputationError::NumericalError(msg) => {
116                sklears_core::error::SklearsError::InvalidInput(msg)
117            }
118            ImputationError::ValidationError(msg) => {
119                sklears_core::error::SklearsError::InvalidInput(msg)
120            }
121            ImputationError::IOError(msg) => sklears_core::error::SklearsError::InvalidInput(msg),
122            ImputationError::MemoryError(msg) => {
123                sklears_core::error::SklearsError::InvalidInput(msg)
124            }
125            ImputationError::NotImplemented(msg) => {
126                sklears_core::error::SklearsError::InvalidInput(msg)
127            }
128            ImputationError::ProcessingError(msg) => {
129                sklears_core::error::SklearsError::InvalidInput(msg)
130            }
131            ImputationError::InvalidConfiguration(msg) => {
132                sklears_core::error::SklearsError::InvalidInput(msg)
133            }
134        }
135    }
136}
137
138/// Core trait for imputation methods
139pub trait Imputer {
140    /// Fit the imputer to the data and return the imputed data
141    fn fit_transform(
142        &self,
143        X: &scirs2_core::ndarray::ArrayView2<f64>,
144    ) -> ImputationResult<scirs2_core::ndarray::Array2<f64>>;
145}
146
147/// Trait for imputers that can be trained separately
148pub trait TrainableImputer {
149    /// The trained state type
150    type Trained;
151
152    /// Fit the imputer to training data
153    fn fit(&self, X: &scirs2_core::ndarray::ArrayView2<f64>) -> ImputationResult<Self::Trained>;
154}
155
156/// Trait for trained imputers that can transform data
157pub trait TransformableImputer {
158    /// Transform data using the trained imputer
159    fn transform(
160        &self,
161        X: &scirs2_core::ndarray::ArrayView2<f64>,
162    ) -> ImputationResult<scirs2_core::ndarray::Array2<f64>>;
163}
164
165/// Configuration trait for imputation methods
166pub trait ImputerConfig {
167    /// Validate the configuration
168    fn validate(&self) -> ImputationResult<()>;
169
170    /// Get default configuration
171    fn default_config() -> Self;
172}
173
174/// Trait for imputation quality assessment
175pub trait QualityAssessment {
176    /// Assess the quality of imputation
177    fn assess_quality(
178        &self,
179        original: &scirs2_core::ndarray::ArrayView2<f64>,
180        imputed: &scirs2_core::ndarray::ArrayView2<f64>,
181    ) -> ImputationResult<f64>;
182}
183
184/// Trait for handling missing value patterns
185pub trait MissingPatternHandler {
186    /// Analyze missing value patterns
187    fn analyze_patterns(
188        &self,
189        X: &scirs2_core::ndarray::ArrayView2<f64>,
190    ) -> ImputationResult<std::collections::HashMap<String, f64>>;
191
192    /// Identify missing value mechanism (MCAR, MAR, MNAR)
193    fn identify_mechanism(
194        &self,
195        X: &scirs2_core::ndarray::ArrayView2<f64>,
196    ) -> ImputationResult<String>;
197}
198
199/// Trait for statistical validation of imputations
200pub trait StatisticalValidator {
201    /// Validate distributional properties
202    fn validate_distribution(
203        &self,
204        original: &scirs2_core::ndarray::ArrayView2<f64>,
205        imputed: &scirs2_core::ndarray::ArrayView2<f64>,
206    ) -> ImputationResult<bool>;
207
208    /// Test for bias in imputation
209    fn test_bias(
210        &self,
211        original: &scirs2_core::ndarray::ArrayView2<f64>,
212        imputed: &scirs2_core::ndarray::ArrayView2<f64>,
213    ) -> ImputationResult<f64>;
214}
215
216/// Metadata about the imputation process
217#[derive(Debug, Clone)]
218pub struct ImputationMetadata {
219    /// Method used for imputation
220    pub method: String,
221    /// Parameters used
222    pub parameters: std::collections::HashMap<String, String>,
223    /// Number of values imputed
224    pub n_imputed: usize,
225    /// Convergence information (if applicable)
226    pub convergence_info: Option<ConvergenceInfo>,
227    /// Quality metrics
228    pub quality_metrics: Option<std::collections::HashMap<String, f64>>,
229    /// Processing time in milliseconds
230    pub processing_time_ms: Option<u64>,
231}
232
/// Convergence summary reported by iterative imputation methods.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// How many iterations were performed.
    pub n_iterations: usize,
    /// Value of the convergence criterion at the end of the run.
    pub final_criterion: f64,
    /// `true` when convergence was achieved.
    pub converged: bool,
    /// Recorded convergence history.
    pub history: Vec<f64>,
}
245
246impl ImputationMetadata {
247    /// Create new metadata
248    pub fn new(method: String) -> Self {
249        Self {
250            method,
251            parameters: std::collections::HashMap::new(),
252            n_imputed: 0,
253            convergence_info: None,
254            quality_metrics: None,
255            processing_time_ms: None,
256        }
257    }
258
259    /// Add parameter information
260    pub fn with_parameter(mut self, key: String, value: String) -> Self {
261        self.parameters.insert(key, value);
262        self
263    }
264
265    /// Set number of imputed values
266    pub fn with_n_imputed(mut self, n_imputed: usize) -> Self {
267        self.n_imputed = n_imputed;
268        self
269    }
270
271    /// Set convergence information
272    pub fn with_convergence(mut self, convergence: ConvergenceInfo) -> Self {
273        self.convergence_info = Some(convergence);
274        self
275    }
276
277    /// Set quality metrics
278    pub fn with_quality_metrics(mut self, metrics: std::collections::HashMap<String, f64>) -> Self {
279        self.quality_metrics = Some(metrics);
280        self
281    }
282
283    /// Set processing time
284    pub fn with_processing_time(mut self, time_ms: u64) -> Self {
285        self.processing_time_ms = Some(time_ms);
286        self
287    }
288}
289
290/// Result of an imputation operation with metadata
291#[derive(Debug, Clone)]
292pub struct ImputationOutputWithMetadata {
293    /// The imputed data
294    pub data: scirs2_core::ndarray::Array2<f64>,
295    /// Metadata about the imputation process
296    pub metadata: ImputationMetadata,
297}
298
299impl ImputationOutputWithMetadata {
300    /// Create new output with metadata
301    pub fn new(data: scirs2_core::ndarray::Array2<f64>, metadata: ImputationMetadata) -> Self {
302        Self { data, metadata }
303    }
304}
305
306/// Utility functions for common operations
307pub mod utils {
308    use super::*;
309
310    /// Count missing values in an array
311    pub fn count_missing(X: &scirs2_core::ndarray::ArrayView2<f64>) -> usize {
312        X.iter().filter(|&&x| x.is_nan()).count()
313    }
314
315    /// Get missing value positions
316    pub fn get_missing_positions(X: &scirs2_core::ndarray::ArrayView2<f64>) -> Vec<(usize, usize)> {
317        X.indexed_iter()
318            .filter_map(|((i, j), &val)| if val.is_nan() { Some((i, j)) } else { None })
319            .collect()
320    }
321
322    /// Compute missing value rate per feature
323    pub fn missing_rates_per_feature(X: &scirs2_core::ndarray::ArrayView2<f64>) -> Vec<f64> {
324        let (n_rows, n_cols) = X.dim();
325        let mut rates = Vec::with_capacity(n_cols);
326
327        for j in 0..n_cols {
328            let missing_count = X.column(j).iter().filter(|&&x| x.is_nan()).count();
329            rates.push(missing_count as f64 / n_rows as f64);
330        }
331
332        rates
333    }
334
335    /// Compute missing value rate per sample
336    pub fn missing_rates_per_sample(X: &scirs2_core::ndarray::ArrayView2<f64>) -> Vec<f64> {
337        let (n_rows, n_cols) = X.dim();
338        let mut rates = Vec::with_capacity(n_rows);
339
340        for i in 0..n_rows {
341            let missing_count = X.row(i).iter().filter(|&&x| x.is_nan()).count();
342            rates.push(missing_count as f64 / n_cols as f64);
343        }
344
345        rates
346    }
347
348    /// Validate input data for imputation
349    pub fn validate_input(X: &scirs2_core::ndarray::ArrayView2<f64>) -> ImputationResult<()> {
350        let (n_rows, n_cols) = X.dim();
351
352        if n_rows == 0 {
353            return Err(ImputationError::ValidationError(
354                "Input array has zero rows".to_string(),
355            ));
356        }
357
358        if n_cols == 0 {
359            return Err(ImputationError::ValidationError(
360                "Input array has zero columns".to_string(),
361            ));
362        }
363
364        // Check if all values are missing
365        let all_missing = X.iter().all(|&x| x.is_nan());
366        if all_missing {
367            return Err(ImputationError::InsufficientData(
368                "All values in the input array are missing".to_string(),
369            ));
370        }
371
372        Ok(())
373    }
374
375    /// Check if arrays have compatible dimensions
376    pub fn check_dimensions_compatible(
377        X1: &scirs2_core::ndarray::ArrayView2<f64>,
378        X2: &scirs2_core::ndarray::ArrayView2<f64>,
379    ) -> ImputationResult<()> {
380        if X1.dim() != X2.dim() {
381            return Err(ImputationError::DimensionMismatch(format!(
382                "Array dimensions don't match: {:?} vs {:?}",
383                X1.dim(),
384                X2.dim()
385            )));
386        }
387        Ok(())
388    }
389}