1use rustfft::FftPlanner;
7use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
8use scirs2_core::numeric::Complex64;
9use scirs2_core::numeric::NumCast;
10use std::sync::Arc;
11
12use crate::error::{FFTError, FFTResult};
13use crate::plan_cache::get_global_cache;
14
15#[allow(dead_code)]
17pub fn fft_strided<S, D>(
18 input: &ArrayBase<S, D>,
19 axis: usize,
20) -> FFTResult<scirs2_core::ndarray::Array<Complex64, D>>
21where
22 S: Data,
23 D: Dimension,
24 S::Elem: NumCast + Copy,
25{
26 if axis >= input.ndim() {
28 return Err(FFTError::ValueError(format!(
29 "Axis {} is out of bounds for array with {} dimensions",
30 axis,
31 input.ndim()
32 )));
33 }
34
35 let mut output = scirs2_core::ndarray::Array::zeros(input.raw_dim());
37
38 let axis_len = input.shape()[axis];
40 let mut planner = FftPlanner::new();
41 let fft_plan = get_global_cache().get_or_create_plan(axis_len, true, &mut planner);
42
43 process_strided_fft(input, &mut output, axis, fft_plan)?;
45
46 Ok(output)
47}
48
49#[allow(dead_code)]
51fn process_strided_fft<S, D>(
52 input: &ArrayBase<S, D>,
53 output: &mut scirs2_core::ndarray::Array<Complex64, D>,
54 axis: usize,
55 fft_plan: Arc<dyn rustfft::Fft<f64>>,
56) -> FFTResult<()>
57where
58 S: Data,
59 D: Dimension,
60 S::Elem: NumCast + Copy,
61{
62 let axis_len = input.shape()[axis];
63
64 let mut buffer = vec![Complex64::new(0.0, 0.0); axis_len];
66
67 for (i_lane, mut o_lane) in input
69 .lanes(scirs2_core::ndarray::Axis(axis))
70 .into_iter()
71 .zip(output.lanes_mut(scirs2_core::ndarray::Axis(axis)))
72 {
73 for (i, &val) in i_lane.iter().enumerate() {
75 let val_f64 = NumCast::from(val).ok_or_else(|| {
76 FFTError::ValueError(format!("Failed to convert value at index {i} to f64"))
77 })?;
78 buffer[i] = Complex64::new(val_f64, 0.0);
79 }
80
81 fft_plan.process(&mut buffer);
83
84 for (i, dst) in o_lane.iter_mut().enumerate() {
86 *dst = buffer[i];
87 }
88 }
89
90 Ok(())
91}
92
93#[allow(dead_code)]
95pub fn fft_strided_complex<S, D>(
96 input: &ArrayBase<S, D>,
97 axis: usize,
98) -> FFTResult<scirs2_core::ndarray::Array<Complex64, D>>
99where
100 S: Data,
101 D: Dimension,
102 S::Elem: Into<Complex64> + Copy,
103{
104 if axis >= input.ndim() {
106 return Err(FFTError::ValueError(format!(
107 "Axis {} is out of bounds for array with {} dimensions",
108 axis,
109 input.ndim()
110 )));
111 }
112
113 let mut output = scirs2_core::ndarray::Array::zeros(input.raw_dim());
115
116 let axis_len = input.shape()[axis];
118 let mut planner = FftPlanner::new();
119 let fft_plan = get_global_cache().get_or_create_plan(axis_len, true, &mut planner);
120
121 process_strided_complex_fft(input, &mut output, axis, fft_plan)?;
123
124 Ok(output)
125}
126
127#[allow(dead_code)]
129fn process_strided_complex_fft<S, D>(
130 input: &ArrayBase<S, D>,
131 output: &mut scirs2_core::ndarray::Array<Complex64, D>,
132 axis: usize,
133 fft_plan: Arc<dyn rustfft::Fft<f64>>,
134) -> FFTResult<()>
135where
136 S: Data,
137 D: Dimension,
138 S::Elem: Into<Complex64> + Copy,
139{
140 let axis_len = input.shape()[axis];
141
142 let mut buffer = vec![Complex64::new(0.0, 0.0); axis_len];
144
145 for (i_lane, mut o_lane) in input
147 .lanes(scirs2_core::ndarray::Axis(axis))
148 .into_iter()
149 .zip(output.lanes_mut(scirs2_core::ndarray::Axis(axis)))
150 {
151 for (i, &val) in i_lane.iter().enumerate() {
153 buffer[i] = val.into();
154 }
155
156 fft_plan.process(&mut buffer);
158
159 for (i, dst) in o_lane.iter_mut().enumerate() {
161 *dst = buffer[i];
162 }
163 }
164
165 Ok(())
166}
167
168#[allow(dead_code)]
170pub fn ifft_strided<S, D>(
171 input: &ArrayBase<S, D>,
172 axis: usize,
173) -> FFTResult<scirs2_core::ndarray::Array<Complex64, D>>
174where
175 S: Data,
176 D: Dimension,
177 S::Elem: Into<Complex64> + Copy,
178{
179 if axis >= input.ndim() {
181 return Err(FFTError::ValueError(format!(
182 "Axis {} is out of bounds for array with {} dimensions",
183 axis,
184 input.ndim()
185 )));
186 }
187
188 let mut output = scirs2_core::ndarray::Array::zeros(input.raw_dim());
190
191 let axis_len = input.shape()[axis];
193 let mut planner = FftPlanner::new();
194 let ifft_plan = get_global_cache().get_or_create_plan(axis_len, false, &mut planner);
195
196 process_strided_inverse_fft(input, &mut output, axis, ifft_plan)?;
198
199 let scale = 1.0 / (axis_len as f64);
201 output.mapv_inplace(|val| val * scale);
202
203 Ok(output)
204}
205
206#[allow(dead_code)]
208fn process_strided_inverse_fft<S, D>(
209 input: &ArrayBase<S, D>,
210 output: &mut scirs2_core::ndarray::Array<Complex64, D>,
211 axis: usize,
212 ifft_plan: Arc<dyn rustfft::Fft<f64>>,
213) -> FFTResult<()>
214where
215 S: Data,
216 D: Dimension,
217 S::Elem: Into<Complex64> + Copy,
218{
219 let axis_len = input.shape()[axis];
220
221 let mut buffer = vec![Complex64::new(0.0, 0.0); axis_len];
223
224 for (i_lane, mut o_lane) in input
226 .lanes(scirs2_core::ndarray::Axis(axis))
227 .into_iter()
228 .zip(output.lanes_mut(scirs2_core::ndarray::Axis(axis)))
229 {
230 for (i, &val) in i_lane.iter().enumerate() {
232 buffer[i] = val.into();
233 }
234
235 ifft_plan.process(&mut buffer);
237
238 for (i, dst) in o_lane.iter_mut().enumerate() {
240 *dst = buffer[i];
241 }
242 }
243
244 Ok(())
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2, Axis};
    use std::f64::consts::PI;

    /// Absolute tolerance for comparing FFT output against the reference DFT.
    const TOL: f64 = 1e-9;

    /// Reference O(n^2) forward DFT, `X[k] = sum_j x[j] * exp(-2*pi*i*k*j/n)`,
    /// used so the tests check actual values rather than just output shapes.
    fn naive_dft(samples: &[Complex64]) -> Vec<Complex64> {
        let n = samples.len();
        (0..n)
            .map(|k| {
                samples
                    .iter()
                    .enumerate()
                    .fold(Complex64::new(0.0, 0.0), |acc, (j, &x)| {
                        // Reduce k*j modulo n first to keep the angle small and precise.
                        let angle = -2.0 * PI * ((k * j) % n) as f64 / n as f64;
                        acc + x * Complex64::new(angle.cos(), angle.sin())
                    })
            })
            .collect()
    }

    /// Asserts that two complex values agree to within `TOL`.
    fn assert_close(got: Complex64, want: Complex64) {
        assert!(
            (got - want).norm() < TOL,
            "got {got:?}, expected {want:?}"
        );
    }

    #[test]
    fn test_fft_strided_1d() {
        let input = Array1::from_iter((0..8).map(|i| i as f64));

        let result = fft_strided(&input, 0).unwrap();
        assert_eq!(result.shape(), input.shape());

        // The DC bin of an unnormalized forward FFT is the sum 0 + 1 + ... + 7.
        assert_close(result[0], Complex64::new(28.0, 0.0));

        let samples: Vec<Complex64> = input.iter().map(|&v| Complex64::new(v, 0.0)).collect();
        for (got, want) in result.iter().zip(naive_dft(&samples)) {
            assert_close(*got, want);
        }
    }

    #[test]
    fn test_fft_strided_2d() {
        let input = Array2::from_shape_fn((4, 6), |(i, j)| (i * 10 + j) as f64);

        for axis in 0..2 {
            let result = fft_strided(&input, axis).unwrap();
            assert_eq!(result.shape(), input.shape());

            // Each lane along `axis` must equal the DFT of the matching input lane.
            for (in_lane, out_lane) in input
                .lanes(Axis(axis))
                .into_iter()
                .zip(result.lanes(Axis(axis)))
            {
                let samples: Vec<Complex64> =
                    in_lane.iter().map(|&v| Complex64::new(v, 0.0)).collect();
                for (got, want) in out_lane.iter().zip(naive_dft(&samples)) {
                    assert_close(*got, want);
                }
            }
        }
    }

    #[test]
    fn test_fft_strided_non_contiguous_view() {
        let base = Array2::from_shape_fn((4, 6), |(i, j)| (i * 10 + j) as f64);
        let transposed = base.t();
        let owned = transposed.to_owned();

        let from_view = fft_strided(&transposed, 1).unwrap();
        let from_owned = fft_strided(&owned, 1).unwrap();

        assert_eq!(from_view.shape(), from_owned.shape());
        for (a, b) in from_view.iter().zip(from_owned.iter()) {
            assert_close(*a, *b);
        }
    }

    #[test]
    fn test_fft_strided_complex_matches_real() {
        let real = Array1::from_iter((0..8).map(|i| i as f64));
        let complex = real.mapv(|v| Complex64::new(v, 0.0));

        let from_real = fft_strided(&real, 0).unwrap();
        let from_complex = fft_strided_complex(&complex, 0).unwrap();

        for (a, b) in from_real.iter().zip(from_complex.iter()) {
            assert_close(*a, *b);
        }
    }

    #[test]
    fn test_axis_out_of_bounds_is_error() {
        let input = Array1::from_iter((0..4).map(|i| i as f64));

        assert!(fft_strided(&input, 1).is_err());
        assert!(fft_strided_complex(&input, 1).is_err());
        assert!(ifft_strided(&input, 1).is_err());
    }

    #[test]
    fn test_ifft_strided() {
        let input = Array1::from_iter((0..8).map(|i| Complex64::new(i as f64, (i * 2) as f64)));

        let forward = fft_strided_complex(&input, 0).unwrap();
        let inverse = ifft_strided(&forward, 0).unwrap();

        assert_eq!(inverse.shape(), input.shape());
        for (got, want) in inverse.iter().zip(input.iter()) {
            assert_close(*got, *want);
        }
    }
}