scirs2_fft/
memory_efficient.rs

1//! Memory-efficient FFT operations
2//!
3//! This module provides memory-efficient implementations of FFT operations
4//! that minimize allocations for large arrays.
5
6use crate::error::{FFTError, FFTResult};
7use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
8use scirs2_core::ndarray::{Array2, ArrayView2};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::numeric::NumCast;
11use std::any::Any;
12use std::fmt::Debug;
13use std::num::NonZeroUsize;
14
15// Helper function to attempt downcast to Complex64
16#[allow(dead_code)]
17fn downcast_to_complex<T: 'static>(value: &T) -> Option<Complex64> {
18    // Check if T is Complex64
19    if let Some(complex) = (value as &dyn Any).downcast_ref::<Complex64>() {
20        return Some(*complex);
21    }
22
23    // Try to directly convert from scirs2_core::numeric::Complex<f32>
24    if let Some(complex) = (value as &dyn Any).downcast_ref::<scirs2_core::numeric::Complex<f32>>()
25    {
26        return Some(Complex64::new(complex.re as f64, complex.im as f64));
27    }
28
29    // Try to convert from rustfft's Complex type
30    if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f64>>() {
31        return Some(Complex64::new(complex.re, complex.im));
32    }
33
34    if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f32>>() {
35        return Some(Complex64::new(complex.re as f64, complex.im as f64));
36    }
37
38    None
39}
40
/// Direction of a memory-efficient FFT operation.
///
/// The type is a plain, field-less enum, so it is `Copy` and totally
/// comparable. It also derives `Hash`, which lets a mode be used as (part of)
/// a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FftMode {
    /// Forward FFT transform
    Forward,
    /// Inverse FFT transform
    Inverse,
}
49
50/// Computes FFT in-place to minimize memory allocations
51///
52/// This function performs an in-place FFT using pre-allocated buffers
53/// to minimize memory allocations, which is beneficial for large arrays
54/// or when performing many FFT operations.
55///
56/// # Arguments
57///
58/// * `input` - Input buffer (will be modified in-place)
59/// * `output` - Pre-allocated output buffer
60/// * `mode` - Whether to compute forward or inverse FFT
61/// * `normalize` - Whether to normalize the result (required for IFFT)
62///
63/// # Returns
64///
65/// * Result with the number of elements processed
66///
67/// # Errors
68///
69/// Returns an error if the computation fails.
70///
71/// # Examples
72///
73/// ```
74/// use scirs2_fft::memory_efficient::{fft_inplace, FftMode};
75/// use scirs2_core::numeric::Complex64;
76///
77/// // Create input and output buffers
78/// let mut input_buffer = vec![Complex64::new(1.0, 0.0),
79///                            Complex64::new(2.0, 0.0),
80///                            Complex64::new(3.0, 0.0),
81///                            Complex64::new(4.0, 0.0)];
82/// let mut output_buffer = vec![Complex64::new(0.0, 0.0); input_buffer.len()];
83///
84/// // Perform in-place FFT
85/// fft_inplace(&mut input_buffer, &mut output_buffer, FftMode::Forward, false).unwrap();
86///
87/// // Input buffer now contains the result
88/// let sum: f64 = (1.0 + 2.0 + 3.0 + 4.0);
89/// assert!((input_buffer[0].re - sum).abs() < 1e-10);
90/// ```
91#[allow(dead_code)]
92pub fn fft_inplace(
93    input: &mut [Complex64],
94    output: &mut [Complex64],
95    mode: FftMode,
96    normalize: bool,
97) -> FFTResult<usize> {
98    let n = input.len();
99
100    if n == 0 {
101        return Err(FFTError::ValueError("Input array is empty".to_string()));
102    }
103
104    if output.len() < n {
105        return Err(FFTError::ValueError(format!(
106            "Output buffer is too small: got {}, need {}",
107            output.len(),
108            n
109        )));
110    }
111
112    // For larger arrays, consider using SIMD acceleration
113    let use_simd = n >= 32 && crate::simd_fft::simd_support_available();
114
115    if use_simd {
116        // Use the SIMD-accelerated FFT implementation
117        let result = match mode {
118            FftMode::Forward => crate::simd_fft::fft_adaptive(
119                input,
120                if normalize { Some("forward") } else { None },
121            )?,
122            FftMode::Inverse => crate::simd_fft::ifft_adaptive(
123                input,
124                if normalize { Some("backward") } else { None },
125            )?,
126        };
127
128        // Copy the results back to the input and output buffers
129        for (i, &val) in result.iter().enumerate() {
130            input[i] = val;
131            output[i] = val;
132        }
133
134        return Ok(n);
135    }
136
137    // Fall back to standard implementation for small arrays
138    // Create FFT plan
139    let mut planner = FftPlanner::new();
140    let fft = match mode {
141        FftMode::Forward => planner.plan_fft_forward(n),
142        FftMode::Inverse => planner.plan_fft_inverse(n),
143    };
144
145    // Convert to rustfft's Complex type
146    let mut buffer: Vec<RustComplex<f64>> = input
147        .iter()
148        .map(|&c| RustComplex::new(c.re, c.im))
149        .collect();
150
151    // Perform the FFT
152    fft.process(&mut buffer);
153
154    // Convert back to scirs2_core::numeric::Complex64 and apply normalization if needed
155    let scale = if normalize { 1.0 / (n as f64) } else { 1.0 };
156
157    if scale != 1.0 && use_simd {
158        // Copy back to input buffer first
159        for (i, &c) in buffer.iter().enumerate() {
160            input[i] = Complex64::new(c.re, c.im);
161        }
162
163        // Use SIMD-accelerated normalization
164        crate::simd_fft::apply_simd_normalization(input, scale);
165
166        // Copy to output buffer
167        output.copy_from_slice(input);
168    } else {
169        // Standard normalization
170        for (i, &c) in buffer.iter().enumerate() {
171            input[i] = Complex64::new(c.re * scale, c.im * scale);
172            output[i] = input[i];
173        }
174    }
175
176    Ok(n)
177}
178
179/// Process large arrays in chunks to minimize memory usage
180///
181/// This function processes a large array in chunks using the provided
182/// operation function, which reduces memory usage for very large arrays.
183///
184/// # Arguments
185///
186/// * `input` - Input array
187/// * `chunk_size` - Size of each chunk to process
188/// * `op` - Operation to apply to each chunk
189///
190/// # Returns
191///
192/// * Result with the processed array
193///
194/// # Errors
195///
196/// Returns an error if the computation fails.
197#[allow(dead_code)]
198pub fn process_in_chunks<T, F>(
199    input: &[T],
200    chunk_size: usize,
201    mut op: F,
202) -> FFTResult<Vec<Complex64>>
203where
204    T: NumCast + Copy + Debug + 'static,
205    F: FnMut(&[T]) -> FFTResult<Vec<Complex64>>,
206{
207    if input.len() <= chunk_size {
208        // If input is smaller than chunk_size, process it directly
209        return op(input);
210    }
211
212    let chunk_size_nz = NonZeroUsize::new(chunk_size).unwrap_or(NonZeroUsize::new(1).unwrap());
213    let n_chunks = input.len().div_ceil(chunk_size_nz.get());
214    let mut result = Vec::with_capacity(input.len());
215
216    for i in 0..n_chunks {
217        let start = i * chunk_size;
218        let end = (start + chunk_size).min(input.len());
219        let chunk = &input[start..end];
220
221        let chunk_result = op(chunk)?;
222        result.extend(chunk_result);
223    }
224
225    Ok(result)
226}
227
228/// Computes 2D FFT with memory efficiency in mind
229///
230/// This function performs a 2D FFT with optimized memory usage,
231/// which is particularly beneficial for large arrays.
232///
233/// # Arguments
234///
235/// * `input` - Input 2D array
236/// * `shape` - Optional shape for the output
237/// * `mode` - Whether to compute forward or inverse FFT
238/// * `normalize` - Whether to normalize the result
239///
240/// # Returns
241///
242/// * Result with the processed 2D array
243///
244/// # Errors
245///
246/// Returns an error if the computation fails.
247#[allow(dead_code)]
248pub fn fft2_efficient<T>(
249    input: &ArrayView2<T>,
250    shape: Option<(usize, usize)>,
251    mode: FftMode,
252    normalize: bool,
253) -> FFTResult<Array2<Complex64>>
254where
255    T: NumCast + Copy + Debug + 'static,
256{
257    let (n_rows, n_cols) = input.dim();
258    let (n_rows_out, n_cols_out) = shape.unwrap_or((n_rows, n_cols));
259
260    // Check if output dimensions are valid
261    if n_rows_out == 0 || n_cols_out == 0 {
262        return Err(FFTError::ValueError(
263            "Output dimensions must be positive".to_string(),
264        ));
265    }
266
267    // Convert input to complex array with proper dimensions
268    let mut complex_input = Array2::zeros((n_rows_out, n_cols_out));
269    for r in 0..n_rows.min(n_rows_out) {
270        for c in 0..n_cols.min(n_cols_out) {
271            let val = input[[r, c]];
272            match NumCast::from(val) {
273                Some(val_f64) => {
274                    complex_input[[r, c]] = Complex64::new(val_f64, 0.0);
275                }
276                None => {
277                    // Check if this is already a complex number
278                    if let Some(complex_val) = downcast_to_complex::<T>(&val) {
279                        complex_input[[r, c]] = complex_val;
280                    } else {
281                        return Err(FFTError::ValueError(format!(
282                            "Could not convert {val:?} to f64 or Complex64"
283                        )));
284                    }
285                }
286            }
287        }
288    }
289
290    // Get a flattened view to avoid allocating additional memory
291    let mut buffer = complex_input.as_slice_mut().unwrap().to_vec();
292
293    // Create FFT planner
294    let mut planner = FftPlanner::new();
295
296    // Storage for row-wise FFTs (kept for future optimizations)
297    let _row_buffer = vec![Complex64::new(0.0, 0.0); n_cols_out];
298
299    // Process each row
300    for r in 0..n_rows_out {
301        let row_start = r * n_cols_out;
302        let row_end = row_start + n_cols_out;
303        let row_slice = &mut buffer[row_start..row_end];
304
305        let row_fft = match mode {
306            FftMode::Forward => planner.plan_fft_forward(n_cols_out),
307            FftMode::Inverse => planner.plan_fft_inverse(n_cols_out),
308        };
309
310        // Convert to rustfft's Complex type
311        let mut row_data: Vec<RustComplex<f64>> = row_slice
312            .iter()
313            .map(|&c| RustComplex::new(c.re, c.im))
314            .collect();
315
316        // Perform row-wise FFT
317        row_fft.process(&mut row_data);
318
319        // Convert back and store in buffer
320        for (i, &c) in row_data.iter().enumerate() {
321            row_slice[i] = Complex64::new(c.re, c.im);
322        }
323    }
324
325    // Process columns (with buffer transposition)
326    let mut transposed = vec![Complex64::new(0.0, 0.0); n_rows_out * n_cols_out];
327
328    // Transpose data
329    for r in 0..n_rows_out {
330        for c in 0..n_cols_out {
331            let src_idx = r * n_cols_out + c;
332            let dst_idx = c * n_rows_out + r;
333            transposed[dst_idx] = buffer[src_idx];
334        }
335    }
336
337    // Storage for column FFTs (kept for future optimizations)
338    let _col_buffer = vec![Complex64::new(0.0, 0.0); n_rows_out];
339
340    // Process each column (as rows in transposed data)
341    for c in 0..n_cols_out {
342        let col_start = c * n_rows_out;
343        let col_end = col_start + n_rows_out;
344        let col_slice = &mut transposed[col_start..col_end];
345
346        let col_fft = match mode {
347            FftMode::Forward => planner.plan_fft_forward(n_rows_out),
348            FftMode::Inverse => planner.plan_fft_inverse(n_rows_out),
349        };
350
351        // Convert to rustfft's Complex type
352        let mut col_data: Vec<RustComplex<f64>> = col_slice
353            .iter()
354            .map(|&c| RustComplex::new(c.re, c.im))
355            .collect();
356
357        // Perform column-wise FFT
358        col_fft.process(&mut col_data);
359
360        // Convert back and store in buffer
361        for (i, &c) in col_data.iter().enumerate() {
362            col_slice[i] = Complex64::new(c.re, c.im);
363        }
364    }
365
366    // Final result with proper normalization
367    let scale = if normalize {
368        1.0 / ((n_rows_out * n_cols_out) as f64)
369    } else {
370        1.0
371    };
372
373    let mut result = Array2::zeros((n_rows_out, n_cols_out));
374
375    // Transpose back to original shape
376    for r in 0..n_rows_out {
377        for c in 0..n_cols_out {
378            let src_idx = c * n_rows_out + r;
379            let val = transposed[src_idx];
380            result[[r, c]] = Complex64::new(val.re * scale, val.im * scale);
381        }
382    }
383
384    Ok(result)
385}
386
387/// Compute large array FFT with streaming to minimize memory usage
388///
389/// This function computes the FFT of a large array by processing it in chunks,
390/// which reduces the memory footprint for very large arrays.
391///
392/// # Arguments
393///
394/// * `input` - Input array
395/// * `n` - Length of the transformed axis (optional)
396/// * `mode` - Whether to compute forward or inverse FFT
397/// * `chunk_size` - Size of chunks to process at once
398///
399/// # Returns
400///
401/// * Result with the processed array
402///
403/// # Errors
404///
405/// Returns an error if the computation fails.
406#[allow(dead_code)]
407pub fn fft_streaming<T>(
408    input: &[T],
409    n: Option<usize>,
410    mode: FftMode,
411    chunk_size: Option<usize>,
412) -> FFTResult<Vec<Complex64>>
413where
414    T: NumCast + Copy + Debug + 'static,
415{
416    let input_length = input.len();
417    let n_val = n.unwrap_or(input_length);
418    let chunk_size_val = chunk_size.unwrap_or(
419        // Default chunk _size based on array _size
420        if input_length > 1_000_000 {
421            // For arrays > 1M, use 1024 * 1024
422            1_048_576
423        } else if input_length > 100_000 {
424            // For arrays > 100k, use 64k
425            65_536
426        } else {
427            // For smaller arrays, process in one chunk
428            input_length
429        },
430    );
431
432    // For small arrays, don't use chunking
433    if input_length <= chunk_size_val || n_val <= chunk_size_val {
434        // Convert input to complex vector
435        let mut complex_input: Vec<Complex64> = Vec::with_capacity(input_length);
436
437        for &val in input {
438            match NumCast::from(val) {
439                Some(val_f64) => {
440                    complex_input.push(Complex64::new(val_f64, 0.0));
441                }
442                None => {
443                    // Check if this is already a complex number
444                    if let Some(complex_val) = downcast_to_complex::<T>(&val) {
445                        complex_input.push(complex_val);
446                    } else {
447                        return Err(FFTError::ValueError(format!(
448                            "Could not convert {val:?} to f64 or Complex64"
449                        )));
450                    }
451                }
452            }
453        }
454
455        // Handle the case where n is provided
456        match n_val.cmp(&complex_input.len()) {
457            std::cmp::Ordering::Less => {
458                // Truncate the input if n is smaller
459                complex_input.truncate(n_val);
460            }
461            std::cmp::Ordering::Greater => {
462                // Zero-pad the input if n is larger
463                complex_input.resize(n_val, Complex64::new(0.0, 0.0));
464            }
465            std::cmp::Ordering::Equal => {
466                // No resizing needed
467            }
468        }
469
470        // Set up rustfft for computation
471        let mut planner = FftPlanner::new();
472        let fft = match mode {
473            FftMode::Forward => planner.plan_fft_forward(n_val),
474            FftMode::Inverse => planner.plan_fft_inverse(n_val),
475        };
476
477        // Convert to rustfft's Complex type
478        let mut buffer: Vec<RustComplex<f64>> = complex_input
479            .iter()
480            .map(|&c| RustComplex::new(c.re, c.im))
481            .collect();
482
483        // Perform the FFT
484        fft.process(&mut buffer);
485
486        // Convert back to scirs2_core::numeric::Complex64 and apply normalization if needed
487        let scale = if mode == FftMode::Inverse {
488            1.0 / (n_val as f64)
489        } else {
490            1.0
491        };
492
493        let result: Vec<Complex64> = buffer
494            .into_iter()
495            .map(|c| Complex64::new(c.re * scale, c.im * scale))
496            .collect();
497
498        return Ok(result);
499    }
500
501    // Process in chunks for large arrays
502    let chunk_size_nz = NonZeroUsize::new(chunk_size_val).unwrap_or(NonZeroUsize::new(1).unwrap());
503    let n_chunks = n_val.div_ceil(chunk_size_nz.get());
504    let mut result = Vec::with_capacity(n_val);
505
506    for i in 0..n_chunks {
507        let start = i * chunk_size_val;
508        let end = (start + chunk_size_val).min(n_val);
509        let chunk_size = end - start;
510
511        // Prepare input chunk (either from original input or zero-padded)
512        let mut chunk_input = Vec::with_capacity(chunk_size);
513
514        if start < input_length {
515            // Part of the chunk comes from the input
516            let input_end = end.min(input_length);
517            for val in input[start..input_end].iter() {
518                match NumCast::from(*val) {
519                    Some(val_f64) => {
520                        chunk_input.push(Complex64::new(val_f64, 0.0));
521                    }
522                    None => {
523                        // Check if this is already a complex number
524                        if let Some(complex_val) = downcast_to_complex::<T>(val) {
525                            chunk_input.push(complex_val);
526                        } else {
527                            return Err(FFTError::ValueError(format!(
528                                "Could not convert {val:?} to f64 or Complex64"
529                            )));
530                        }
531                    }
532                }
533            }
534
535            // Zero-pad the rest if needed
536            if input_end < end {
537                chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
538            }
539        } else {
540            // Chunk is entirely outside the input range, so zero-pad
541            chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
542        }
543
544        // Set up rustfft for computation on this chunk
545        let mut planner = FftPlanner::new();
546        let fft = match mode {
547            FftMode::Forward => planner.plan_fft_forward(chunk_size),
548            FftMode::Inverse => planner.plan_fft_inverse(chunk_size),
549        };
550
551        // Convert to rustfft's Complex type
552        let mut buffer: Vec<RustComplex<f64>> = chunk_input
553            .iter()
554            .map(|&c| RustComplex::new(c.re, c.im))
555            .collect();
556
557        // Perform the FFT on this chunk
558        fft.process(&mut buffer);
559
560        // Convert back to scirs2_core::numeric::Complex64 and apply normalization if needed
561        let scale = if mode == FftMode::Inverse {
562            1.0 / (chunk_size as f64)
563        } else {
564            1.0
565        };
566
567        let chunk_result: Vec<Complex64> = buffer
568            .into_iter()
569            .map(|c| Complex64::new(c.re * scale, c.im * scale))
570            .collect();
571
572        // Add chunk result to the final result
573        result.extend(chunk_result);
574    }
575
576    // For inverse transforms, we need to normalize by the full length
577    // instead of chunk size, so adjust the scaling
578    if mode == FftMode::Inverse {
579        let full_scale = 1.0 / (n_val as f64);
580        let chunk_scale = 1.0 / (chunk_size_val as f64);
581        let scale_adjustment = full_scale / chunk_scale;
582
583        for val in &mut result {
584            val.re *= scale_adjustment;
585            val.im *= scale_adjustment;
586        }
587    }
588
589    Ok(result)
590}
591
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Forward transform yields the sum in the DC bin; a normalized inverse
    /// transform restores the original samples.
    #[test]
    fn test_fft_inplace() {
        let samples = [1.0, 2.0, 3.0, 4.0];

        let mut input: Vec<Complex64> = samples.iter().map(|&x| Complex64::new(x, 0.0)).collect();
        let mut output = vec![Complex64::new(0.0, 0.0); 4];

        // Forward: DC component is the sum of all inputs.
        fft_inplace(&mut input, &mut output, FftMode::Forward, false).unwrap();
        assert_relative_eq!(input[0].re, 10.0, epsilon = 1e-10);

        // Inverse (normalized): the original signal comes back.
        fft_inplace(&mut input, &mut output, FftMode::Inverse, true).unwrap();
        for (i, &expected) in samples.iter().enumerate() {
            assert_relative_eq!(input[i].re, expected, epsilon = 1e-10);
        }
    }

    /// 2x2 round trip through the 2D transform.
    #[test]
    fn test_fft2_efficient() {
        let arr = array![[1.0, 2.0], [3.0, 4.0]];

        // Forward: DC component is the sum of all elements.
        let spectrum_2d = fft2_efficient(&arr.view(), None, FftMode::Forward, false).unwrap();
        assert_relative_eq!(spectrum_2d[[0, 0]].re, 10.0, epsilon = 1e-10);

        // Inverse (normalized): the original values come back.
        let recovered = fft2_efficient(&spectrum_2d.view(), None, FftMode::Inverse, true).unwrap();
        let expected: [((usize, usize), f64); 4] =
            [((0, 0), 1.0), ((0, 1), 2.0), ((1, 0), 3.0), ((1, 1), 4.0)];
        for &((r, c), want) in expected.iter() {
            assert_relative_eq!(recovered[[r, c]].re, want, epsilon = 1e-10);
        }
    }

    /// Streaming transform with default and explicit chunk sizes.
    #[test]
    fn test_fft_streaming() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        // Default chunk size: DC component is the sum of the inputs.
        let result = fft_streaming(&signal, None, FftMode::Forward, None).unwrap();
        assert_relative_eq!(result[0].re, 10.0, epsilon = 1e-10);

        // Inverse recovers the original signal.
        let inverse = fft_streaming(&result, None, FftMode::Inverse, None).unwrap();
        for (i, &expected) in signal.iter().enumerate() {
            assert_relative_eq!(inverse[i].re, expected, epsilon = 1e-10);
        }

        // Explicit chunk size equal to the signal length keeps the test stable
        // and must give the same spectrum as the default.
        let result_chunked =
            fft_streaming(&signal, None, FftMode::Forward, Some(signal.len())).unwrap();
        for (lhs, rhs) in result.iter().zip(result_chunked.iter()) {
            assert_relative_eq!(lhs.re, rhs.re, epsilon = 1e-10);
            assert_relative_eq!(lhs.im, rhs.im, epsilon = 1e-10);
        }
    }
}