scirs2_fft/
memory_efficient.rs

1//! Memory-efficient FFT operations
2//!
3//! This module provides memory-efficient implementations of FFT operations
4//! that minimize allocations for large arrays.
5
6use crate::error::{FFTError, FFTResult};
7use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
8use scirs2_core::ndarray::{Array2, ArrayView2};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::numeric::NumCast;
11use std::any::Any;
12use std::fmt::Debug;
13use std::num::NonZeroUsize;
14
15// Helper function to attempt downcast to Complex64
16#[allow(dead_code)]
17fn downcast_to_complex<T: 'static>(value: &T) -> Option<Complex64> {
18    // Check if T is Complex64
19    if let Some(complex) = (value as &dyn Any).downcast_ref::<Complex64>() {
20        return Some(*complex);
21    }
22
23    // Try to directly convert from scirs2_core::numeric::Complex<f32>
24    if let Some(complex) = (value as &dyn Any).downcast_ref::<scirs2_core::numeric::Complex<f32>>()
25    {
26        return Some(Complex64::new(complex.re as f64, complex.im as f64));
27    }
28
29    // Try to convert from rustfft's Complex type
30    if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f64>>() {
31        return Some(Complex64::new(complex.re, complex.im));
32    }
33
34    if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f32>>() {
35        return Some(Complex64::new(complex.re as f64, complex.im as f64));
36    }
37
38    None
39}
40
/// Direction of the transform computed by the memory-efficient FFT helpers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftMode {
    /// Compute the forward FFT.
    Forward,
    /// Compute the inverse FFT.
    Inverse,
}
49
50/// Computes FFT in-place to minimize memory allocations
51///
52/// This function performs an in-place FFT using pre-allocated buffers
53/// to minimize memory allocations, which is beneficial for large arrays
54/// or when performing many FFT operations.
55///
56/// # Arguments
57///
58/// * `input` - Input buffer (will be modified in-place)
59/// * `output` - Pre-allocated output buffer
60/// * `mode` - Whether to compute forward or inverse FFT
61/// * `normalize` - Whether to normalize the result (required for IFFT)
62///
63/// # Returns
64///
65/// * Result with the number of elements processed
66///
67/// # Errors
68///
69/// Returns an error if the computation fails.
70///
71/// # Examples
72///
73/// ```
74/// use scirs2_fft::memory_efficient::{fft_inplace, FftMode};
75/// use scirs2_core::numeric::Complex64;
76///
77/// // Create input and output buffers
78/// let mut input_buffer = vec![Complex64::new(1.0, 0.0),
79///                            Complex64::new(2.0, 0.0),
80///                            Complex64::new(3.0, 0.0),
81///                            Complex64::new(4.0, 0.0)];
82/// let mut output_buffer = vec![Complex64::new(0.0, 0.0); input_buffer.len()];
83///
84/// // Perform in-place FFT
85/// fft_inplace(&mut input_buffer, &mut output_buffer, FftMode::Forward, false).expect("Operation failed");
86///
87/// // Input buffer now contains the result
88/// let sum: f64 = (1.0 + 2.0 + 3.0 + 4.0);
89/// assert!((input_buffer[0].re - sum).abs() < 1e-10);
90/// ```
91#[allow(dead_code)]
92pub fn fft_inplace(
93    input: &mut [Complex64],
94    output: &mut [Complex64],
95    mode: FftMode,
96    normalize: bool,
97) -> FFTResult<usize> {
98    let n = input.len();
99
100    if n == 0 {
101        return Err(FFTError::ValueError("Input array is empty".to_string()));
102    }
103
104    if output.len() < n {
105        return Err(FFTError::ValueError(format!(
106            "Output buffer is too small: got {}, need {}",
107            output.len(),
108            n
109        )));
110    }
111
112    // For larger arrays, consider using SIMD acceleration
113    let use_simd = n >= 32 && crate::simd_fft::simd_support_available();
114
115    if use_simd {
116        // Use the SIMD-accelerated FFT implementation
117        let result = match mode {
118            FftMode::Forward => crate::simd_fft::fft_adaptive(
119                input,
120                if normalize { Some("forward") } else { None },
121            )?,
122            FftMode::Inverse => crate::simd_fft::ifft_adaptive(
123                input,
124                if normalize { Some("backward") } else { None },
125            )?,
126        };
127
128        // Copy the results back to the input and output buffers
129        for (i, &val) in result.iter().enumerate() {
130            input[i] = val;
131            output[i] = val;
132        }
133
134        return Ok(n);
135    }
136
137    // Fall back to standard implementation for small arrays
138    // Create FFT plan
139    let mut planner = FftPlanner::new();
140    let fft = match mode {
141        FftMode::Forward => planner.plan_fft_forward(n),
142        FftMode::Inverse => planner.plan_fft_inverse(n),
143    };
144
145    // Convert to rustfft's Complex type
146    let mut buffer: Vec<RustComplex<f64>> = input
147        .iter()
148        .map(|&c| RustComplex::new(c.re, c.im))
149        .collect();
150
151    // Perform the FFT
152    fft.process(&mut buffer);
153
154    // Convert back to scirs2_core::numeric::Complex64 and apply normalization if needed
155    let scale = if normalize { 1.0 / (n as f64) } else { 1.0 };
156
157    if scale != 1.0 && use_simd {
158        // Copy back to input buffer first
159        for (i, &c) in buffer.iter().enumerate() {
160            input[i] = Complex64::new(c.re, c.im);
161        }
162
163        // Use SIMD-accelerated normalization
164        crate::simd_fft::apply_simd_normalization(input, scale);
165
166        // Copy to output buffer
167        output.copy_from_slice(input);
168    } else {
169        // Standard normalization
170        for (i, &c) in buffer.iter().enumerate() {
171            input[i] = Complex64::new(c.re * scale, c.im * scale);
172            output[i] = input[i];
173        }
174    }
175
176    Ok(n)
177}
178
179/// Process large arrays in chunks to minimize memory usage
180///
181/// This function processes a large array in chunks using the provided
182/// operation function, which reduces memory usage for very large arrays.
183///
184/// # Arguments
185///
186/// * `input` - Input array
187/// * `chunk_size` - Size of each chunk to process
188/// * `op` - Operation to apply to each chunk
189///
190/// # Returns
191///
192/// * Result with the processed array
193///
194/// # Errors
195///
196/// Returns an error if the computation fails.
197#[allow(dead_code)]
198pub fn process_in_chunks<T, F>(
199    input: &[T],
200    chunk_size: usize,
201    mut op: F,
202) -> FFTResult<Vec<Complex64>>
203where
204    T: NumCast + Copy + Debug + 'static,
205    F: FnMut(&[T]) -> FFTResult<Vec<Complex64>>,
206{
207    if input.len() <= chunk_size {
208        // If input is smaller than chunk_size, process it directly
209        return op(input);
210    }
211
212    let chunk_size_nz =
213        NonZeroUsize::new(chunk_size).unwrap_or(NonZeroUsize::new(1).expect("Operation failed"));
214    let n_chunks = input.len().div_ceil(chunk_size_nz.get());
215    let mut result = Vec::with_capacity(input.len());
216
217    for i in 0..n_chunks {
218        let start = i * chunk_size;
219        let end = (start + chunk_size).min(input.len());
220        let chunk = &input[start..end];
221
222        let chunk_result = op(chunk)?;
223        result.extend(chunk_result);
224    }
225
226    Ok(result)
227}
228
229/// Computes 2D FFT with memory efficiency in mind
230///
231/// This function performs a 2D FFT with optimized memory usage,
232/// which is particularly beneficial for large arrays.
233///
234/// # Arguments
235///
236/// * `input` - Input 2D array
237/// * `shape` - Optional shape for the output
238/// * `mode` - Whether to compute forward or inverse FFT
239/// * `normalize` - Whether to normalize the result
240///
241/// # Returns
242///
243/// * Result with the processed 2D array
244///
245/// # Errors
246///
247/// Returns an error if the computation fails.
248#[allow(dead_code)]
249pub fn fft2_efficient<T>(
250    input: &ArrayView2<T>,
251    shape: Option<(usize, usize)>,
252    mode: FftMode,
253    normalize: bool,
254) -> FFTResult<Array2<Complex64>>
255where
256    T: NumCast + Copy + Debug + 'static,
257{
258    let (n_rows, n_cols) = input.dim();
259    let (n_rows_out, n_cols_out) = shape.unwrap_or((n_rows, n_cols));
260
261    // Check if output dimensions are valid
262    if n_rows_out == 0 || n_cols_out == 0 {
263        return Err(FFTError::ValueError(
264            "Output dimensions must be positive".to_string(),
265        ));
266    }
267
268    // Convert input to complex array with proper dimensions
269    let mut complex_input = Array2::zeros((n_rows_out, n_cols_out));
270    for r in 0..n_rows.min(n_rows_out) {
271        for c in 0..n_cols.min(n_cols_out) {
272            let val = input[[r, c]];
273            match NumCast::from(val) {
274                Some(val_f64) => {
275                    complex_input[[r, c]] = Complex64::new(val_f64, 0.0);
276                }
277                None => {
278                    // Check if this is already a complex number
279                    if let Some(complex_val) = downcast_to_complex::<T>(&val) {
280                        complex_input[[r, c]] = complex_val;
281                    } else {
282                        return Err(FFTError::ValueError(format!(
283                            "Could not convert {val:?} to f64 or Complex64"
284                        )));
285                    }
286                }
287            }
288        }
289    }
290
291    // Get a flattened view to avoid allocating additional memory
292    let mut buffer = complex_input
293        .as_slice_mut()
294        .expect("Operation failed")
295        .to_vec();
296
297    // Create FFT planner
298    let mut planner = FftPlanner::new();
299
300    // Storage for row-wise FFTs (kept for future optimizations)
301    let _row_buffer = vec![Complex64::new(0.0, 0.0); n_cols_out];
302
303    // Process each row
304    for r in 0..n_rows_out {
305        let row_start = r * n_cols_out;
306        let row_end = row_start + n_cols_out;
307        let row_slice = &mut buffer[row_start..row_end];
308
309        let row_fft = match mode {
310            FftMode::Forward => planner.plan_fft_forward(n_cols_out),
311            FftMode::Inverse => planner.plan_fft_inverse(n_cols_out),
312        };
313
314        // Convert to rustfft's Complex type
315        let mut row_data: Vec<RustComplex<f64>> = row_slice
316            .iter()
317            .map(|&c| RustComplex::new(c.re, c.im))
318            .collect();
319
320        // Perform row-wise FFT
321        row_fft.process(&mut row_data);
322
323        // Convert back and store in buffer
324        for (i, &c) in row_data.iter().enumerate() {
325            row_slice[i] = Complex64::new(c.re, c.im);
326        }
327    }
328
329    // Process columns (with buffer transposition)
330    let mut transposed = vec![Complex64::new(0.0, 0.0); n_rows_out * n_cols_out];
331
332    // Transpose data
333    for r in 0..n_rows_out {
334        for c in 0..n_cols_out {
335            let src_idx = r * n_cols_out + c;
336            let dst_idx = c * n_rows_out + r;
337            transposed[dst_idx] = buffer[src_idx];
338        }
339    }
340
341    // Storage for column FFTs (kept for future optimizations)
342    let _col_buffer = vec![Complex64::new(0.0, 0.0); n_rows_out];
343
344    // Process each column (as rows in transposed data)
345    for c in 0..n_cols_out {
346        let col_start = c * n_rows_out;
347        let col_end = col_start + n_rows_out;
348        let col_slice = &mut transposed[col_start..col_end];
349
350        let col_fft = match mode {
351            FftMode::Forward => planner.plan_fft_forward(n_rows_out),
352            FftMode::Inverse => planner.plan_fft_inverse(n_rows_out),
353        };
354
355        // Convert to rustfft's Complex type
356        let mut col_data: Vec<RustComplex<f64>> = col_slice
357            .iter()
358            .map(|&c| RustComplex::new(c.re, c.im))
359            .collect();
360
361        // Perform column-wise FFT
362        col_fft.process(&mut col_data);
363
364        // Convert back and store in buffer
365        for (i, &c) in col_data.iter().enumerate() {
366            col_slice[i] = Complex64::new(c.re, c.im);
367        }
368    }
369
370    // Final result with proper normalization
371    let scale = if normalize {
372        1.0 / ((n_rows_out * n_cols_out) as f64)
373    } else {
374        1.0
375    };
376
377    let mut result = Array2::zeros((n_rows_out, n_cols_out));
378
379    // Transpose back to original shape
380    for r in 0..n_rows_out {
381        for c in 0..n_cols_out {
382            let src_idx = c * n_rows_out + r;
383            let val = transposed[src_idx];
384            result[[r, c]] = Complex64::new(val.re * scale, val.im * scale);
385        }
386    }
387
388    Ok(result)
389}
390
391/// Compute large array FFT with streaming to minimize memory usage
392///
393/// This function computes the FFT of a large array by processing it in chunks,
394/// which reduces the memory footprint for very large arrays.
395///
396/// # Arguments
397///
398/// * `input` - Input array
399/// * `n` - Length of the transformed axis (optional)
400/// * `mode` - Whether to compute forward or inverse FFT
401/// * `chunk_size` - Size of chunks to process at once
402///
403/// # Returns
404///
405/// * Result with the processed array
406///
407/// # Errors
408///
409/// Returns an error if the computation fails.
410#[allow(dead_code)]
411pub fn fft_streaming<T>(
412    input: &[T],
413    n: Option<usize>,
414    mode: FftMode,
415    chunk_size: Option<usize>,
416) -> FFTResult<Vec<Complex64>>
417where
418    T: NumCast + Copy + Debug + 'static,
419{
420    let input_length = input.len();
421    let n_val = n.unwrap_or(input_length);
422    let chunk_size_val = chunk_size.unwrap_or(
423        // Default chunk _size based on array _size
424        if input_length > 1_000_000 {
425            // For arrays > 1M, use 1024 * 1024
426            1_048_576
427        } else if input_length > 100_000 {
428            // For arrays > 100k, use 64k
429            65_536
430        } else {
431            // For smaller arrays, process in one chunk
432            input_length
433        },
434    );
435
436    // For small arrays, don't use chunking
437    if input_length <= chunk_size_val || n_val <= chunk_size_val {
438        // Convert input to complex vector
439        let mut complex_input: Vec<Complex64> = Vec::with_capacity(input_length);
440
441        for &val in input {
442            match NumCast::from(val) {
443                Some(val_f64) => {
444                    complex_input.push(Complex64::new(val_f64, 0.0));
445                }
446                None => {
447                    // Check if this is already a complex number
448                    if let Some(complex_val) = downcast_to_complex::<T>(&val) {
449                        complex_input.push(complex_val);
450                    } else {
451                        return Err(FFTError::ValueError(format!(
452                            "Could not convert {val:?} to f64 or Complex64"
453                        )));
454                    }
455                }
456            }
457        }
458
459        // Handle the case where n is provided
460        match n_val.cmp(&complex_input.len()) {
461            std::cmp::Ordering::Less => {
462                // Truncate the input if n is smaller
463                complex_input.truncate(n_val);
464            }
465            std::cmp::Ordering::Greater => {
466                // Zero-pad the input if n is larger
467                complex_input.resize(n_val, Complex64::new(0.0, 0.0));
468            }
469            std::cmp::Ordering::Equal => {
470                // No resizing needed
471            }
472        }
473
474        // Set up rustfft for computation
475        let mut planner = FftPlanner::new();
476        let fft = match mode {
477            FftMode::Forward => planner.plan_fft_forward(n_val),
478            FftMode::Inverse => planner.plan_fft_inverse(n_val),
479        };
480
481        // Convert to rustfft's Complex type
482        let mut buffer: Vec<RustComplex<f64>> = complex_input
483            .iter()
484            .map(|&c| RustComplex::new(c.re, c.im))
485            .collect();
486
487        // Perform the FFT
488        fft.process(&mut buffer);
489
490        // Convert back to scirs2_core::numeric::Complex64 and apply normalization if needed
491        let scale = if mode == FftMode::Inverse {
492            1.0 / (n_val as f64)
493        } else {
494            1.0
495        };
496
497        let result: Vec<Complex64> = buffer
498            .into_iter()
499            .map(|c| Complex64::new(c.re * scale, c.im * scale))
500            .collect();
501
502        return Ok(result);
503    }
504
505    // Process in chunks for large arrays
506    let chunk_size_nz = NonZeroUsize::new(chunk_size_val)
507        .unwrap_or(NonZeroUsize::new(1).expect("Operation failed"));
508    let n_chunks = n_val.div_ceil(chunk_size_nz.get());
509    let mut result = Vec::with_capacity(n_val);
510
511    for i in 0..n_chunks {
512        let start = i * chunk_size_val;
513        let end = (start + chunk_size_val).min(n_val);
514        let chunk_size = end - start;
515
516        // Prepare input chunk (either from original input or zero-padded)
517        let mut chunk_input = Vec::with_capacity(chunk_size);
518
519        if start < input_length {
520            // Part of the chunk comes from the input
521            let input_end = end.min(input_length);
522            for val in input[start..input_end].iter() {
523                match NumCast::from(*val) {
524                    Some(val_f64) => {
525                        chunk_input.push(Complex64::new(val_f64, 0.0));
526                    }
527                    None => {
528                        // Check if this is already a complex number
529                        if let Some(complex_val) = downcast_to_complex::<T>(val) {
530                            chunk_input.push(complex_val);
531                        } else {
532                            return Err(FFTError::ValueError(format!(
533                                "Could not convert {val:?} to f64 or Complex64"
534                            )));
535                        }
536                    }
537                }
538            }
539
540            // Zero-pad the rest if needed
541            if input_end < end {
542                chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
543            }
544        } else {
545            // Chunk is entirely outside the input range, so zero-pad
546            chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
547        }
548
549        // Set up rustfft for computation on this chunk
550        let mut planner = FftPlanner::new();
551        let fft = match mode {
552            FftMode::Forward => planner.plan_fft_forward(chunk_size),
553            FftMode::Inverse => planner.plan_fft_inverse(chunk_size),
554        };
555
556        // Convert to rustfft's Complex type
557        let mut buffer: Vec<RustComplex<f64>> = chunk_input
558            .iter()
559            .map(|&c| RustComplex::new(c.re, c.im))
560            .collect();
561
562        // Perform the FFT on this chunk
563        fft.process(&mut buffer);
564
565        // Convert back to scirs2_core::numeric::Complex64 and apply normalization if needed
566        let scale = if mode == FftMode::Inverse {
567            1.0 / (chunk_size as f64)
568        } else {
569            1.0
570        };
571
572        let chunk_result: Vec<Complex64> = buffer
573            .into_iter()
574            .map(|c| Complex64::new(c.re * scale, c.im * scale))
575            .collect();
576
577        // Add chunk result to the final result
578        result.extend(chunk_result);
579    }
580
581    // For inverse transforms, we need to normalize by the full length
582    // instead of chunk size, so adjust the scaling
583    if mode == FftMode::Inverse {
584        let full_scale = 1.0 / (n_val as f64);
585        let chunk_scale = 1.0 / (chunk_size_val as f64);
586        let scale_adjustment = full_scale / chunk_scale;
587
588        for val in &mut result {
589            val.re *= scale_adjustment;
590            val.im *= scale_adjustment;
591        }
592    }
593
594    Ok(result)
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_fft_inplace() {
        // Simple real-valued ramp signal.
        let mut input: Vec<Complex64> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&re| Complex64::new(re, 0.0))
            .collect();
        let mut output = vec![Complex64::new(0.0, 0.0); input.len()];

        // Forward transform: the DC bin is the sum of all samples.
        fft_inplace(&mut input, &mut output, FftMode::Forward, false).expect("Operation failed");
        assert_relative_eq!(input[0].re, 10.0, epsilon = 1e-10);

        // Normalised inverse transform recovers the original samples.
        fft_inplace(&mut input, &mut output, FftMode::Inverse, true).expect("Operation failed");
        for (idx, &expected) in [1.0, 2.0, 3.0, 4.0].iter().enumerate() {
            assert_relative_eq!(input[idx].re, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fft2_efficient() {
        let arr = array![[1.0, 2.0], [3.0, 4.0]];

        // Forward 2D transform: the DC bin is the sum of all elements.
        let spectrum_2d =
            fft2_efficient(&arr.view(), None, FftMode::Forward, false).expect("Operation failed");
        assert_relative_eq!(spectrum_2d[[0, 0]].re, 10.0, epsilon = 1e-10);

        // Normalised inverse transform recovers the original grid.
        let recovered = fft2_efficient(&spectrum_2d.view(), None, FftMode::Inverse, true)
            .expect("Operation failed");
        let expected = [[1.0, 2.0], [3.0, 4.0]];
        for (r, row) in expected.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert_relative_eq!(recovered[[r, c]].re, value, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_fft_streaming() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        // Default chunk size: the DC bin is the sum of the samples.
        let result =
            fft_streaming(&signal, None, FftMode::Forward, None).expect("Operation failed");
        assert_relative_eq!(result[0].re, 10.0, epsilon = 1e-10);

        // The inverse transform recovers the original samples.
        let inverse =
            fft_streaming(&result, None, FftMode::Inverse, None).expect("Operation failed");
        for (idx, &expected) in signal.iter().enumerate() {
            assert_relative_eq!(inverse[idx].re, expected, epsilon = 1e-10);
        }

        // An explicit chunk size equal to the signal length must give the same spectrum.
        let result_chunked = fft_streaming(&signal, None, FftMode::Forward, Some(signal.len()))
            .expect("Operation failed");
        for (a, b) in result.iter().zip(&result_chunked) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }
}