scirs2_fft/
strided_fft.rs

1//! Advanced Strided FFT Operations
2//!
3//! This module provides optimized FFT operations for arrays with
4//! arbitrary memory layouts and striding patterns.
5
6use rustfft::FftPlanner;
7use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
8use scirs2_core::numeric::Complex64;
9use scirs2_core::numeric::NumCast;
10use std::sync::Arc;
11
12use crate::error::{FFTError, FFTResult};
13use crate::plan_cache::get_global_cache;
14
15/// Execute FFT on strided data with optimal memory access
16#[allow(dead_code)]
17pub fn fft_strided<S, D>(
18    input: &ArrayBase<S, D>,
19    axis: usize,
20) -> FFTResult<scirs2_core::ndarray::Array<Complex64, D>>
21where
22    S: Data,
23    D: Dimension,
24    S::Elem: NumCast + Copy,
25{
26    // Validate axis
27    if axis >= input.ndim() {
28        return Err(FFTError::ValueError(format!(
29            "Axis {} is out of bounds for array with {} dimensions",
30            axis,
31            input.ndim()
32        )));
33    }
34
35    // Create output array with same shape
36    let mut output = scirs2_core::ndarray::Array::zeros(input.raw_dim());
37
38    // Get FFT plan from cache
39    let axis_len = input.shape()[axis];
40    let mut planner = FftPlanner::new();
41    let fft_plan = get_global_cache().get_or_create_plan(axis_len, true, &mut planner);
42
43    // Process data along the specified axis
44    process_strided_fft(input, &mut output, axis, fft_plan)?;
45
46    Ok(output)
47}
48
49/// Process data with arbitrary striding
50#[allow(dead_code)]
51fn process_strided_fft<S, D>(
52    input: &ArrayBase<S, D>,
53    output: &mut scirs2_core::ndarray::Array<Complex64, D>,
54    axis: usize,
55    fft_plan: Arc<dyn rustfft::Fft<f64>>,
56) -> FFTResult<()>
57where
58    S: Data,
59    D: Dimension,
60    S::Elem: NumCast + Copy,
61{
62    let axis_len = input.shape()[axis];
63
64    // Create temporary buffer for FFT input/output
65    let mut buffer = vec![Complex64::new(0.0, 0.0); axis_len];
66
67    // Process each lane along the given axis
68    for (i_lane, mut o_lane) in input
69        .lanes(scirs2_core::ndarray::Axis(axis))
70        .into_iter()
71        .zip(output.lanes_mut(scirs2_core::ndarray::Axis(axis)))
72    {
73        // Copy data to input buffer with proper conversion
74        for (i, &val) in i_lane.iter().enumerate() {
75            let val_f64 = NumCast::from(val).ok_or_else(|| {
76                FFTError::ValueError(format!("Failed to convert value at index {i} to f64"))
77            })?;
78            buffer[i] = Complex64::new(val_f64, 0.0);
79        }
80
81        // Perform FFT (in-place)
82        fft_plan.process(&mut buffer);
83
84        // Copy results back to output
85        for (i, dst) in o_lane.iter_mut().enumerate() {
86            *dst = buffer[i];
87        }
88    }
89
90    Ok(())
91}
92
93/// Execute FFT on strided data with optimal memory access for complex input
94#[allow(dead_code)]
95pub fn fft_strided_complex<S, D>(
96    input: &ArrayBase<S, D>,
97    axis: usize,
98) -> FFTResult<scirs2_core::ndarray::Array<Complex64, D>>
99where
100    S: Data,
101    D: Dimension,
102    S::Elem: Into<Complex64> + Copy,
103{
104    // Validate axis
105    if axis >= input.ndim() {
106        return Err(FFTError::ValueError(format!(
107            "Axis {} is out of bounds for array with {} dimensions",
108            axis,
109            input.ndim()
110        )));
111    }
112
113    // Create output array with same shape
114    let mut output = scirs2_core::ndarray::Array::zeros(input.raw_dim());
115
116    // Get FFT plan from cache
117    let axis_len = input.shape()[axis];
118    let mut planner = FftPlanner::new();
119    let fft_plan = get_global_cache().get_or_create_plan(axis_len, true, &mut planner);
120
121    // Process data along the specified axis
122    process_strided_complex_fft(input, &mut output, axis, fft_plan)?;
123
124    Ok(output)
125}
126
127/// Process complex data with arbitrary striding
128#[allow(dead_code)]
129fn process_strided_complex_fft<S, D>(
130    input: &ArrayBase<S, D>,
131    output: &mut scirs2_core::ndarray::Array<Complex64, D>,
132    axis: usize,
133    fft_plan: Arc<dyn rustfft::Fft<f64>>,
134) -> FFTResult<()>
135where
136    S: Data,
137    D: Dimension,
138    S::Elem: Into<Complex64> + Copy,
139{
140    let axis_len = input.shape()[axis];
141
142    // Create temporary buffer for FFT input/output
143    let mut buffer = vec![Complex64::new(0.0, 0.0); axis_len];
144
145    // Process each lane along the given axis
146    for (i_lane, mut o_lane) in input
147        .lanes(scirs2_core::ndarray::Axis(axis))
148        .into_iter()
149        .zip(output.lanes_mut(scirs2_core::ndarray::Axis(axis)))
150    {
151        // Copy data to input buffer with proper conversion
152        for (i, &val) in i_lane.iter().enumerate() {
153            buffer[i] = val.into();
154        }
155
156        // Perform FFT (in-place)
157        fft_plan.process(&mut buffer);
158
159        // Copy results back to output
160        for (i, dst) in o_lane.iter_mut().enumerate() {
161            *dst = buffer[i];
162        }
163    }
164
165    Ok(())
166}
167
168/// Execute inverse FFT on strided data
169#[allow(dead_code)]
170pub fn ifft_strided<S, D>(
171    input: &ArrayBase<S, D>,
172    axis: usize,
173) -> FFTResult<scirs2_core::ndarray::Array<Complex64, D>>
174where
175    S: Data,
176    D: Dimension,
177    S::Elem: Into<Complex64> + Copy,
178{
179    // Validate axis
180    if axis >= input.ndim() {
181        return Err(FFTError::ValueError(format!(
182            "Axis {} is out of bounds for array with {} dimensions",
183            axis,
184            input.ndim()
185        )));
186    }
187
188    // Create output array with same shape
189    let mut output = scirs2_core::ndarray::Array::zeros(input.raw_dim());
190
191    // Get inverse FFT plan from cache
192    let axis_len = input.shape()[axis];
193    let mut planner = FftPlanner::new();
194    let ifft_plan = get_global_cache().get_or_create_plan(axis_len, false, &mut planner);
195
196    // Process data along the specified axis
197    process_strided_inverse_fft(input, &mut output, axis, ifft_plan)?;
198
199    // Apply normalization
200    let scale = 1.0 / (axis_len as f64);
201    output.mapv_inplace(|val| val * scale);
202
203    Ok(output)
204}
205
206/// Process data with arbitrary striding for inverse FFT
207#[allow(dead_code)]
208fn process_strided_inverse_fft<S, D>(
209    input: &ArrayBase<S, D>,
210    output: &mut scirs2_core::ndarray::Array<Complex64, D>,
211    axis: usize,
212    ifft_plan: Arc<dyn rustfft::Fft<f64>>,
213) -> FFTResult<()>
214where
215    S: Data,
216    D: Dimension,
217    S::Elem: Into<Complex64> + Copy,
218{
219    let axis_len = input.shape()[axis];
220
221    // Create temporary buffer for FFT input/output
222    let mut buffer = vec![Complex64::new(0.0, 0.0); axis_len];
223
224    // Process each lane along the given axis
225    for (i_lane, mut o_lane) in input
226        .lanes(scirs2_core::ndarray::Axis(axis))
227        .into_iter()
228        .zip(output.lanes_mut(scirs2_core::ndarray::Axis(axis)))
229    {
230        // Copy data to input buffer with proper conversion
231        for (i, &val) in i_lane.iter().enumerate() {
232            buffer[i] = val.into();
233        }
234
235        // Perform inverse FFT (in-place)
236        ifft_plan.process(&mut buffer);
237
238        // Copy results back to output
239        for (i, dst) in o_lane.iter_mut().enumerate() {
240            *dst = buffer[i];
241        }
242    }
243
244    Ok(())
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    const TOL: f64 = 1e-9;

    /// Naive O(n^2) forward DFT, used as an independent reference.
    fn naive_dft(signal: &[Complex64]) -> Vec<Complex64> {
        let n = signal.len();
        (0..n)
            .map(|k| {
                signal
                    .iter()
                    .enumerate()
                    .map(|(t, &x)| {
                        let angle =
                            -2.0 * std::f64::consts::PI * ((k * t) % n) as f64 / n as f64;
                        x * Complex64::new(angle.cos(), angle.sin())
                    })
                    .fold(Complex64::new(0.0, 0.0), |acc, z| acc + z)
            })
            .collect()
    }

    #[test]
    fn test_fft_strided_1d() {
        let n = 8;
        let mut input = Array1::zeros(n);
        for i in 0..n {
            input[i] = i as f64;
        }

        let result = fft_strided(&input, 0).unwrap();
        assert_eq!(result.shape(), input.shape());

        // Compare against a naive DFT of the same signal.
        let as_complex: Vec<Complex64> = input.iter().map(|&x| Complex64::new(x, 0.0)).collect();
        let expected = naive_dft(&as_complex);
        for (k, want) in expected.iter().enumerate() {
            assert!(
                (result[k] - *want).norm() < TOL,
                "bin {k}: got {:?}, want {:?}",
                result[k],
                want
            );
        }
    }

    #[test]
    fn test_fft_strided_2d() {
        let mut input = Array2::zeros((4, 6));
        for i in 0..4 {
            for j in 0..6 {
                input[[i, j]] = (i * 10 + j) as f64;
            }
        }

        // FFT along first axis
        let result0 = fft_strided(&input, 0).unwrap();
        assert_eq!(result0.shape(), input.shape());

        // FFT along second axis
        let result1 = fft_strided(&input, 1).unwrap();
        assert_eq!(result1.shape(), input.shape());

        // Each column of `result0` must be the DFT of the matching input column.
        for j in 0..6 {
            let column: Vec<Complex64> = input
                .column(j)
                .iter()
                .map(|&x| Complex64::new(x, 0.0))
                .collect();
            for (i, want) in naive_dft(&column).iter().enumerate() {
                assert!((result0[[i, j]] - *want).norm() < TOL);
            }
        }

        // Each row of `result1` must be the DFT of the matching input row.
        for i in 0..4 {
            let row: Vec<Complex64> = input
                .row(i)
                .iter()
                .map(|&x| Complex64::new(x, 0.0))
                .collect();
            for (j, want) in naive_dft(&row).iter().enumerate() {
                assert!((result1[[i, j]] - *want).norm() < TOL);
            }
        }
    }

    #[test]
    fn test_fft_strided_non_contiguous_view() {
        let mut input = Array2::zeros((4, 6));
        for i in 0..4 {
            for j in 0..6 {
                input[[i, j]] = (i * 7 + j * 3) as f64;
            }
        }

        // A transposed view has non-standard strides; transforming along its
        // axis 1 must equal the transpose of transforming the original along 0.
        let reference = fft_strided(&input, 0).unwrap();
        let transposed = fft_strided(&input.t(), 1).unwrap();
        assert_eq!(transposed.shape(), &[6, 4]);
        for i in 0..4 {
            for j in 0..6 {
                assert!((transposed[[j, i]] - reference[[i, j]]).norm() < TOL);
            }
        }
    }

    #[test]
    fn test_invalid_axis_is_rejected() {
        let real = Array2::<f64>::zeros((2, 3));
        let complex = real.mapv(|x| Complex64::new(x, 0.0));

        assert!(fft_strided(&real, 2).is_err());
        assert!(fft_strided_complex(&complex, 2).is_err());
        assert!(ifft_strided(&complex, 5).is_err());
    }

    #[test]
    fn test_ifft_strided() {
        // Create a complex test signal
        let n = 8;
        let mut input = Array1::zeros(n);
        for i in 0..n {
            input[i] = Complex64::new(i as f64, (i * 2) as f64);
        }

        // Forward and inverse FFT should give back the input
        let forward = fft_strided_complex(&input, 0).unwrap();
        let inverse = ifft_strided(&forward, 0).unwrap();

        // Check round-trip accuracy
        for i in 0..n {
            assert!((inverse[i] - input[i]).norm() < 1e-10);
        }
    }
}