scirs2_fft/
ndim_optimized.rs

1//! Optimized N-dimensional FFT operations
2//!
3//! This module provides optimized implementations of N-dimensional FFT
4//! operations with better memory access patterns and performance.
5
6use scirs2_core::ndarray::{Array, ArrayView, Axis, Dimension};
7use scirs2_core::numeric::Complex64;
8use scirs2_core::numeric::NumCast;
9use scirs2_core::parallel_ops::*;
10use std::cmp::min;
11
12use crate::error::{FFTError, FFTResult};
13use crate::fft::fft;
14use crate::rfft::rfft;
15
16/// Optimized N-dimensional FFT with better memory access patterns
17#[allow(dead_code)]
18pub fn fftn_optimized<T, D>(
19    x: &ArrayView<T, D>,
20    _shape: Option<Vec<usize>>,
21    axes: Option<Vec<usize>>,
22) -> FFTResult<Array<Complex64, D>>
23where
24    T: NumCast + Copy + Send + Sync,
25    D: Dimension,
26{
27    let ndim = x.ndim();
28
29    // Convert input to complex
30    let mut result = Array::zeros(x.raw_dim());
31    scirs2_core::ndarray::Zip::from(&mut result)
32        .and(x)
33        .for_each(|dst, &src| {
34            *dst = Complex64::new(
35                NumCast::from(src)
36                    .ok_or_else(|| {
37                        FFTError::ValueError("Failed to convert input to complex".to_string())
38                    })
39                    .unwrap(),
40                0.0,
41            );
42        });
43
44    // Determine axes to transform
45    let axes_to_transform = if let Some(a) = axes {
46        validate_axes(&a, ndim)?;
47        a
48    } else {
49        (0..ndim).collect()
50    };
51
52    // Optimize axis order based on memory layout
53    let optimized_order = optimize_axis_order(&axes_to_transform, result.shape());
54
55    // Apply FFT along each axis in optimized order
56    for &axis in &optimized_order {
57        apply_fft_along_axis(&mut result, axis)?;
58    }
59
60    Ok(result)
61}
62
63/// Apply FFT along a specific axis
64#[allow(dead_code)]
65fn apply_fft_along_axis<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
66where
67    D: Dimension,
68{
69    let axis_len = data.shape()[axis];
70
71    // Create temporary buffer for FFT
72    let mut buffer = vec![Complex64::new(0.0, 0.0); axis_len];
73
74    // Process slices along the specified axis
75    for mut lane in data.lanes_mut(Axis(axis)) {
76        // Copy _data to buffer
77        buffer
78            .iter_mut()
79            .zip(lane.iter())
80            .for_each(|(b, &x)| *b = x);
81
82        // Perform FFT
83        let transformed = fft(&buffer, None)?;
84
85        // Copy results back
86        lane.iter_mut()
87            .zip(transformed.iter())
88            .for_each(|(dst, &src)| *dst = src);
89    }
90
91    Ok(())
92}
93
/// Choose a processing order for `axes` that favors cache locality.
///
/// Assuming row-major (C-order) layout, an axis's stride is the product of
/// all dimensions to its right. Axes with smaller strides touch memory more
/// contiguously, so they are listed first (stable sort preserves the input
/// order of axes with equal strides).
#[allow(dead_code)]
fn optimize_axis_order(axes: &[usize], shape: &[usize]) -> Vec<usize> {
    // Row-major stride of `axis`: product of all trailing dimensions.
    let stride_of = |axis: usize| -> usize { shape[axis + 1..].iter().product() };

    let mut ordered = axes.to_vec();
    ordered.sort_by_key(|&axis| stride_of(axis));
    ordered
}
112
113/// Validate that axes are within bounds
114#[allow(dead_code)]
115fn validate_axes(axes: &[usize], ndim: usize) -> FFTResult<()> {
116    for &axis in axes {
117        if axis >= ndim {
118            return Err(FFTError::ValueError(format!(
119                "Axis {axis} is out of bounds for array with {ndim} dimensions"
120            )));
121        }
122    }
123    Ok(())
124}
125
/// Decide whether parallel lane processing is worthwhile.
///
/// Parallelism only pays off when there is enough total work (`data_size`
/// elements overall) and each 1-D transform (`axis_len`) is long enough to
/// amortize thread-scheduling overhead.
#[allow(dead_code)]
fn should_parallelize(data_size: usize, axis_len: usize) -> bool {
    // Below this many total elements, thread overhead outweighs the FFT work.
    const MIN_PARALLEL_SIZE: usize = 10000;
    data_size > MIN_PARALLEL_SIZE && axis_len > 64
}
133
/// Apply FFT along `axis`, processing lanes in parallel when the array is
/// large enough to amortize scheduling overhead; otherwise delegates to the
/// serial implementation.
#[cfg(feature = "parallel")]
#[allow(dead_code)]
fn apply_fft_parallel<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
where
    D: Dimension,
{
    let lane_len = data.shape()[axis];
    let element_count: usize = data.shape().iter().product();

    // Small arrays: serial path avoids thread overhead.
    if !should_parallelize(element_count, lane_len) {
        return apply_fft_along_axis(data, axis);
    }

    // Materialize the mutable lane views so rayon can distribute them.
    let mut lanes: Vec<_> = data.lanes_mut(Axis(axis)).into_iter().collect();

    lanes.par_iter_mut().try_for_each(|lane| {
        // Each worker transforms its own contiguous copy of the lane.
        let buffer: Vec<Complex64> = lane.to_vec();
        let spectrum = fft(&buffer, None)?;
        for (dst, &src) in lane.iter_mut().zip(spectrum.iter()) {
            *dst = src;
        }
        Ok(())
    })
}
160
161/// Memory-efficient FFT for very large arrays
162#[allow(dead_code)]
163pub fn fftn_memory_efficient<T, D>(
164    x: &ArrayView<T, D>,
165    axes: Option<Vec<usize>>,
166    _max_memory_gb: f64,
167) -> FFTResult<Array<Complex64, D>>
168where
169    T: NumCast + Copy + Send + Sync,
170    D: Dimension,
171{
172    let ndim = x.ndim();
173    let axes_to_transform = if let Some(a) = axes {
174        validate_axes(&a, ndim)?;
175        a
176    } else {
177        (0..ndim).collect()
178    };
179
180    // For memory efficiency, we process one axis at a time
181    // and use chunking for very large dimensions
182    let mut result = Array::zeros(x.raw_dim());
183
184    // Convert input to complex
185    scirs2_core::ndarray::Zip::from(&mut result)
186        .and(x)
187        .for_each(|dst, &src| {
188            *dst = Complex64::new(
189                NumCast::from(src)
190                    .ok_or_else(|| {
191                        FFTError::ValueError("Failed to convert input to complex".to_string())
192                    })
193                    .unwrap(),
194                0.0,
195            );
196        });
197
198    // Process each axis with chunking if needed
199    for &axis in &axes_to_transform {
200        let axis_len: usize = result.shape()[axis];
201
202        if axis_len > 1048576 {
203            // For very large axes, use chunked processing
204            apply_fft_chunked(&mut result, axis)?;
205        } else {
206            apply_fft_along_axis(&mut result, axis)?;
207        }
208    }
209
210    Ok(result)
211}
212
213/// Apply FFT along axis using chunked processing for large dimensions
214#[allow(dead_code)]
215fn apply_fft_chunked<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
216where
217    D: Dimension,
218{
219    let axis_len = data.shape()[axis];
220    const CHUNK_SIZE: usize = 65536; // Process in 64K chunks
221
222    // This is a simplified chunking strategy
223    // In practice, we'd need to handle overlapping chunks
224    // for proper FFT computation
225    let n_chunks = axis_len.div_ceil(CHUNK_SIZE);
226
227    for chunk_idx in 0..n_chunks {
228        let start = chunk_idx * CHUNK_SIZE;
229        let end = min(start + CHUNK_SIZE, axis_len);
230        let chunk_len = end - start;
231
232        // Process chunk
233        let mut buffer = vec![Complex64::new(0.0, 0.0); chunk_len];
234
235        for mut lane in data.lanes_mut(Axis(axis)) {
236            // Extract chunk from lane
237            buffer
238                .iter_mut()
239                .zip(lane.slice_axis(Axis(0), (start..end).into()).iter())
240                .for_each(|(b, &x)| *b = x);
241
242            // Perform FFT on chunk
243            let transformed = fft(&buffer, None)?;
244
245            // Copy results back to chunk
246            lane.slice_axis_mut(Axis(0), (start..end).into())
247                .iter_mut()
248                .zip(transformed.iter())
249                .for_each(|(dst, &src)| *dst = src);
250        }
251    }
252
253    Ok(())
254}
255
256/// Optimized real-to-complex N-dimensional FFT
257#[allow(dead_code)]
258pub fn rfftn_optimized<T, D>(
259    x: &ArrayView<T, D>,
260    _shape: Option<Vec<usize>>,
261    axes: Option<Vec<usize>>,
262) -> FFTResult<Array<Complex64, D>>
263where
264    T: NumCast + Copy + Send + Sync,
265    D: Dimension,
266{
267    // For real FFT, we can optimize the first transform
268    // and use symmetry properties
269    let ndim = x.ndim();
270    let mut axes_to_transform = if let Some(a) = axes {
271        validate_axes(&a, ndim)?;
272        a
273    } else {
274        (0..ndim).collect()
275    };
276
277    // Process the last axis with real FFT for efficiency
278    let last_axis = axes_to_transform.pop().unwrap_or(ndim - 1);
279
280    // Convert to real array for first transform
281    let mut real_data = Array::zeros(x.raw_dim());
282    scirs2_core::ndarray::Zip::from(&mut real_data)
283        .and(x)
284        .for_each(|dst, &src| {
285            *dst = NumCast::from(src)
286                .ok_or_else(|| FFTError::ValueError("Failed to convert input to float".to_string()))
287                .unwrap();
288        });
289
290    // Apply real FFT on the last axis
291    let mut result: Array<Complex64, D> = Array::zeros(x.raw_dim());
292
293    // This is a simplified implementation - proper real FFT would have different output dimensions
294    for lane in real_data.lanes(Axis(last_axis)) {
295        let real_vec: Vec<f64> = lane.to_vec();
296        let _complex_vec = rfft(&real_vec, None)?;
297
298        // For now, just convert to complex array format
299        // This is a placeholder implementation
300    }
301
302    // Apply complex FFT on remaining axes
303    for &axis in &axes_to_transform {
304        apply_fft_along_axis(&mut result, axis)?;
305    }
306
307    Ok(result)
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_axis_optimization() {
        // Row-major layout: the rightmost axis has the smallest stride and
        // should be transformed first, so the order reverses here.
        let optimized = optimize_axis_order(&[0, 1, 2], &[10, 100, 1000]);
        assert_eq!(optimized, vec![2, 1, 0]);
    }

    #[test]
    fn test_parallelize_decision() {
        // Both thresholds met: enough data and long enough lanes.
        assert!(should_parallelize(10001, 100));
        // Enough data overall, but lanes too short to amortize overhead.
        assert!(!should_parallelize(10001, 50));
        // Neither threshold met.
        assert!(!should_parallelize(100, 10));
    }

    #[test]
    fn test_validate_axes() {
        // All axes in bounds for a 3-D array.
        assert!(validate_axes(&[0, 1, 2], 3).is_ok());
        // Axis 3 is out of bounds for a 3-D array.
        assert!(validate_axes(&[0, 1, 3], 3).is_err());
    }
}
341}