sklears_impute/parallel.rs

//! Parallel imputation algorithms for high-performance missing data processing
//!
//! This module provides parallel implementations of imputation algorithms that can
//! leverage multiple CPU cores for significant performance improvements on large datasets.
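//!
//! # Example
//!
//! A minimal usage sketch of the parallel KNN imputer (illustrative data; assumes
//! the `Fit`/`Transform` traits imported below are in scope):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//!
//! let data = array![[1.0, 2.0], [f64::NAN, 4.0], [5.0, 6.0]];
//! let imputer = ParallelKNNImputer::new().n_neighbors(2).max_threads(2);
//! let fitted = imputer.fit(&data.view(), &()).unwrap();
//! let imputed = fitted.transform(&data.view()).unwrap();
//! assert!(!imputed.iter().any(|x| x.is_nan()));
//! ```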

// ✅ SciRS2 Policy compliant imports
use crate::simd_ops::{SimdDistanceCalculator, SimdImputationOps, SimdStatistics};
use rayon::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::sync::Arc;

/// Configuration for parallel processing
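///
/// A construction sketch using struct-update syntax over the defaults:
///
/// ```ignore
/// let config = ParallelConfig {
///     max_threads: Some(4),
///     chunk_size: 512,
///     ..Default::default()
/// };
/// ```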
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParallelConfig {
    /// Maximum number of worker threads (`None` uses all available cores)
    pub max_threads: Option<usize>,
    /// Number of rows processed per parallel work unit
    pub chunk_size: usize,
    /// Whether to enable dynamic load balancing across threads
    pub load_balancing: bool,
    /// Whether to favor lower memory usage over raw throughput
    pub memory_efficient: bool,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            max_threads: None, // Use all available cores
            chunk_size: 1000,  // Process data in chunks of 1000 rows
            load_balancing: true,
            memory_efficient: false,
        }
    }
}

/// Parallel KNN Imputer with SIMD optimizations
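///
/// A configuration sketch using the builder methods defined below:
///
/// ```ignore
/// let imputer = ParallelKNNImputer::new()
///     .n_neighbors(3)
///     .weights("distance".to_string())
///     .metric("manhattan".to_string());
/// ```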
#[derive(Debug, Clone)]
pub struct ParallelKNNImputer<S = Untrained> {
    state: S,
    n_neighbors: usize,
    weights: String,
    metric: String,
    missing_values: f64,
    parallel_config: ParallelConfig,
}

/// Trained state for parallel KNN imputer
#[derive(Debug, Clone)]
pub struct ParallelKNNImputerTrained {
    X_train_: Array2<f64>,
    n_features_in_: usize,
    parallel_config: ParallelConfig,
}

impl Default for ParallelKNNImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl ParallelKNNImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_neighbors: 5,
            weights: "uniform".to_string(),
            metric: "euclidean".to_string(),
            missing_values: f64::NAN,
            parallel_config: ParallelConfig::default(),
        }
    }

    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.n_neighbors = n_neighbors;
        self
    }

    pub fn weights(mut self, weights: String) -> Self {
        self.weights = weights;
        self
    }

    pub fn metric(mut self, metric: String) -> Self {
        self.metric = metric;
        self
    }

    pub fn parallel_config(mut self, config: ParallelConfig) -> Self {
        self.parallel_config = config;
        self
    }

    pub fn max_threads(mut self, max_threads: usize) -> Self {
        self.parallel_config.max_threads = Some(max_threads);
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Estimator for ParallelKNNImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for ParallelKNNImputer<Untrained> {
    type Fitted = ParallelKNNImputer<ParallelKNNImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        Ok(ParallelKNNImputer {
            state: ParallelKNNImputerTrained {
                X_train_: X,
                n_features_in_: n_features,
                parallel_config: self.parallel_config.clone(),
            },
            n_neighbors: self.n_neighbors,
            weights: self.weights,
            metric: self.metric,
            missing_values: self.missing_values,
            parallel_config: self.parallel_config,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for ParallelKNNImputer<ParallelKNNImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (_n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        // Set up the global thread pool if a limit was requested; the error
        // is ignored if the pool has already been initialized elsewhere
        if let Some(max_threads) = self.state.parallel_config.max_threads {
            let _ = rayon::ThreadPoolBuilder::new()
                .num_threads(max_threads)
                .build_global();
        }

        let mut X_imputed = X.clone();
        let X_train = &self.state.X_train_;

        // Collect the coordinates of all missing entries up front
        let missing_positions: Vec<(usize, usize)> = X
            .indexed_iter()
            .filter_map(|((i, j), &val)| {
                if self.is_missing(val) {
                    Some((i, j))
                } else {
                    None
                }
            })
            .collect();

        // Process missing values in parallel chunks
        let chunk_size = self
            .state
            .parallel_config
            .chunk_size
            .min(missing_positions.len().max(1));

        // Each chunk reads the original (still-missing) matrix and returns
        // (position, value) pairs, which are written back sequentially. This
        // avoids cloning the full matrix per chunk and guarantees every
        // imputed entry is merged exactly once.
        let imputed_values: Vec<((usize, usize), f64)> = missing_positions
            .par_chunks(chunk_size)
            .map(|chunk| {
                chunk
                    .iter()
                    .map(|&(i, j)| Ok(((i, j), self.impute_single_value(&X, X_train, i, j)?)))
                    .collect::<SklResult<Vec<_>>>()
            })
            .collect::<SklResult<Vec<_>>>()?
            .into_iter()
            .flatten()
            .collect();

        for ((i, j), value) in imputed_values {
            X_imputed[[i, j]] = value;
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl ParallelKNNImputer<ParallelKNNImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    fn impute_single_value(
        &self,
        X: &Array2<f64>,
        X_train: &Array2<f64>,
        row_idx: usize,
        col_idx: usize,
    ) -> SklResult<f64> {
        let query_row = X.row(row_idx);

        // Calculate distances to all training samples in parallel
        let distances: Vec<(f64, usize)> = X_train
            .axis_iter(Axis(0))
            .enumerate()
            .par_bridge()
            .map(|(train_idx, train_row)| {
                let distance = match self.metric.as_str() {
                    "euclidean" => SimdDistanceCalculator::euclidean_distance_simd(
                        query_row.as_slice().unwrap(),
                        train_row.as_slice().unwrap(),
                    ),
                    "manhattan" => SimdDistanceCalculator::manhattan_distance_simd(
                        query_row.as_slice().unwrap(),
                        train_row.as_slice().unwrap(),
                    ),
                    // Unknown metrics fall back to Euclidean distance
                    _ => SimdDistanceCalculator::euclidean_distance_simd(
                        query_row.as_slice().unwrap(),
                        train_row.as_slice().unwrap(),
                    ),
                };
                (distance, train_idx)
            })
            .collect();

        // Sort so the k nearest neighbors come first; NaN distances are
        // treated as larger than any real distance
        let mut sorted_distances = distances;
        sorted_distances.sort_by(|a, b| {
            a.0.partial_cmp(&b.0).unwrap_or_else(|| {
                if a.0.is_nan() && b.0.is_nan() {
                    std::cmp::Ordering::Equal
                } else if a.0.is_nan() {
                    std::cmp::Ordering::Greater
                } else {
                    std::cmp::Ordering::Less
                }
            })
        });

        // Collect up to n_neighbors values observed in the target column,
        // scanning at most 3 * n_neighbors candidates before falling back
        let mut neighbor_values = Vec::new();
        let mut weights = Vec::new();

        for &(distance, train_idx) in sorted_distances.iter().take(self.n_neighbors * 3) {
            if !self.is_missing(X_train[[train_idx, col_idx]]) {
                neighbor_values.push(X_train[[train_idx, col_idx]]);

                let weight = match self.weights.as_str() {
                    "distance" => {
                        if distance > 0.0 {
                            1.0 / distance
                        } else {
                            1e6 // effectively infinite weight for exact matches
                        }
                    }
                    _ => 1.0,
                };
                weights.push(weight);

                if neighbor_values.len() >= self.n_neighbors {
                    break;
                }
            }
        }

        if neighbor_values.is_empty() {
            // Fallback to column mean
            let column = X_train.column(col_idx);
            let valid_values: Vec<f64> = column
                .iter()
                .filter(|&&x| !self.is_missing(x))
                .cloned()
                .collect();

            if !valid_values.is_empty() {
                Ok(SimdStatistics::mean_simd(&valid_values))
            } else {
                Ok(0.0)
            }
        } else {
            // Use SIMD-optimized weighted mean
            Ok(SimdImputationOps::weighted_mean_simd(
                &neighbor_values,
                &weights,
            ))
        }
    }
}

/// Parallel Iterative Imputer (MICE) with multi-threading
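///
/// A usage sketch mirroring the KNN imputer (illustrative data):
///
/// ```ignore
/// let data = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
/// let imputer = ParallelIterativeImputer::new().max_iter(10).tol(1e-3);
/// let fitted = imputer.fit(&data.view(), &()).unwrap();
/// let imputed = fitted.transform(&data.view()).unwrap();
/// ```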
#[derive(Debug, Clone)]
pub struct ParallelIterativeImputer<S = Untrained> {
    state: S,
    max_iter: usize,
    tol: f64,
    n_nearest_features: Option<usize>,
    random_state: Option<u64>,
    parallel_config: ParallelConfig,
}

/// Trained state for parallel iterative imputer
#[derive(Debug, Clone)]
pub struct ParallelIterativeImputerTrained {
    n_features_in_: usize,
    missing_mask_: Array2<bool>,
    parallel_config: ParallelConfig,
    random_state: Option<u64>,
}

impl Default for ParallelIterativeImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl ParallelIterativeImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            max_iter: 10,
            tol: 1e-3,
            n_nearest_features: None,
            random_state: None,
            parallel_config: ParallelConfig::default(),
        }
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    pub fn parallel_config(mut self, config: ParallelConfig) -> Self {
        self.parallel_config = config;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Estimator for ParallelIterativeImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for ParallelIterativeImputer<Untrained> {
    type Fitted = ParallelIterativeImputer<ParallelIterativeImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        // Create missing mask
        let missing_mask = X.mapv(|x| x.is_nan());

        Ok(ParallelIterativeImputer {
            state: ParallelIterativeImputerTrained {
                n_features_in_: n_features,
                missing_mask_: missing_mask,
                parallel_config: self.parallel_config.clone(),
                random_state: self.random_state,
            },
            max_iter: self.max_iter,
            tol: self.tol,
            n_nearest_features: self.n_nearest_features,
            random_state: self.random_state,
            parallel_config: self.parallel_config,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for ParallelIterativeImputer<ParallelIterativeImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (_n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        // Initialize with simple mean imputation
        let mut X_imputed = self.initial_imputation(&X)?;
        let missing_mask = X.mapv(|x| x.is_nan());

        // Iterative imputation with parallel feature processing
        for _iteration in 0..self.max_iter {
            let X_prev = X_imputed.clone();

            // Process features in parallel
            let feature_indices: Vec<usize> = (0..n_features).collect();
            let imputed_features: SklResult<Vec<Array1<f64>>> = feature_indices
                .par_iter()
                .map(|&feature_idx| self.impute_feature(&X_imputed, &missing_mask, feature_idx))
                .collect();

            let imputed_features = imputed_features?;

            // Update imputed values for each feature
            for (feature_idx, feature_values) in imputed_features.into_iter().enumerate() {
                for (sample_idx, &value) in feature_values.iter().enumerate() {
                    if missing_mask[[sample_idx, feature_idx]] {
                        X_imputed[[sample_idx, feature_idx]] = value;
                    }
                }
            }

            // Check convergence using SIMD-optimized calculations
            let diff = self.calculate_convergence_difference(&X_prev, &X_imputed, &missing_mask);
            if diff < self.tol {
                break;
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl ParallelIterativeImputer<ParallelIterativeImputerTrained> {
    fn initial_imputation(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let mut X_imputed = X.clone();

        // Parallel mean imputation for each column
        let column_means: Vec<f64> = (0..X.ncols())
            .into_par_iter()
            .map(|col_idx| {
                let column = X.column(col_idx);
                let valid_values: Vec<f64> =
                    column.iter().filter(|&&x| !x.is_nan()).cloned().collect();

                if !valid_values.is_empty() {
                    SimdStatistics::mean_simd(&valid_values)
                } else {
                    0.0
                }
            })
            .collect();

        // Fill missing values with means
        for ((_i, j), value) in X_imputed.indexed_iter_mut() {
            if value.is_nan() {
                *value = column_means[j];
            }
        }

        Ok(X_imputed)
    }

    fn impute_feature(
        &self,
        X: &Array2<f64>,
        missing_mask: &Array2<bool>,
        feature_idx: usize,
    ) -> SklResult<Array1<f64>> {
        let n_samples = X.nrows();
        let mut imputed_feature = Array1::zeros(n_samples);

        // Identify samples with missing values for this feature
        let missing_samples: Vec<usize> = (0..n_samples)
            .filter(|&i| missing_mask[[i, feature_idx]])
            .collect();

        if missing_samples.is_empty() {
            return Ok(X.column(feature_idx).to_owned());
        }

        // Use all other features as predictors
        let predictor_features: Vec<usize> = (0..X.ncols()).filter(|&i| i != feature_idx).collect();

        // Predict each missing sample from its nearest neighbors in predictor space
        missing_samples
            .par_iter()
            .map(|&sample_idx| {
                self.predict_missing_value(X, sample_idx, feature_idx, &predictor_features)
            })
            .collect::<SklResult<Vec<_>>>()?
            .into_iter()
            .zip(missing_samples.iter())
            .for_each(|(predicted_value, &sample_idx)| {
                imputed_feature[sample_idx] = predicted_value;
            });

        // Copy non-missing values
        for i in 0..n_samples {
            if !missing_mask[[i, feature_idx]] {
                imputed_feature[i] = X[[i, feature_idx]];
            }
        }

        Ok(imputed_feature)
    }

    fn predict_missing_value(
        &self,
        X: &Array2<f64>,
        sample_idx: usize,
        target_feature: usize,
        predictor_features: &[usize],
    ) -> SklResult<f64> {
        // Find complete cases to serve as candidate neighbors
        let complete_samples: Vec<usize> = (0..X.nrows())
            .filter(|&i| {
                !X[[i, target_feature]].is_nan()
                    && predictor_features.iter().all(|&j| !X[[i, j]].is_nan())
            })
            .collect();

        if complete_samples.len() < 2 {
            // Fallback to column mean
            let column = X.column(target_feature);
            let valid_values: Vec<f64> = column.iter().filter(|&&x| !x.is_nan()).cloned().collect();

            return Ok(if !valid_values.is_empty() {
                SimdStatistics::mean_simd(&valid_values)
            } else {
                0.0
            });
        }

        // Average the k nearest neighbors in predictor-feature space
        let query_features: Vec<f64> = predictor_features
            .iter()
            .map(|&j| X[[sample_idx, j]])
            .collect();

        let mut distances: Vec<(f64, usize)> = complete_samples
            .par_iter()
            .map(|&complete_idx| {
                let sample_features: Vec<f64> = predictor_features
                    .iter()
                    .map(|&j| X[[complete_idx, j]])
                    .collect();

                let distance = SimdDistanceCalculator::euclidean_distance_simd(
                    &query_features,
                    &sample_features,
                );

                (distance, complete_idx)
            })
            .collect();

        // Total ordering avoids a panic if a distance comes back NaN
        distances.sort_by(|a, b| a.0.total_cmp(&b.0));

        // Use top k neighbors
        let k = 5.min(distances.len());
        let neighbor_values: Vec<f64> = distances
            .iter()
            .take(k)
            .map(|(_, idx)| X[[*idx, target_feature]])
            .collect();

        Ok(SimdStatistics::mean_simd(&neighbor_values))
    }

    fn calculate_convergence_difference(
        &self,
        X_prev: &Array2<f64>,
        X_current: &Array2<f64>,
        missing_mask: &Array2<bool>,
    ) -> f64 {
        let differences: Vec<f64> = X_prev
            .indexed_iter()
            .par_bridge()
            .filter_map(|((i, j), &prev_val)| {
                if missing_mask[[i, j]] {
                    Some((prev_val - X_current[[i, j]]).abs())
                } else {
                    None
                }
            })
            .collect();

        if differences.is_empty() {
            0.0
        } else {
            SimdStatistics::mean_simd(&differences)
        }
    }
}

/// Parallel memory-efficient imputation for large datasets
pub struct MemoryEfficientImputer;

impl MemoryEfficientImputer {
    /// Process large datasets in chunks to minimize memory usage
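    ///
    /// A usage sketch with a trivial pass-through "imputer" closure:
    ///
    /// ```ignore
    /// let data = Array2::from_elem((1_000, 8), 1.0);
    /// let result = MemoryEfficientImputer::impute_chunked(&data, 256, |chunk| {
    ///     Ok(chunk.clone()) // a real caller would fill missing entries here
    /// })?;
    /// ```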
    pub fn impute_chunked<F>(
        data: &Array2<f64>,
        chunk_size: usize,
        impute_fn: F,
    ) -> SklResult<Array2<f64>>
    where
        F: Fn(&Array2<f64>) -> SklResult<Array2<f64>> + Sync + Send,
    {
        let (n_rows, n_cols) = data.dim();
        let mut result = Array2::zeros((n_rows, n_cols));

        // Process in row chunks
        let chunks: Vec<_> = (0..n_rows).step_by(chunk_size).collect();

        chunks
            .par_iter()
            .map(|&start_row| {
                let end_row = (start_row + chunk_size).min(n_rows);
                let chunk = data.slice(s![start_row..end_row, ..]).to_owned();

                let imputed_chunk = impute_fn(&chunk)?;
                Ok((start_row, imputed_chunk))
            })
            .collect::<SklResult<Vec<_>>>()?
            .into_iter()
            .for_each(|(start_row, imputed_chunk)| {
                let end_row = start_row + imputed_chunk.nrows();
                result
                    .slice_mut(s![start_row..end_row, ..])
                    .assign(&imputed_chunk);
            });

        Ok(result)
    }

    /// Process streaming data with online imputation
    pub fn stream_impute<F>(
        stream: impl Iterator<Item = SklResult<Array1<f64>>> + Send,
        impute_fn: F,
    ) -> impl Iterator<Item = SklResult<Array1<f64>>> + Send
    where
        F: Fn(&Array1<f64>) -> SklResult<Array1<f64>> + Sync + Send + Clone + 'static,
    {
        stream.map(move |result| {
            let row = result?;
            impute_fn(&row)
        })
    }
}

/// Streaming imputation for very large datasets that don't fit in memory
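///
/// A usage sketch over an in-memory stream of rows:
///
/// ```ignore
/// let rows = vec![
///     Ok(Array1::from_vec(vec![1.0, 2.0])),
///     Ok(Array1::from_vec(vec![f64::NAN, 4.0])),
/// ];
/// let imputer = StreamingImputer::new().strategy("mean".to_string());
/// let imputed = imputer.fit_transform_stream(rows.into_iter()).unwrap();
/// ```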
#[derive(Debug, Clone)]
pub struct StreamingImputer {
    window_size: usize,
    buffer_size: usize,
    strategy: String,
    missing_values: f64,
}

impl Default for StreamingImputer {
    fn default() -> Self {
        Self::new()
    }
}

impl StreamingImputer {
    pub fn new() -> Self {
        Self {
            window_size: 1000,
            buffer_size: 10000,
            strategy: "mean".to_string(),
            missing_values: f64::NAN,
        }
    }

    pub fn window_size(mut self, window_size: usize) -> Self {
        self.window_size = window_size;
        self
    }

    pub fn buffer_size(mut self, buffer_size: usize) -> Self {
        self.buffer_size = buffer_size;
        self
    }

    pub fn strategy(mut self, strategy: String) -> Self {
        self.strategy = strategy;
        self
    }

    /// Process streaming data with online statistics updates
    pub fn fit_transform_stream<I>(&self, data_stream: I) -> SklResult<Vec<Array1<f64>>>
    where
        I: Iterator<Item = SklResult<Array1<f64>>>,
    {
        let mut results = Vec::new();
        let mut statistics = OnlineStatistics::new();
        let mut buffer = Vec::new();

        for row_result in data_stream {
            let row = row_result?;
            buffer.push(row.clone());

            // Update online statistics
            statistics.update(&row);

            // Process buffer when full
            if buffer.len() >= self.buffer_size {
                let processed_buffer = self.process_buffer(&buffer, &statistics)?;
                results.extend(processed_buffer);
                buffer.clear();
            }
        }

        // Process remaining buffer
        if !buffer.is_empty() {
            let processed_buffer = self.process_buffer(&buffer, &statistics)?;
            results.extend(processed_buffer);
        }

        Ok(results)
    }

    fn process_buffer(
        &self,
        buffer: &[Array1<f64>],
        statistics: &OnlineStatistics,
    ) -> SklResult<Vec<Array1<f64>>> {
        buffer
            .par_iter()
            .map(|row| self.impute_row(row, statistics))
            .collect()
    }

    fn impute_row(
        &self,
        row: &Array1<f64>,
        statistics: &OnlineStatistics,
    ) -> SklResult<Array1<f64>> {
        let mut imputed_row = row.clone();

        for (i, &value) in row.iter().enumerate() {
            if self.is_missing(value) {
                let imputed_value = match self.strategy.as_str() {
                    "mean" => statistics.get_mean(i),
                    "median" => statistics.get_median(i),
                    "mode" => statistics.get_mode(i),
                    _ => statistics.get_mean(i),
                };
                imputed_row[i] = imputed_value;
            }
        }

        Ok(imputed_row)
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

/// Online statistics for streaming imputation
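///
/// Running moments are maintained with Welford's update,
/// `mean_n = mean_{n-1} + (x - mean_{n-1}) / n`, so a pass over the data
/// needs only O(1) memory per feature:
///
/// ```ignore
/// let mut stats = OnlineStatistics::new();
/// stats.update(&Array1::from_vec(vec![1.0, 10.0]));
/// stats.update(&Array1::from_vec(vec![3.0, 20.0]));
/// assert_eq!(stats.get_mean(0), 2.0);
/// ```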
#[derive(Debug, Clone)]
pub struct OnlineStatistics {
    n_samples: usize,
    counts: Vec<usize>,
    means: Vec<f64>,
    variances: Vec<f64>,
    mins: Vec<f64>,
    maxs: Vec<f64>,
    value_counts: Vec<std::collections::HashMap<i64, usize>>,
    n_features: usize,
}

impl Default for OnlineStatistics {
    fn default() -> Self {
        Self::new()
    }
}

impl OnlineStatistics {
    pub fn new() -> Self {
        Self {
            n_samples: 0,
            counts: Vec::new(),
            means: Vec::new(),
            variances: Vec::new(),
            mins: Vec::new(),
            maxs: Vec::new(),
            value_counts: Vec::new(),
            n_features: 0,
        }
    }

    pub fn update(&mut self, row: &Array1<f64>) {
        if self.n_features == 0 {
            self.n_features = row.len();
            self.counts = vec![0; self.n_features];
            self.means = vec![0.0; self.n_features];
            self.variances = vec![0.0; self.n_features];
            self.mins = vec![f64::INFINITY; self.n_features];
            self.maxs = vec![f64::NEG_INFINITY; self.n_features];
            self.value_counts = vec![std::collections::HashMap::new(); self.n_features];
        }

        self.n_samples += 1;

        for (i, &value) in row.iter().enumerate() {
            if !value.is_nan() {
                // Update mean using Welford's online algorithm; each feature
                // keeps its own observation count so NaN entries do not bias it
                self.counts[i] += 1;
                let delta = value - self.means[i];
                self.means[i] += delta / self.counts[i] as f64;
                let delta2 = value - self.means[i];
                self.variances[i] += delta * delta2;

                // Update min/max
                self.mins[i] = self.mins[i].min(value);
                self.maxs[i] = self.maxs[i].max(value);

                // Update value counts for mode calculation (values are
                // bucketed at three decimal places)
                let rounded_value = (value * 1000.0).round() as i64;
                *self.value_counts[i].entry(rounded_value).or_insert(0) += 1;
            }
        }
    }

    pub fn get_mean(&self, feature_idx: usize) -> f64 {
        if feature_idx < self.means.len() {
            self.means[feature_idx]
        } else {
            0.0
        }
    }

    /// Streaming approximation: returns the running mean, since an exact
    /// median would require buffering all observed values
    pub fn get_median(&self, feature_idx: usize) -> f64 {
        if feature_idx < self.means.len() {
            self.means[feature_idx]
        } else {
            0.0
        }
    }

    pub fn get_mode(&self, feature_idx: usize) -> f64 {
        if feature_idx < self.value_counts.len() {
            self.value_counts[feature_idx]
                .iter()
                .max_by_key(|(_, &count)| count)
                .map(|(&value, _)| value as f64 / 1000.0)
                .unwrap_or(0.0)
        } else {
            0.0
        }
    }

    pub fn get_variance(&self, feature_idx: usize) -> f64 {
        if feature_idx < self.variances.len() && self.counts[feature_idx] > 1 {
            self.variances[feature_idx] / (self.counts[feature_idx] - 1) as f64
        } else {
            0.0
        }
    }
}

/// Adaptive streaming imputation that learns from incoming data
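///
/// A per-row usage sketch; early rows (before `min_samples_for_adaptation`
/// rows have been seen) fall back to 0.0 for missing entries:
///
/// ```ignore
/// let mut imputer = AdaptiveStreamingImputer::new().forgetting_factor(0.95);
/// let row = Array1::from_vec(vec![1.0, f64::NAN]);
/// let imputed = imputer.fit_transform_single(&row).unwrap();
/// ```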
#[derive(Debug, Clone)]
pub struct AdaptiveStreamingImputer {
    learning_rate: f64,
    forgetting_factor: f64,
    min_samples_for_adaptation: usize,
    statistics: OnlineStatistics,
}

impl Default for AdaptiveStreamingImputer {
    fn default() -> Self {
        Self::new()
    }
}

impl AdaptiveStreamingImputer {
    pub fn new() -> Self {
        Self {
            learning_rate: 0.01,
            forgetting_factor: 0.99,
            min_samples_for_adaptation: 100,
            statistics: OnlineStatistics::new(),
        }
    }

    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    pub fn forgetting_factor(mut self, forgetting_factor: f64) -> Self {
        self.forgetting_factor = forgetting_factor;
        self
    }

    /// Process a single row with adaptive learning
    pub fn fit_transform_single(&mut self, row: &Array1<f64>) -> SklResult<Array1<f64>> {
        let mut imputed_row = row.clone();

        // First, impute missing values using current statistics
        for (i, &value) in row.iter().enumerate() {
            if value.is_nan() {
                let imputed_value = if self.statistics.n_samples >= self.min_samples_for_adaptation
                {
                    self.statistics.get_mean(i)
                } else {
                    0.0
                };
                imputed_row[i] = imputed_value;
            }
        }

        // Update statistics with the imputed row
        self.statistics.update(&imputed_row);

        // Adapt to data drift by blending the running mean toward the most
        // recent observation (a bare multiplicative decay would shrink the
        // mean toward zero rather than toward the new data)
        if self.statistics.n_samples > self.min_samples_for_adaptation {
            for i in 0..self.statistics.n_features {
                self.statistics.means[i] = self.forgetting_factor * self.statistics.means[i]
                    + (1.0 - self.forgetting_factor) * imputed_row[i];
                self.statistics.variances[i] *= self.forgetting_factor;
            }
        }

        Ok(imputed_row)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_parallel_knn_imputer() {
        let data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, //
                f64::NAN, 5.0, 6.0, //
                7.0, 8.0, 9.0, //
                10.0, 11.0, 12.0,
            ],
        )
        .unwrap();

        let config = ParallelConfig {
            max_threads: Some(2),
            ..Default::default()
        };

        let imputer = ParallelKNNImputer::new()
            .n_neighbors(2)
            .parallel_config(config);

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let result = fitted.transform(&data.view()).unwrap();

        // Should have no missing values
        assert!(!result.iter().any(|&x| x.is_nan()));

        // Non-missing values should be preserved
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 2]], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_parallel_iterative_imputer() {
        let data = Array2::from_shape_vec(
            (5, 3),
            vec![
                1.0, 2.0, 3.0, //
                f64::NAN, 5.0, 6.0, //
                7.0, f64::NAN, 9.0, //
                10.0, 11.0, 12.0, //
                13.0, 14.0, f64::NAN,
            ],
        )
        .unwrap();

        let config = ParallelConfig {
            max_threads: Some(2),
            ..Default::default()
        };

        let imputer = ParallelIterativeImputer::new()
            .max_iter(5)
            .tol(1e-3)
            .parallel_config(config);

        let fitted = imputer.fit(&data.view(), &()).unwrap();
        let result = fitted.transform(&data.view()).unwrap();

        // Should have no missing values
        assert!(!result.iter().any(|&x| x.is_nan()));

        // Non-missing values should be preserved
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 2]], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_memory_efficient_chunked_processing() {
        let data = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 2.0, 3.0, //
                f64::NAN, 5.0, 6.0, //
                7.0, 8.0, 9.0, //
                10.0, f64::NAN, 12.0, //
                13.0, 14.0, 15.0, //
                16.0, 17.0, f64::NAN, //
                19.0, 20.0, 21.0, //
                22.0, f64::NAN, 24.0, //
                25.0, 26.0, 27.0, //
                28.0, 29.0, 30.0,
            ],
        )
        .unwrap();

        let simple_impute_fn = |chunk: &Array2<f64>| -> SklResult<Array2<f64>> {
            let mut result = chunk.clone();

            // Simple mean imputation
            for j in 0..chunk.ncols() {
                let column = chunk.column(j);
                let valid_values: Vec<f64> =
                    column.iter().filter(|&&x| !x.is_nan()).cloned().collect();

                if !valid_values.is_empty() {
                    let mean = SimdStatistics::mean_simd(&valid_values);

                    for i in 0..chunk.nrows() {
                        if chunk[[i, j]].is_nan() {
                            result[[i, j]] = mean;
                        }
                    }
                }
            }

            Ok(result)
        };

        let result = MemoryEfficientImputer::impute_chunked(&data, 3, simple_impute_fn).unwrap();

        // Should have no missing values
        assert!(!result.iter().any(|&x| x.is_nan()));

        // Should preserve non-missing values
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 2]], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_streaming_imputer() {
        // Create a mock data stream
        let data_rows = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![f64::NAN, 5.0, 6.0]),
            Array1::from_vec(vec![7.0, f64::NAN, 9.0]),
            Array1::from_vec(vec![10.0, 11.0, 12.0]),
            Array1::from_vec(vec![13.0, 14.0, f64::NAN]),
        ];

        let data_stream = data_rows.into_iter().map(Ok);

        let imputer = StreamingImputer::new()
            .buffer_size(3)
            .strategy("mean".to_string());

        let results = imputer.fit_transform_stream(data_stream).unwrap();

        // Should have same number of rows
        assert_eq!(results.len(), 5);

        // Should have no missing values
        for result_row in &results {
            assert!(!result_row.iter().any(|&x| x.is_nan()));
        }

        // Non-missing values should be preserved
        assert_abs_diff_eq!(results[0][0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(results[0][1], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(results[0][2], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_adaptive_streaming_imputer() {
        let mut imputer = AdaptiveStreamingImputer::new()
            .learning_rate(0.1)
            .forgetting_factor(0.95);

        // Process several rows to build up statistics
        let rows = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
            Array1::from_vec(vec![7.0, 8.0, 9.0]),
            Array1::from_vec(vec![f64::NAN, 11.0, 12.0]),
        ];

        let mut results = Vec::new();

        for row in &rows {
            let result = imputer.fit_transform_single(row).unwrap();
            results.push(result);
        }

        // Should have no missing values
        for result_row in &results {
            assert!(!result_row.iter().any(|&x| x.is_nan()));
        }

        // Non-missing values should be preserved
        assert_abs_diff_eq!(results[0][0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(results[1][1], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(results[2][2], 9.0, epsilon = 1e-10);
    }

    #[test]
    fn test_online_statistics() {
        let mut stats = OnlineStatistics::new();

        // Add some data points
        let rows = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
            Array1::from_vec(vec![7.0, 8.0, 9.0]),
        ];

        for row in &rows {
            stats.update(row);
        }

        // Check means (should be 4.0, 5.0, 6.0)
        assert_abs_diff_eq!(stats.get_mean(0), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(stats.get_mean(1), 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(stats.get_mean(2), 6.0, epsilon = 1e-10);

        // Check sample count
        assert_eq!(stats.n_samples, 3);
    }

    #[test]
    fn test_memory_efficient_stream_processing() {
        let data_stream = (0..100).map(|i| {
            let row = Array1::from_vec(vec![
                i as f64,
                (i * 2) as f64,
                if i % 10 == 0 {
                    f64::NAN
                } else {
                    (i * 3) as f64
                },
            ]);
            Ok(row)
        });

        let impute_fn = |row: &Array1<f64>| -> SklResult<Array1<f64>> {
            let mut result = row.clone();
            for (i, value) in result.iter_mut().enumerate() {
                if value.is_nan() {
                    *value = i as f64; // Simple fallback
                }
            }
            Ok(result)
        };

        let processed_stream: Vec<_> =
            MemoryEfficientImputer::stream_impute(data_stream, impute_fn)
                .collect::<SklResult<Vec<_>>>()
                .unwrap();

        assert_eq!(processed_stream.len(), 100);

        // Should have no missing values
        for row in &processed_stream {
            assert!(!row.iter().any(|&x| x.is_nan()));
        }
    }
}

// Additional memory efficiency improvements

/// Sparse matrix representation for missing data patterns
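///
/// A round-trip sketch over some `Array2<f64>` named `data`; the COO layout
/// stores one (row, col, value) triplet per observed entry:
///
/// ```ignore
/// let sparse = SparseMatrix::from_dense(&data, f64::NAN);
/// assert_eq!(sparse.shape, data.dim());
/// let dense = sparse.to_dense(f64::NAN); // missing cells refilled with NaN
/// ```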
#[derive(Debug, Clone)]
pub struct SparseMatrix {
    /// Row indices for non-missing values
    pub row_indices: Vec<usize>,
    /// Column indices for non-missing values
    pub col_indices: Vec<usize>,
    /// Non-missing values
    pub values: Vec<f64>,
    /// Matrix dimensions
    pub shape: (usize, usize),
    /// Sparsity ratio (fraction of missing values)
    pub sparsity: f64,
}

impl SparseMatrix {
    /// Create a sparse matrix from a dense array
    pub fn from_dense(array: &Array2<f64>, missing_value: f64) -> Self {
        let (n_rows, n_cols) = array.dim();
        let mut row_indices = Vec::new();
        let mut col_indices = Vec::new();
        let mut values = Vec::new();
        let mut missing_count = 0;

        for i in 0..n_rows {
            for j in 0..n_cols {
                let value = array[[i, j]];
                let is_missing = if missing_value.is_nan() {
                    value.is_nan()
                } else {
                    (value - missing_value).abs() < f64::EPSILON
                };

                if !is_missing {
                    row_indices.push(i);
                    col_indices.push(j);
                    values.push(value);
                } else {
                    missing_count += 1;
                }
            }
        }

        let total_elements = n_rows * n_cols;
        let sparsity = missing_count as f64 / total_elements as f64;

        Self {
            row_indices,
            col_indices,
            values,
            shape: (n_rows, n_cols),
            sparsity,
        }
    }

    /// Convert back to dense array with missing values filled
    pub fn to_dense(&self, missing_value: f64) -> Array2<f64> {
        let mut array = Array2::from_elem(self.shape, missing_value);

        for ((i, j), &value) in self
            .row_indices
            .iter()
            .zip(self.col_indices.iter())
            .zip(self.values.iter())
        {
            array[[*i, *j]] = value;
        }

        array
    }

    /// Get the non-missing value at specific coordinates, if present
    /// (a linear scan over the stored triplets, O(number of non-missing values))
    pub fn get(&self, row: usize, col: usize) -> Option<f64> {
        for ((r, c), &value) in self
            .row_indices
            .iter()
            .zip(self.col_indices.iter())
            .zip(self.values.iter())
        {
            if *r == row && *c == col {
                return Some(value);
            }
        }
        None
    }

    /// Check if the matrix is sparse enough to benefit from sparse representation
    pub fn is_beneficial(&self) -> bool {
        // Each stored entry costs three words (value + row + col), so this
        // COO layout only saves memory when fewer than a third of the cells
        // are present, i.e. when sparsity exceeds 2/3
        self.sparsity > 2.0 / 3.0
    }

    /// Calculate memory savings compared to dense representation
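    ///
    /// Worked example (64-bit target, where `usize` is 8 bytes): a 100×100
    /// matrix with 80% missing keeps 2_000 triplets, so sparse storage is
    /// 2_000 × 24 = 48_000 bytes versus 80_000 bytes dense, a saving of 0.4:
    ///
    /// ```ignore
    /// assert!((sparse.memory_savings() - 0.4).abs() < 1e-12);
    /// ```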
    pub fn memory_savings(&self) -> f64 {
        let dense_size = self.shape.0 * self.shape.1 * std::mem::size_of::<f64>();
        let sparse_size = (self.values.len() * std::mem::size_of::<f64>())
            + (self.row_indices.len() * std::mem::size_of::<usize>())
            + (self.col_indices.len() * std::mem::size_of::<usize>());

        1.0 - (sparse_size as f64 / dense_size as f64)
    }
}

/// File-backed array storage for large datasets (the current implementation
/// streams the whole array through ordinary file I/O rather than a true
/// OS-level memory map)
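///
/// A write/read round-trip sketch (the file name is hypothetical):
///
/// ```ignore
/// let mmap = MemoryMappedData::new(std::env::temp_dir().join("demo.bin"), data.dim());
/// mmap.write_array(&data)?;
/// assert_eq!(mmap.read_array()?, data);
/// ```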
pub struct MemoryMappedData {
    /// Path of the backing file
    file_path: std::path::PathBuf,
    /// Data dimensions
    shape: (usize, usize),
    /// Data type size in bytes
    dtype_size: usize,
}

impl MemoryMappedData {
    /// Create a new file-backed data structure
    pub fn new(file_path: std::path::PathBuf, shape: (usize, usize)) -> Self {
        Self {
            file_path,
            shape,
            dtype_size: std::mem::size_of::<f64>(),
        }
    }

    /// Write array data to the backing file
    pub fn write_array(&self, array: &Array2<f64>) -> SklResult<()> {
        if array.dim() != self.shape {
            return Err(SklearsError::InvalidInput(
                "Array dimensions don't match memory-mapped data shape".to_string(),
            ));
        }

        use std::fs::File;
        use std::io::Write;

        let mut file = File::create(&self.file_path)
            .map_err(|e| SklearsError::InvalidInput(format!("Failed to create file: {}", e)))?;

        // Write raw bytes
        let data_slice = array
            .as_slice()
            .ok_or_else(|| SklearsError::InvalidInput("Array is not contiguous".to_string()))?;

        // Viewing f64s as bytes is sound: u8 has alignment 1
        let bytes = unsafe {
            std::slice::from_raw_parts(
                data_slice.as_ptr() as *const u8,
                data_slice.len() * self.dtype_size,
            )
        };

        file.write_all(bytes)
            .map_err(|e| SklearsError::InvalidInput(format!("Failed to write data: {}", e)))?;

        Ok(())
    }

    /// Read array data back from the backing file
    pub fn read_array(&self) -> SklResult<Array2<f64>> {
        use std::fs::File;
        use std::io::Read;

        let mut file = File::open(&self.file_path)
            .map_err(|e| SklearsError::InvalidInput(format!("Failed to open file: {}", e)))?;

        let expected_size = self.shape.0 * self.shape.1 * self.dtype_size;
        let mut buffer = vec![0u8; expected_size];

        file.read_exact(&mut buffer)
            .map_err(|e| SklearsError::InvalidInput(format!("Failed to read data: {}", e)))?;

        // Decode chunk-by-chunk with from_ne_bytes; casting the byte buffer
        // to *const f64 would be undefined behavior if the buffer happened
        // to be misaligned for f64
        let data: Vec<f64> = buffer
            .chunks_exact(self.dtype_size)
            .map(|chunk| {
                let mut bytes = [0u8; 8];
                bytes.copy_from_slice(chunk);
                f64::from_ne_bytes(bytes)
            })
            .collect();

        Array2::from_shape_vec(self.shape, data)
            .map_err(|e| SklearsError::InvalidInput(format!("Failed to reshape array: {}", e)))
    }

    /// Get estimated memory usage
    pub fn memory_usage(&self) -> usize {
        self.shape.0 * self.shape.1 * self.dtype_size
    }

    /// Check if the backing file exists
    pub fn exists(&self) -> bool {
        self.file_path.exists()
    }
}

/// Reference-counted shared data for efficient memory usage
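///
/// A copy-on-write sketch: clones share one allocation until a writer calls
/// `make_mut`:
///
/// ```ignore
/// let mut a = SharedDataRef::new(vec![1, 2, 3]);
/// let b = a.clone();
/// assert_eq!(a.ref_count(), 2);
/// a.make_mut().push(4);      // `a` now owns a private copy
/// assert_eq!(b.get().len(), 3);
/// ```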
#[derive(Debug, Clone)]
pub struct SharedDataRef<T> {
    data: Arc<T>,
}

impl<T> SharedDataRef<T> {
    /// Create a new shared data reference
    pub fn new(data: T) -> Self {
        Self {
            data: Arc::new(data),
        }
    }

    /// Get a reference to the data
    pub fn get(&self) -> &T {
        &self.data
    }

    /// Get the current reference count (derived from the `Arc` itself, so it
    /// stays accurate across clones)
    pub fn ref_count(&self) -> usize {
        Arc::strong_count(&self.data)
    }

    /// Check if this is the only reference
    pub fn is_unique(&self) -> bool {
        Arc::strong_count(&self.data) == 1
    }
}

impl<T: Clone> SharedDataRef<T> {
    /// Make the data mutable, cloning it first if it is currently shared
    pub fn make_mut(&mut self) -> &mut T {
        Arc::make_mut(&mut self.data)
    }
}

/// Memory-efficient imputation strategies
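///
/// A configuration sketch; see `process_large_dataset` for how each
/// `MemoryStrategy` variant is applied:
///
/// ```ignore
/// let imputer = MemoryOptimizedImputer::new()
///     .strategy(MemoryStrategy::Chunked)
///     .chunk_size(512);
/// ```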
pub struct MemoryOptimizedImputer {
    strategy: MemoryStrategy,
    chunk_size: usize,
    use_sparse: bool,
    use_mmap: bool,
    temp_dir: std::path::PathBuf,
}

#[derive(Debug, Clone)]
pub enum MemoryStrategy {
    /// Process data in chunks to limit memory usage
    Chunked,
    /// Use sparse representations for high-sparsity data
    Sparse,
    /// Use memory-mapped files for very large datasets
    MemoryMapped,
    /// Combine multiple strategies
    Hybrid,
}

impl MemoryOptimizedImputer {
    /// Create a new memory-optimized imputer
    pub fn new() -> Self {
        Self {
            strategy: MemoryStrategy::Hybrid,
            chunk_size: 1000,
            use_sparse: true,
            use_mmap: false,
            temp_dir: std::env::temp_dir(),
        }
    }

    /// Set the memory optimization strategy
    pub fn strategy(mut self, strategy: MemoryStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Set chunk size for chunked processing
    pub fn chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }

    /// Enable/disable sparse matrix optimization
    pub fn use_sparse(mut self, use_sparse: bool) -> Self {
        self.use_sparse = use_sparse;
        self
    }

    /// Enable/disable memory mapping
    pub fn use_memory_mapping(mut self, use_mmap: bool) -> Self {
        self.use_mmap = use_mmap;
        self
    }

    /// Set temporary directory for memory-mapped files
    pub fn temp_dir(mut self, temp_dir: std::path::PathBuf) -> Self {
        self.temp_dir = temp_dir;
        self
    }

    /// Estimate memory requirements for a dataset
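    ///
    /// For the chunked strategy only `chunk_size` rows are resident at a
    /// time, e.g. 500 rows × 10 columns × 8 bytes = 40_000 bytes:
    ///
    /// ```ignore
    /// let imputer = MemoryOptimizedImputer::new()
    ///     .strategy(MemoryStrategy::Chunked)
    ///     .chunk_size(500);
    /// assert_eq!(imputer.estimate_memory_usage((100_000, 10)), 40_000);
    /// ```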
    pub fn estimate_memory_usage(&self, shape: (usize, usize)) -> usize {
        let base_size = shape.0 * shape.1 * std::mem::size_of::<f64>();

        match self.strategy {
            MemoryStrategy::Chunked => {
                // Only chunk_size rows are loaded at a time
                self.chunk_size * shape.1 * std::mem::size_of::<f64>()
            }
            MemoryStrategy::Sparse => {
                // Assume 50% sparsity for estimation
                base_size / 2
            }
            MemoryStrategy::MemoryMapped => {
                // Minimal resident memory, just metadata
                1024
            }
            MemoryStrategy::Hybrid => {
                // Pick the most efficient strategy for the given size
                if base_size > 1_000_000_000 {
                    // Over 1 GB: memory-mapped
                    1024
                } else if base_size > 100_000_000 {
                    // Over 100 MB: chunked
                    self.chunk_size * shape.1 * std::mem::size_of::<f64>()
                } else {
                    // Otherwise assume sparse is beneficial
                    base_size / 2
                }
            }
        }
    }

    /// Check if the dataset would benefit from sparse representation
    pub fn should_use_sparse(&self, array: &Array2<f64>) -> bool {
        if !self.use_sparse {
            return false;
        }

        // Quick sparsity check on a sample of the data
        let sample_size = 1000.min(array.len());
        let mut missing_count = 0;

        for &value in array.iter().take(sample_size) {
            if value.is_nan() {
                missing_count += 1;
            }
        }

        // The COO triplet layout only pays off past roughly 2/3 missing
        // (see SparseMatrix::is_beneficial)
        let sparsity = missing_count as f64 / sample_size as f64;
        sparsity > 2.0 / 3.0
    }

    /// Process large dataset with memory optimization
    pub fn process_large_dataset<F>(
        &self,
        array: &Array2<f64>,
        mut processor: F,
    ) -> SklResult<Array2<f64>>
    where
        F: FnMut(&ArrayView2<f64>) -> SklResult<Array2<f64>>,
    {
        let (n_rows, n_cols) = array.dim();

        match self.strategy {
            MemoryStrategy::Chunked | MemoryStrategy::Hybrid => {
                let mut result = Array2::zeros((n_rows, n_cols));

                // Process in chunks
                for chunk_start in (0..n_rows).step_by(self.chunk_size) {
                    let chunk_end = (chunk_start + self.chunk_size).min(n_rows);
                    let chunk = array.slice(s![chunk_start..chunk_end, ..]);
                    let processed_chunk = processor(&chunk)?;

                    result
                        .slice_mut(s![chunk_start..chunk_end, ..])
                        .assign(&processed_chunk);
                }

                Ok(result)
            }
            MemoryStrategy::Sparse => {
                if self.should_use_sparse(array) {
                    // Round-trip through the sparse representation; this
                    // exercises the sparse path but does not reduce peak
                    // memory, since a dense copy is rebuilt before processing
                    let sparse = SparseMatrix::from_dense(array, f64::NAN);
                    let dense = sparse.to_dense(f64::NAN);
                    processor(&dense.view())
                } else {
                    processor(&array.view())
                }
            }
            MemoryStrategy::MemoryMapped => {
                if self.use_mmap {
                    let temp_file = self.temp_dir.join("temp_data.bin");
                    let mmap = MemoryMappedData::new(temp_file, (n_rows, n_cols));
                    mmap.write_array(array)?;
                    let loaded_array = mmap.read_array()?;
                    processor(&loaded_array.view())
                } else {
                    processor(&array.view())
                }
            }
        }
    }
}

impl Default for MemoryOptimizedImputer {
    fn default() -> Self {
        Self::new()
    }
}