1use std::fmt;
7
/// Convenience alias for `Result` values whose error type is [`ImputationError`].
pub type ImputationResult<T> = Result<T, ImputationError>;
10
/// Error type returned by the imputation routines in this file.
///
/// Every variant wraps a free-form, human-readable message. The variant itself
/// identifies the error category and selects the prefix used by the `Display`
/// implementation (`"<category>: <message>"`).
#[derive(Debug, Clone)]
pub enum ImputationError {
    /// Displayed with the prefix `Invalid parameter`.
    InvalidParameter(String),
    /// Displayed with the prefix `Insufficient data`; `utils::validate_input`
    /// returns it when every value of the input is NaN.
    InsufficientData(String),
    /// Displayed with the prefix `Convergence failure`.
    ConvergenceFailure(String),
    /// Displayed with the prefix `Matrix error`.
    MatrixError(String),
    /// Displayed with the prefix `Dimension mismatch`;
    /// `utils::check_dimensions_compatible` returns it when two shapes differ.
    DimensionMismatch(String),
    /// Displayed with the prefix `Numerical error`.
    NumericalError(String),
    /// Displayed with the prefix `Validation error`; `utils::validate_input`
    /// returns it for inputs with zero rows or zero columns.
    ValidationError(String),
    /// Displayed with the prefix `I/O error`; produced when converting from
    /// [`std::io::Error`].
    IOError(String),
    /// Displayed with the prefix `Memory error`.
    MemoryError(String),
    /// Displayed with the prefix `Not implemented`.
    NotImplemented(String),
    /// Displayed with the prefix `Processing error`.
    ProcessingError(String),
    /// Displayed with the prefix `Invalid configuration`.
    InvalidConfiguration(String),
}

