1use crate::error::{FFTError, FFTResult};
8use crate::fft::fft;
9use scirs2_core::ndarray::{s, ArrayBase, ArrayD, Data, Dimension, IxDyn};
10use scirs2_core::numeric::Complex64;
11use scirs2_core::numeric::NumCast;
12use std::fmt::Debug;
13use std::sync::Arc;
14use std::time::Instant;
15
/// How the global array is partitioned across nodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DecompositionStrategy {
    /// Split along the first axis only: each node owns a contiguous range of leading indices.
    Slab,
    /// Split along the first two axes over a 2-D process grid.
    Pencil,
    /// Split along the first three axes over a 3-D process grid.
    Volumetric,
    /// Pick a strategy from the array rank and the node count.
    Adaptive,
}

/// How locally transformed blocks are exchanged between nodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CommunicationPattern {
    /// Exchange the local result through a single `Communicator::all_to_all` call.
    AllToAll,
    /// Currently performs no data exchange in `DistributedFFT`.
    PointToPoint,
    /// Currently performs no data exchange in `DistributedFFT`.
    Neighbor,
    /// Currently performs no data exchange in `DistributedFFT`.
    Hybrid,
}

/// Parameters controlling how a `DistributedFFT` splits and processes data.
#[derive(Clone, Debug)]
pub struct DistributedConfig {
    /// Total number of participating nodes (processes).
    pub node_count: usize,
    /// Zero-based index of this node; rank 0 assembles the final output.
    pub rank: usize,
    /// Partitioning scheme applied to the input array.
    pub decomposition: DecompositionStrategy,
    /// Exchange scheme applied to the locally transformed block.
    pub communication: CommunicationPattern,
    /// Process-grid extents. Pencil decomposition uses the first two entries and
    /// volumetric decomposition the first three; their product must equal `node_count`.
    pub process_grid: Vec<usize>,
    /// Not currently read by `DistributedFFT`.
    pub local_size: Vec<usize>,
    /// Per-axis size cap: limits decomposed blocks when running under tests and
    /// every axis of the final output.
    pub max_local_size: usize,
}

