//! Distributed imputation estimators.
//!
//! Rows of the input matrix are partitioned into chunks, each chunk is
//! imputed in parallel (via rayon), and the per-partition results are
//! concatenated back into a single matrix.

use crate::core::{ImputationError, ImputationMetadata, Imputer};
use crate::simple::SimpleImputer;
use rayon::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use serde::{Deserialize, Serialize};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};

/// Configuration for distributed imputation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Number of worker threads.
    pub num_workers: usize,
    /// Maximum number of rows per data partition.
    pub chunk_size: usize,
    /// How workers exchange data and intermediate results.
    pub communication_strategy: CommunicationStrategy,
    /// Whether partition sizes are balanced across workers.
    pub load_balancing: bool,
    /// Whether failed workers are retried.
    pub fault_tolerance: bool,
    /// Maximum number of retries for a failed worker.
    pub max_retries: usize,
    /// Time after which an unresponsive worker is considered failed.
    pub worker_timeout: Duration,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            num_workers: num_cpus::get(),
            chunk_size: 10_000,
            communication_strategy: CommunicationStrategy::SharedMemory,
            load_balancing: true,
            fault_tolerance: true,
            max_retries: 3,
            worker_timeout: Duration::from_secs(300),
        }
    }
}

/// Strategy used for communication between workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommunicationStrategy {
    /// Workers operate on shared in-process memory.
    SharedMemory,
    /// Workers exchange explicit messages.
    MessagePassing,
    /// A central parameter server coordinates worker state.
    ParameterServer,
    /// Worker results are combined with an all-reduce collective.
    AllReduce,
}

/// A contiguous slice of rows assigned to one worker.
#[derive(Debug, Clone)]
pub struct DataPartition {
    /// Partition identifier.
    pub id: usize,
    /// First row (inclusive) of this partition in the original matrix.
    pub start_row: usize,
    /// Last row (exclusive) of this partition in the original matrix.
    pub end_row: usize,
    /// The partition's data.
    pub data: Array2<f64>,
    /// Boolean mask marking the missing entries of `data`.
    pub missing_mask: Array2<bool>,
    /// Bookkeeping information about this partition.
    pub metadata: PartitionMetadata,
}

/// Bookkeeping information attached to a [`DataPartition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionMetadata {
    /// Partition identifier.
    pub partition_id: usize,
    /// Identifier of the worker assigned to this partition.
    pub worker_id: usize,
    /// Number of rows in the partition.
    pub num_samples: usize,
    /// Number of columns in the partition.
    pub num_features: usize,
    /// Fraction of entries that are missing.
    pub missing_ratio: f64,
    /// Time spent processing the partition.
    pub processing_time: Duration,
    /// Approximate memory footprint of the partition in bytes.
    pub memory_usage: usize,
}

/// A single worker participating in distributed imputation.
pub struct DistributedWorker {
    /// Worker identifier.
    pub id: usize,
    /// Configuration shared with the coordinator.
    pub config: DistributedConfig,
    /// Partitions assigned to this worker.
    pub partitions: Vec<DataPartition>,
    /// Imputer used for this worker's local computations.
    pub local_imputer: Box<dyn Imputer + Send + Sync>,
    /// Runtime statistics collected by this worker.
    pub statistics: WorkerStatistics,
}

/// Runtime statistics collected per worker.
#[derive(Debug, Default, Clone)]
pub struct WorkerStatistics {
    /// Number of samples processed.
    pub samples_processed: usize,
    /// Number of values imputed.
    pub features_imputed: usize,
    /// Total processing time.
    pub processing_time: Duration,
    /// Peak memory usage in bytes.
    pub memory_peak: usize,
    /// Number of errors encountered.
    pub errors_count: usize,
    /// Number of retries performed.
    pub retries_count: usize,
}

/// Distributed KNN imputer (untrained by default).
pub struct DistributedKNNImputer<S = Untrained> {
    state: S,
    n_neighbors: usize,
    weights: String,
    missing_values: f64,
    config: DistributedConfig,
    workers: Vec<DistributedWorker>,
    coordinator: Option<ImputationCoordinator>,
}

/// Trained state of [`DistributedKNNImputer`].
pub struct DistributedKNNImputerTrained {
    reference_data: Arc<RwLock<Array2<f64>>>,
    n_features_in_: usize,
    config: DistributedConfig,
    workers: Vec<Arc<Mutex<DistributedWorker>>>,
    coordinator: ImputationCoordinator,
}

/// Coordinates partitioning, worker execution, and result aggregation.
#[derive(Debug)]
pub struct ImputationCoordinator {
    /// Shared configuration.
    pub config: DistributedConfig,
    /// Handles to the coordinated workers, keyed by worker id.
    pub workers: HashMap<usize, WorkerHandle>,
    /// Splits input data into partitions.
    pub data_partitioner: DataPartitioner,
    /// Merges per-partition results.
    pub result_aggregator: ResultAggregator,
    /// Handles worker failures and retries.
    pub fault_handler: FaultHandler,
}

/// Handle to a worker thread managed by the coordinator.
#[derive(Debug)]
pub struct WorkerHandle {
    /// Worker identifier.
    pub id: usize,
    /// Join handle of the worker thread, if one is running.
    pub thread_handle: Option<thread::JoinHandle<Result<WorkerResult, ImputationError>>>,
    /// Current status of the worker.
    pub status: WorkerStatus,
    /// Time of the last heartbeat received from the worker.
    pub last_heartbeat: Instant,
}

/// Status of a worker as seen by the coordinator.
#[derive(Debug, Clone, PartialEq)]
pub enum WorkerStatus {
    /// Waiting for work.
    Idle,
    /// Currently processing a partition.
    Processing,
    /// Finished successfully.
    Completed,
    /// Failed with an error.
    Failed,
    /// Exceeded the configured worker timeout.
    Timeout,
}

/// Result produced by a worker for one partition.
#[derive(Debug, Clone)]
pub struct WorkerResult {
    /// Identifier of the worker that produced this result.
    pub worker_id: usize,
    /// Identifier of the processed partition.
    pub partition_id: usize,
    /// The imputed partition data.
    pub imputed_data: Array2<f64>,
    /// Statistics collected while processing the partition.
    pub statistics: WorkerStatistics,
    /// Metadata describing the imputation run.
    pub metadata: ImputationMetadata,
}

/// Splits input data into partitions.
#[derive(Debug)]
pub struct DataPartitioner {
    strategy: PartitioningStrategy,
}

/// How input data is split across workers.
#[derive(Debug, Clone)]
pub enum PartitioningStrategy {
    /// Split by contiguous row ranges.
    Horizontal,
    /// Split by columns.
    Vertical,
    /// Assign rows to partitions at random.
    Random,
    /// Preserve class proportions in each partition.
    Stratified,
    /// Assign rows by hashing.
    Hash,
}

/// Merges per-partition results into a single matrix.
#[derive(Debug)]
pub struct ResultAggregator {
    strategy: AggregationStrategy,
}

/// How per-partition results are combined.
#[derive(Debug, Clone)]
pub enum AggregationStrategy {
    /// Stack partition results in row order.
    Concatenate,
    /// Average overlapping results with weights.
    WeightedAverage,
    /// Keep values that a majority of workers agree on.
    Consensus,
    /// Average model parameters rather than outputs.
    ModelAveraging,
}

/// Retry and checkpoint policy for worker failures.
#[derive(Debug)]
pub struct FaultHandler {
    /// Maximum number of retries per failed worker.
    pub max_retries: usize,
    /// Delay between successive retries.
    pub retry_delay: Duration,
    /// Whether periodic checkpointing is enabled.
    pub checkpointing_enabled: bool,
    /// Interval between checkpoints.
    pub checkpoint_interval: Duration,
}

impl DistributedKNNImputer<Untrained> {
    /// Creates a new distributed KNN imputer with default settings
    /// (5 neighbors, uniform weights, NaN as the missing-value marker).
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_neighbors: 5,
            weights: "uniform".to_string(),
            missing_values: f64::NAN,
            config: DistributedConfig::default(),
            workers: Vec::new(),
            coordinator: None,
        }
    }

    /// Sets the number of neighbors used for imputation.
    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.n_neighbors = n_neighbors;
        self
    }

    /// Sets the weighting scheme: "uniform" or "distance".
    pub fn weights(mut self, weights: String) -> Self {
        self.weights = weights;
        self
    }

    /// Replaces the entire distributed configuration.
    pub fn distributed_config(mut self, config: DistributedConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the number of worker threads.
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.config.num_workers = num_workers;
        self
    }

    /// Sets the maximum number of rows per partition.
    pub fn chunk_size(mut self, chunk_size: usize) -> Self {
        self.config.chunk_size = chunk_size;
        self
    }

    /// Sets the communication strategy.
    pub fn communication_strategy(mut self, strategy: CommunicationStrategy) -> Self {
        self.config.communication_strategy = strategy;
        self
    }

    /// Enables or disables fault tolerance.
    pub fn fault_tolerance(mut self, enabled: bool) -> Self {
        self.config.fault_tolerance = enabled;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for DistributedKNNImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DistributedKNNImputer<Untrained> {
    type Config = DistributedConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for DistributedKNNImputer<Untrained> {
    type Fitted = DistributedKNNImputer<DistributedKNNImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        // Materialize an owned copy of the training data.
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_samples < self.config.num_workers {
            return Err(SklearsError::InvalidInput(
                "Dataset too small for distributed processing. Use regular KNN imputer."
                    .to_string(),
            ));
        }

        // Set up the coordinator components with the default strategies:
        // horizontal (row-wise) partitioning and simple concatenation.
        let data_partitioner = DataPartitioner {
            strategy: PartitioningStrategy::Horizontal,
        };

        let result_aggregator = ResultAggregator {
            strategy: AggregationStrategy::Concatenate,
        };

        let fault_handler = FaultHandler {
            max_retries: self.config.max_retries,
            retry_delay: Duration::from_secs(1),
            checkpointing_enabled: false,
            checkpoint_interval: Duration::from_secs(60),
        };

        let coordinator = ImputationCoordinator {
            config: self.config.clone(),
            workers: HashMap::new(),
            data_partitioner,
            result_aggregator,
            fault_handler,
        };

        // Create one worker per configured slot.
        let mut workers = Vec::new();
        for worker_id in 0..self.config.num_workers {
            let worker = DistributedWorker {
                id: worker_id,
                config: self.config.clone(),
                partitions: Vec::new(),
                local_imputer: Box::new(SimpleImputer::default()),
                statistics: WorkerStatistics::default(),
            };
            workers.push(Arc::new(Mutex::new(worker)));
        }

        Ok(DistributedKNNImputer {
            state: DistributedKNNImputerTrained {
                reference_data: Arc::new(RwLock::new(X)),
                n_features_in_: n_features,
                config: self.config,
                workers,
                coordinator,
            },
            n_neighbors: self.n_neighbors,
            weights: self.weights,
            missing_values: self.missing_values,
            config: Default::default(),
            workers: Vec::new(),
            coordinator: None,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for DistributedKNNImputer<DistributedKNNImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.mapv(|x| x);
        let (_n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        // Split the input into row-wise partitions, impute each partition in
        // parallel, then stitch the results back together in row order.
        let partitions = self.partition_data(&X)?;
        let results = self.process_partitions_distributed(partitions)?;
        let X_imputed = self.aggregate_results(results)?;

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl DistributedKNNImputer<DistributedKNNImputerTrained> {
    /// Splits `X` into contiguous row chunks, one partition per chunk.
    fn partition_data(&self, X: &Array2<f64>) -> Result<Vec<DataPartition>, ImputationError> {
        let (n_samples, _n_features) = X.dim();
        // Cap the chunk size so every worker receives at least one chunk; the
        // `.max(1)` guards against a zero chunk size (which would panic in
        // `axis_chunks_iter`) when the input has fewer rows than workers.
        let chunk_size = self
            .state
            .config
            .chunk_size
            .min(n_samples / self.state.config.num_workers)
            .max(1);
        let mut partitions = Vec::new();

        for (partition_id, chunk) in X.axis_chunks_iter(Axis(0), chunk_size).enumerate() {
            let start_row = partition_id * chunk_size;
            let end_row = (start_row + chunk.nrows()).min(n_samples);

            // Mark the missing entries of this chunk.
            let mut missing_mask = Array2::<bool>::from_elem(chunk.dim(), false);
            for ((i, j), &value) in chunk.indexed_iter() {
                missing_mask[[i, j]] = self.is_missing(value);
            }

            let total_elements = chunk.len();
            let missing_count = missing_mask.iter().filter(|&&x| x).count();
            let missing_ratio = missing_count as f64 / total_elements as f64;

            let metadata = PartitionMetadata {
                partition_id,
                worker_id: partition_id % self.state.config.num_workers,
                num_samples: chunk.nrows(),
                num_features: chunk.ncols(),
                missing_ratio,
                processing_time: Duration::default(),
                memory_usage: chunk.len() * std::mem::size_of::<f64>(),
            };

            partitions.push(DataPartition {
                id: partition_id,
                start_row,
                end_row,
                data: chunk.to_owned(),
                missing_mask,
                metadata,
            });
        }

        Ok(partitions)
    }

    /// Imputes each partition in parallel using rayon's work-stealing pool.
    fn process_partitions_distributed(
        &self,
        partitions: Vec<DataPartition>,
    ) -> Result<Vec<WorkerResult>, ImputationError> {
        let reference_data = self.state.reference_data.clone();
        let n_neighbors = self.n_neighbors;
        let weights = self.weights.clone();
        let _missing_values = self.missing_values;

        let results: Result<Vec<_>, _> = partitions
            .into_par_iter()
            .map(|partition| -> Result<WorkerResult, ImputationError> {
                let start_time = Instant::now();
                let worker_id = partition.metadata.worker_id;

                let ref_data = reference_data.read().map_err(|_| {
                    ImputationError::ProcessingError("Failed to access reference data".to_string())
                })?;

                let mut imputed_data = partition.data.clone();

                // Impute every missing cell from its k nearest neighbors in
                // the reference (training) data.
                for i in 0..imputed_data.nrows() {
                    for j in 0..imputed_data.ncols() {
                        if partition.missing_mask[[i, j]] {
                            let query_row = imputed_data.row(i);
                            let query_row_2d = query_row.insert_axis(Axis(0));
                            let neighbors =
                                self.find_knn_neighbors(&ref_data, query_row_2d, n_neighbors, j)?;

                            let imputed_value =
                                self.compute_weighted_average(&neighbors, &weights)?;
                            imputed_data[[i, j]] = imputed_value;
                        }
                    }
                }

                let processing_time = start_time.elapsed();

                let statistics = WorkerStatistics {
                    samples_processed: partition.metadata.num_samples,
                    features_imputed: partition.missing_mask.iter().filter(|&&x| x).count(),
                    processing_time,
                    memory_peak: partition.metadata.memory_usage,
                    errors_count: 0,
                    retries_count: 0,
                };

                let metadata = ImputationMetadata {
                    method: "DistributedKNN".to_string(),
                    parameters: {
                        let mut params = HashMap::new();
                        params.insert("n_neighbors".to_string(), n_neighbors.to_string());
                        params.insert("weights".to_string(), weights.clone());
                        params
                    },
                    processing_time_ms: Some(processing_time.as_millis() as u64),
                    n_imputed: partition.missing_mask.iter().filter(|&&x| x).count(),
                    convergence_info: None,
                    quality_metrics: None,
                };

                Ok(WorkerResult {
                    worker_id,
                    partition_id: partition.id,
                    imputed_data,
                    statistics,
                    metadata,
                })
            })
            .collect();

        results.map_err(|_| {
            ImputationError::ProcessingError("Distributed processing failed".to_string())
        })
    }

    /// Brute-force k-nearest-neighbor search: returns `(distance, value)`
    /// pairs for the `k` reference rows closest to `query_row` that have an
    /// observed value in `target_feature`.
    fn find_knn_neighbors(
        &self,
        reference_data: &Array2<f64>,
        query_row: ArrayView2<f64>,
        k: usize,
        target_feature: usize,
    ) -> Result<Vec<(f64, f64)>, ImputationError> {
        let mut neighbors = Vec::new();

        for ref_row_idx in 0..reference_data.nrows() {
            let ref_row = reference_data.row(ref_row_idx);

            // Only rows with an observed target value can donate one.
            if self.is_missing(ref_row[target_feature]) {
                continue;
            }

            let distance = self.calculate_nan_euclidean_distance(query_row.row(0), ref_row);

            if distance.is_finite() {
                neighbors.push((distance, ref_row[target_feature]));
            }
        }

        // Keep the k closest donors.
        neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        neighbors.truncate(k);

        Ok(neighbors)
    }

    /// Euclidean distance over the coordinates observed in both rows,
    /// normalized by the number of observed pairs; returns infinity when the
    /// rows share no observed coordinate.
    fn calculate_nan_euclidean_distance(
        &self,
        row1: ArrayView1<f64>,
        row2: ArrayView1<f64>,
    ) -> f64 {
        let mut sum_sq = 0.0;
        let mut valid_count = 0;

        for (&x1, &x2) in row1.iter().zip(row2.iter()) {
            if !self.is_missing(x1) && !self.is_missing(x2) {
                sum_sq += (x1 - x2).powi(2);
                valid_count += 1;
            }
        }

        if valid_count > 0 {
            (sum_sq / valid_count as f64).sqrt()
        } else {
            f64::INFINITY
        }
    }

    /// Combines neighbor values into one imputed value using either uniform
    /// or inverse-distance weighting.
    fn compute_weighted_average(
        &self,
        neighbors: &[(f64, f64)],
        weights_type: &str,
    ) -> Result<f64, ImputationError> {
        if neighbors.is_empty() {
            return Err(ImputationError::ProcessingError(
                "No valid neighbors found".to_string(),
            ));
        }

        match weights_type {
            "uniform" => {
                let sum: f64 = neighbors.iter().map(|(_, value)| value).sum();
                Ok(sum / neighbors.len() as f64)
            }
            "distance" => {
                let mut weighted_sum = 0.0;
                let mut weight_sum = 0.0;

                for &(distance, value) in neighbors {
                    // Cap the weight of exact matches (distance zero) at 1e6
                    // to avoid dividing by zero.
                    let weight = if distance > 0.0 { 1.0 / distance } else { 1e6 };
                    weighted_sum += weight * value;
                    weight_sum += weight;
                }

                if weight_sum > 0.0 {
                    Ok(weighted_sum / weight_sum)
                } else {
                    Ok(neighbors[0].1)
                }
            }
            _ => Err(ImputationError::InvalidConfiguration(format!(
                "Unknown weights type: {}",
                weights_type
            ))),
        }
    }

    /// Concatenates per-partition results back into one matrix, in the
    /// original row order.
    fn aggregate_results(
        &self,
        results: Vec<WorkerResult>,
    ) -> Result<Array2<f64>, ImputationError> {
        if results.is_empty() {
            return Err(ImputationError::ProcessingError(
                "No results to aggregate".to_string(),
            ));
        }

        // Restore the original row order before stitching.
        let mut sorted_results = results;
        sorted_results.sort_by_key(|r| r.partition_id);

        let first_result = &sorted_results[0];
        let n_features = first_result.imputed_data.ncols();
        let total_rows: usize = sorted_results.iter().map(|r| r.imputed_data.nrows()).sum();

        let mut aggregated_data = Array2::<f64>::zeros((total_rows, n_features));
        let mut current_row = 0;

        for result in sorted_results {
            let chunk_rows = result.imputed_data.nrows();
            aggregated_data
                .slice_mut(s![current_row..current_row + chunk_rows, ..])
                .assign(&result.imputed_data);
            current_row += chunk_rows;
        }

        Ok(aggregated_data)
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

/// Distributed simple (statistics-based) imputer.
#[derive(Debug)]
pub struct DistributedSimpleImputer<S = Untrained> {
    state: S,
    strategy: String,
    missing_values: f64,
    config: DistributedConfig,
}

/// Trained state of [`DistributedSimpleImputer`].
#[derive(Debug)]
pub struct DistributedSimpleImputerTrained {
    statistics_: Array1<f64>,
    n_features_in_: usize,
    config: DistributedConfig,
}

impl DistributedSimpleImputer<Untrained> {
    /// Creates a new distributed simple imputer with the "mean" strategy and
    /// NaN as the missing-value marker.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            strategy: "mean".to_string(),
            missing_values: f64::NAN,
            config: DistributedConfig::default(),
        }
    }

    /// Sets the imputation strategy: "mean", "median", or "most_frequent".
    pub fn strategy(mut self, strategy: String) -> Self {
        self.strategy = strategy;
        self
    }

    /// Replaces the entire distributed configuration.
    pub fn distributed_config(mut self, config: DistributedConfig) -> Self {
        self.config = config;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for DistributedSimpleImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DistributedSimpleImputer<Untrained> {
    type Config = DistributedConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for DistributedSimpleImputer<Untrained> {
    type Fitted = DistributedSimpleImputer<DistributedSimpleImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.mapv(|x| x);
        let (_, n_features) = X.dim();

        // Compute one statistic per column, with the columns processed in
        // parallel.
        let statistics: Vec<f64> = (0..n_features)
            .into_par_iter()
            .map(|j| {
                let column = X.column(j);
                let valid_values: Vec<f64> = column
                    .iter()
                    .filter(|&&x| !self.is_missing(x))
                    .copied()
                    .collect();

                if valid_values.is_empty() {
                    0.0
                } else {
                    match self.strategy.as_str() {
                        "mean" => valid_values.iter().sum::<f64>() / valid_values.len() as f64,
                        "median" => {
                            let mut sorted_values = valid_values.clone();
                            // total_cmp gives a total order, so stray NaNs
                            // cannot panic the sort.
                            sorted_values.sort_by(|a, b| a.total_cmp(b));
                            let mid = sorted_values.len() / 2;
                            if sorted_values.len() % 2 == 0 {
                                (sorted_values[mid - 1] + sorted_values[mid]) / 2.0
                            } else {
                                sorted_values[mid]
                            }
                        }
                        "most_frequent" => {
                            // Note: values are binned to integers before the
                            // frequency count, so this strategy is only exact
                            // for integer-valued features.
                            let mut frequency_map = HashMap::new();
                            for &value in &valid_values {
                                *frequency_map.entry(value as i64).or_insert(0) += 1;
                            }
                            frequency_map
                                .into_iter()
                                .max_by_key(|(_, count)| *count)
                                .map(|(value, _)| value as f64)
                                .unwrap_or(0.0)
                        }
                        // Unknown strategies fall back to the mean.
                        _ => valid_values.iter().sum::<f64>() / valid_values.len() as f64,
                    }
                }
            })
            .collect();

        Ok(DistributedSimpleImputer {
            state: DistributedSimpleImputerTrained {
                statistics_: Array1::from_vec(statistics),
                n_features_in_: n_features,
                config: self.config,
            },
            strategy: self.strategy,
            missing_values: self.missing_values,
            config: Default::default(),
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for DistributedSimpleImputer<DistributedSimpleImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        // Fill missing entries with the fitted per-column statistics, with
        // the rows processed in parallel.
        let imputed_rows: Vec<Array1<f64>> = (0..n_samples)
            .into_par_iter()
            .map(|i| {
                let mut row = X.row(i).to_owned();
                for j in 0..n_features {
                    if self.is_missing(row[j]) {
                        row[j] = self.state.statistics_[j];
                    }
                }
                row
            })
            .collect();

        let mut X_imputed = Array2::zeros((n_samples, n_features));
        for (i, row) in imputed_rows.into_iter().enumerate() {
            X_imputed.row_mut(i).assign(&row);
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl DistributedSimpleImputer<DistributedSimpleImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_distributed_simple_imputer() {
        let X = array![[1.0, 2.0, 3.0], [4.0, f64::NAN, 6.0], [7.0, 8.0, 9.0]];

        let imputer = DistributedSimpleImputer::new()
            .strategy("mean".to_string())
            .distributed_config(DistributedConfig {
                num_workers: 2,
                ..Default::default()
            });

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        // The missing value is replaced by the column mean (2 + 8) / 2 = 5.
        assert_abs_diff_eq!(X_imputed[[1, 1]], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }
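
    // A minimal additional check of the "median" strategy (same setup as the
    // mean test above): for column 1 the observed values are [4.0, 6.0], so
    // the fitted median is (4 + 6) / 2 = 5, which should replace the NaN.
    #[test]
    #[allow(non_snake_case)]
    fn test_distributed_simple_imputer_median() {
        let X = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];

        let imputer = DistributedSimpleImputer::new().strategy("median".to_string());

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert_abs_diff_eq!(X_imputed[[0, 1]], 5.0, epsilon = 1e-10);
        // Observed values pass through unchanged.
        assert_abs_diff_eq!(X_imputed[[1, 0]], 3.0, epsilon = 1e-10);
    }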

    #[test]
    #[allow(non_snake_case)]
    fn test_distributed_knn_imputer() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let imputer = DistributedKNNImputer::new()
            .n_neighbors(2)
            .weights("uniform".to_string())
            .distributed_config(DistributedConfig {
                num_workers: 2,
                chunk_size: 2,
                ..Default::default()
            });

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        // The missing value is filled in and observed values are unchanged.
        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }
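
    // A minimal check of the private `compute_weighted_average` helper,
    // exercised directly from this child test module. With neighbors
    // (distance, value) = (1.0, 10.0) and (2.0, 20.0), uniform weighting
    // gives (10 + 20) / 2 = 15, while inverse-distance weighting gives
    // (1.0 * 10 + 0.5 * 20) / 1.5 = 40 / 3.
    #[test]
    #[allow(non_snake_case)]
    fn test_compute_weighted_average() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = DistributedKNNImputer::new()
            .distributed_config(DistributedConfig {
                num_workers: 2,
                ..Default::default()
            })
            .fit(&X.view(), &())
            .unwrap();

        let neighbors = vec![(1.0, 10.0), (2.0, 20.0)];
        let uniform = fitted.compute_weighted_average(&neighbors, "uniform").unwrap();
        let distance = fitted.compute_weighted_average(&neighbors, "distance").unwrap();

        assert_abs_diff_eq!(uniform, 15.0, epsilon = 1e-10);
        assert_abs_diff_eq!(distance, 40.0 / 3.0, epsilon = 1e-10);

        // Unknown weighting schemes are rejected.
        assert!(fitted.compute_weighted_average(&neighbors, "gaussian").is_err());
    }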

    #[test]
    #[allow(non_snake_case)]
    fn test_data_partitioning() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let imputer = DistributedKNNImputer::new().distributed_config(DistributedConfig {
            num_workers: 2,
            chunk_size: 2,
            ..Default::default()
        });

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let partitions = fitted.partition_data(&X.mapv(|x| x)).unwrap();

        // Four rows with chunk_size 2 yield two partitions of two rows each.
        assert_eq!(partitions.len(), 2);
        assert_eq!(partitions[0].data.nrows(), 2);
        assert_eq!(partitions[1].data.nrows(), 2);
        assert_eq!(partitions[0].start_row, 0);
        assert_eq!(partitions[0].end_row, 2);
        assert_eq!(partitions[1].start_row, 2);
        assert_eq!(partitions[1].end_row, 4);
    }
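
    // A minimal check of the private NaN-aware distance helper: only the
    // coordinates observed in both rows contribute, the squared sum is
    // normalized by the number of observed pairs, and rows with no observed
    // overlap are reported as infinitely far apart.
    #[test]
    #[allow(non_snake_case)]
    fn test_nan_euclidean_distance() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = DistributedKNNImputer::new()
            .distributed_config(DistributedConfig {
                num_workers: 2,
                ..Default::default()
            })
            .fit(&X.view(), &())
            .unwrap();

        // Pairs (1, 1) and (3, 5) are observed: sqrt((0 + 4) / 2) = sqrt(2).
        let a = array![1.0, f64::NAN, 3.0];
        let b = array![1.0, 2.0, 5.0];
        let d = fitted.calculate_nan_euclidean_distance(a.view(), b.view());
        assert_abs_diff_eq!(d, 2.0_f64.sqrt(), epsilon = 1e-10);

        // No observed overlap yields an infinite distance.
        let c = array![f64::NAN, f64::NAN, f64::NAN];
        let d_inf = fitted.calculate_nan_euclidean_distance(c.view(), b.view());
        assert!(d_inf.is_infinite());
    }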
}