impl fmt::Display for ImputationError {
    /// Writes the error as `"<category label>: <message>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label and borrow the payload first, then format exactly once.
        let (label, detail) = match self {
            Self::InvalidParameter(detail) => ("Invalid parameter", detail),
            Self::InsufficientData(detail) => ("Insufficient data", detail),
            Self::ConvergenceFailure(detail) => ("Convergence failure", detail),
            Self::MatrixError(detail) => ("Matrix error", detail),
            Self::DimensionMismatch(detail) => ("Dimension mismatch", detail),
            Self::NumericalError(detail) => ("Numerical error", detail),
            Self::ValidationError(detail) => ("Validation error", detail),
            Self::IOError(detail) => ("I/O error", detail),
            Self::MemoryError(detail) => ("Memory error", detail),
            Self::NotImplemented(detail) => ("Not implemented", detail),
            Self::ProcessingError(detail) => ("Processing error", detail),
            Self::InvalidConfiguration(detail) => ("Invalid configuration", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

/// Marks [`ImputationError`] as a standard error type. No `source` is exposed
/// because the variants only store message strings.
impl std::error::Error for ImputationError {}

impl From<std::io::Error> for ImputationError {
    /// Wraps the I/O error's display text in [`ImputationError::IOError`].
    fn from(err: std::io::Error) -> Self {
        let detail = err.to_string();
        Self::IOError(detail)
    }
}
90
91impl From<sklears_core::error::SklearsError> for ImputationError {
92 fn from(err: sklears_core::error::SklearsError) -> Self {
93 ImputationError::ProcessingError(err.to_string())
94 }
95}
96
97impl From<ImputationError> for sklears_core::error::SklearsError {
98 fn from(err: ImputationError) -> Self {
99 match err {
100 ImputationError::InvalidParameter(msg) => {
101 sklears_core::error::SklearsError::InvalidInput(msg)
102 }
103 ImputationError::InsufficientData(msg) => {
104 sklears_core::error::SklearsError::InvalidInput(msg)
105 }
106 ImputationError::ConvergenceFailure(msg) => {
107 sklears_core::error::SklearsError::FitError(msg)
108 }
109 ImputationError::MatrixError(msg) => {
110 sklears_core::error::SklearsError::InvalidInput(msg)
111 }
112 ImputationError::DimensionMismatch(msg) => {
113 sklears_core::error::SklearsError::InvalidInput(msg)
114 }
115 ImputationError::NumericalError(msg) => {
116 sklears_core::error::SklearsError::InvalidInput(msg)
117 }
118 ImputationError::ValidationError(msg) => {
119 sklears_core::error::SklearsError::InvalidInput(msg)
120 }
121 ImputationError::IOError(msg) => sklears_core::error::SklearsError::InvalidInput(msg),
122 ImputationError::MemoryError(msg) => {
123 sklears_core::error::SklearsError::InvalidInput(msg)
124 }
125 ImputationError::NotImplemented(msg) => {
126 sklears_core::error::SklearsError::InvalidInput(msg)
127 }
128 ImputationError::ProcessingError(msg) => {
129 sklears_core::error::SklearsError::InvalidInput(msg)
130 }
131 ImputationError::InvalidConfiguration(msg) => {
132 sklears_core::error::SklearsError::InvalidInput(msg)
133 }
134 }
135 }
136}
137
/// Interface for imputers that expose a single fit-and-transform call.
pub trait Imputer {
    /// Takes a borrowed 2-D `f64` view and returns a newly owned 2-D `f64` array,
    /// or an [`ImputationError`] on failure.
    ///
    /// NOTE(review): presumably fits on `X` and fills its missing entries in one
    /// step; the exact contract is defined by implementors — confirm.
    fn fit_transform(
        &self,
        X: &scirs2_core::ndarray::ArrayView2<f64>,
    ) -> ImputationResult<scirs2_core::ndarray::Array2<f64>>;
}
146
/// Interface for imputers whose fitting step produces a separate value.
pub trait TrainableImputer {
    /// Type returned by a successful [`TrainableImputer::fit`].
    type Trained;

    /// Takes a borrowed 2-D `f64` view and returns a `Self::Trained` value, or an
    /// [`ImputationError`] on failure. The receiver is borrowed immutably.
    fn fit(&self, X: &scirs2_core::ndarray::ArrayView2<f64>) -> ImputationResult<Self::Trained>;
}
155
/// Interface for values that can transform a 2-D `f64` array.
///
/// NOTE(review): presumably implemented by already-fitted imputers — confirm.
pub trait TransformableImputer {
    /// Takes a borrowed 2-D `f64` view and returns a newly owned 2-D `f64` array,
    /// or an [`ImputationError`] on failure.
    fn transform(
        &self,
        X: &scirs2_core::ndarray::ArrayView2<f64>,
    ) -> ImputationResult<scirs2_core::ndarray::Array2<f64>>;
}
164
/// Interface for imputer configuration types.
pub trait ImputerConfig {
    /// Checks this configuration, returning `Ok(())` or an [`ImputationError`]
    /// describing the problem. Which checks run is up to the implementor.
    fn validate(&self) -> ImputationResult<()>;

    /// Constructs a configuration without any arguments.
    fn default_config() -> Self;
}
173
/// Interface for scoring an imputed array against the original one.
pub trait QualityAssessment {
    /// Takes the `original` and `imputed` 2-D `f64` views and returns a single
    /// `f64` score, or an [`ImputationError`] on failure.
    ///
    /// NOTE(review): the scale and direction of the score (higher vs. lower is
    /// better) are not fixed by this trait — confirm per implementor.
    fn assess_quality(
        &self,
        original: &scirs2_core::ndarray::ArrayView2<f64>,
        imputed: &scirs2_core::ndarray::ArrayView2<f64>,
    ) -> ImputationResult<f64>;
}
183
/// Interface for analysing the missing-value structure of a 2-D `f64` array.
pub trait MissingPatternHandler {
    /// Takes a borrowed 2-D `f64` view and returns a map from string keys to
    /// `f64` values, or an [`ImputationError`] on failure.
    ///
    /// NOTE(review): the set of keys is implementor-defined — confirm.
    fn analyze_patterns(
        &self,
        X: &scirs2_core::ndarray::ArrayView2<f64>,
    ) -> ImputationResult<std::collections::HashMap<String, f64>>;

    /// Takes a borrowed 2-D `f64` view and returns a `String`, or an
    /// [`ImputationError`] on failure.
    ///
    /// NOTE(review): presumably the name of a missingness mechanism; the
    /// vocabulary is not fixed by this trait — confirm.
    fn identify_mechanism(
        &self,
        X: &scirs2_core::ndarray::ArrayView2<f64>,
    ) -> ImputationResult<String>;
}
198
/// Interface for statistical checks comparing an imputed array with the original.
pub trait StatisticalValidator {
    /// Takes the `original` and `imputed` 2-D `f64` views and returns a `bool`
    /// verdict, or an [`ImputationError`] on failure.
    ///
    /// NOTE(review): presumably `true` means the distributions are considered
    /// compatible — confirm per implementor.
    fn validate_distribution(
        &self,
        original: &scirs2_core::ndarray::ArrayView2<f64>,
        imputed: &scirs2_core::ndarray::ArrayView2<f64>,
    ) -> ImputationResult<bool>;

    /// Takes the `original` and `imputed` 2-D `f64` views and returns a single
    /// `f64`, or an [`ImputationError`] on failure.
    ///
    /// NOTE(review): whether the value is a test statistic, a p-value or a bias
    /// estimate is not fixed by this trait — confirm per implementor.
    fn test_bias(
        &self,
        original: &scirs2_core::ndarray::ArrayView2<f64>,
        imputed: &scirs2_core::ndarray::ArrayView2<f64>,
    ) -> ImputationResult<f64>;
}
215
/// Descriptive record about an imputation run.
///
/// Create one with [`ImputationMetadata::new`] and fill in the remaining fields
/// with the consuming `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ImputationMetadata {
    /// Method name, as passed to [`ImputationMetadata::new`].
    pub method: String,
    /// String key/value pairs recorded through `with_parameter`.
    pub parameters: std::collections::HashMap<String, String>,
    /// Count of imputed values (per the field name); starts at 0.
    pub n_imputed: usize,
    /// Convergence details, if any were recorded.
    pub convergence_info: Option<ConvergenceInfo>,
    /// Named `f64` quality metrics, if any were recorded.
    pub quality_metrics: Option<std::collections::HashMap<String, f64>>,
    /// Processing time in milliseconds (per the field name), if recorded.
    pub processing_time_ms: Option<u64>,
}

/// Convergence details that can be attached to [`ImputationMetadata`].
///
/// NOTE(review): field meanings below are read off the field names; nothing in
/// this file populates them — confirm with the producers.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Number of iterations.
    pub n_iterations: usize,
    /// Final value of the convergence criterion.
    pub final_criterion: f64,
    /// Whether convergence was reached.
    pub converged: bool,
    /// Sequence of `f64` values, presumably the criterion per iteration.
    pub history: Vec<f64>,
}

impl ImputationMetadata {
    /// Creates metadata for `method` with no parameters, `n_imputed == 0` and
    /// every optional field set to `None`.
    pub fn new(method: String) -> Self {
        Self {
            method,
            parameters: Default::default(),
            n_imputed: 0,
            convergence_info: None,
            quality_metrics: None,
            processing_time_ms: None,
        }
    }