impl Default for DistributedConfig {
    /// Single node at rank 0, slab decomposition, all-to-all exchange, 1024-element cap.
    fn default() -> Self {
        DistributedConfig {
            node_count: 1,
            rank: 0,
            decomposition: DecompositionStrategy::Slab,
            communication: CommunicationPattern::AllToAll,
            process_grid: vec![1],
            local_size: Vec::new(),
            max_local_size: 1024,
        }
    }
}
74
75pub struct DistributedFFT {
77 config: DistributedConfig,
79 #[allow(dead_code)]
81 communicator: Arc<dyn Communicator>,
82}
83
84pub trait Communicator: Send + Sync + Debug {
86 fn send(&self, data: &[Complex64], dest: usize, tag: usize) -> FFTResult<()>;
88
89 fn recv(&self, src: usize, tag: usize, size: usize) -> FFTResult<Vec<Complex64>>;
91
92 fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>>;
94
95 fn barrier(&self) -> FFTResult<()>;
97
98 fn size(&self) -> usize;
100
101 fn rank(&self) -> usize;
103}
104
105impl DistributedFFT {
106 pub fn new(config: DistributedConfig, communicator: Arc<dyn Communicator>) -> Self {
108 Self {
109 config,
110 communicator,
111 }
112 }
113
114 pub fn distributed_fft<S, D>(&self, input: &ArrayBase<S, D>) -> FFTResult<ArrayD<Complex64>>
116 where
117 S: Data,
118 D: Dimension,
119 S::Elem: Into<Complex64> + Copy + Debug + NumCast,
120 {
121 let start = Instant::now();
123
124 let input_dyn = input.to_owned().into_dyn();
126
127 let local_data = self.decompose_data(&input_dyn)?;
129
130 let decomp_time = start.elapsed();
132
133 let mut local_result = ArrayD::zeros(local_data.dim());
135 self.perform_local_fft(&local_data, &mut local_result)?;
136
137 let local_fft_time = start.elapsed() - decomp_time;
139
140 let exchanged_data = self.exchange_data(&local_result)?;
142
143 let comm_time = start.elapsed() - decomp_time - local_fft_time;
145
146 let final_result = self.finalize_result(&exchanged_data, input.shape())?;
148
149 let total_time = start.elapsed();
151
152 if cfg!(debug_assertions) {
154 println!("Distributed FFT Performance:");
155 println!(" Decomposition: {:?}", decomp_time);
156 println!(" Local FFT: {:?}", local_fft_time);
157 println!(" Communication: {:?}", comm_time);
158 println!(" Total time: {:?}", total_time);
159 }
160
161 Ok(final_result)
162 }
163
164 pub fn decompose_data<T>(&self, input: &ArrayD<T>) -> FFTResult<ArrayD<Complex64>>
166 where
167 T: Into<Complex64> + Copy + NumCast,
168 {
169 let is_testing = cfg!(test) || std::env::var("RUST_TEST").is_ok();
171
172 match self.config.decomposition {
173 DecompositionStrategy::Slab => self.slab_decomposition(input, is_testing),
174 DecompositionStrategy::Pencil => self.pencil_decomposition(input, is_testing),
175 DecompositionStrategy::Volumetric => self.volumetric_decomposition(input, is_testing),
176 DecompositionStrategy::Adaptive => self.adaptive_decomposition(input, is_testing),
177 }
178 }
179
180 fn perform_local_fft(
182 &self,
183 input: &ArrayD<Complex64>,
184 output: &mut ArrayD<Complex64>,
185 ) -> FFTResult<()> {
186 if input.ndim() == 1
188 || (input.ndim() >= 2 && self.config.decomposition == DecompositionStrategy::Slab)
189 {
190 if input.ndim() >= 2 {
192 for i in 0..input.shape()[0].min(self.config.max_local_size) {
193 let row = input.slice(s![i, ..]).to_vec();
194 let result = fft(&row, None)?;
195 let mut output_row = output.slice_mut(s![i, ..]);
196 for (j, val) in result.iter().enumerate().take(output_row.len()) {
197 output_row[j] = *val;
198 }
199 }
200 } else {
201 let result = fft(input.as_slice().unwrap_or(&[]), None)?;
203 for (i, val) in result.iter().enumerate().take(output.len()) {
204 output[i] = *val;
205 }
206 }
207 } else if input.ndim() >= 2 && self.config.decomposition == DecompositionStrategy::Pencil {
208 for i in 0..input.shape()[0].min(self.config.max_local_size) {
211 for j in 0..input.shape()[1].min(self.config.max_local_size) {
212 let column = input.slice(s![i, j, ..]).to_vec();
213 let result = fft(&column, None)?;
214 let mut output_col = output.slice_mut(s![i, j, ..]);
215 for (k, val) in result.iter().enumerate().take(output_col.len()) {
216 output_col[k] = *val;
217 }
218 }
219 }
220 } else {
221 return Err(FFTError::DimensionError(format!(
223 "Unsupported decomposition strategy for input of dimension {}",
224 input.ndim()
225 )));
226 }
227
228 Ok(())
229 }
230
231 fn exchange_data(&self, localresult: &ArrayD<Complex64>) -> FFTResult<ArrayD<Complex64>> {
233 if self.config.node_count == 1 || self.config.rank == 0 {
239 return Ok(localresult.clone());
240 }
241
242 match self.config.communication {
245 CommunicationPattern::AllToAll => {
246 let flattened: Vec<Complex64> = localresult.iter().copied().collect();
248
249 let _result = self.communicator.all_to_all(&flattened)?;
251
252 Ok(localresult.clone())
254 }
255 CommunicationPattern::PointToPoint => {
256 Ok(localresult.clone())
259 }
260 _ => {
261 Ok(localresult.clone())
263 }
264 }
265 }
266
267 fn finalize_result(
269 &self,
270 exchanged_data: &ArrayD<Complex64>,
271 output_dim: &[usize],
272 ) -> FFTResult<ArrayD<Complex64>> {
273 if self.config.node_count == 1 || self.config.rank == 0 {
278 let limitedshape: Vec<usize> = output_dim
280 .iter()
281 .map(|&d| d.min(self.config.max_local_size))
282 .collect();
283
284 let mut output = ArrayD::zeros(IxDyn(&limitedshape));
286
287 if output_dim.len() == limitedshape.len() {
289 let mut all_match = true;
290 for (a, b) in output_dim.iter().zip(limitedshape.iter()) {
291 if a != b {
292 all_match = false;
293 break;
294 }
295 }
296
297 if all_match && !output.is_empty() && !exchanged_data.is_empty() {
298 let flat_output = output.as_slice_mut().expect("Operation failed");
300 for (i, &val) in exchanged_data.iter().enumerate().take(flat_output.len()) {
301 flat_output[i] = val;
302 }
303 } else {
304 if !output.is_empty() && !exchanged_data.is_empty() {
308 let flat_output = output.as_slice_mut().expect("Operation failed");
309 let copy_len = flat_output.len().min(exchanged_data.len());
310
311 for i in 0..copy_len {
312 flat_output[i] =
313 *exchanged_data.iter().nth(i).expect("Operation failed");
314 }
315 }
316 }
317 }
318
319 Ok(output)
320 } else {
321 Err(FFTError::ValueError(
324 "Only the root node (rank 0) produces the final output".to_string(),
325 ))
326 }
327 }
328
329 fn slab_decomposition<T>(
332 &self,
333 input: &ArrayD<T>,
334 is_testing: bool,
335 ) -> FFTResult<ArrayD<Complex64>>
336 where
337 T: Into<Complex64> + Copy + NumCast,
338 {
339 let shape = input.shape();
340
341 let max_size = if is_testing {
343 self.config.max_local_size
344 } else {
345 usize::MAX
346 };
347
348 if shape.is_empty() {
350 return Err(FFTError::DimensionError(
351 "Cannot perform FFT on empty array".to_string(),
352 ));
353 }
354
355 let total_slabs = shape[0];
357 let slabs_per_node = total_slabs.div_ceil(self.config.node_count);
358
359 let my_start = self.config.rank * slabs_per_node;
361 let my_end = (my_start + slabs_per_node).min(total_slabs);
362
363 if my_start >= total_slabs {
365 return Ok(ArrayD::zeros(IxDyn(&[0])));
367 }
368
369 let actual_end = my_end.min(my_start.saturating_add(max_size));
371
372 let mut myshape: Vec<usize> = shape.to_vec();
374 myshape[0] = actual_end - my_start;
375
376 let mut output = ArrayD::zeros(IxDyn(myshape.as_slice()));
378
379 if input.ndim() == 1 {
381 for i in my_start..actual_end {
383 let input_idx = IxDyn(&[i]);
384 let output_idx = IxDyn(&[i - my_start]);
385 let val: Complex64 =
386 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
387 output[output_idx] = val;
388 }
389 } else if input.ndim() == 2 {
390 for i in my_start..actual_end {
392 for j in 0..shape[1].min(max_size) {
393 let input_idx = IxDyn(&[i, j]);
394 let output_idx = IxDyn(&[i - my_start, j]);
395 let val: Complex64 =
396 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
397 output[output_idx] = val;
398 }
399 }
400 } else if input.ndim() == 3 {
401 for i in my_start..actual_end {
403 for j in 0..shape[1].min(max_size) {
404 for k in 0..shape[2].min(max_size) {
405 let input_idx = IxDyn(&[i, j, k]);
406 let output_idx = IxDyn(&[i - my_start, j, k]);
407 let val: Complex64 =
408 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
409 output[output_idx] = val;
410 }
411 }
412 }
413 } else {
414 let ndim = input.ndim();
417 let iter_shape: Vec<usize> = (0..ndim)
419 .map(|ax| {
420 if ax == 0 {
421 actual_end - my_start
422 } else {
423 myshape[ax].min(max_size)
424 }
425 })
426 .collect();
427 let iter_dim = IxDyn(iter_shape.as_slice());
428 for local_idx in scirs2_core::ndarray::indices(iter_dim) {
429 let local_slice = local_idx.slice();
430 let mut global = local_slice.to_vec();
432 global[0] += my_start;
433 let val: Complex64 = NumCast::from(input[IxDyn(global.as_slice())])
434 .unwrap_or(Complex64::new(0.0, 0.0));
435 output[IxDyn(local_slice)] = val;
436 }
437 }
438
439 Ok(output)
440 }
441
442 fn pencil_decomposition<T>(
443 &self,
444 input: &ArrayD<T>,
445 is_testing: bool,
446 ) -> FFTResult<ArrayD<Complex64>>
447 where
448 T: Into<Complex64> + Copy + NumCast,
449 {
450 let shape = input.shape();
451
452 let max_size = if is_testing {
454 self.config.max_local_size
455 } else {
456 usize::MAX
457 };
458
459 if shape.len() < 2 {
461 return Err(FFTError::DimensionError(
462 "Pencil decomposition requires at least 2D input".to_string(),
463 ));
464 }
465
466 let process_grid = &self.config.process_grid;
469 if process_grid.len() < 2 {
470 return Err(FFTError::ValueError(
471 "Pencil decomposition requires a 2D process grid".to_string(),
472 ));
473 }
474
475 let p1 = process_grid[0];
476 let p2 = process_grid[1];
477
478 if p1 * p2 != self.config.node_count {
479 return Err(FFTError::ValueError(format!(
480 "Process grid ({} x {}) doesn't match node count ({})",
481 p1, p2, self.config.node_count
482 )));
483 }
484
485 let my_row = self.config.rank / p2;
487 let my_col = self.config.rank % p2;
488
489 let n1 = shape[0];
491 let n2 = shape[1];
492
493 let rows_per_node = n1.div_ceil(p1);
494 let cols_per_node = n2.div_ceil(p2);
495
496 let my_start_row = my_row * rows_per_node;
497 let my_end_row = (my_start_row + rows_per_node).min(n1);
498
499 let my_start_col = my_col * cols_per_node;
500 let my_end_col = (my_start_col + cols_per_node).min(n2);
501
502 if my_start_row >= n1 || my_start_col >= n2 {
504 return Ok(ArrayD::zeros(IxDyn(&[0])));
506 }
507
508 let actual_end_row = my_end_row.min(my_start_row.saturating_add(max_size));
510 let actual_end_col = my_end_col.min(my_start_col.saturating_add(max_size));
511
512 let mut myshape: Vec<usize> = shape.to_vec();
514 myshape[0] = actual_end_row - my_start_row;
515 myshape[1] = actual_end_col - my_start_col;
516
517 let mut output = ArrayD::zeros(IxDyn(myshape.as_slice()));
519
520 if input.ndim() == 2 {
522 for i in my_start_row..actual_end_row {
524 for j in my_start_col..actual_end_col {
525 let input_idx = IxDyn(&[i, j]);
526 let output_idx = IxDyn(&[i - my_start_row, j - my_start_col]);
527 let val: Complex64 =
528 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
529 output[output_idx] = val;
530 }
531 }
532 } else if input.ndim() == 3 {
533 for i in my_start_row..actual_end_row {
535 for j in my_start_col..actual_end_col {
536 for k in 0..shape[2].min(max_size) {
537 let input_idx = IxDyn(&[i, j, k]);
538 let output_idx = IxDyn(&[i - my_start_row, j - my_start_col, k]);
539 let val: Complex64 =
540 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
541 output[output_idx] = val;
542 }
543 }
544 }
545 } else {
546 let ndim = input.ndim();
549 let iter_shape: Vec<usize> = (0..ndim)
550 .map(|ax| match ax {
551 0 => actual_end_row - my_start_row,
552 1 => actual_end_col - my_start_col,
553 _ => myshape[ax].min(max_size),
554 })
555 .collect();
556 let iter_dim = IxDyn(iter_shape.as_slice());
557 for local_idx in scirs2_core::ndarray::indices(iter_dim) {
558 let local_slice = local_idx.slice();
559 let mut global = local_slice.to_vec();
561 global[0] += my_start_row;
562 global[1] += my_start_col;
563 let val: Complex64 = NumCast::from(input[IxDyn(global.as_slice())])
564 .unwrap_or(Complex64::new(0.0, 0.0));
565 output[IxDyn(local_slice)] = val;
566 }
567 }
568
569 Ok(output)
570 }
571
572 fn volumetric_decomposition<T>(
573 &self,
574 input: &ArrayD<T>,
575 is_testing: bool,
576 ) -> FFTResult<ArrayD<Complex64>>
577 where
578 T: Into<Complex64> + Copy + NumCast,
579 {
580 let shape = input.shape();
581
582 let max_size = if is_testing {
584 self.config.max_local_size
585 } else {
586 usize::MAX
587 };
588
589 if shape.len() < 3 {
591 return Err(FFTError::DimensionError(
592 "Volumetric decomposition requires at least 3D input".to_string(),
593 ));
594 }
595
596 let process_grid = &self.config.process_grid;
599 if process_grid.len() < 3 {
600 return Err(FFTError::ValueError(
601 "Volumetric decomposition requires a 3D process grid".to_string(),
602 ));
603 }
604
605 let p1 = process_grid[0];
606 let p2 = process_grid[1];
607 let p3 = process_grid[2];
608
609 if p1 * p2 * p3 != self.config.node_count {
610 return Err(FFTError::ValueError(format!(
611 "Process grid ({} x {} x {}) doesn't match node count ({})",
612 p1, p2, p3, self.config.node_count
613 )));
614 }
615
616 let my_plane = self.config.rank / (p2 * p3);
618 let remainder = self.config.rank % (p2 * p3);
619 let my_row = remainder / p3;
620 let my_col = remainder % p3;
621
622 let n1 = shape[0];
624 let n2 = shape[1];
625 let n3 = shape[2];
626
627 let planes_per_node = n1.div_ceil(p1);
628 let rows_per_node = n2.div_ceil(p2);
629 let cols_per_node = n3.div_ceil(p3);
630
631 let my_start_plane = my_plane * planes_per_node;
632 let my_end_plane = (my_start_plane + planes_per_node).min(n1);
633
634 let my_start_row = my_row * rows_per_node;
635 let my_end_row = (my_start_row + rows_per_node).min(n2);
636
637 let my_start_col = my_col * cols_per_node;
638 let my_end_col = (my_start_col + cols_per_node).min(n3);
639
640 if my_start_plane >= n1 || my_start_row >= n2 || my_start_col >= n3 {
642 return Ok(ArrayD::zeros(IxDyn(&[0])));
644 }
645
646 let actual_end_plane = my_end_plane.min(my_start_plane.saturating_add(max_size));
648 let actual_end_row = my_end_row.min(my_start_row.saturating_add(max_size));
649 let actual_end_col = my_end_col.min(my_start_col.saturating_add(max_size));
650
651 let mut myshape: Vec<usize> = shape.to_vec();
653 myshape[0] = actual_end_plane - my_start_plane;
654 myshape[1] = actual_end_row - my_start_row;
655 myshape[2] = actual_end_col - my_start_col;
656
657 let mut output = ArrayD::zeros(IxDyn(myshape.as_slice()));
659
660 if input.ndim() == 3 {
662 for i in my_start_plane..actual_end_plane {
664 for j in my_start_row..actual_end_row {
665 for k in my_start_col..actual_end_col {
666 let input_idx = IxDyn(&[i, j, k]);
667 let output_idx =
668 IxDyn(&[i - my_start_plane, j - my_start_row, k - my_start_col]);
669 let val: Complex64 =
670 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
671 output[output_idx] = val;
672 }
673 }
674 }
675 } else {
676 let ndim = input.ndim();
679 let iter_shape: Vec<usize> = (0..ndim)
680 .map(|ax| match ax {
681 0 => actual_end_plane - my_start_plane,
682 1 => actual_end_row - my_start_row,
683 2 => actual_end_col - my_start_col,
684 _ => myshape[ax].min(max_size),
685 })
686 .collect();
687 let iter_dim = IxDyn(iter_shape.as_slice());
688 for local_idx in scirs2_core::ndarray::indices(iter_dim) {
689 let local_slice = local_idx.slice();
690 let mut global = local_slice.to_vec();
692 global[0] += my_start_plane;
693 global[1] += my_start_row;
694 global[2] += my_start_col;
695 let val: Complex64 = NumCast::from(input[IxDyn(global.as_slice())])
696 .unwrap_or(Complex64::new(0.0, 0.0));
697 output[IxDyn(local_slice)] = val;
698 }
699 }
700
701 Ok(output)
702 }
703
704 fn adaptive_decomposition<T>(
705 &self,
706 input: &ArrayD<T>,
707 is_testing: bool,
708 ) -> FFTResult<ArrayD<Complex64>>
709 where
710 T: Into<Complex64> + Copy + NumCast,
711 {
712 let ndim = input.ndim();
713
714 if ndim == 1 || self.config.node_count == 1 {
716 self.slab_decomposition(input, is_testing)
718 } else if ndim == 2 || self.config.node_count < 8 {
719 self.slab_decomposition(input, is_testing)
721 } else if ndim == 3 && self.config.node_count >= 8 {
722 let mut config = self.config.clone();
725 if config.process_grid.len() < 2 {
726 let sqrt_nodes = (self.config.node_count as f64).sqrt().floor() as usize;
727 config.process_grid = vec![sqrt_nodes, self.config.node_count / sqrt_nodes];
728 }
729
730 let temp_dfft = DistributedFFT {
732 config,
733 communicator: self.communicator.clone(),
734 };
735
736 temp_dfft.pencil_decomposition(input, is_testing)
737 } else if ndim >= 3 && self.config.node_count >= 27 {
738 let mut config = self.config.clone();
741 if config.process_grid.len() < 3 {
742 let cbrt_nodes = (self.config.node_count as f64).cbrt().floor() as usize;
743 let remaining = self.config.node_count / cbrt_nodes;
744 let sqrt_remaining = (remaining as f64).sqrt().floor() as usize;
745 config.process_grid = vec![cbrt_nodes, sqrt_remaining, remaining / sqrt_remaining];
746 }
747
748 let temp_dfft = DistributedFFT {
750 config,
751 communicator: self.communicator.clone(),
752 };
753
754 temp_dfft.volumetric_decomposition(input, is_testing)
755 } else {
756 self.slab_decomposition(input, is_testing)
758 }
759 }
760
761 #[cfg(test)]
763 pub fn new_mock(config: DistributedConfig) -> Self {
764 let communicator = Arc::new(MockCommunicator::new(config.node_count, config.rank));
765 Self {
766 config,
767 communicator,
768 }
769 }
770}
771
/// Minimal in-process `Communicator` that validates ranks but moves no data.
#[derive(Debug)]
pub struct BasicCommunicator {
    /// Number of ranks in the communicator.
    size: usize,
    /// Rank of this process.
    rank: usize,
}

