1use crate::error::{FFTError, FFTResult};
8use crate::fft::fft;
9use scirs2_core::ndarray::{s, ArrayBase, ArrayD, Data, Dimension, IxDyn};
10use scirs2_core::numeric::Complex64;
11use scirs2_core::numeric::NumCast;
12use std::fmt::Debug;
13use std::sync::Arc;
14use std::time::Instant;
15
/// How the global input array is partitioned across ranks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecompositionStrategy {
    /// Split axis 0 into one contiguous chunk per node.
    Slab,
    /// Split axes 0 and 1 over a two-dimensional process grid.
    Pencil,
    /// Split axes 0, 1 and 2 over a three-dimensional process grid.
    Volumetric,
    /// Pick slab, pencil or volumetric partitioning from the input rank and node count.
    Adaptive,
}
28
/// How ranks exchange data after the local FFT stage.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CommunicationPattern {
    /// Exchange through the collective [`Communicator::all_to_all`].
    AllToAll,
    /// Pairwise messages; currently handled as a pass-through (no exchange).
    PointToPoint,
    /// Neighbour-only exchange; currently handled as a pass-through.
    Neighbor,
    /// Mixed scheme; currently handled as a pass-through.
    Hybrid,
}
41
42#[derive(Debug, Clone)]
44pub struct DistributedConfig {
45 pub node_count: usize,
47 pub rank: usize,
49 pub decomposition: DecompositionStrategy,
51 pub communication: CommunicationPattern,
53 pub process_grid: Vec<usize>,
55 pub local_size: Vec<usize>,
57 pub max_local_size: usize,
59}
60
61impl Default for DistributedConfig {
62 fn default() -> Self {
63 Self {
64 node_count: 1,
65 rank: 0,
66 decomposition: DecompositionStrategy::Slab,
67 communication: CommunicationPattern::AllToAll,
68 process_grid: vec![1],
69 local_size: vec![],
70 max_local_size: 1024, }
72 }
73}
74
75pub struct DistributedFFT {
77 config: DistributedConfig,
79 #[allow(dead_code)]
81 communicator: Arc<dyn Communicator>,
82}
83
84pub trait Communicator: Send + Sync + Debug {
86 fn send(&self, data: &[Complex64], dest: usize, tag: usize) -> FFTResult<()>;
88
89 fn recv(&self, src: usize, tag: usize, size: usize) -> FFTResult<Vec<Complex64>>;
91
92 fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>>;
94
95 fn barrier(&self) -> FFTResult<()>;
97
98 fn size(&self) -> usize;
100
101 fn rank(&self) -> usize;
103}
104
105impl DistributedFFT {
106 pub fn new(config: DistributedConfig, communicator: Arc<dyn Communicator>) -> Self {
108 Self {
109 config,
110 communicator,
111 }
112 }
113
114 pub fn distributed_fft<S, D>(&self, input: &ArrayBase<S, D>) -> FFTResult<ArrayD<Complex64>>
116 where
117 S: Data,
118 D: Dimension,
119 S::Elem: Into<Complex64> + Copy + Debug + NumCast,
120 {
121 let start = Instant::now();
123
124 let input_dyn = input.to_owned().into_dyn();
126
127 let local_data = self.decompose_data(&input_dyn)?;
129
130 let decomp_time = start.elapsed();
132
133 let mut local_result = ArrayD::zeros(local_data.dim());
135 self.perform_local_fft(&local_data, &mut local_result)?;
136
137 let local_fft_time = start.elapsed() - decomp_time;
139
140 let exchanged_data = self.exchange_data(&local_result)?;
142
143 let comm_time = start.elapsed() - decomp_time - local_fft_time;
145
146 let final_result = self.finalize_result(&exchanged_data, input.shape())?;
148
149 let total_time = start.elapsed();
151
152 if cfg!(debug_assertions) {
154 println!("Distributed FFT Performance:");
155 println!(" Decomposition: {:?}", decomp_time);
156 println!(" Local FFT: {:?}", local_fft_time);
157 println!(" Communication: {:?}", comm_time);
158 println!(" Total time: {:?}", total_time);
159 }
160
161 Ok(final_result)
162 }
163
164 pub fn decompose_data<T>(&self, input: &ArrayD<T>) -> FFTResult<ArrayD<Complex64>>
166 where
167 T: Into<Complex64> + Copy + NumCast,
168 {
169 let is_testing = cfg!(test) || std::env::var("RUST_TEST").is_ok();
171
172 match self.config.decomposition {
173 DecompositionStrategy::Slab => self.slab_decomposition(input, is_testing),
174 DecompositionStrategy::Pencil => self.pencil_decomposition(input, is_testing),
175 DecompositionStrategy::Volumetric => self.volumetric_decomposition(input, is_testing),
176 DecompositionStrategy::Adaptive => self.adaptive_decomposition(input, is_testing),
177 }
178 }
179
180 fn perform_local_fft(
182 &self,
183 input: &ArrayD<Complex64>,
184 output: &mut ArrayD<Complex64>,
185 ) -> FFTResult<()> {
186 if input.ndim() == 1
188 || (input.ndim() >= 2 && self.config.decomposition == DecompositionStrategy::Slab)
189 {
190 if input.ndim() >= 2 {
192 for i in 0..input.shape()[0].min(self.config.max_local_size) {
193 let row = input.slice(s![i, ..]).to_vec();
194 let result = fft(&row, None)?;
195 let mut output_row = output.slice_mut(s![i, ..]);
196 for (j, val) in result.iter().enumerate().take(output_row.len()) {
197 output_row[j] = *val;
198 }
199 }
200 } else {
201 let result = fft(input.as_slice().unwrap_or(&[]), None)?;
203 for (i, val) in result.iter().enumerate().take(output.len()) {
204 output[i] = *val;
205 }
206 }
207 } else if input.ndim() >= 2 && self.config.decomposition == DecompositionStrategy::Pencil {
208 for i in 0..input.shape()[0].min(self.config.max_local_size) {
211 for j in 0..input.shape()[1].min(self.config.max_local_size) {
212 let column = input.slice(s![i, j, ..]).to_vec();
213 let result = fft(&column, None)?;
214 let mut output_col = output.slice_mut(s![i, j, ..]);
215 for (k, val) in result.iter().enumerate().take(output_col.len()) {
216 output_col[k] = *val;
217 }
218 }
219 }
220 } else {
221 return Err(FFTError::DimensionError(format!(
223 "Unsupported decomposition strategy for input of dimension {}",
224 input.ndim()
225 )));
226 }
227
228 Ok(())
229 }
230
231 fn exchange_data(&self, localresult: &ArrayD<Complex64>) -> FFTResult<ArrayD<Complex64>> {
233 if self.config.node_count == 1 || self.config.rank == 0 {
239 return Ok(localresult.clone());
240 }
241
242 match self.config.communication {
245 CommunicationPattern::AllToAll => {
246 let flattened: Vec<Complex64> = localresult.iter().copied().collect();
248
249 let _result = self.communicator.all_to_all(&flattened)?;
251
252 Ok(localresult.clone())
254 }
255 CommunicationPattern::PointToPoint => {
256 Ok(localresult.clone())
259 }
260 _ => {
261 Ok(localresult.clone())
263 }
264 }
265 }
266
267 fn finalize_result(
269 &self,
270 exchanged_data: &ArrayD<Complex64>,
271 output_dim: &[usize],
272 ) -> FFTResult<ArrayD<Complex64>> {
273 if self.config.node_count == 1 || self.config.rank == 0 {
278 let limitedshape: Vec<usize> = output_dim
280 .iter()
281 .map(|&d| d.min(self.config.max_local_size))
282 .collect();
283
284 let mut output = ArrayD::zeros(IxDyn(&limitedshape));
286
287 if output_dim.len() == limitedshape.len() {
289 let mut all_match = true;
290 for (a, b) in output_dim.iter().zip(limitedshape.iter()) {
291 if a != b {
292 all_match = false;
293 break;
294 }
295 }
296
297 if all_match && !output.is_empty() && !exchanged_data.is_empty() {
298 let flat_output = output.as_slice_mut().expect("Operation failed");
300 for (i, &val) in exchanged_data.iter().enumerate().take(flat_output.len()) {
301 flat_output[i] = val;
302 }
303 } else {
304 if !output.is_empty() && !exchanged_data.is_empty() {
308 let flat_output = output.as_slice_mut().expect("Operation failed");
309 let copy_len = flat_output.len().min(exchanged_data.len());
310
311 for i in 0..copy_len {
312 flat_output[i] =
313 *exchanged_data.iter().nth(i).expect("Operation failed");
314 }
315 }
316 }
317 }
318
319 Ok(output)
320 } else {
321 Err(FFTError::ValueError(
324 "Only the root node (rank 0) produces the final output".to_string(),
325 ))
326 }
327 }
328
329 fn slab_decomposition<T>(
332 &self,
333 input: &ArrayD<T>,
334 is_testing: bool,
335 ) -> FFTResult<ArrayD<Complex64>>
336 where
337 T: Into<Complex64> + Copy + NumCast,
338 {
339 let shape = input.shape();
340
341 let max_size = if is_testing {
343 self.config.max_local_size
344 } else {
345 usize::MAX
346 };
347
348 if shape.is_empty() {
350 return Err(FFTError::DimensionError(
351 "Cannot perform FFT on empty array".to_string(),
352 ));
353 }
354
355 let total_slabs = shape[0];
357 let slabs_per_node = total_slabs.div_ceil(self.config.node_count);
358
359 let my_start = self.config.rank * slabs_per_node;
361 let my_end = (my_start + slabs_per_node).min(total_slabs);
362
363 if my_start >= total_slabs {
365 return Ok(ArrayD::zeros(IxDyn(&[0])));
367 }
368
369 let actual_end = my_end.min(my_start + max_size);
371
372 let mut myshape: Vec<usize> = shape.to_vec();
374 myshape[0] = actual_end - my_start;
375
376 let mut output = ArrayD::zeros(IxDyn(myshape.as_slice()));
378
379 if input.ndim() == 1 {
381 for i in my_start..actual_end {
383 let input_idx = IxDyn(&[i]);
384 let output_idx = IxDyn(&[i - my_start]);
385 let val: Complex64 =
386 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
387 output[output_idx] = val;
388 }
389 } else if input.ndim() == 2 {
390 for i in my_start..actual_end {
392 for j in 0..shape[1].min(max_size) {
393 let input_idx = IxDyn(&[i, j]);
394 let output_idx = IxDyn(&[i - my_start, j]);
395 let val: Complex64 =
396 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
397 output[output_idx] = val;
398 }
399 }
400 } else if input.ndim() == 3 {
401 for i in my_start..actual_end {
403 for j in 0..shape[1].min(max_size) {
404 for k in 0..shape[2].min(max_size) {
405 let input_idx = IxDyn(&[i, j, k]);
406 let output_idx = IxDyn(&[i - my_start, j, k]);
407 let val: Complex64 =
408 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
409 output[output_idx] = val;
410 }
411 }
412 }
413 } else {
414 return Err(FFTError::DimensionError(
417 "Dimensions higher than 3 not yet implemented for slab decomposition".to_string(),
418 ));
419 }
420
421 Ok(output)
422 }
423
424 fn pencil_decomposition<T>(
425 &self,
426 input: &ArrayD<T>,
427 is_testing: bool,
428 ) -> FFTResult<ArrayD<Complex64>>
429 where
430 T: Into<Complex64> + Copy + NumCast,
431 {
432 let shape = input.shape();
433
434 let max_size = if is_testing {
436 self.config.max_local_size
437 } else {
438 usize::MAX
439 };
440
441 if shape.len() < 2 {
443 return Err(FFTError::DimensionError(
444 "Pencil decomposition requires at least 2D input".to_string(),
445 ));
446 }
447
448 let process_grid = &self.config.process_grid;
451 if process_grid.len() < 2 {
452 return Err(FFTError::ValueError(
453 "Pencil decomposition requires a 2D process grid".to_string(),
454 ));
455 }
456
457 let p1 = process_grid[0];
458 let p2 = process_grid[1];
459
460 if p1 * p2 != self.config.node_count {
461 return Err(FFTError::ValueError(format!(
462 "Process grid ({} x {}) doesn't match node count ({})",
463 p1, p2, self.config.node_count
464 )));
465 }
466
467 let my_row = self.config.rank / p2;
469 let my_col = self.config.rank % p2;
470
471 let n1 = shape[0];
473 let n2 = shape[1];
474
475 let rows_per_node = n1.div_ceil(p1);
476 let cols_per_node = n2.div_ceil(p2);
477
478 let my_start_row = my_row * rows_per_node;
479 let my_end_row = (my_start_row + rows_per_node).min(n1);
480
481 let my_start_col = my_col * cols_per_node;
482 let my_end_col = (my_start_col + cols_per_node).min(n2);
483
484 if my_start_row >= n1 || my_start_col >= n2 {
486 return Ok(ArrayD::zeros(IxDyn(&[0])));
488 }
489
490 let actual_end_row = my_end_row.min(my_start_row + max_size);
492 let actual_end_col = my_end_col.min(my_start_col + max_size);
493
494 let mut myshape: Vec<usize> = shape.to_vec();
496 myshape[0] = actual_end_row - my_start_row;
497 myshape[1] = actual_end_col - my_start_col;
498
499 let mut output = ArrayD::zeros(IxDyn(myshape.as_slice()));
501
502 if input.ndim() == 2 {
504 for i in my_start_row..actual_end_row {
506 for j in my_start_col..actual_end_col {
507 let input_idx = IxDyn(&[i, j]);
508 let output_idx = IxDyn(&[i - my_start_row, j - my_start_col]);
509 let val: Complex64 =
510 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
511 output[output_idx] = val;
512 }
513 }
514 } else if input.ndim() == 3 {
515 for i in my_start_row..actual_end_row {
517 for j in my_start_col..actual_end_col {
518 for k in 0..shape[2].min(max_size) {
519 let input_idx = IxDyn(&[i, j, k]);
520 let output_idx = IxDyn(&[i - my_start_row, j - my_start_col, k]);
521 let val: Complex64 =
522 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
523 output[output_idx] = val;
524 }
525 }
526 }
527 } else {
528 return Err(FFTError::DimensionError(
530 "Dimensions higher than 3 not yet implemented for pencil decomposition".to_string(),
531 ));
532 }
533
534 Ok(output)
535 }
536
537 fn volumetric_decomposition<T>(
538 &self,
539 input: &ArrayD<T>,
540 is_testing: bool,
541 ) -> FFTResult<ArrayD<Complex64>>
542 where
543 T: Into<Complex64> + Copy + NumCast,
544 {
545 let shape = input.shape();
546
547 let max_size = if is_testing {
549 self.config.max_local_size
550 } else {
551 usize::MAX
552 };
553
554 if shape.len() < 3 {
556 return Err(FFTError::DimensionError(
557 "Volumetric decomposition requires at least 3D input".to_string(),
558 ));
559 }
560
561 let process_grid = &self.config.process_grid;
564 if process_grid.len() < 3 {
565 return Err(FFTError::ValueError(
566 "Volumetric decomposition requires a 3D process grid".to_string(),
567 ));
568 }
569
570 let p1 = process_grid[0];
571 let p2 = process_grid[1];
572 let p3 = process_grid[2];
573
574 if p1 * p2 * p3 != self.config.node_count {
575 return Err(FFTError::ValueError(format!(
576 "Process grid ({} x {} x {}) doesn't match node count ({})",
577 p1, p2, p3, self.config.node_count
578 )));
579 }
580
581 let my_plane = self.config.rank / (p2 * p3);
583 let remainder = self.config.rank % (p2 * p3);
584 let my_row = remainder / p3;
585 let my_col = remainder % p3;
586
587 let n1 = shape[0];
589 let n2 = shape[1];
590 let n3 = shape[2];
591
592 let planes_per_node = n1.div_ceil(p1);
593 let rows_per_node = n2.div_ceil(p2);
594 let cols_per_node = n3.div_ceil(p3);
595
596 let my_start_plane = my_plane * planes_per_node;
597 let my_end_plane = (my_start_plane + planes_per_node).min(n1);
598
599 let my_start_row = my_row * rows_per_node;
600 let my_end_row = (my_start_row + rows_per_node).min(n2);
601
602 let my_start_col = my_col * cols_per_node;
603 let my_end_col = (my_start_col + cols_per_node).min(n3);
604
605 if my_start_plane >= n1 || my_start_row >= n2 || my_start_col >= n3 {
607 return Ok(ArrayD::zeros(IxDyn(&[0])));
609 }
610
611 let actual_end_plane = my_end_plane.min(my_start_plane + max_size);
613 let actual_end_row = my_end_row.min(my_start_row + max_size);
614 let actual_end_col = my_end_col.min(my_start_col + max_size);
615
616 let mut myshape: Vec<usize> = shape.to_vec();
618 myshape[0] = actual_end_plane - my_start_plane;
619 myshape[1] = actual_end_row - my_start_row;
620 myshape[2] = actual_end_col - my_start_col;
621
622 let mut output = ArrayD::zeros(IxDyn(myshape.as_slice()));
624
625 if input.ndim() == 3 {
627 for i in my_start_plane..actual_end_plane {
629 for j in my_start_row..actual_end_row {
630 for k in my_start_col..actual_end_col {
631 let input_idx = IxDyn(&[i, j, k]);
632 let output_idx =
633 IxDyn(&[i - my_start_plane, j - my_start_row, k - my_start_col]);
634 let val: Complex64 =
635 NumCast::from(input[input_idx]).unwrap_or(Complex64::new(0.0, 0.0));
636 output[output_idx] = val;
637 }
638 }
639 }
640 } else {
641 return Err(FFTError::DimensionError(
643 "Dimensions higher than 3 not yet implemented for volumetric decomposition"
644 .to_string(),
645 ));
646 }
647
648 Ok(output)
649 }
650
651 fn adaptive_decomposition<T>(
652 &self,
653 input: &ArrayD<T>,
654 is_testing: bool,
655 ) -> FFTResult<ArrayD<Complex64>>
656 where
657 T: Into<Complex64> + Copy + NumCast,
658 {
659 let ndim = input.ndim();
660
661 if ndim == 1 || self.config.node_count == 1 {
663 self.slab_decomposition(input, is_testing)
665 } else if ndim == 2 || self.config.node_count < 8 {
666 self.slab_decomposition(input, is_testing)
668 } else if ndim == 3 && self.config.node_count >= 8 {
669 let mut config = self.config.clone();
672 if config.process_grid.len() < 2 {
673 let sqrt_nodes = (self.config.node_count as f64).sqrt().floor() as usize;
674 config.process_grid = vec![sqrt_nodes, self.config.node_count / sqrt_nodes];
675 }
676
677 let temp_dfft = DistributedFFT {
679 config,
680 communicator: self.communicator.clone(),
681 };
682
683 temp_dfft.pencil_decomposition(input, is_testing)
684 } else if ndim >= 3 && self.config.node_count >= 27 {
685 let mut config = self.config.clone();
688 if config.process_grid.len() < 3 {
689 let cbrt_nodes = (self.config.node_count as f64).cbrt().floor() as usize;
690 let remaining = self.config.node_count / cbrt_nodes;
691 let sqrt_remaining = (remaining as f64).sqrt().floor() as usize;
692 config.process_grid = vec![cbrt_nodes, sqrt_remaining, remaining / sqrt_remaining];
693 }
694
695 let temp_dfft = DistributedFFT {
697 config,
698 communicator: self.communicator.clone(),
699 };
700
701 temp_dfft.volumetric_decomposition(input, is_testing)
702 } else {
703 self.slab_decomposition(input, is_testing)
705 }
706 }
707
708 #[cfg(test)]
710 pub fn new_mock(config: DistributedConfig) -> Self {
711 let communicator = Arc::new(MockCommunicator::new(config.node_count, config.rank));
712 Self {
713 config,
714 communicator,
715 }
716 }
717}
718
/// Minimal in-process [`Communicator`] that validates ranks but transfers no data.
#[derive(Debug)]
pub struct BasicCommunicator {
    /// Number of ranks this communicator reports.
    size: usize,
    /// Rank reported for the calling process.
    rank: usize,
}
727
728impl BasicCommunicator {
729 pub fn new(size: usize, rank: usize) -> Self {
731 Self { size, rank }
732 }
733}
734
735impl Communicator for BasicCommunicator {
736 fn send(&self, data: &[Complex64], dest: usize, tag: usize) -> FFTResult<()> {
737 let _ = tag; if dest >= self.size {
739 return Err(FFTError::ValueError(format!(
740 "Invalid destination rank: {} (size: {})",
741 dest, self.size
742 )));
743 }
744
745 if data.is_empty() {
748 return Err(FFTError::ValueError("Cannot send empty data".to_string()));
749 }
750
751 Ok(())
752 }
753
754 fn recv(&self, src: usize, tag: usize, size: usize) -> FFTResult<Vec<Complex64>> {
755 let _ = tag; if src >= self.size {
757 return Err(FFTError::ValueError(format!(
758 "Invalid source rank: {} (size: {})",
759 src, self.size
760 )));
761 }
762
763 Ok(vec![Complex64::new(0.0, 0.0); size])
766 }
767
768 fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>> {
769 Ok(senddata.to_vec())
772 }
773
774 fn barrier(&self) -> FFTResult<()> {
775 Ok(())
778 }
779
780 fn size(&self) -> usize {
781 self.size
782 }
783
784 fn rank(&self) -> usize {
785 self.rank
786 }
787}
788
/// Test double for [`Communicator`]: validates ranks, accepts any payload and
/// echoes collective data back.
#[derive(Debug)]
pub struct MockCommunicator {
    /// Number of ranks this mock reports.
    size: usize,
    /// Rank reported for the calling process.
    rank: usize,
}
795
796impl MockCommunicator {
797 pub fn new(size: usize, rank: usize) -> Self {
799 Self { size, rank }
800 }
801}
802
803impl Communicator for MockCommunicator {
804 fn send(&self, data: &[Complex64], dest: usize, tag: usize) -> FFTResult<()> {
805 let _ = tag; if dest >= self.size {
807 return Err(FFTError::ValueError(format!(
808 "Invalid destination rank: {} (size: {})",
809 dest, self.size
810 )));
811 }
812
813 Ok(())
815 }
816
817 fn recv(&self, src: usize, tag: usize, size: usize) -> FFTResult<Vec<Complex64>> {
818 let _ = tag; if src >= self.size {
820 return Err(FFTError::ValueError(format!(
821 "Invalid source rank: {} (size: {})",
822 src, self.size
823 )));
824 }
825
826 Ok(vec![Complex64::new(0.0, 0.0); size])
828 }
829
830 fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>> {
831 Ok(senddata.to_vec())
833 }
834
835 fn barrier(&self) -> FFTResult<()> {
836 Ok(())
838 }
839
840 fn size(&self) -> usize {
841 self.size
842 }
843
844 fn rank(&self) -> usize {
845 self.rank
846 }
847}
848
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Mock-backed driver for rank 0 with a 16-element per-axis limit.
    fn mock_dfft(
        node_count: usize,
        decomposition: DecompositionStrategy,
        process_grid: Vec<usize>,
    ) -> DistributedFFT {
        let config = DistributedConfig {
            node_count,
            rank: 0,
            decomposition,
            communication: CommunicationPattern::AllToAll,
            process_grid,
            local_size: Vec::new(),
            max_local_size: 16,
        };
        DistributedFFT::new_mock(config)
    }

    #[test]
    fn test_distributed_config_default() {
        let DistributedConfig {
            node_count,
            rank,
            decomposition,
            ..
        } = DistributedConfig::default();
        assert_eq!(node_count, 1);
        assert_eq!(rank, 0);
        assert_eq!(decomposition, DecompositionStrategy::Slab);
    }

    #[test]
    fn test_mock_communicator() {
        let comm = MockCommunicator::new(4, 0);
        assert_eq!(comm.size(), 4);
        assert_eq!(comm.rank(), 0);

        let data = vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        // In-range and out-of-range destinations.
        assert!(comm.send(&data, 1, 0).is_ok());
        assert!(comm.send(&data, 4, 0).is_err());

        // In-range and out-of-range sources.
        let received = comm.recv(1, 0, 2);
        assert!(received.is_ok());
        assert_eq!(received.expect("Operation failed").len(), 2);
        assert!(comm.recv(4, 0, 2).is_err());

        let exchanged = comm.all_to_all(&data);
        assert!(exchanged.is_ok());
        assert_eq!(exchanged.expect("Operation failed"), data);

        assert!(comm.barrier().is_ok());
    }

    #[test]
    fn test_slab_decomposition_1d() {
        let dfft = mock_dfft(2, DecompositionStrategy::Slab, vec![2]);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]).into_dyn();

        let result = dfft.slab_decomposition(&input, true);
        assert!(result.is_ok());

        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.ndim(), 1);
        // Rank 0 of 2 owns the first half of the four samples.
        assert_eq!(local_data.shape()[0], 2);
    }

    #[test]
    fn test_slab_decomposition_2d() {
        let dfft = mock_dfft(2, DecompositionStrategy::Slab, vec![2]);
        let input = Array2::from_shape_vec((4, 2), (1..=8).map(f64::from).collect())
            .expect("Operation failed")
            .into_dyn();

        let result = dfft.slab_decomposition(&input, true);
        assert!(result.is_ok());

        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.ndim(), 2);
        // Two of the four rows, all columns.
        assert_eq!(local_data.shape()[0], 2);
        assert_eq!(local_data.shape()[1], 2);
    }

    #[test]
    fn test_pencil_decomposition_2d() {
        let dfft = mock_dfft(4, DecompositionStrategy::Pencil, vec![2, 2]);
        let input = Array2::from_shape_vec((4, 4), (1..=16).map(f64::from).collect())
            .expect("Operation failed")
            .into_dyn();

        let result = dfft.pencil_decomposition(&input, true);
        assert!(result.is_ok());

        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.ndim(), 2);
        // A 2 x 2 process grid gives each rank a 2 x 2 tile of the 4 x 4 input.
        assert_eq!(local_data.shape()[0], 2);
        assert_eq!(local_data.shape()[1], 2);
    }

    #[test]
    fn test_adaptive_decomposition() {
        let dfft1 = mock_dfft(4, DecompositionStrategy::Adaptive, vec![4]);
        let input1 = Array1::from_vec((1..=16).map(f64::from).collect()).into_dyn();
        assert!(dfft1.adaptive_decomposition(&input1, true).is_ok());

        let dfft2 = mock_dfft(4, DecompositionStrategy::Adaptive, vec![2, 2]);
        let input2 = Array2::from_shape_vec((4, 4), (1..=16).map(f64::from).collect())
            .expect("Operation failed")
            .into_dyn();
        assert!(dfft2.adaptive_decomposition(&input2, true).is_ok());
    }
}
1007}