scirs2_fft/
memory_efficient.rs

1//! Memory-efficient FFT operations
2//!
3//! This module provides memory-efficient implementations of FFT operations
4//! that minimize allocations for large arrays.
5
6use crate::error::{FFTError, FFTResult};
7use ndarray::{Array2, ArrayView2};
8use num_complex::Complex64;
9use num_traits::NumCast;
10use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
11use std::any::Any;
12use std::fmt::Debug;
13use std::num::NonZeroUsize;
14
15// Helper function to attempt downcast to Complex64
16#[allow(dead_code)]
17fn downcast_to_complex<T: 'static>(value: &T) -> Option<Complex64> {
18    // Check if T is Complex64
19    if let Some(complex) = (value as &dyn Any).downcast_ref::<Complex64>() {
20        return Some(*complex);
21    }
22
23    // Try to directly convert from num_complex::Complex<f32>
24    if let Some(complex) = (value as &dyn Any).downcast_ref::<num_complex::Complex<f32>>() {
25        return Some(Complex64::new(complex.re as f64, complex.im as f64));
26    }
27
28    // Try to convert from rustfft's Complex type
29    if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f64>>() {
30        return Some(Complex64::new(complex.re, complex.im));
31    }
32
33    if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f32>>() {
34        return Some(Complex64::new(complex.re as f64, complex.im as f64));
35    }
36
37    None
38}
39
/// Direction of a transform performed by the memory-efficient FFT routines.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftMode {
    /// Time/space domain to frequency domain.
    Forward,
    /// Frequency domain back to time/space domain.
    Inverse,
}
48
49/// Computes FFT in-place to minimize memory allocations
50///
51/// This function performs an in-place FFT using pre-allocated buffers
52/// to minimize memory allocations, which is beneficial for large arrays
53/// or when performing many FFT operations.
54///
55/// # Arguments
56///
57/// * `input` - Input buffer (will be modified in-place)
58/// * `output` - Pre-allocated output buffer
59/// * `mode` - Whether to compute forward or inverse FFT
60/// * `normalize` - Whether to normalize the result (required for IFFT)
61///
62/// # Returns
63///
64/// * Result with the number of elements processed
65///
66/// # Errors
67///
68/// Returns an error if the computation fails.
69///
70/// # Examples
71///
72/// ```
73/// use scirs2_fft::memory_efficient::{fft_inplace, FftMode};
74/// use num_complex::Complex64;
75///
76/// // Create input and output buffers
77/// let mut input_buffer = vec![Complex64::new(1.0, 0.0),
78///                            Complex64::new(2.0, 0.0),
79///                            Complex64::new(3.0, 0.0),
80///                            Complex64::new(4.0, 0.0)];
81/// let mut output_buffer = vec![Complex64::new(0.0, 0.0); input_buffer.len()];
82///
83/// // Perform in-place FFT
84/// fft_inplace(&mut input_buffer, &mut output_buffer, FftMode::Forward, false).unwrap();
85///
86/// // Input buffer now contains the result
87/// let sum: f64 = (1.0 + 2.0 + 3.0 + 4.0);
88/// assert!((input_buffer[0].re - sum).abs() < 1e-10);
89/// ```
90#[allow(dead_code)]
91pub fn fft_inplace(
92    input: &mut [Complex64],
93    output: &mut [Complex64],
94    mode: FftMode,
95    normalize: bool,
96) -> FFTResult<usize> {
97    let n = input.len();
98
99    if n == 0 {
100        return Err(FFTError::ValueError("Input array is empty".to_string()));
101    }
102
103    if output.len() < n {
104        return Err(FFTError::ValueError(format!(
105            "Output buffer is too small: got {}, need {}",
106            output.len(),
107            n
108        )));
109    }
110
111    // For larger arrays, consider using SIMD acceleration
112    let use_simd = n >= 32 && crate::simd_fft::simd_support_available();
113
114    if use_simd {
115        // Use the SIMD-accelerated FFT implementation
116        let result = match mode {
117            FftMode::Forward => crate::simd_fft::fft_adaptive(
118                input,
119                if normalize { Some("forward") } else { None },
120            )?,
121            FftMode::Inverse => crate::simd_fft::ifft_adaptive(
122                input,
123                if normalize { Some("backward") } else { None },
124            )?,
125        };
126
127        // Copy the results back to the input and output buffers
128        for (i, &val) in result.iter().enumerate() {
129            input[i] = val;
130            output[i] = val;
131        }
132
133        return Ok(n);
134    }
135
136    // Fall back to standard implementation for small arrays
137    // Create FFT plan
138    let mut planner = FftPlanner::new();
139    let fft = match mode {
140        FftMode::Forward => planner.plan_fft_forward(n),
141        FftMode::Inverse => planner.plan_fft_inverse(n),
142    };
143
144    // Convert to rustfft's Complex type
145    let mut buffer: Vec<RustComplex<f64>> = input
146        .iter()
147        .map(|&c| RustComplex::new(c.re, c.im))
148        .collect();
149
150    // Perform the FFT
151    fft.process(&mut buffer);
152
153    // Convert back to num_complex::Complex64 and apply normalization if needed
154    let scale = if normalize { 1.0 / (n as f64) } else { 1.0 };
155
156    if scale != 1.0 && use_simd {
157        // Copy back to input buffer first
158        for (i, &c) in buffer.iter().enumerate() {
159            input[i] = Complex64::new(c.re, c.im);
160        }
161
162        // Use SIMD-accelerated normalization
163        crate::simd_fft::apply_simd_normalization(input, scale);
164
165        // Copy to output buffer
166        output.copy_from_slice(input);
167    } else {
168        // Standard normalization
169        for (i, &c) in buffer.iter().enumerate() {
170            input[i] = Complex64::new(c.re * scale, c.im * scale);
171            output[i] = input[i];
172        }
173    }
174
175    Ok(n)
176}
177
178/// Process large arrays in chunks to minimize memory usage
179///
180/// This function processes a large array in chunks using the provided
181/// operation function, which reduces memory usage for very large arrays.
182///
183/// # Arguments
184///
185/// * `input` - Input array
186/// * `chunk_size` - Size of each chunk to process
187/// * `op` - Operation to apply to each chunk
188///
189/// # Returns
190///
191/// * Result with the processed array
192///
193/// # Errors
194///
195/// Returns an error if the computation fails.
196#[allow(dead_code)]
197pub fn process_in_chunks<T, F>(
198    input: &[T],
199    chunk_size: usize,
200    mut op: F,
201) -> FFTResult<Vec<Complex64>>
202where
203    T: NumCast + Copy + Debug + 'static,
204    F: FnMut(&[T]) -> FFTResult<Vec<Complex64>>,
205{
206    if input.len() <= chunk_size {
207        // If input is smaller than chunk_size, process it directly
208        return op(input);
209    }
210
211    let chunk_size_nz = NonZeroUsize::new(chunk_size).unwrap_or(NonZeroUsize::new(1).unwrap());
212    let n_chunks = input.len().div_ceil(chunk_size_nz.get());
213    let mut result = Vec::with_capacity(input.len());
214
215    for i in 0..n_chunks {
216        let start = i * chunk_size;
217        let end = (start + chunk_size).min(input.len());
218        let chunk = &input[start..end];
219
220        let chunk_result = op(chunk)?;
221        result.extend(chunk_result);
222    }
223
224    Ok(result)
225}
226
227/// Computes 2D FFT with memory efficiency in mind
228///
229/// This function performs a 2D FFT with optimized memory usage,
230/// which is particularly beneficial for large arrays.
231///
232/// # Arguments
233///
234/// * `input` - Input 2D array
235/// * `shape` - Optional shape for the output
236/// * `mode` - Whether to compute forward or inverse FFT
237/// * `normalize` - Whether to normalize the result
238///
239/// # Returns
240///
241/// * Result with the processed 2D array
242///
243/// # Errors
244///
245/// Returns an error if the computation fails.
246#[allow(dead_code)]
247pub fn fft2_efficient<T>(
248    input: &ArrayView2<T>,
249    shape: Option<(usize, usize)>,
250    mode: FftMode,
251    normalize: bool,
252) -> FFTResult<Array2<Complex64>>
253where
254    T: NumCast + Copy + Debug + 'static,
255{
256    let (n_rows, n_cols) = input.dim();
257    let (n_rows_out, n_cols_out) = shape.unwrap_or((n_rows, n_cols));
258
259    // Check if output dimensions are valid
260    if n_rows_out == 0 || n_cols_out == 0 {
261        return Err(FFTError::ValueError(
262            "Output dimensions must be positive".to_string(),
263        ));
264    }
265
266    // Convert input to complex array with proper dimensions
267    let mut complex_input = Array2::zeros((n_rows_out, n_cols_out));
268    for r in 0..n_rows.min(n_rows_out) {
269        for c in 0..n_cols.min(n_cols_out) {
270            let val = input[[r, c]];
271            match num_traits::cast::cast::<T, f64>(val) {
272                Some(val_f64) => {
273                    complex_input[[r, c]] = Complex64::new(val_f64, 0.0);
274                }
275                None => {
276                    // Check if this is already a complex number
277                    if let Some(complex_val) = downcast_to_complex::<T>(&val) {
278                        complex_input[[r, c]] = complex_val;
279                    } else {
280                        return Err(FFTError::ValueError(format!(
281                            "Could not convert {val:?} to f64 or Complex64"
282                        )));
283                    }
284                }
285            }
286        }
287    }
288
289    // Get a flattened view to avoid allocating additional memory
290    let mut buffer = complex_input.as_slice_mut().unwrap().to_vec();
291
292    // Create FFT planner
293    let mut planner = FftPlanner::new();
294
295    // Storage for row-wise FFTs (kept for future optimizations)
296    let _row_buffer = vec![Complex64::new(0.0, 0.0); n_cols_out];
297
298    // Process each row
299    for r in 0..n_rows_out {
300        let row_start = r * n_cols_out;
301        let row_end = row_start + n_cols_out;
302        let row_slice = &mut buffer[row_start..row_end];
303
304        let row_fft = match mode {
305            FftMode::Forward => planner.plan_fft_forward(n_cols_out),
306            FftMode::Inverse => planner.plan_fft_inverse(n_cols_out),
307        };
308
309        // Convert to rustfft's Complex type
310        let mut row_data: Vec<RustComplex<f64>> = row_slice
311            .iter()
312            .map(|&c| RustComplex::new(c.re, c.im))
313            .collect();
314
315        // Perform row-wise FFT
316        row_fft.process(&mut row_data);
317
318        // Convert back and store in buffer
319        for (i, &c) in row_data.iter().enumerate() {
320            row_slice[i] = Complex64::new(c.re, c.im);
321        }
322    }
323
324    // Process columns (with buffer transposition)
325    let mut transposed = vec![Complex64::new(0.0, 0.0); n_rows_out * n_cols_out];
326
327    // Transpose data
328    for r in 0..n_rows_out {
329        for c in 0..n_cols_out {
330            let src_idx = r * n_cols_out + c;
331            let dst_idx = c * n_rows_out + r;
332            transposed[dst_idx] = buffer[src_idx];
333        }
334    }
335
336    // Storage for column FFTs (kept for future optimizations)
337    let _col_buffer = vec![Complex64::new(0.0, 0.0); n_rows_out];
338
339    // Process each column (as rows in transposed data)
340    for c in 0..n_cols_out {
341        let col_start = c * n_rows_out;
342        let col_end = col_start + n_rows_out;
343        let col_slice = &mut transposed[col_start..col_end];
344
345        let col_fft = match mode {
346            FftMode::Forward => planner.plan_fft_forward(n_rows_out),
347            FftMode::Inverse => planner.plan_fft_inverse(n_rows_out),
348        };
349
350        // Convert to rustfft's Complex type
351        let mut col_data: Vec<RustComplex<f64>> = col_slice
352            .iter()
353            .map(|&c| RustComplex::new(c.re, c.im))
354            .collect();
355
356        // Perform column-wise FFT
357        col_fft.process(&mut col_data);
358
359        // Convert back and store in buffer
360        for (i, &c) in col_data.iter().enumerate() {
361            col_slice[i] = Complex64::new(c.re, c.im);
362        }
363    }
364
365    // Final result with proper normalization
366    let scale = if normalize {
367        1.0 / ((n_rows_out * n_cols_out) as f64)
368    } else {
369        1.0
370    };
371
372    let mut result = Array2::zeros((n_rows_out, n_cols_out));
373
374    // Transpose back to original shape
375    for r in 0..n_rows_out {
376        for c in 0..n_cols_out {
377            let src_idx = c * n_rows_out + r;
378            let val = transposed[src_idx];
379            result[[r, c]] = Complex64::new(val.re * scale, val.im * scale);
380        }
381    }
382
383    Ok(result)
384}
385
386/// Compute large array FFT with streaming to minimize memory usage
387///
388/// This function computes the FFT of a large array by processing it in chunks,
389/// which reduces the memory footprint for very large arrays.
390///
391/// # Arguments
392///
393/// * `input` - Input array
394/// * `n` - Length of the transformed axis (optional)
395/// * `mode` - Whether to compute forward or inverse FFT
396/// * `chunk_size` - Size of chunks to process at once
397///
398/// # Returns
399///
400/// * Result with the processed array
401///
402/// # Errors
403///
404/// Returns an error if the computation fails.
405#[allow(dead_code)]
406pub fn fft_streaming<T>(
407    input: &[T],
408    n: Option<usize>,
409    mode: FftMode,
410    chunk_size: Option<usize>,
411) -> FFTResult<Vec<Complex64>>
412where
413    T: NumCast + Copy + Debug + 'static,
414{
415    let input_length = input.len();
416    let n_val = n.unwrap_or(input_length);
417    let chunk_size_val = chunk_size.unwrap_or(
418        // Default chunk _size based on array _size
419        if input_length > 1_000_000 {
420            // For arrays > 1M, use 1024 * 1024
421            1_048_576
422        } else if input_length > 100_000 {
423            // For arrays > 100k, use 64k
424            65_536
425        } else {
426            // For smaller arrays, process in one chunk
427            input_length
428        },
429    );
430
431    // For small arrays, don't use chunking
432    if input_length <= chunk_size_val || n_val <= chunk_size_val {
433        // Convert input to complex vector
434        let mut complex_input: Vec<Complex64> = Vec::with_capacity(input_length);
435
436        for &val in input {
437            match num_traits::cast::cast::<T, f64>(val) {
438                Some(val_f64) => {
439                    complex_input.push(Complex64::new(val_f64, 0.0));
440                }
441                None => {
442                    // Check if this is already a complex number
443                    if let Some(complex_val) = downcast_to_complex::<T>(&val) {
444                        complex_input.push(complex_val);
445                    } else {
446                        return Err(FFTError::ValueError(format!(
447                            "Could not convert {val:?} to f64 or Complex64"
448                        )));
449                    }
450                }
451            }
452        }
453
454        // Handle the case where n is provided
455        match n_val.cmp(&complex_input.len()) {
456            std::cmp::Ordering::Less => {
457                // Truncate the input if n is smaller
458                complex_input.truncate(n_val);
459            }
460            std::cmp::Ordering::Greater => {
461                // Zero-pad the input if n is larger
462                complex_input.resize(n_val, Complex64::new(0.0, 0.0));
463            }
464            std::cmp::Ordering::Equal => {
465                // No resizing needed
466            }
467        }
468
469        // Set up rustfft for computation
470        let mut planner = FftPlanner::new();
471        let fft = match mode {
472            FftMode::Forward => planner.plan_fft_forward(n_val),
473            FftMode::Inverse => planner.plan_fft_inverse(n_val),
474        };
475
476        // Convert to rustfft's Complex type
477        let mut buffer: Vec<RustComplex<f64>> = complex_input
478            .iter()
479            .map(|&c| RustComplex::new(c.re, c.im))
480            .collect();
481
482        // Perform the FFT
483        fft.process(&mut buffer);
484
485        // Convert back to num_complex::Complex64 and apply normalization if needed
486        let scale = if mode == FftMode::Inverse {
487            1.0 / (n_val as f64)
488        } else {
489            1.0
490        };
491
492        let result: Vec<Complex64> = buffer
493            .into_iter()
494            .map(|c| Complex64::new(c.re * scale, c.im * scale))
495            .collect();
496
497        return Ok(result);
498    }
499
500    // Process in chunks for large arrays
501    let chunk_size_nz = NonZeroUsize::new(chunk_size_val).unwrap_or(NonZeroUsize::new(1).unwrap());
502    let n_chunks = n_val.div_ceil(chunk_size_nz.get());
503    let mut result = Vec::with_capacity(n_val);
504
505    for i in 0..n_chunks {
506        let start = i * chunk_size_val;
507        let end = (start + chunk_size_val).min(n_val);
508        let chunk_size = end - start;
509
510        // Prepare input chunk (either from original input or zero-padded)
511        let mut chunk_input = Vec::with_capacity(chunk_size);
512
513        if start < input_length {
514            // Part of the chunk comes from the input
515            let input_end = end.min(input_length);
516            for val in input[start..input_end].iter() {
517                match num_traits::cast::cast::<T, f64>(*val) {
518                    Some(val_f64) => {
519                        chunk_input.push(Complex64::new(val_f64, 0.0));
520                    }
521                    None => {
522                        // Check if this is already a complex number
523                        if let Some(complex_val) = downcast_to_complex::<T>(val) {
524                            chunk_input.push(complex_val);
525                        } else {
526                            return Err(FFTError::ValueError(format!(
527                                "Could not convert {val:?} to f64 or Complex64"
528                            )));
529                        }
530                    }
531                }
532            }
533
534            // Zero-pad the rest if needed
535            if input_end < end {
536                chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
537            }
538        } else {
539            // Chunk is entirely outside the input range, so zero-pad
540            chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
541        }
542
543        // Set up rustfft for computation on this chunk
544        let mut planner = FftPlanner::new();
545        let fft = match mode {
546            FftMode::Forward => planner.plan_fft_forward(chunk_size),
547            FftMode::Inverse => planner.plan_fft_inverse(chunk_size),
548        };
549
550        // Convert to rustfft's Complex type
551        let mut buffer: Vec<RustComplex<f64>> = chunk_input
552            .iter()
553            .map(|&c| RustComplex::new(c.re, c.im))
554            .collect();
555
556        // Perform the FFT on this chunk
557        fft.process(&mut buffer);
558
559        // Convert back to num_complex::Complex64 and apply normalization if needed
560        let scale = if mode == FftMode::Inverse {
561            1.0 / (chunk_size as f64)
562        } else {
563            1.0
564        };
565
566        let chunk_result: Vec<Complex64> = buffer
567            .into_iter()
568            .map(|c| Complex64::new(c.re * scale, c.im * scale))
569            .collect();
570
571        // Add chunk result to the final result
572        result.extend(chunk_result);
573    }
574
575    // For inverse transforms, we need to normalize by the full length
576    // instead of chunk size, so adjust the scaling
577    if mode == FftMode::Inverse {
578        let full_scale = 1.0 / (n_val as f64);
579        let chunk_scale = 1.0 / (chunk_size_val as f64);
580        let scale_adjustment = full_scale / chunk_scale;
581
582        for val in &mut result {
583            val.re *= scale_adjustment;
584            val.im *= scale_adjustment;
585        }
586    }
587
588    Ok(result)
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_fft_inplace() {
        // Test with a simple signal
        let mut input = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ];
        let mut output = vec![Complex64::new(0.0, 0.0); 4];

        // Perform forward FFT
        fft_inplace(&mut input, &mut output, FftMode::Forward, false).unwrap();

        // Check DC component is sum of all inputs
        assert_relative_eq!(input[0].re, 10.0, epsilon = 1e-10);

        // Perform inverse FFT
        fft_inplace(&mut input, &mut output, FftMode::Inverse, true).unwrap();

        // Check that we recover the original signal
        assert_relative_eq!(input[0].re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(input[1].re, 2.0, epsilon = 1e-10);
        assert_relative_eq!(input[2].re, 3.0, epsilon = 1e-10);
        assert_relative_eq!(input[3].re, 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_fft_inplace_mirrors_into_output() {
        let mut input = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 1.0),
            Complex64::new(-1.0, 0.0),
        ];
        let mut output = vec![Complex64::new(0.0, 0.0); 3];

        let processed = fft_inplace(&mut input, &mut output, FftMode::Forward, false).unwrap();

        assert_eq!(processed, 3);
        assert_eq!(input, output);
    }

    #[test]
    fn test_fft_inplace_rejects_bad_buffers() {
        let mut empty: Vec<Complex64> = Vec::new();
        let mut out = vec![Complex64::new(0.0, 0.0); 1];
        assert!(fft_inplace(&mut empty, &mut out, FftMode::Forward, false).is_err());

        let mut input = vec![Complex64::new(1.0, 0.0); 4];
        let mut too_small = vec![Complex64::new(0.0, 0.0); 2];
        assert!(fft_inplace(&mut input, &mut too_small, FftMode::Forward, false).is_err());
    }

    #[test]
    fn test_process_in_chunks_concatenates_results() {
        let data = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let mut calls = 0;

        let out = process_in_chunks(&data, 2, |chunk: &[f64]| {
            calls += 1;
            Ok(chunk.iter().map(|&x| Complex64::new(x, 0.0)).collect())
        })
        .unwrap();

        assert_eq!(calls, 3);
        let re: Vec<f64> = out.iter().map(|c| c.re).collect();
        assert_eq!(re, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_fft2_efficient() {
        // Create a 2x2 test array
        let arr = array![[1.0, 2.0], [3.0, 4.0]];

        // Compute 2D FFT
        let spectrum_2d = fft2_efficient(&arr.view(), None, FftMode::Forward, false).unwrap();

        // DC component should be sum of all elements
        assert_relative_eq!(spectrum_2d[[0, 0]].re, 10.0, epsilon = 1e-10);

        // Compute inverse FFT
        let recovered = fft2_efficient(&spectrum_2d.view(), None, FftMode::Inverse, true).unwrap();

        // Check original values are recovered
        assert_relative_eq!(recovered[[0, 0]].re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(recovered[[0, 1]].re, 2.0, epsilon = 1e-10);
        assert_relative_eq!(recovered[[1, 0]].re, 3.0, epsilon = 1e-10);
        assert_relative_eq!(recovered[[1, 1]].re, 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_fft2_efficient_zero_pads_to_shape() {
        // A scaled impulse at the origin has a flat spectrum.
        let arr = array![[2.0]];
        let spectrum = fft2_efficient(&arr.view(), Some((2, 2)), FftMode::Forward, false).unwrap();

        assert_eq!(spectrum.dim(), (2, 2));
        for v in spectrum.iter() {
            assert_relative_eq!(v.re, 2.0, epsilon = 1e-12);
            assert_relative_eq!(v.im, 0.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_fft2_efficient_rejects_zero_shape() {
        let arr = array![[1.0, 2.0]];
        assert!(fft2_efficient(&arr.view(), Some((0, 2)), FftMode::Forward, false).is_err());
    }

    #[test]
    fn test_fft_streaming() {
        // Create a test signal
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        // Test with default chunk size
        let result = fft_streaming(&signal, None, FftMode::Forward, None).unwrap();

        // Check DC component is sum of inputs
        assert_relative_eq!(result[0].re, 10.0, epsilon = 1e-10);

        // Test inverse
        let inverse = fft_streaming(&result, None, FftMode::Inverse, None).unwrap();

        // Check we recover original signal
        assert_relative_eq!(inverse[0].re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(inverse[1].re, 2.0, epsilon = 1e-10);
        assert_relative_eq!(inverse[2].re, 3.0, epsilon = 1e-10);
        assert_relative_eq!(inverse[3].re, 4.0, epsilon = 1e-10);

        // Explicit chunk size equal to the signal length (single-chunk case)
        let result_chunked =
            fft_streaming(&signal, None, FftMode::Forward, Some(signal.len())).unwrap();

        // Results should be the same
        for (a, b) in result.iter().zip(result_chunked.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fft_streaming_zero_pads_to_n() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let padded = fft_streaming(&signal, Some(8), FftMode::Forward, None).unwrap();

        assert_eq!(padded.len(), 8);
        // Zero-padding does not change the DC component.
        assert_relative_eq!(padded[0].re, 10.0, epsilon = 1e-10);
    }
}