1use scirs2_core::ndarray::{Array, ArrayView, Axis, Dimension};
7use scirs2_core::numeric::Complex64;
8use scirs2_core::numeric::NumCast;
9use scirs2_core::parallel_ops::*;
10use std::cmp::min;
11
12use crate::error::{FFTError, FFTResult};
13use crate::fft::fft;
14use crate::rfft::rfft;
15
16#[allow(dead_code)]
18pub fn fftn_optimized<T, D>(
19 x: &ArrayView<T, D>,
20 _shape: Option<Vec<usize>>,
21 axes: Option<Vec<usize>>,
22) -> FFTResult<Array<Complex64, D>>
23where
24 T: NumCast + Copy + Send + Sync,
25 D: Dimension,
26{
27 let ndim = x.ndim();
28
29 let mut result = Array::zeros(x.raw_dim());
31 scirs2_core::ndarray::Zip::from(&mut result)
32 .and(x)
33 .for_each(|dst, &src| {
34 *dst = Complex64::new(
35 NumCast::from(src)
36 .ok_or_else(|| {
37 FFTError::ValueError("Failed to convert input to complex".to_string())
38 })
39 .unwrap(),
40 0.0,
41 );
42 });
43
44 let axes_to_transform = if let Some(a) = axes {
46 validate_axes(&a, ndim)?;
47 a
48 } else {
49 (0..ndim).collect()
50 };
51
52 let optimized_order = optimize_axis_order(&axes_to_transform, result.shape());
54
55 for &axis in &optimized_order {
57 apply_fft_along_axis(&mut result, axis)?;
58 }
59
60 Ok(result)
61}
62
63#[allow(dead_code)]
65fn apply_fft_along_axis<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
66where
67 D: Dimension,
68{
69 let axis_len = data.shape()[axis];
70
71 let mut buffer = vec![Complex64::new(0.0, 0.0); axis_len];
73
74 for mut lane in data.lanes_mut(Axis(axis)) {
76 buffer
78 .iter_mut()
79 .zip(lane.iter())
80 .for_each(|(b, &x)| *b = x);
81
82 let transformed = fft(&buffer, None)?;
84
85 lane.iter_mut()
87 .zip(transformed.iter())
88 .for_each(|(dst, &src)| *dst = src);
89 }
90
91 Ok(())
92}
93
/// Reorders `axes` by ascending product of the extents that follow each axis
/// in `shape`.
///
/// The sort is stable, so axes with equal keys keep their relative order from
/// `axes`.
///
/// # Panics
///
/// Panics if any entry of `axes` is not a valid index into `shape`.
#[allow(dead_code)]
fn optimize_axis_order(axes: &[usize], shape: &[usize]) -> Vec<usize> {
    let mut keyed: Vec<(usize, usize)> = Vec::with_capacity(axes.len());

    for &axis in axes {
        // Indexing keeps the bounds check on `axis`.
        let _extent = shape[axis];
        let trailing: usize = shape[axis + 1..].iter().product();
        keyed.push((trailing, axis));
    }

    keyed.sort_by_key(|&(trailing, _)| trailing);

    keyed.into_iter().map(|(_, axis)| axis).collect()
}
112
113#[allow(dead_code)]
115fn validate_axes(axes: &[usize], ndim: usize) -> FFTResult<()> {
116 for &axis in axes {
117 if axis >= ndim {
118 return Err(FFTError::ValueError(format!(
119 "Axis {axis} is out of bounds for array with {ndim} dimensions"
120 )));
121 }
122 }
123 Ok(())
124}
125
/// Decides whether the parallel code path should be used.
///
/// Returns `true` only when the total element count is greater than 10000
/// and the axis length is greater than 64.
#[allow(dead_code)]
fn should_parallelize(_data_size: usize, axislen: usize) -> bool {
    const MIN_PARALLEL_SIZE: usize = 10000;
    const MIN_AXIS_LEN: usize = 64;

    if _data_size <= MIN_PARALLEL_SIZE {
        return false;
    }
    axislen > MIN_AXIS_LEN
}
133
/// Replaces every lane of `data` along `axis` with its one-dimensional FFT,
/// using a parallel iterator over the lanes when `should_parallelize` allows.
///
/// When the size check fails, the work is handed to `apply_fft_along_axis`.
///
/// # Errors
///
/// Forwards any error returned by `fft`.
///
/// # Panics
///
/// Panics if `axis` is not a valid axis of `data`.
#[cfg(feature = "parallel")]
#[allow(dead_code)]
fn apply_fft_parallel<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
where
    D: Dimension,
{
    let axis_len = data.shape()[axis];
    let total_size: usize = data.shape().iter().product();

    // Not enough work to go parallel: use the sequential path.
    if !should_parallelize(total_size, axis_len) {
        return apply_fft_along_axis(data, axis);
    }

    // Collect the lane views so they can be handed to the parallel iterator.
    let mut lanes: Vec<_> = data.lanes_mut(Axis(axis)).into_iter().collect();

    lanes.par_iter_mut().try_for_each(|lane| {
        let input: Vec<Complex64> = lane.to_vec();
        let spectrum = fft(&input, None)?;
        for (out, &value) in lane.iter_mut().zip(spectrum.iter()) {
            *out = value;
        }
        Ok(())
    })
}
160
161#[allow(dead_code)]
163pub fn fftn_memory_efficient<T, D>(
164 x: &ArrayView<T, D>,
165 axes: Option<Vec<usize>>,
166 _max_memory_gb: f64,
167) -> FFTResult<Array<Complex64, D>>
168where
169 T: NumCast + Copy + Send + Sync,
170 D: Dimension,
171{
172 let ndim = x.ndim();
173 let axes_to_transform = if let Some(a) = axes {
174 validate_axes(&a, ndim)?;
175 a
176 } else {
177 (0..ndim).collect()
178 };
179
180 let mut result = Array::zeros(x.raw_dim());
183
184 scirs2_core::ndarray::Zip::from(&mut result)
186 .and(x)
187 .for_each(|dst, &src| {
188 *dst = Complex64::new(
189 NumCast::from(src)
190 .ok_or_else(|| {
191 FFTError::ValueError("Failed to convert input to complex".to_string())
192 })
193 .unwrap(),
194 0.0,
195 );
196 });
197
198 for &axis in &axes_to_transform {
200 let axis_len: usize = result.shape()[axis];
201
202 if axis_len > 1048576 {
203 apply_fft_chunked(&mut result, axis)?;
205 } else {
206 apply_fft_along_axis(&mut result, axis)?;
207 }
208 }
209
210 Ok(result)
211}
212
213#[allow(dead_code)]
215fn apply_fft_chunked<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
216where
217 D: Dimension,
218{
219 let axis_len = data.shape()[axis];
220 const CHUNK_SIZE: usize = 65536; let n_chunks = axis_len.div_ceil(CHUNK_SIZE);
226
227 for chunk_idx in 0..n_chunks {
228 let start = chunk_idx * CHUNK_SIZE;
229 let end = min(start + CHUNK_SIZE, axis_len);
230 let chunk_len = end - start;
231
232 let mut buffer = vec![Complex64::new(0.0, 0.0); chunk_len];
234
235 for mut lane in data.lanes_mut(Axis(axis)) {
236 buffer
238 .iter_mut()
239 .zip(lane.slice_axis(Axis(0), (start..end).into()).iter())
240 .for_each(|(b, &x)| *b = x);
241
242 let transformed = fft(&buffer, None)?;
244
245 lane.slice_axis_mut(Axis(0), (start..end).into())
247 .iter_mut()
248 .zip(transformed.iter())
249 .for_each(|(dst, &src)| *dst = src);
250 }
251 }
252
253 Ok(())
254}
255
256#[allow(dead_code)]
258pub fn rfftn_optimized<T, D>(
259 x: &ArrayView<T, D>,
260 _shape: Option<Vec<usize>>,
261 axes: Option<Vec<usize>>,
262) -> FFTResult<Array<Complex64, D>>
263where
264 T: NumCast + Copy + Send + Sync,
265 D: Dimension,
266{
267 let ndim = x.ndim();
270 let mut axes_to_transform = if let Some(a) = axes {
271 validate_axes(&a, ndim)?;
272 a
273 } else {
274 (0..ndim).collect()
275 };
276
277 let last_axis = axes_to_transform.pop().unwrap_or(ndim - 1);
279
280 let mut real_data = Array::zeros(x.raw_dim());
282 scirs2_core::ndarray::Zip::from(&mut real_data)
283 .and(x)
284 .for_each(|dst, &src| {
285 *dst = NumCast::from(src)
286 .ok_or_else(|| FFTError::ValueError("Failed to convert input to float".to_string()))
287 .unwrap();
288 });
289
290 let mut result: Array<Complex64, D> = Array::zeros(x.raw_dim());
292
293 for lane in real_data.lanes(Axis(last_axis)) {
295 let real_vec: Vec<f64> = lane.to_vec();
296 let _complex_vec = rfft(&real_vec, None)?;
297
298 }
301
302 for &axis in &axes_to_transform {
304 apply_fft_along_axis(&mut result, axis)?;
305 }
306
307 Ok(result)
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// For shape `[10, 100, 1000]` the axes must come back as 2, 1, 0.
    #[test]
    fn test_axis_optimization() {
        let shape: [usize; 3] = [10, 100, 1000];
        let optimized = optimize_axis_order(&[0, 1, 2], &shape);

        let expected: [usize; 3] = [2, 1, 0];
        for (position, &axis) in expected.iter().enumerate() {
            assert_eq!(optimized[position], axis);
        }
    }

    /// Both the total size and the axis length must exceed their thresholds.
    #[test]
    fn test_parallelize_decision() {
        let cases: [(usize, usize, bool); 3] =
            [(10001, 100, true), (10001, 50, false), (100, 10, false)];

        for &(data_size, axis_len, expected) in cases.iter() {
            assert_eq!(should_parallelize(data_size, axis_len), expected);
        }
    }

    /// Axes below `ndim` are accepted; an axis equal to `ndim` is rejected.
    #[test]
    fn test_validate_axes() {
        let in_bounds = validate_axes(&[0, 1, 2], 3);
        let out_of_bounds = validate_axes(&[0, 1, 3], 3);

        assert!(in_bounds.is_ok());
        assert!(out_of_bounds.is_err());
    }
}