impl BasicCommunicator {
    /// Creates a communicator of `size` ranks for the process at `rank`.
    ///
    /// `rank` is stored as given; it is not checked against `size`.
    pub fn new(size: usize, rank: usize) -> Self {
        BasicCommunicator { rank, size }
    }
}
787
788impl Communicator for BasicCommunicator {
789 fn send(&self, data: &[Complex64], dest: usize, tag: usize) -> FFTResult<()> {
790 let _ = tag; if dest >= self.size {
792 return Err(FFTError::ValueError(format!(
793 "Invalid destination rank: {} (size: {})",
794 dest, self.size
795 )));
796 }
797
798 if data.is_empty() {
801 return Err(FFTError::ValueError("Cannot send empty data".to_string()));
802 }
803
804 Ok(())
805 }
806
807 fn recv(&self, src: usize, tag: usize, size: usize) -> FFTResult<Vec<Complex64>> {
808 let _ = tag; if src >= self.size {
810 return Err(FFTError::ValueError(format!(
811 "Invalid source rank: {} (size: {})",
812 src, self.size
813 )));
814 }
815
816 Ok(vec![Complex64::new(0.0, 0.0); size])
819 }
820
821 fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>> {
822 Ok(senddata.to_vec())
825 }
826
827 fn barrier(&self) -> FFTResult<()> {
828 Ok(())
831 }
832
833 fn size(&self) -> usize {
834 self.size
835 }
836
837 fn rank(&self) -> usize {
838 self.rank
839 }
840}
841
/// Test double for `Communicator`: validates ranks, accepts empty sends, moves no data.
#[derive(Debug)]
pub struct MockCommunicator {
    /// Number of ranks in the communicator.
    size: usize,
    /// Rank of this process.
    rank: usize,
}