    /// Records `value` under `key`, replacing any value already stored for it.
    pub fn with_parameter(self, key: String, value: String) -> Self {
        let mut parameters = self.parameters;
        parameters.insert(key, value);
        Self { parameters, ..self }
    }

    /// Sets the imputed-value count.
    pub fn with_n_imputed(self, n_imputed: usize) -> Self {
        Self { n_imputed, ..self }
    }

    /// Attaches convergence details, replacing any previous ones.
    pub fn with_convergence(self, convergence: ConvergenceInfo) -> Self {
        Self {
            convergence_info: Some(convergence),
            ..self
        }
    }

    /// Attaches quality metrics, replacing any previous ones.
    pub fn with_quality_metrics(self, metrics: std::collections::HashMap<String, f64>) -> Self {
        Self {
            quality_metrics: Some(metrics),
            ..self
        }
    }

    /// Sets the processing time in milliseconds.
    pub fn with_processing_time(self, time_ms: u64) -> Self {
        Self {
            processing_time_ms: Some(time_ms),
            ..self
        }
    }
}
289
/// Pairs a 2-D `f64` array with the [`ImputationMetadata`] describing it.
#[derive(Debug, Clone)]
pub struct ImputationOutputWithMetadata {
    /// The owned 2-D `f64` array (presumably the imputed data — confirm with callers).
    pub data: scirs2_core::ndarray::Array2<f64>,
    /// Metadata stored alongside `data`.
    pub metadata: ImputationMetadata,
}

impl ImputationOutputWithMetadata {
    /// Bundles `data` and `metadata` into one value; both are moved in unchanged.
    pub fn new(data: scirs2_core::ndarray::Array2<f64>, metadata: ImputationMetadata) -> Self {
        Self { data, metadata }
    }
}
305
306pub mod utils {
308 use super::*;
309
310 pub fn count_missing(X: &scirs2_core::ndarray::ArrayView2<f64>) -> usize {
312 X.iter().filter(|&&x| x.is_nan()).count()
313 }
314
315 pub fn get_missing_positions(X: &scirs2_core::ndarray::ArrayView2<f64>) -> Vec<(usize, usize)> {
317 X.indexed_iter()
318 .filter_map(|((i, j), &val)| if val.is_nan() { Some((i, j)) } else { None })
319 .collect()
320 }
321
322 pub fn missing_rates_per_feature(X: &scirs2_core::ndarray::ArrayView2<f64>) -> Vec<f64> {
324 let (n_rows, n_cols) = X.dim();
325 let mut rates = Vec::with_capacity(n_cols);
326
327 for j in 0..n_cols {
328 let missing_count = X.column(j).iter().filter(|&&x| x.is_nan()).count();
329 rates.push(missing_count as f64 / n_rows as f64);
330 }
331
332 rates
333 }
334
335 pub fn missing_rates_per_sample(X: &scirs2_core::ndarray::ArrayView2<f64>) -> Vec<f64> {
337 let (n_rows, n_cols) = X.dim();
338 let mut rates = Vec::with_capacity(n_rows);
339
340 for i in 0..n_rows {
341 let missing_count = X.row(i).iter().filter(|&&x| x.is_nan()).count();
342 rates.push(missing_count as f64 / n_cols as f64);
343 }
344
345 rates
346 }
347
348 pub fn validate_input(X: &scirs2_core::ndarray::ArrayView2<f64>) -> ImputationResult<()> {
350 let (n_rows, n_cols) = X.dim();
351
352 if n_rows == 0 {
353 return Err(ImputationError::ValidationError(
354 "Input array has zero rows".to_string(),
355 ));
356 }
357
358 if n_cols == 0 {
359 return Err(ImputationError::ValidationError(
360 "Input array has zero columns".to_string(),
361 ));
362 }
363
364 let all_missing = X.iter().all(|&x| x.is_nan());
366 if all_missing {
367 return Err(ImputationError::InsufficientData(
368 "All values in the input array are missing".to_string(),
369 ));
370 }
371
372 Ok(())
373 }
374
375 pub fn check_dimensions_compatible(
377 X1: &scirs2_core::ndarray::ArrayView2<f64>,
378 X2: &scirs2_core::ndarray::ArrayView2<f64>,
379 ) -> ImputationResult<()> {
380 if X1.dim() != X2.dim() {
381 return Err(ImputationError::DimensionMismatch(format!(
382 "Array dimensions don't match: {:?} vs {:?}",
383 X1.dim(),
384 X2.dim()
385 )));
386 }
387 Ok(())
388 }
389}