1use crate::error::{FFTError, FFTResult};
9use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
10use scirs2_core::ndarray::{Array2, ArrayD, Axis, IxDyn};
11use scirs2_core::numeric::Complex64;
12use scirs2_core::numeric::NumCast;
13use scirs2_core::safe_ops::{safe_divide, safe_sqrt};
14use std::fmt::Debug;
15
/// Normalization convention applied to FFT results.
///
/// Usually obtained from a SciPy-style `norm` string via `NormMode::from` or
/// `parse_norm_mode`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormMode {
    /// No scaling at all (also produced by an unrecognised `norm` string).
    None,
    /// `"backward"`: the `1/n` factor is applied on the inverse transform.
    Backward,
    /// `"ortho"`: both directions are scaled by `1/sqrt(n)`.
    Ortho,
    /// `"forward"`: the forward transform is scaled by `1/n`.
    Forward,
}
31
32impl From<&str> for NormMode {
33 fn from(s: &str) -> Self {
34 match s {
35 "backward" => NormMode::Backward,
36 "ortho" => NormMode::Ortho,
37 "forward" => NormMode::Forward,
38 _ => NormMode::None,
39 }
40 }
41}
42
43#[allow(dead_code)]
45pub fn parse_norm_mode(_norm: Option<&str>, isinverse: bool) -> NormMode {
46 match _norm {
47 Some(s) => NormMode::from(s),
48 None if isinverse => NormMode::Backward, None => NormMode::None, }
51}
52
53#[allow(dead_code)]
55fn apply_normalization(data: &mut [Complex64], n: usize, mode: NormMode) -> FFTResult<()> {
56 match mode {
57 NormMode::None => {} NormMode::Backward => {
59 let n_f64 = n as f64;
60 let scale = safe_divide(1.0, n_f64).map_err(|_| {
61 FFTError::ValueError(
62 "Division by zero in backward normalization: FFT size is zero".to_string(),
63 )
64 })?;
65 data.iter_mut().for_each(|c| *c *= scale);
66 }
67 NormMode::Ortho => {
68 let n_f64 = n as f64;
69 let sqrt_n = safe_sqrt(n_f64).map_err(|_| {
70 FFTError::ComputationError(
71 "Invalid square root in orthogonal normalization".to_string(),
72 )
73 })?;
74 let scale = safe_divide(1.0, sqrt_n).map_err(|_| {
75 FFTError::ValueError("Division by zero in orthogonal normalization".to_string())
76 })?;
77 data.iter_mut().for_each(|c| *c *= scale);
78 }
79 NormMode::Forward => {
80 let n_f64 = n as f64;
81 let scale = safe_divide(1.0, n_f64).map_err(|_| {
82 FFTError::ValueError(
83 "Division by zero in forward normalization: FFT size is zero".to_string(),
84 )
85 })?;
86 data.iter_mut().for_each(|c| *c *= scale);
87 }
88 }
89 Ok(())
90}
91
92#[allow(dead_code)]
94fn convert_to_complex<T>(val: T) -> FFTResult<Complex64>
95where
96 T: NumCast + Copy + Debug + 'static,
97{
98 if let Some(real) = NumCast::from(val) {
100 return Ok(Complex64::new(real, 0.0));
101 }
102
103 use std::any::Any;
105 if let Some(complex) = (&val as &dyn Any).downcast_ref::<Complex64>() {
106 return Ok(*complex);
107 }
108
109 if let Some(complex32) = (&val as &dyn Any).downcast_ref::<scirs2_core::numeric::Complex<f32>>()
111 {
112 return Ok(Complex64::new(complex32.re as f64, complex32.im as f64));
113 }
114
115 Err(FFTError::ValueError(format!(
116 "Could not convert {val:?} to numeric type"
117 )))
118}
119
120#[allow(dead_code)]
122fn to_complex<T>(input: &[T]) -> FFTResult<Vec<Complex64>>
123where
124 T: NumCast + Copy + Debug + 'static,
125{
126 input.iter().map(|&val| convert_to_complex(val)).collect()
127}
128
129#[allow(dead_code)]
157pub fn fft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
158where
159 T: NumCast + Copy + Debug + 'static,
160{
161 if input.is_empty() {
163 return Err(FFTError::ValueError("Input cannot be empty".to_string()));
164 }
165
166 let input_len = input.len();
168 let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
169
170 let mut data = to_complex(input)?;
172
173 if fft_size != input_len {
175 if fft_size > input_len {
176 data.resize(fft_size, Complex64::new(0.0, 0.0));
178 } else {
179 data.truncate(fft_size);
181 }
182 }
183
184 let mut planner = FftPlanner::new();
186 let fft = planner.plan_fft_forward(fft_size);
187
188 let mut buffer: Vec<RustComplex<f64>> =
190 data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
191
192 fft.process(&mut buffer);
194
195 let result: Vec<Complex64> = buffer
197 .into_iter()
198 .map(|c| Complex64::new(c.re, c.im))
199 .collect();
200
201 Ok(result)
202}
203
204#[allow(dead_code)]
237pub fn ifft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
238where
239 T: NumCast + Copy + Debug + 'static,
240{
241 if input.is_empty() {
243 return Err(FFTError::ValueError("Input cannot be empty".to_string()));
244 }
245
246 let input_len = input.len();
248 let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
249
250 let mut data = to_complex(input)?;
252
253 if fft_size != input_len {
255 if fft_size > input_len {
256 data.resize(fft_size, Complex64::new(0.0, 0.0));
258 } else {
259 data.truncate(fft_size);
261 }
262 }
263
264 let mut planner = FftPlanner::new();
266 let ifft = planner.plan_fft_inverse(fft_size);
267
268 let mut buffer: Vec<RustComplex<f64>> =
270 data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
271
272 ifft.process(&mut buffer);
274
275 let mut result: Vec<Complex64> = buffer
277 .into_iter()
278 .map(|c| Complex64::new(c.re, c.im))
279 .collect();
280
281 apply_normalization(&mut result, fft_size, NormMode::Backward)?;
283
284 if n.is_none() && fft_size > input_len {
286 result.truncate(input_len);
287 }
288
289 Ok(result)
290}
291
292#[allow(dead_code)]
321pub fn fft2<T>(
322 input: &Array2<T>,
323 shape: Option<(usize, usize)>,
324 axes: Option<(i32, i32)>,
325 norm: Option<&str>,
326) -> FFTResult<Array2<Complex64>>
327where
328 T: NumCast + Copy + Debug + 'static,
329{
330 let inputshape = input.shape();
332
333 let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
335
336 let axes = axes.unwrap_or((0, 1));
338
339 if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
341 return Err(FFTError::ValueError("Invalid axes for 2D FFT".to_string()));
342 }
343
344 let norm_mode = parse_norm_mode(norm, false);
346
347 let mut output = Array2::<Complex64>::zeros(outputshape);
349
350 let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
352 for i in 0..inputshape[0] {
353 for j in 0..inputshape[1] {
354 let val = input[[i, j]];
355
356 complex_input[[i, j]] = convert_to_complex(val)?;
358 }
359 }
360
361 let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
363 let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
364 let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
365 let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
366
367 for i in 0..copy_rows {
368 for j in 0..copy_cols {
369 padded[[i, j]] = complex_input[[i, j]];
370 }
371 }
372 padded
373 } else {
374 complex_input
375 };
376
377 let mut planner = FftPlanner::new();
379
380 let row_fft = planner.plan_fft_forward(outputshape.1);
382 for mut row in padded_input.rows_mut() {
383 let mut buffer: Vec<RustComplex<f64>> =
385 row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
386
387 row_fft.process(&mut buffer);
389
390 for (i, val) in buffer.iter().enumerate() {
392 row[i] = Complex64::new(val.re, val.im);
393 }
394 }
395
396 let col_fft = planner.plan_fft_forward(outputshape.0);
398 for mut col in padded_input.columns_mut() {
399 let mut buffer: Vec<RustComplex<f64>> =
401 col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
402
403 col_fft.process(&mut buffer);
405
406 for (i, val) in buffer.iter().enumerate() {
408 col[i] = Complex64::new(val.re, val.im);
409 }
410 }
411
412 if norm_mode != NormMode::None {
414 let total_elements = outputshape.0 * outputshape.1;
415 let scale = match norm_mode {
416 NormMode::Backward => 1.0 / (total_elements as f64),
417 NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
418 NormMode::Forward => 1.0 / (total_elements as f64),
419 NormMode::None => 1.0, };
421
422 padded_input.mapv_inplace(|x| x * scale);
423 }
424
425 output.assign(&padded_input);
427
428 Ok(output)
429}
430
431#[allow(dead_code)]
468pub fn ifft2<T>(
469 input: &Array2<T>,
470 shape: Option<(usize, usize)>,
471 axes: Option<(i32, i32)>,
472 norm: Option<&str>,
473) -> FFTResult<Array2<Complex64>>
474where
475 T: NumCast + Copy + Debug + 'static,
476{
477 let inputshape = input.shape();
479
480 let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
482
483 let axes = axes.unwrap_or((0, 1));
485
486 if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
488 return Err(FFTError::ValueError("Invalid axes for 2D IFFT".to_string()));
489 }
490
491 let norm_mode = parse_norm_mode(norm, true);
493
494 let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
496 for i in 0..inputshape[0] {
497 for j in 0..inputshape[1] {
498 let val = input[[i, j]];
499
500 complex_input[[i, j]] = convert_to_complex(val)?;
502 }
503 }
504
505 let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
507 let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
508 let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
509 let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
510
511 for i in 0..copy_rows {
512 for j in 0..copy_cols {
513 padded[[i, j]] = complex_input[[i, j]];
514 }
515 }
516 padded
517 } else {
518 complex_input
519 };
520
521 let mut planner = FftPlanner::new();
523
524 let row_ifft = planner.plan_fft_inverse(outputshape.1);
526 for mut row in padded_input.rows_mut() {
527 let mut buffer: Vec<RustComplex<f64>> =
529 row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
530
531 row_ifft.process(&mut buffer);
533
534 for (i, val) in buffer.iter().enumerate() {
536 row[i] = Complex64::new(val.re, val.im);
537 }
538 }
539
540 let col_ifft = planner.plan_fft_inverse(outputshape.0);
542 for mut col in padded_input.columns_mut() {
543 let mut buffer: Vec<RustComplex<f64>> =
545 col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
546
547 col_ifft.process(&mut buffer);
549
550 for (i, val) in buffer.iter().enumerate() {
552 col[i] = Complex64::new(val.re, val.im);
553 }
554 }
555
556 let total_elements = outputshape.0 * outputshape.1;
558 let scale = match norm_mode {
559 NormMode::Backward => 1.0 / (total_elements as f64),
560 NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
561 NormMode::Forward => 1.0, NormMode::None => 1.0, };
564
565 if scale != 1.0 {
566 padded_input.mapv_inplace(|x| x * scale);
567 }
568
569 Ok(padded_input)
570}
571
572#[allow(clippy::too_many_arguments)]
605#[allow(dead_code)]
606pub fn fftn<T>(
607 input: &ArrayD<T>,
608 shape: Option<Vec<usize>>,
609 axes: Option<Vec<usize>>,
610 norm: Option<&str>,
611 _overwrite_x: Option<bool>,
612 _workers: Option<usize>,
613) -> FFTResult<ArrayD<Complex64>>
614where
615 T: NumCast + Copy + Debug + 'static,
616{
617 let inputshape = input.shape().to_vec();
618 let input_ndim = inputshape.len();
619
620 let outputshape = shape.unwrap_or_else(|| inputshape.clone());
622
623 if outputshape.len() != input_ndim {
625 return Err(FFTError::ValueError(
626 "Output shape must have the same number of dimensions as input".to_string(),
627 ));
628 }
629
630 let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
632
633 for &axis in &axes {
635 if axis >= input_ndim {
636 return Err(FFTError::ValueError(format!(
637 "Axis {axis} out of bounds for array of dimension {input_ndim}"
638 )));
639 }
640 }
641
642 let norm_mode = parse_norm_mode(norm, false);
644
645 let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
647 for (idx, &val) in input.iter().enumerate() {
648 let mut idx_vec = Vec::with_capacity(input_ndim);
649 let mut remaining = idx;
650
651 for &dim in input.shape().iter().rev() {
652 idx_vec.push(remaining % dim);
653 remaining /= dim;
654 }
655
656 idx_vec.reverse();
657
658 complex_input[IxDyn(&idx_vec)] = convert_to_complex(val)?;
659 }
660
661 let mut result = if inputshape != outputshape {
663 let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
664
665 for (idx, &val) in complex_input.iter().enumerate() {
667 let mut idx_vec = Vec::with_capacity(input_ndim);
668 let mut remaining = idx;
669
670 for &dim in input.shape().iter().rev() {
671 idx_vec.push(remaining % dim);
672 remaining /= dim;
673 }
674
675 idx_vec.reverse();
676
677 let mut in_bounds = true;
678 for (dim, &idx_val) in idx_vec.iter().enumerate() {
679 if idx_val >= outputshape[dim] {
680 in_bounds = false;
681 break;
682 }
683 }
684
685 if in_bounds {
686 padded[IxDyn(&idx_vec)] = val;
687 }
688 }
689
690 padded
691 } else {
692 complex_input
693 };
694
695 let mut planner = FftPlanner::new();
697
698 for &axis in &axes {
700 let axis_len = outputshape[axis];
701 let fft = planner.plan_fft_forward(axis_len);
702
703 let axis = Axis(axis);
705
706 for mut lane in result.lanes_mut(axis) {
707 let mut buffer: Vec<RustComplex<f64>> =
709 lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
710
711 fft.process(&mut buffer);
713
714 for (i, val) in buffer.iter().enumerate() {
716 lane[i] = Complex64::new(val.re, val.im);
717 }
718 }
719 }
720
721 if norm_mode != NormMode::None {
723 let total_elements: usize = outputshape.iter().product();
724 let scale = match norm_mode {
725 NormMode::Backward => 1.0 / (total_elements as f64),
726 NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
727 NormMode::Forward => 1.0 / (total_elements as f64),
728 NormMode::None => 1.0, };
730
731 result.mapv_inplace(|_x| _x * scale);
732 }
733
734 Ok(result)
735}
736
737#[allow(clippy::too_many_arguments)]
786#[allow(dead_code)]
787pub fn ifftn<T>(
788 input: &ArrayD<T>,
789 shape: Option<Vec<usize>>,
790 axes: Option<Vec<usize>>,
791 norm: Option<&str>,
792 _overwrite_x: Option<bool>,
793 _workers: Option<usize>,
794) -> FFTResult<ArrayD<Complex64>>
795where
796 T: NumCast + Copy + Debug + 'static,
797{
798 let inputshape = input.shape().to_vec();
799 let input_ndim = inputshape.len();
800
801 let outputshape = shape.unwrap_or_else(|| inputshape.clone());
803
804 if outputshape.len() != input_ndim {
806 return Err(FFTError::ValueError(
807 "Output shape must have the same number of dimensions as input".to_string(),
808 ));
809 }
810
811 let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
813
814 for &axis in &axes {
816 if axis >= input_ndim {
817 return Err(FFTError::ValueError(format!(
818 "Axis {axis} out of bounds for array of dimension {input_ndim}"
819 )));
820 }
821 }
822
823 let norm_mode = parse_norm_mode(norm, true);
825
826 let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
828 for (idx, &val) in input.iter().enumerate() {
829 let mut idx_vec = Vec::with_capacity(input_ndim);
830 let mut remaining = idx;
831
832 for &dim in input.shape().iter().rev() {
833 idx_vec.push(remaining % dim);
834 remaining /= dim;
835 }
836
837 idx_vec.reverse();
838
839 complex_input[IxDyn(&idx_vec)] = convert_to_complex(val)?;
841 }
842
843 let mut result = if inputshape != outputshape {
845 let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
846
847 for (idx, &val) in complex_input.iter().enumerate() {
849 let mut idx_vec = Vec::with_capacity(input_ndim);
850 let mut remaining = idx;
851
852 for &dim in input.shape().iter().rev() {
853 idx_vec.push(remaining % dim);
854 remaining /= dim;
855 }
856
857 idx_vec.reverse();
858
859 let mut in_bounds = true;
860 for (dim, &idx_val) in idx_vec.iter().enumerate() {
861 if idx_val >= outputshape[dim] {
862 in_bounds = false;
863 break;
864 }
865 }
866
867 if in_bounds {
868 padded[IxDyn(&idx_vec)] = val;
869 }
870 }
871
872 padded
873 } else {
874 complex_input
875 };
876
877 let mut planner = FftPlanner::new();
879
880 for &axis in &axes {
882 let axis_len = outputshape[axis];
883 let ifft = planner.plan_fft_inverse(axis_len);
884
885 let axis = Axis(axis);
887
888 for mut lane in result.lanes_mut(axis) {
889 let mut buffer: Vec<RustComplex<f64>> =
891 lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
892
893 ifft.process(&mut buffer);
895
896 for (i, val) in buffer.iter().enumerate() {
898 lane[i] = Complex64::new(val.re, val.im);
899 }
900 }
901 }
902
903 if norm_mode != NormMode::None {
905 let total_elements: usize = axes.iter().map(|&a| outputshape[a]).product();
906 let scale = match norm_mode {
907 NormMode::Backward => 1.0 / (total_elements as f64),
908 NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
909 NormMode::Forward => 1.0, NormMode::None => 1.0, };
912
913 if scale != 1.0 {
914 result.mapv_inplace(|_x| _x * scale);
915 }
916 }
917
918 Ok(result)
919}