1use crate::error::{FFTError, FFTResult};
9use ndarray::{Array2, ArrayD, Axis, IxDyn};
10use num_complex::Complex64;
11use num_traits::NumCast;
12use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
13use scirs2_core::safe_ops::{safe_divide, safe_sqrt};
14use std::fmt::Debug;
15
/// Scaling convention applied to the output of an FFT routine.
///
/// How each variant turns into a concrete scale factor is decided by the
/// transform function that consumes it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormMode {
    /// No scaling is applied.
    None,
    /// The `"backward"` convention; selected by default for inverse
    /// transforms by `parse_norm_mode`.
    Backward,
    /// The `"ortho"` convention: results are scaled by `1/sqrt(n)`.
    Ortho,
    /// The `"forward"` convention.
    Forward,
}
31
32impl From<&str> for NormMode {
33 fn from(s: &str) -> Self {
34 match s {
35 "backward" => NormMode::Backward,
36 "ortho" => NormMode::Ortho,
37 "forward" => NormMode::Forward,
38 _ => NormMode::None,
39 }
40 }
41}
42
43#[allow(dead_code)]
45pub fn parse_norm_mode(_norm: Option<&str>, isinverse: bool) -> NormMode {
46 match _norm {
47 Some(s) => NormMode::from(s),
48 None if isinverse => NormMode::Backward, None => NormMode::None, }
51}
52
53#[allow(dead_code)]
55fn apply_normalization(data: &mut [Complex64], n: usize, mode: NormMode) -> FFTResult<()> {
56 match mode {
57 NormMode::None => {} NormMode::Backward => {
59 let n_f64 = n as f64;
60 let scale = safe_divide(1.0, n_f64).map_err(|_| {
61 FFTError::ValueError(
62 "Division by zero in backward normalization: FFT size is zero".to_string(),
63 )
64 })?;
65 data.iter_mut().for_each(|c| *c *= scale);
66 }
67 NormMode::Ortho => {
68 let n_f64 = n as f64;
69 let sqrt_n = safe_sqrt(n_f64).map_err(|_| {
70 FFTError::ComputationError(
71 "Invalid square root in orthogonal normalization".to_string(),
72 )
73 })?;
74 let scale = safe_divide(1.0, sqrt_n).map_err(|_| {
75 FFTError::ValueError("Division by zero in orthogonal normalization".to_string())
76 })?;
77 data.iter_mut().for_each(|c| *c *= scale);
78 }
79 NormMode::Forward => {
80 let n_f64 = n as f64;
81 let scale = safe_divide(1.0, n_f64).map_err(|_| {
82 FFTError::ValueError(
83 "Division by zero in forward normalization: FFT size is zero".to_string(),
84 )
85 })?;
86 data.iter_mut().for_each(|c| *c *= scale);
87 }
88 }
89 Ok(())
90}
91
92#[allow(dead_code)]
94fn convert_to_complex<T>(val: T) -> FFTResult<Complex64>
95where
96 T: NumCast + Copy + Debug + 'static,
97{
98 if let Some(real) = num_traits::cast::<T, f64>(val) {
100 return Ok(Complex64::new(real, 0.0));
101 }
102
103 use std::any::Any;
105 if let Some(complex) = (&val as &dyn Any).downcast_ref::<Complex64>() {
106 return Ok(*complex);
107 }
108
109 if let Some(complex32) = (&val as &dyn Any).downcast_ref::<num_complex::Complex<f32>>() {
111 return Ok(Complex64::new(complex32.re as f64, complex32.im as f64));
112 }
113
114 Err(FFTError::ValueError(format!(
115 "Could not convert {val:?} to numeric type"
116 )))
117}
118
119#[allow(dead_code)]
121fn to_complex<T>(input: &[T]) -> FFTResult<Vec<Complex64>>
122where
123 T: NumCast + Copy + Debug + 'static,
124{
125 input.iter().map(|&val| convert_to_complex(val)).collect()
126}
127
128#[allow(dead_code)]
156pub fn fft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
157where
158 T: NumCast + Copy + Debug + 'static,
159{
160 if input.is_empty() {
162 return Err(FFTError::ValueError("Input cannot be empty".to_string()));
163 }
164
165 let input_len = input.len();
167 let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
168
169 let mut data = to_complex(input)?;
171
172 if fft_size != input_len {
174 if fft_size > input_len {
175 data.resize(fft_size, Complex64::new(0.0, 0.0));
177 } else {
178 data.truncate(fft_size);
180 }
181 }
182
183 let mut planner = FftPlanner::new();
185 let fft = planner.plan_fft_forward(fft_size);
186
187 let mut buffer: Vec<RustComplex<f64>> =
189 data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
190
191 fft.process(&mut buffer);
193
194 let result: Vec<Complex64> = buffer
196 .into_iter()
197 .map(|c| Complex64::new(c.re, c.im))
198 .collect();
199
200 Ok(result)
201}
202
203#[allow(dead_code)]
236pub fn ifft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
237where
238 T: NumCast + Copy + Debug + 'static,
239{
240 if input.is_empty() {
242 return Err(FFTError::ValueError("Input cannot be empty".to_string()));
243 }
244
245 let input_len = input.len();
247 let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
248
249 let mut data = to_complex(input)?;
251
252 if fft_size != input_len {
254 if fft_size > input_len {
255 data.resize(fft_size, Complex64::new(0.0, 0.0));
257 } else {
258 data.truncate(fft_size);
260 }
261 }
262
263 let mut planner = FftPlanner::new();
265 let ifft = planner.plan_fft_inverse(fft_size);
266
267 let mut buffer: Vec<RustComplex<f64>> =
269 data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
270
271 ifft.process(&mut buffer);
273
274 let mut result: Vec<Complex64> = buffer
276 .into_iter()
277 .map(|c| Complex64::new(c.re, c.im))
278 .collect();
279
280 apply_normalization(&mut result, fft_size, NormMode::Backward)?;
282
283 if n.is_none() && fft_size > input_len {
285 result.truncate(input_len);
286 }
287
288 Ok(result)
289}
290
291#[allow(dead_code)]
320pub fn fft2<T>(
321 input: &Array2<T>,
322 shape: Option<(usize, usize)>,
323 axes: Option<(i32, i32)>,
324 norm: Option<&str>,
325) -> FFTResult<Array2<Complex64>>
326where
327 T: NumCast + Copy + Debug + 'static,
328{
329 let inputshape = input.shape();
331
332 let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
334
335 let axes = axes.unwrap_or((0, 1));
337
338 if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
340 return Err(FFTError::ValueError("Invalid axes for 2D FFT".to_string()));
341 }
342
343 let norm_mode = parse_norm_mode(norm, false);
345
346 let mut output = Array2::<Complex64>::zeros(outputshape);
348
349 let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
351 for i in 0..inputshape[0] {
352 for j in 0..inputshape[1] {
353 let val = input[[i, j]];
354
355 complex_input[[i, j]] = convert_to_complex(val)?;
357 }
358 }
359
360 let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
362 let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
363 let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
364 let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
365
366 for i in 0..copy_rows {
367 for j in 0..copy_cols {
368 padded[[i, j]] = complex_input[[i, j]];
369 }
370 }
371 padded
372 } else {
373 complex_input
374 };
375
376 let mut planner = FftPlanner::new();
378
379 let row_fft = planner.plan_fft_forward(outputshape.1);
381 for mut row in padded_input.rows_mut() {
382 let mut buffer: Vec<RustComplex<f64>> =
384 row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
385
386 row_fft.process(&mut buffer);
388
389 for (i, val) in buffer.iter().enumerate() {
391 row[i] = Complex64::new(val.re, val.im);
392 }
393 }
394
395 let col_fft = planner.plan_fft_forward(outputshape.0);
397 for mut col in padded_input.columns_mut() {
398 let mut buffer: Vec<RustComplex<f64>> =
400 col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
401
402 col_fft.process(&mut buffer);
404
405 for (i, val) in buffer.iter().enumerate() {
407 col[i] = Complex64::new(val.re, val.im);
408 }
409 }
410
411 if norm_mode != NormMode::None {
413 let total_elements = outputshape.0 * outputshape.1;
414 let scale = match norm_mode {
415 NormMode::Backward => 1.0 / (total_elements as f64),
416 NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
417 NormMode::Forward => 1.0 / (total_elements as f64),
418 NormMode::None => 1.0, };
420
421 padded_input.mapv_inplace(|x| x * scale);
422 }
423
424 output.assign(&padded_input);
426
427 Ok(output)
428}
429
430#[allow(dead_code)]
467pub fn ifft2<T>(
468 input: &Array2<T>,
469 shape: Option<(usize, usize)>,
470 axes: Option<(i32, i32)>,
471 norm: Option<&str>,
472) -> FFTResult<Array2<Complex64>>
473where
474 T: NumCast + Copy + Debug + 'static,
475{
476 let inputshape = input.shape();
478
479 let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
481
482 let axes = axes.unwrap_or((0, 1));
484
485 if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
487 return Err(FFTError::ValueError("Invalid axes for 2D IFFT".to_string()));
488 }
489
490 let norm_mode = parse_norm_mode(norm, true);
492
493 let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
495 for i in 0..inputshape[0] {
496 for j in 0..inputshape[1] {
497 let val = input[[i, j]];
498
499 complex_input[[i, j]] = convert_to_complex(val)?;
501 }
502 }
503
504 let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
506 let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
507 let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
508 let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
509
510 for i in 0..copy_rows {
511 for j in 0..copy_cols {
512 padded[[i, j]] = complex_input[[i, j]];
513 }
514 }
515 padded
516 } else {
517 complex_input
518 };
519
520 let mut planner = FftPlanner::new();
522
523 let row_ifft = planner.plan_fft_inverse(outputshape.1);
525 for mut row in padded_input.rows_mut() {
526 let mut buffer: Vec<RustComplex<f64>> =
528 row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
529
530 row_ifft.process(&mut buffer);
532
533 for (i, val) in buffer.iter().enumerate() {
535 row[i] = Complex64::new(val.re, val.im);
536 }
537 }
538
539 let col_ifft = planner.plan_fft_inverse(outputshape.0);
541 for mut col in padded_input.columns_mut() {
542 let mut buffer: Vec<RustComplex<f64>> =
544 col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
545
546 col_ifft.process(&mut buffer);
548
549 for (i, val) in buffer.iter().enumerate() {
551 col[i] = Complex64::new(val.re, val.im);
552 }
553 }
554
555 let total_elements = outputshape.0 * outputshape.1;
557 let scale = match norm_mode {
558 NormMode::Backward => 1.0 / (total_elements as f64),
559 NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
560 NormMode::Forward => 1.0, NormMode::None => 1.0, };
563
564 if scale != 1.0 {
565 padded_input.mapv_inplace(|x| x * scale);
566 }
567
568 Ok(padded_input)
569}
570
571#[allow(clippy::too_many_arguments)]
604#[allow(dead_code)]
605pub fn fftn<T>(
606 input: &ArrayD<T>,
607 shape: Option<Vec<usize>>,
608 axes: Option<Vec<usize>>,
609 norm: Option<&str>,
610 _overwrite_x: Option<bool>,
611 _workers: Option<usize>,
612) -> FFTResult<ArrayD<Complex64>>
613where
614 T: NumCast + Copy + Debug + 'static,
615{
616 let inputshape = input.shape().to_vec();
617 let input_ndim = inputshape.len();
618
619 let outputshape = shape.unwrap_or_else(|| inputshape.clone());
621
622 if outputshape.len() != input_ndim {
624 return Err(FFTError::ValueError(
625 "Output shape must have the same number of dimensions as input".to_string(),
626 ));
627 }
628
629 let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
631
632 for &axis in &axes {
634 if axis >= input_ndim {
635 return Err(FFTError::ValueError(format!(
636 "Axis {axis} out of bounds for array of dimension {input_ndim}"
637 )));
638 }
639 }
640
641 let norm_mode = parse_norm_mode(norm, false);
643
644 let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
646 for (idx, &val) in input.iter().enumerate() {
647 let mut idx_vec = Vec::with_capacity(input_ndim);
648 let mut remaining = idx;
649
650 for &dim in input.shape().iter().rev() {
651 idx_vec.push(remaining % dim);
652 remaining /= dim;
653 }
654
655 idx_vec.reverse();
656
657 complex_input[IxDyn(&idx_vec)] = convert_to_complex(val)?;
658 }
659
660 let mut result = if inputshape != outputshape {
662 let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
663
664 for (idx, &val) in complex_input.iter().enumerate() {
666 let mut idx_vec = Vec::with_capacity(input_ndim);
667 let mut remaining = idx;
668
669 for &dim in input.shape().iter().rev() {
670 idx_vec.push(remaining % dim);
671 remaining /= dim;
672 }
673
674 idx_vec.reverse();
675
676 let mut in_bounds = true;
677 for (dim, &idx_val) in idx_vec.iter().enumerate() {
678 if idx_val >= outputshape[dim] {
679 in_bounds = false;
680 break;
681 }
682 }
683
684 if in_bounds {
685 padded[IxDyn(&idx_vec)] = val;
686 }
687 }
688
689 padded
690 } else {
691 complex_input
692 };
693
694 let mut planner = FftPlanner::new();
696
697 for &axis in &axes {
699 let axis_len = outputshape[axis];
700 let fft = planner.plan_fft_forward(axis_len);
701
702 let axis = Axis(axis);
704
705 for mut lane in result.lanes_mut(axis) {
706 let mut buffer: Vec<RustComplex<f64>> =
708 lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
709
710 fft.process(&mut buffer);
712
713 for (i, val) in buffer.iter().enumerate() {
715 lane[i] = Complex64::new(val.re, val.im);
716 }
717 }
718 }
719
720 if norm_mode != NormMode::None {
722 let total_elements: usize = outputshape.iter().product();
723 let scale = match norm_mode {
724 NormMode::Backward => 1.0 / (total_elements as f64),
725 NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
726 NormMode::Forward => 1.0 / (total_elements as f64),
727 NormMode::None => 1.0, };
729
730 result.mapv_inplace(|_x| _x * scale);
731 }
732
733 Ok(result)
734}
735
736#[allow(clippy::too_many_arguments)]
785#[allow(dead_code)]
786pub fn ifftn<T>(
787 input: &ArrayD<T>,
788 shape: Option<Vec<usize>>,
789 axes: Option<Vec<usize>>,
790 norm: Option<&str>,
791 _overwrite_x: Option<bool>,
792 _workers: Option<usize>,
793) -> FFTResult<ArrayD<Complex64>>
794where
795 T: NumCast + Copy + Debug + 'static,
796{
797 let inputshape = input.shape().to_vec();
798 let input_ndim = inputshape.len();
799
800 let outputshape = shape.unwrap_or_else(|| inputshape.clone());
802
803 if outputshape.len() != input_ndim {
805 return Err(FFTError::ValueError(
806 "Output shape must have the same number of dimensions as input".to_string(),
807 ));
808 }
809
810 let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
812
813 for &axis in &axes {
815 if axis >= input_ndim {
816 return Err(FFTError::ValueError(format!(
817 "Axis {axis} out of bounds for array of dimension {input_ndim}"
818 )));
819 }
820 }
821
822 let norm_mode = parse_norm_mode(norm, true);
824
825 let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
827 for (idx, &val) in input.iter().enumerate() {
828 let mut idx_vec = Vec::with_capacity(input_ndim);
829 let mut remaining = idx;
830
831 for &dim in input.shape().iter().rev() {
832 idx_vec.push(remaining % dim);
833 remaining /= dim;
834 }
835
836 idx_vec.reverse();
837
838 complex_input[IxDyn(&idx_vec)] = convert_to_complex(val)?;
840 }
841
842 let mut result = if inputshape != outputshape {
844 let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
845
846 for (idx, &val) in complex_input.iter().enumerate() {
848 let mut idx_vec = Vec::with_capacity(input_ndim);
849 let mut remaining = idx;
850
851 for &dim in input.shape().iter().rev() {
852 idx_vec.push(remaining % dim);
853 remaining /= dim;
854 }
855
856 idx_vec.reverse();
857
858 let mut in_bounds = true;
859 for (dim, &idx_val) in idx_vec.iter().enumerate() {
860 if idx_val >= outputshape[dim] {
861 in_bounds = false;
862 break;
863 }
864 }
865
866 if in_bounds {
867 padded[IxDyn(&idx_vec)] = val;
868 }
869 }
870
871 padded
872 } else {
873 complex_input
874 };
875
876 let mut planner = FftPlanner::new();
878
879 for &axis in &axes {
881 let axis_len = outputshape[axis];
882 let ifft = planner.plan_fft_inverse(axis_len);
883
884 let axis = Axis(axis);
886
887 for mut lane in result.lanes_mut(axis) {
888 let mut buffer: Vec<RustComplex<f64>> =
890 lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
891
892 ifft.process(&mut buffer);
894
895 for (i, val) in buffer.iter().enumerate() {
897 lane[i] = Complex64::new(val.re, val.im);
898 }
899 }
900 }
901
902 if norm_mode != NormMode::None {
904 let total_elements: usize = axes.iter().map(|&a| outputshape[a]).product();
905 let scale = match norm_mode {
906 NormMode::Backward => 1.0 / (total_elements as f64),
907 NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
908 NormMode::Forward => 1.0, NormMode::None => 1.0, };
911
912 if scale != 1.0 {
913 result.mapv_inplace(|_x| _x * scale);
914 }
915 }
916
917 Ok(result)
918}