impl MockCommunicator {
    /// Creates a mock communicator of `size` ranks for the process at `rank`.
    ///
    /// `rank` is stored as given; it is not checked against `size`.
    pub fn new(size: usize, rank: usize) -> Self {
        MockCommunicator { rank, size }
    }
}
855
856impl Communicator for MockCommunicator {
857 fn send(&self, data: &[Complex64], dest: usize, tag: usize) -> FFTResult<()> {
858 let _ = tag; if dest >= self.size {
860 return Err(FFTError::ValueError(format!(
861 "Invalid destination rank: {} (size: {})",
862 dest, self.size
863 )));
864 }
865
866 Ok(())
868 }
869
870 fn recv(&self, src: usize, tag: usize, size: usize) -> FFTResult<Vec<Complex64>> {
871 let _ = tag; if src >= self.size {
873 return Err(FFTError::ValueError(format!(
874 "Invalid source rank: {} (size: {})",
875 src, self.size
876 )));
877 }
878
879 Ok(vec![Complex64::new(0.0, 0.0); size])
881 }
882
883 fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>> {
884 Ok(senddata.to_vec())
886 }
887
888 fn barrier(&self) -> FFTResult<()> {
889 Ok(())
891 }
892
893 fn size(&self) -> usize {
894 self.size
895 }
896
897 fn rank(&self) -> usize {
898 self.rank
899 }
900}
901
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Rank-0 configuration with a 16-element local cap, shared by the decomposition tests.
    fn rank_zero_config(
        node_count: usize,
        decomposition: DecompositionStrategy,
        process_grid: Vec<usize>,
    ) -> DistributedConfig {
        DistributedConfig {
            node_count,
            rank: 0,
            decomposition,
            communication: CommunicationPattern::AllToAll,
            process_grid,
            local_size: Vec::new(),
            max_local_size: 16,
        }
    }

    /// A 4x4 matrix holding 1.0..=16.0 in row-major order.
    fn four_by_four() -> ArrayD<f64> {
        Array2::from_shape_vec((4, 4), (1..=16).map(|x| x as f64).collect())
            .expect("Operation failed")
            .into_dyn()
    }

    #[test]
    fn test_distributed_config_default() {
        let DistributedConfig {
            node_count,
            rank,
            decomposition,
            ..
        } = DistributedConfig::default();
        assert_eq!(node_count, 1);
        assert_eq!(rank, 0);
        assert_eq!(decomposition, DecompositionStrategy::Slab);
    }

    #[test]
    fn test_mock_communicator() {
        let comm = MockCommunicator::new(4, 0);
        assert_eq!((comm.size(), comm.rank()), (4, 0));

        let data = vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        // Sends succeed for in-range ranks and fail for out-of-range ones.
        assert!(comm.send(&data, 1, 0).is_ok());
        assert!(comm.send(&data, 4, 0).is_err());

        let received = comm.recv(1, 0, 2);
        assert!(received.is_ok());
        assert_eq!(received.expect("Operation failed").len(), 2);
        assert!(comm.recv(4, 0, 2).is_err());

        let exchanged = comm.all_to_all(&data);
        assert!(exchanged.is_ok());
        assert_eq!(exchanged.expect("Operation failed"), data);

        assert!(comm.barrier().is_ok());
    }

    #[test]
    fn test_slab_decomposition_1d() {
        let dfft = DistributedFFT::new_mock(rank_zero_config(
            2,
            DecompositionStrategy::Slab,
            vec![2],
        ));
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]).into_dyn();

        let result = dfft.slab_decomposition(&input, true);
        assert!(result.is_ok());
        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.shape(), &[2]);
    }

    #[test]
    fn test_slab_decomposition_2d() {
        let dfft = DistributedFFT::new_mock(rank_zero_config(
            2,
            DecompositionStrategy::Slab,
            vec![2],
        ));
        let input = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("Operation failed")
            .into_dyn();

        let result = dfft.slab_decomposition(&input, true);
        assert!(result.is_ok());
        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.shape(), &[2, 2]);
    }

    #[test]
    fn test_pencil_decomposition_2d() {
        let dfft = DistributedFFT::new_mock(rank_zero_config(
            4,
            DecompositionStrategy::Pencil,
            vec![2, 2],
        ));

        let result = dfft.pencil_decomposition(&four_by_four(), true);
        assert!(result.is_ok());
        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.shape(), &[2, 2]);
    }

    #[test]
    fn test_adaptive_decomposition() {
        // 1-D input over four nodes.
        let dfft1 = DistributedFFT::new_mock(rank_zero_config(
            4,
            DecompositionStrategy::Adaptive,
            vec![4],
        ));
        let input1 = Array1::from_vec((1..=16).map(|x| x as f64).collect()).into_dyn();
        assert!(dfft1.adaptive_decomposition(&input1, true).is_ok());

        // 2-D input over a 2x2 grid.
        let dfft2 = DistributedFFT::new_mock(rank_zero_config(
            4,
            DecompositionStrategy::Adaptive,
            vec![2, 2],
        ));
        assert!(dfft2.adaptive_decomposition(&four_by_four(), true).is_ok());
    }
}