sklears_impute/
distributed.rs

1//! Distributed imputation algorithms for large-scale missing data processing
2//!
3//! This module provides distributed implementations that can process datasets
4//! across multiple machines or cores, enabling imputation of very large datasets
5//! that don't fit on a single machine.
6
7// ✅ SciRS2 Policy compliant imports
8use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
9// use scirs2_core::parallel::{LoadBalancer}; // Note: ParallelExecutor, ChunkStrategy not available
10// use scirs2_core::memory_efficient::{ChunkedArray, AdaptiveChunking}; // Note: memory_efficient feature-gated
11// use scirs2_core::simd::{SimdOps}; // Note: SimdArray not available
12
13use crate::core::{ImputationError, ImputationMetadata, Imputer};
14use crate::simple::SimpleImputer;
15use rayon::prelude::*;
16use serde::{Deserialize, Serialize};
17use sklears_core::{
18    error::{Result as SklResult, SklearsError},
19    traits::{Estimator, Fit, Transform, Untrained},
20    types::Float,
21};
22use std::collections::HashMap;
23use std::sync::{Arc, Mutex, RwLock};
24use std::thread;
25use std::time::{Duration, Instant};
26
27/// Configuration for distributed imputation
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct DistributedConfig {
30    /// Number of worker nodes/processes
31    pub num_workers: usize,
32    /// Chunk size for data partitioning
33    pub chunk_size: usize,
34    /// Communication strategy between workers
35    pub communication_strategy: CommunicationStrategy,
36    /// Load balancing enabled
37    pub load_balancing: bool,
38    /// Fault tolerance enabled
39    pub fault_tolerance: bool,
40    /// Maximum retry attempts for failed operations
41    pub max_retries: usize,
42    /// Timeout for worker operations (in seconds)
43    pub worker_timeout: Duration,
44}
45
46impl Default for DistributedConfig {
47    fn default() -> Self {
48        Self {
49            num_workers: num_cpus::get(),
50            chunk_size: 10000,
51            communication_strategy: CommunicationStrategy::SharedMemory,
52            load_balancing: true,
53            fault_tolerance: true,
54            max_retries: 3,
55            worker_timeout: Duration::from_secs(300),
56        }
57    }
58}
59
60/// Communication strategies for distributed processing
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub enum CommunicationStrategy {
63    /// Use shared memory for communication (single machine)
64    SharedMemory,
65    /// Use message passing for communication (multi-machine)
66    MessagePassing,
67    /// Use parameter server architecture
68    ParameterServer,
69    /// Use all-reduce communication pattern
70    AllReduce,
71}
72
73/// Distributed data partition
74#[derive(Debug, Clone)]
75pub struct DataPartition {
76    /// Partition identifier
77    pub id: usize,
78    /// Start row index
79    pub start_row: usize,
80    /// End row index
81    pub end_row: usize,
82    /// Data chunk
83    pub data: Array2<f64>,
84    /// Missing value mask
85    pub missing_mask: Array2<bool>,
86    /// Partition metadata
87    pub metadata: PartitionMetadata,
88}
89
90/// Metadata for data partitions
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct PartitionMetadata {
93    /// partition_id
94    pub partition_id: usize,
95    /// worker_id
96    pub worker_id: usize,
97    /// num_samples
98    pub num_samples: usize,
99    /// num_features
100    pub num_features: usize,
101    /// missing_ratio
102    pub missing_ratio: f64,
103    /// processing_time
104    pub processing_time: Duration,
105    /// memory_usage
106    pub memory_usage: usize,
107}
108
109/// Worker node for distributed processing
110pub struct DistributedWorker {
111    /// id
112    pub id: usize,
113    /// config
114    pub config: DistributedConfig,
115    /// partitions
116    pub partitions: Vec<DataPartition>,
117    /// local_imputer
118    pub local_imputer: Box<dyn Imputer + Send + Sync>,
119    /// statistics
120    pub statistics: WorkerStatistics,
121}
122
/// Per-worker counters; `Default` gives all-zero values.
#[derive(Debug, Default, Clone)]
pub struct WorkerStatistics {
    /// Rows handled by the worker.
    pub samples_processed: usize,
    /// Individual entries that were imputed.
    pub features_imputed: usize,
    /// Wall-clock time spent working.
    pub processing_time: Duration,
    /// Peak memory attributed to the worker.
    pub memory_peak: usize,
    /// Errors encountered.
    pub errors_count: usize,
    /// Retries performed.
    pub retries_count: usize,
}
139
140/// Distributed KNN Imputer
141pub struct DistributedKNNImputer<S = Untrained> {
142    state: S,
143    n_neighbors: usize,
144    weights: String,
145    missing_values: f64,
146    config: DistributedConfig,
147    workers: Vec<DistributedWorker>,
148    coordinator: Option<ImputationCoordinator>,
149}
150
151/// Trained state for distributed KNN imputer
152pub struct DistributedKNNImputerTrained {
153    reference_data: Arc<RwLock<Array2<f64>>>,
154    n_features_in_: usize,
155    config: DistributedConfig,
156    workers: Vec<Arc<Mutex<DistributedWorker>>>,
157    coordinator: ImputationCoordinator,
158}
159
160/// Coordinator for managing distributed imputation
161#[derive(Debug)]
162pub struct ImputationCoordinator {
163    /// config
164    pub config: DistributedConfig,
165    /// workers
166    pub workers: HashMap<usize, WorkerHandle>,
167    /// data_partitioner
168    pub data_partitioner: DataPartitioner,
169    /// result_aggregator
170    pub result_aggregator: ResultAggregator,
171    /// fault_handler
172    pub fault_handler: FaultHandler,
173}
174
175/// Handle for a worker process/thread
176#[derive(Debug)]
177pub struct WorkerHandle {
178    /// id
179    pub id: usize,
180    /// thread_handle
181    pub thread_handle: Option<thread::JoinHandle<Result<WorkerResult, ImputationError>>>,
182    /// status
183    pub status: WorkerStatus,
184    /// last_heartbeat
185    pub last_heartbeat: Instant,
186}
187
/// Lifecycle state of a worker.
#[derive(Debug, Clone, PartialEq)]
pub enum WorkerStatus {
    /// Waiting for work.
    Idle,
    /// Currently running.
    Processing,
    /// Finished successfully.
    Completed,
    /// Stopped with an error.
    Failed,
    /// Exceeded its time limit.
    Timeout,
}
202
203/// Result from a worker
204#[derive(Debug, Clone)]
205pub struct WorkerResult {
206    /// worker_id
207    pub worker_id: usize,
208    /// partition_id
209    pub partition_id: usize,
210    /// imputed_data
211    pub imputed_data: Array2<f64>,
212    /// statistics
213    pub statistics: WorkerStatistics,
214    /// metadata
215    pub metadata: ImputationMetadata,
216}
217
218/// Data partitioning strategies
219#[derive(Debug)]
220pub struct DataPartitioner {
221    strategy: PartitioningStrategy,
222}
223
/// Ways of dividing a data matrix among workers.
#[derive(Debug, Clone)]
pub enum PartitioningStrategy {
    /// Contiguous blocks of rows.
    Horizontal,
    /// Blocks of columns.
    Vertical,
    /// Rows assigned at random.
    Random,
    /// Rows grouped by their missing-value pattern.
    Stratified,
    /// Rows assigned by hashing.
    Hash,
}
238
239/// Result aggregation strategies
240#[derive(Debug)]
241pub struct ResultAggregator {
242    strategy: AggregationStrategy,
243}
244
/// Ways of combining the outputs of several workers.
#[derive(Debug, Clone)]
pub enum AggregationStrategy {
    /// Stack partition results in order.
    Concatenate,
    /// Blend results with weights.
    WeightedAverage,
    /// Keep values the workers agree on.
    Consensus,
    /// Average fitted models rather than outputs.
    ModelAveraging,
}
257
/// Retry and checkpoint policy for distributed runs.
#[derive(Debug)]
pub struct FaultHandler {
    /// How many times a failed operation may be retried.
    pub max_retries: usize,
    /// Pause between retries.
    pub retry_delay: Duration,
    /// Whether periodic checkpoints are taken.
    pub checkpointing_enabled: bool,
    /// Interval between checkpoints.
    pub checkpoint_interval: Duration,
}
270
271impl DistributedKNNImputer<Untrained> {
272    /// Create a new distributed KNN imputer
273    pub fn new() -> Self {
274        Self {
275            state: Untrained,
276            n_neighbors: 5,
277            weights: "uniform".to_string(),
278            missing_values: f64::NAN,
279            config: DistributedConfig::default(),
280            workers: Vec::new(),
281            coordinator: None,
282        }
283    }
284
285    /// Set the number of neighbors
286    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
287        self.n_neighbors = n_neighbors;
288        self
289    }
290
291    /// Set the weight function
292    pub fn weights(mut self, weights: String) -> Self {
293        self.weights = weights;
294        self
295    }
296
297    /// Set the distributed configuration
298    pub fn distributed_config(mut self, config: DistributedConfig) -> Self {
299        self.config = config;
300        self
301    }
302
303    /// Set the number of workers
304    pub fn num_workers(mut self, num_workers: usize) -> Self {
305        self.config.num_workers = num_workers;
306        self
307    }
308
309    /// Set the chunk size
310    pub fn chunk_size(mut self, chunk_size: usize) -> Self {
311        self.config.chunk_size = chunk_size;
312        self
313    }
314
315    /// Set the communication strategy
316    pub fn communication_strategy(mut self, strategy: CommunicationStrategy) -> Self {
317        self.config.communication_strategy = strategy;
318        self
319    }
320
321    /// Enable fault tolerance
322    pub fn fault_tolerance(mut self, enabled: bool) -> Self {
323        self.config.fault_tolerance = enabled;
324        self
325    }
326
327    fn is_missing(&self, value: f64) -> bool {
328        if self.missing_values.is_nan() {
329            value.is_nan()
330        } else {
331            (value - self.missing_values).abs() < f64::EPSILON
332        }
333    }
334}
335
336impl Default for DistributedKNNImputer<Untrained> {
337    fn default() -> Self {
338        Self::new()
339    }
340}
341
342impl Estimator for DistributedKNNImputer<Untrained> {
343    type Config = DistributedConfig;
344    type Error = SklearsError;
345    type Float = Float;
346
347    fn config(&self) -> &Self::Config {
348        &self.config
349    }
350}
351
352impl Fit<ArrayView2<'_, Float>, ()> for DistributedKNNImputer<Untrained> {
353    type Fitted = DistributedKNNImputer<DistributedKNNImputerTrained>;
354
355    #[allow(non_snake_case)]
356    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
357        let X = X.mapv(|x| x);
358        let (n_samples, n_features) = X.dim();
359
360        if n_samples < self.config.num_workers {
361            return Err(SklearsError::InvalidInput(
362                "Dataset too small for distributed processing. Use regular KNN imputer."
363                    .to_string(),
364            ));
365        }
366
367        // Create data partitioner
368        let data_partitioner = DataPartitioner {
369            strategy: PartitioningStrategy::Horizontal,
370        };
371
372        // Create result aggregator
373        let result_aggregator = ResultAggregator {
374            strategy: AggregationStrategy::Concatenate,
375        };
376
377        // Create fault handler
378        let fault_handler = FaultHandler {
379            max_retries: self.config.max_retries,
380            retry_delay: Duration::from_secs(1),
381            checkpointing_enabled: false,
382            checkpoint_interval: Duration::from_secs(60),
383        };
384
385        // Create coordinator
386        let coordinator = ImputationCoordinator {
387            config: self.config.clone(),
388            workers: HashMap::new(),
389            data_partitioner,
390            result_aggregator,
391            fault_handler,
392        };
393
394        // Initialize workers
395        let mut workers = Vec::new();
396        for worker_id in 0..self.config.num_workers {
397            let worker = DistributedWorker {
398                id: worker_id,
399                config: self.config.clone(),
400                partitions: Vec::new(),
401                local_imputer: Box::new(SimpleImputer::default()),
402                statistics: WorkerStatistics::default(),
403            };
404            workers.push(Arc::new(Mutex::new(worker)));
405        }
406
407        Ok(DistributedKNNImputer {
408            state: DistributedKNNImputerTrained {
409                reference_data: Arc::new(RwLock::new(X.clone())),
410                n_features_in_: n_features,
411                config: self.config,
412                workers,
413                coordinator,
414            },
415            n_neighbors: self.n_neighbors,
416            weights: self.weights,
417            missing_values: self.missing_values,
418            config: Default::default(),
419            workers: Vec::new(),
420            coordinator: None,
421        })
422    }
423}
424
425impl Transform<ArrayView2<'_, Float>, Array2<Float>>
426    for DistributedKNNImputer<DistributedKNNImputerTrained>
427{
428    #[allow(non_snake_case)]
429    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
430        let X = X.mapv(|x| x);
431        let (_n_samples, n_features) = X.dim();
432
433        if n_features != self.state.n_features_in_ {
434            return Err(SklearsError::InvalidInput(format!(
435                "Number of features {} does not match training features {}",
436                n_features, self.state.n_features_in_
437            )));
438        }
439
440        // Partition data across workers
441        let partitions = self.partition_data(&X)?;
442
443        // Process partitions in parallel using workers
444        let results = self.process_partitions_distributed(partitions)?;
445
446        // Aggregate results
447        let X_imputed = self.aggregate_results(results)?;
448
449        Ok(X_imputed.mapv(|x| x as Float))
450    }
451}
452
453impl DistributedKNNImputer<DistributedKNNImputerTrained> {
454    /// Partition data for distributed processing
455    fn partition_data(&self, X: &Array2<f64>) -> Result<Vec<DataPartition>, ImputationError> {
456        let (n_samples, _n_features) = X.dim();
457        let chunk_size = self
458            .state
459            .config
460            .chunk_size
461            .min(n_samples / self.state.config.num_workers);
462        let mut partitions = Vec::new();
463
464        for (partition_id, chunk) in X.axis_chunks_iter(Axis(0), chunk_size).enumerate() {
465            let start_row = partition_id * chunk_size;
466            let end_row = (start_row + chunk.nrows()).min(n_samples);
467
468            // Create missing value mask
469            let mut missing_mask = Array2::<bool>::from_elem(chunk.dim(), false);
470            for ((i, j), &value) in chunk.indexed_iter() {
471                missing_mask[[i, j]] = self.is_missing(value);
472            }
473
474            // Calculate missing ratio
475            let total_elements = chunk.len();
476            let missing_count = missing_mask.iter().filter(|&&x| x).count();
477            let missing_ratio = missing_count as f64 / total_elements as f64;
478
479            let metadata = PartitionMetadata {
480                partition_id,
481                worker_id: partition_id % self.state.config.num_workers,
482                num_samples: chunk.nrows(),
483                num_features: chunk.ncols(),
484                missing_ratio,
485                processing_time: Duration::default(),
486                memory_usage: chunk.len() * std::mem::size_of::<f64>(),
487            };
488
489            partitions.push(DataPartition {
490                id: partition_id,
491                start_row,
492                end_row,
493                data: chunk.to_owned(),
494                missing_mask,
495                metadata,
496            });
497        }
498
499        Ok(partitions)
500    }
501
502    /// Process partitions using distributed workers
503    fn process_partitions_distributed(
504        &self,
505        partitions: Vec<DataPartition>,
506    ) -> Result<Vec<WorkerResult>, ImputationError> {
507        let reference_data = self.state.reference_data.clone();
508        let n_neighbors = self.n_neighbors;
509        let weights = self.weights.clone();
510        let _missing_values = self.missing_values;
511
512        // Process partitions in parallel
513        let results: Result<Vec<_>, _> = partitions
514            .into_par_iter()
515            .map(|partition| -> Result<WorkerResult, ImputationError> {
516                let start_time = Instant::now();
517                let worker_id = partition.metadata.worker_id;
518
519                // Access reference data
520                let ref_data = reference_data.read().map_err(|_| {
521                    ImputationError::ProcessingError("Failed to access reference data".to_string())
522                })?;
523
524                // Perform KNN imputation on this partition
525                let mut imputed_data = partition.data.clone();
526
527                for i in 0..imputed_data.nrows() {
528                    for j in 0..imputed_data.ncols() {
529                        if partition.missing_mask[[i, j]] {
530                            // Find k nearest neighbors from reference data
531                            let query_row = imputed_data.row(i);
532                            let query_row_2d = query_row.insert_axis(Axis(0));
533                            let neighbors =
534                                self.find_knn_neighbors(&ref_data, query_row_2d, n_neighbors, j)?;
535
536                            // Compute weighted average
537                            let imputed_value =
538                                self.compute_weighted_average(&neighbors, &weights)?;
539                            imputed_data[[i, j]] = imputed_value;
540                        }
541                    }
542                }
543
544                let processing_time = start_time.elapsed();
545
546                let statistics = WorkerStatistics {
547                    samples_processed: partition.metadata.num_samples,
548                    features_imputed: partition.missing_mask.iter().filter(|&&x| x).count(),
549                    processing_time,
550                    memory_peak: partition.metadata.memory_usage,
551                    errors_count: 0,
552                    retries_count: 0,
553                };
554
555                let metadata = ImputationMetadata {
556                    method: "DistributedKNN".to_string(),
557                    parameters: {
558                        let mut params = std::collections::HashMap::new();
559                        params.insert("n_neighbors".to_string(), n_neighbors.to_string());
560                        params.insert("weights".to_string(), weights.clone());
561                        params
562                    },
563                    processing_time_ms: Some(processing_time.as_millis() as u64),
564                    n_imputed: partition.missing_mask.iter().filter(|&&x| x).count(),
565                    convergence_info: None,
566                    quality_metrics: None,
567                };
568
569                Ok(WorkerResult {
570                    worker_id,
571                    partition_id: partition.id,
572                    imputed_data,
573                    statistics,
574                    metadata,
575                })
576            })
577            .collect();
578
579        results.map_err(|_| {
580            ImputationError::ProcessingError("Distributed processing failed".to_string())
581        })
582    }
583
584    /// Find k nearest neighbors for imputation
585    fn find_knn_neighbors(
586        &self,
587        reference_data: &Array2<f64>,
588        query_row: ArrayView2<f64>,
589        k: usize,
590        target_feature: usize,
591    ) -> Result<Vec<(f64, f64)>, ImputationError> {
592        let mut neighbors = Vec::new();
593
594        for ref_row_idx in 0..reference_data.nrows() {
595            let ref_row = reference_data.row(ref_row_idx);
596
597            // Skip if reference row has missing value for target feature
598            if self.is_missing(ref_row[target_feature]) {
599                continue;
600            }
601
602            // Calculate distance (ignoring missing values)
603            let distance = self.calculate_nan_euclidean_distance(query_row.row(0), ref_row);
604
605            if distance.is_finite() {
606                neighbors.push((distance, ref_row[target_feature]));
607            }
608        }
609
610        // Sort by distance and take k nearest
611        neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
612        neighbors.truncate(k);
613
614        Ok(neighbors)
615    }
616
617    /// Calculate Euclidean distance ignoring NaN values
618    fn calculate_nan_euclidean_distance(
619        &self,
620        row1: ArrayView1<f64>,
621        row2: ArrayView1<f64>,
622    ) -> f64 {
623        let mut sum_sq = 0.0;
624        let mut valid_count = 0;
625
626        for (&x1, &x2) in row1.iter().zip(row2.iter()) {
627            if !self.is_missing(x1) && !self.is_missing(x2) {
628                sum_sq += (x1 - x2).powi(2);
629                valid_count += 1;
630            }
631        }
632
633        if valid_count > 0 {
634            (sum_sq / valid_count as f64).sqrt()
635        } else {
636            f64::INFINITY
637        }
638    }
639
640    /// Compute weighted average of neighbor values
641    fn compute_weighted_average(
642        &self,
643        neighbors: &[(f64, f64)],
644        weights_type: &str,
645    ) -> Result<f64, ImputationError> {
646        if neighbors.is_empty() {
647            return Err(ImputationError::ProcessingError(
648                "No valid neighbors found".to_string(),
649            ));
650        }
651
652        match weights_type {
653            "uniform" => {
654                let sum: f64 = neighbors.iter().map(|(_, value)| value).sum();
655                Ok(sum / neighbors.len() as f64)
656            }
657            "distance" => {
658                let mut weighted_sum = 0.0;
659                let mut weight_sum = 0.0;
660
661                for &(distance, value) in neighbors {
662                    let weight = if distance > 0.0 { 1.0 / distance } else { 1e6 };
663                    weighted_sum += weight * value;
664                    weight_sum += weight;
665                }
666
667                if weight_sum > 0.0 {
668                    Ok(weighted_sum / weight_sum)
669                } else {
670                    Ok(neighbors[0].1) // Fallback to first neighbor
671                }
672            }
673            _ => Err(ImputationError::InvalidConfiguration(format!(
674                "Unknown weights type: {}",
675                weights_type
676            ))),
677        }
678    }
679
680    /// Aggregate results from distributed workers
681    fn aggregate_results(
682        &self,
683        results: Vec<WorkerResult>,
684    ) -> Result<Array2<f64>, ImputationError> {
685        if results.is_empty() {
686            return Err(ImputationError::ProcessingError(
687                "No results to aggregate".to_string(),
688            ));
689        }
690
691        // Sort results by partition ID to maintain order
692        let mut sorted_results = results;
693        sorted_results.sort_by_key(|r| r.partition_id);
694
695        // Concatenate imputed data
696        let first_result = &sorted_results[0];
697        let n_features = first_result.imputed_data.ncols();
698        let total_rows: usize = sorted_results.iter().map(|r| r.imputed_data.nrows()).sum();
699
700        let mut aggregated_data = Array2::<f64>::zeros((total_rows, n_features));
701        let mut current_row = 0;
702
703        for result in sorted_results {
704            let chunk_rows = result.imputed_data.nrows();
705            aggregated_data
706                .slice_mut(s![current_row..current_row + chunk_rows, ..])
707                .assign(&result.imputed_data);
708            current_row += chunk_rows;
709        }
710
711        Ok(aggregated_data)
712    }
713
714    fn is_missing(&self, value: f64) -> bool {
715        if self.missing_values.is_nan() {
716            value.is_nan()
717        } else {
718            (value - self.missing_values).abs() < f64::EPSILON
719        }
720    }
721}
722
723/// Distributed Simple Imputer for basic imputation strategies
724#[derive(Debug)]
725pub struct DistributedSimpleImputer<S = Untrained> {
726    state: S,
727    strategy: String,
728    missing_values: f64,
729    config: DistributedConfig,
730}
731
732/// Trained state for distributed simple imputer
733#[derive(Debug)]
734pub struct DistributedSimpleImputerTrained {
735    statistics_: Array1<f64>,
736    n_features_in_: usize,
737    config: DistributedConfig,
738}
739
740impl DistributedSimpleImputer<Untrained> {
741    pub fn new() -> Self {
742        Self {
743            state: Untrained,
744            strategy: "mean".to_string(),
745            missing_values: f64::NAN,
746            config: DistributedConfig::default(),
747        }
748    }
749
750    pub fn strategy(mut self, strategy: String) -> Self {
751        self.strategy = strategy;
752        self
753    }
754
755    pub fn distributed_config(mut self, config: DistributedConfig) -> Self {
756        self.config = config;
757        self
758    }
759
760    fn is_missing(&self, value: f64) -> bool {
761        if self.missing_values.is_nan() {
762            value.is_nan()
763        } else {
764            (value - self.missing_values).abs() < f64::EPSILON
765        }
766    }
767}
768
769impl Default for DistributedSimpleImputer<Untrained> {
770    fn default() -> Self {
771        Self::new()
772    }
773}
774
775impl Estimator for DistributedSimpleImputer<Untrained> {
776    type Config = DistributedConfig;
777    type Error = SklearsError;
778    type Float = Float;
779
780    fn config(&self) -> &Self::Config {
781        &self.config
782    }
783}
784
785impl Fit<ArrayView2<'_, Float>, ()> for DistributedSimpleImputer<Untrained> {
786    type Fitted = DistributedSimpleImputer<DistributedSimpleImputerTrained>;
787
788    #[allow(non_snake_case)]
789    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
790        let X = X.mapv(|x| x);
791        let (_, n_features) = X.dim();
792
793        // Compute statistics in parallel across features
794        let statistics: Vec<f64> = (0..n_features)
795            .into_par_iter()
796            .map(|j| {
797                let column = X.column(j);
798                let valid_values: Vec<f64> = column
799                    .iter()
800                    .filter(|&&x| !self.is_missing(x))
801                    .cloned()
802                    .collect();
803
804                if valid_values.is_empty() {
805                    0.0
806                } else {
807                    match self.strategy.as_str() {
808                        "mean" => valid_values.iter().sum::<f64>() / valid_values.len() as f64,
809                        "median" => {
810                            let mut sorted_values = valid_values.clone();
811                            sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
812                            let mid = sorted_values.len() / 2;
813                            if sorted_values.len() % 2 == 0 {
814                                (sorted_values[mid - 1] + sorted_values[mid]) / 2.0
815                            } else {
816                                sorted_values[mid]
817                            }
818                        }
819                        "most_frequent" => {
820                            let mut frequency_map = HashMap::new();
821                            for &value in &valid_values {
822                                *frequency_map.entry(value as i64).or_insert(0) += 1;
823                            }
824                            frequency_map
825                                .into_iter()
826                                .max_by_key(|(_, count)| *count)
827                                .map(|(value, _)| value as f64)
828                                .unwrap_or(0.0)
829                        }
830                        _ => valid_values.iter().sum::<f64>() / valid_values.len() as f64,
831                    }
832                }
833            })
834            .collect();
835
836        Ok(DistributedSimpleImputer {
837            state: DistributedSimpleImputerTrained {
838                statistics_: Array1::from_vec(statistics),
839                n_features_in_: n_features,
840                config: self.config,
841            },
842            strategy: self.strategy,
843            missing_values: self.missing_values,
844            config: Default::default(),
845        })
846    }
847}
848
849impl Transform<ArrayView2<'_, Float>, Array2<Float>>
850    for DistributedSimpleImputer<DistributedSimpleImputerTrained>
851{
852    #[allow(non_snake_case)]
853    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
854        let X = X.mapv(|x| x);
855        let (n_samples, n_features) = X.dim();
856
857        if n_features != self.state.n_features_in_ {
858            return Err(SklearsError::InvalidInput(format!(
859                "Number of features {} does not match training features {}",
860                n_features, self.state.n_features_in_
861            )));
862        }
863
864        // Parallel imputation across rows
865        let imputed_rows: Vec<Array1<f64>> = (0..n_samples)
866            .into_par_iter()
867            .map(|i| {
868                let mut row = X.row(i).to_owned();
869                for j in 0..n_features {
870                    if self.is_missing(row[j]) {
871                        row[j] = self.state.statistics_[j];
872                    }
873                }
874                row
875            })
876            .collect();
877
878        // Reconstruct the array
879        let mut X_imputed = Array2::zeros((n_samples, n_features));
880        for (i, row) in imputed_rows.into_iter().enumerate() {
881            X_imputed.row_mut(i).assign(&row);
882        }
883
884        Ok(X_imputed.mapv(|x| x as Float))
885    }
886}
887
888impl DistributedSimpleImputer<DistributedSimpleImputerTrained> {
889    fn is_missing(&self, value: f64) -> bool {
890        if self.missing_values.is_nan() {
891            value.is_nan()
892        } else {
893            (value - self.missing_values).abs() < f64::EPSILON
894        }
895    }
896}
897
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Two workers, everything else at its default.
    fn two_worker_config() -> DistributedConfig {
        DistributedConfig {
            num_workers: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_distributed_simple_imputer() {
        let data = array![[1.0, 2.0, 3.0], [4.0, f64::NAN, 6.0], [7.0, 8.0, 9.0]];

        let fitted = DistributedSimpleImputer::new()
            .strategy("mean".to_string())
            .distributed_config(two_worker_config())
            .fit(&data.view(), &())
            .unwrap();
        let imputed = fitted.transform(&data.view()).unwrap();

        // The gap in column 1 becomes that column's mean: (2 + 8) / 2 = 5.
        assert_abs_diff_eq!(imputed[[1, 1]], 5.0, epsilon = 1e-10);
        // Observed entries are untouched.
        assert_abs_diff_eq!(imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }

    #[test]
    fn test_distributed_knn_imputer() {
        let data = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let fitted = DistributedKNNImputer::new()
            .n_neighbors(2)
            .weights("uniform".to_string())
            .distributed_config(DistributedConfig {
                chunk_size: 2,
                ..two_worker_config()
            })
            .fit(&data.view(), &())
            .unwrap();
        let imputed = fitted.transform(&data.view()).unwrap();

        assert!(!imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }

    #[test]
    fn test_data_partitioning() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let fitted = DistributedKNNImputer::new()
            .distributed_config(DistributedConfig {
                chunk_size: 2,
                ..two_worker_config()
            })
            .fit(&data.view(), &())
            .unwrap();
        let partitions = fitted.partition_data(&data.mapv(|x| x)).unwrap();

        assert_eq!(partitions.len(), 2);
        let bounds: Vec<(usize, usize, usize)> = partitions
            .iter()
            .map(|p| (p.data.nrows(), p.start_row, p.end_row))
            .collect();
        assert_eq!(bounds, vec![(2, 0, 2), (2, 2, 4)]);
    }
}