Skip to main content

scirs2_fft/
czt_enhanced.rs

1//! Enhanced Chirp Z-Transform (CZT) module
2//!
3//! This module extends the basic CZT with:
4//! - Generalized CZT along arbitrary spiral contours in the z-plane
5//! - Batch CZT for processing multiple signals efficiently
6//! - CZT-based fast convolution for arbitrary-length sequences
7//! - Inverse CZT (ICZT) reconstruction
8//! - Frequency-domain zoom with adaptive resolution
9//!
10//! # Mathematical Background
11//!
12//! The Chirp Z-Transform evaluates the Z-transform at points along a
13//! logarithmic spiral in the complex z-plane:
14//!
15//! ```text
16//!   X(z_k) = sum_{n=0}^{N-1} x[n] * z_k^{-n}
17//! ```
18//!
19//! where z_k = A * W^{-k} for k = 0, 1, ..., M-1
20//!
21//! The key insight (Bluestein's algorithm) is that n*k = -(n-k)^2/2 + n^2/2 + k^2/2,
22//! which converts the CZT into a convolution computable via FFT.
23//!
24//! # References
25//!
26//! * Bluestein, L. I. "A linear filtering approach to the computation of
27//!   discrete Fourier transform." IEEE Trans. Audio Electroacoustics, 1970.
28//! * Rabiner, L. R., Schafer, R. W., Rader, C. M. "The chirp z-transform
29//!   algorithm." IEEE Trans. Audio Electroacoustics, 1969.
30
31use crate::{next_fast_len, FFTError, FFTResult};
32use scirs2_core::ndarray::{Array1, Array2, Zip};
33use scirs2_core::numeric::Complex;
34use std::f64::consts::PI;
35
36/// Configuration for generalized CZT along a spiral contour
37#[derive(Clone, Debug)]
38pub struct SpiralContour {
39    /// Starting point on the z-plane
40    pub a: Complex<f64>,
41    /// Ratio between consecutive evaluation points
42    pub w: Complex<f64>,
43    /// Number of output points
44    pub m: usize,
45}
46
47impl SpiralContour {
48    /// Create a contour on the unit circle (standard DFT-like)
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if `m` is zero.
53    pub fn unit_circle(m: usize) -> FFTResult<Self> {
54        if m == 0 {
55            return Err(FFTError::ValueError(
56                "Number of output points must be positive".to_string(),
57            ));
58        }
59        let w = Complex::from_polar(1.0, -2.0 * PI / m as f64);
60        Ok(SpiralContour {
61            a: Complex::new(1.0, 0.0),
62            w,
63            m,
64        })
65    }
66
67    /// Create a contour for zoom FFT on a frequency subrange
68    ///
69    /// # Arguments
70    ///
71    /// * `m` - Number of output points
72    /// * `f0` - Starting normalized frequency (0 to 1)
73    /// * `f1` - Ending normalized frequency (0 to 1)
74    /// * `n` - Length of the input signal
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if frequencies are out of range or if `f0 >= f1`.
79    pub fn zoom_range(m: usize, f0: f64, f1: f64, n: usize) -> FFTResult<Self> {
80        if m == 0 {
81            return Err(FFTError::ValueError(
82                "Number of output points must be positive".to_string(),
83            ));
84        }
85        if f0 < 0.0 || f1 > 1.0 || f0 >= f1 {
86            return Err(FFTError::ValueError(
87                "Frequencies must satisfy 0 <= f0 < f1 <= 1".to_string(),
88            ));
89        }
90
91        let phi_start = 2.0 * PI * f0;
92        let phi_end = 2.0 * PI * f1;
93        let a = Complex::from_polar(1.0, phi_start);
94
95        let step = if m > 1 {
96            (phi_end - phi_start) / (m - 1) as f64
97        } else {
98            0.0
99        };
100        let w = Complex::from_polar(1.0, -step);
101
102        Ok(SpiralContour { a, w, m })
103    }
104
105    /// Create a logarithmic spiral contour
106    ///
107    /// Points follow r_k = r0 * rho^k at angles theta_k = theta0 + k * dtheta
108    ///
109    /// # Arguments
110    ///
111    /// * `m` - Number of output points
112    /// * `r0` - Starting radius
113    /// * `rho` - Radial growth factor per step
114    /// * `theta0` - Starting angle (radians)
115    /// * `dtheta` - Angular step (radians)
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if `m` is zero or `r0` is non-positive.
120    pub fn log_spiral(m: usize, r0: f64, rho: f64, theta0: f64, dtheta: f64) -> FFTResult<Self> {
121        if m == 0 {
122            return Err(FFTError::ValueError(
123                "Number of output points must be positive".to_string(),
124            ));
125        }
126        if r0 <= 0.0 {
127            return Err(FFTError::ValueError(
128                "Starting radius must be positive".to_string(),
129            ));
130        }
131
132        let a = Complex::from_polar(r0, theta0);
133        // W^{-k} should give the next point: a * W^{-1} = (r0*rho) * exp(j*(theta0+dtheta))
134        // So W^{-1} = rho * exp(j*dtheta) => W = (1/rho) * exp(-j*dtheta)
135        let w = Complex::from_polar(1.0 / rho, -dtheta);
136
137        Ok(SpiralContour { a, w, m })
138    }
139
140    /// Get the evaluation points for this contour
141    pub fn points(&self) -> Array1<Complex<f64>> {
142        (0..self.m)
143            .map(|k| self.a * self.w.powf(-(k as f64)))
144            .collect()
145    }
146}
147
148/// Enhanced CZT engine with pre-computed kernels for efficient reuse
149#[derive(Clone)]
150pub struct EnhancedCZT {
151    n: usize,
152    contour: SpiralContour,
153    nfft: usize,
154    /// Pre-computed: a^{-k} * w^{k^2/2} for k = 0..n-1
155    awk2: Array1<Complex<f64>>,
156    /// Pre-computed FFT of the reciprocal chirp sequence
157    fwk2: Array1<Complex<f64>>,
158    /// Pre-computed: w^{k^2/2} for k = 0..m-1
159    wk2: Array1<Complex<f64>>,
160}
161
162impl EnhancedCZT {
163    /// Create a new enhanced CZT engine
164    ///
165    /// # Arguments
166    ///
167    /// * `n` - Length of input signals
168    /// * `contour` - Spiral contour defining evaluation points
169    ///
170    /// # Errors
171    ///
172    /// Returns an error if `n` is zero or if internal FFT computation fails.
173    pub fn new(n: usize, contour: SpiralContour) -> FFTResult<Self> {
174        if n == 0 {
175            return Err(FFTError::ValueError(
176                "Input length must be positive".to_string(),
177            ));
178        }
179
180        let m = contour.m;
181        let a = contour.a;
182        let w = contour.w;
183        let max_size = n.max(m);
184        let nfft = next_fast_len(n + m - 1, false);
185
186        // Compute w^{k^2/2} for k = 0..max_size-1
187        let wk2_full: Array1<Complex<f64>> = (0..max_size)
188            .map(|k| w.powf(k as f64 * k as f64 / 2.0))
189            .collect();
190
191        // Compute a^{-k} * w^{k^2/2} for k = 0..n-1
192        let awk2: Array1<Complex<f64>> =
193            (0..n).map(|k| a.powf(-(k as f64)) * wk2_full[k]).collect();
194
195        // Build the chirp kernel for convolution and compute its FFT
196        let mut chirp_vec = vec![Complex::new(0.0, 0.0); nfft];
197
198        // Place 1/w^{k^2/2} values at the correct positions
199        for i in 0..m {
200            chirp_vec[n - 1 + i] = Complex::new(1.0, 0.0) / wk2_full[i];
201        }
202        for i in 1..n {
203            chirp_vec[n - 1 - i] = Complex::new(1.0, 0.0) / wk2_full[i];
204        }
205
206        let fwk2_vec = crate::fft::fft(&chirp_vec, None)?;
207        let fwk2 = Array1::from_vec(fwk2_vec);
208
209        // Extract w^{k^2/2} for output (first m values)
210        let wk2: Array1<Complex<f64>> = wk2_full.slice(scirs2_core::ndarray::s![..m]).to_owned();
211
212        Ok(EnhancedCZT {
213            n,
214            contour,
215            nfft,
216            awk2,
217            fwk2,
218            wk2,
219        })
220    }
221
222    /// Transform a single complex signal
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if input length does not match expected `n`.
227    pub fn transform(&self, x: &[Complex<f64>]) -> FFTResult<Array1<Complex<f64>>> {
228        if x.len() != self.n {
229            return Err(FFTError::ValueError(format!(
230                "Input length ({}) does not match CZT engine size ({})",
231                x.len(),
232                self.n
233            )));
234        }
235
236        let x_arr = Array1::from_vec(x.to_vec());
237
238        // Step 1: Pre-multiply by a^{-k} * w^{k^2/2}
239        let x_weighted: Array1<Complex<f64>> = Zip::from(&x_arr)
240            .and(&self.awk2)
241            .map_collect(|&xi, &awki| xi * awki);
242
243        // Step 2: Zero-pad and FFT
244        let mut padded = vec![Complex::new(0.0, 0.0); self.nfft];
245        for (i, &val) in x_weighted.iter().enumerate() {
246            padded[i] = val;
247        }
248        let x_fft_vec = crate::fft::fft(&padded, None)?;
249        let x_fft = Array1::from_vec(x_fft_vec);
250
251        // Step 3: Multiply in frequency domain
252        let product: Array1<Complex<f64>> = Zip::from(&x_fft)
253            .and(&self.fwk2)
254            .map_collect(|&xi, &fi| xi * fi);
255
256        // Step 4: Inverse FFT
257        let y_full_vec = crate::fft::ifft(&product.to_vec(), None)?;
258        let y_full = Array1::from_vec(y_full_vec);
259
260        // Step 5: Extract and post-multiply by w^{k^2/2}
261        let m = self.contour.m;
262        let y_slice = y_full.slice(scirs2_core::ndarray::s![self.n - 1..self.n - 1 + m]);
263        let result: Array1<Complex<f64>> = Zip::from(&y_slice)
264            .and(&self.wk2)
265            .map_collect(|&yi, &wki| yi * wki);
266
267        Ok(result)
268    }
269
270    /// Transform a real-valued signal
271    ///
272    /// # Errors
273    ///
274    /// Returns an error if input length does not match expected `n`.
275    pub fn transform_real(&self, x: &[f64]) -> FFTResult<Array1<Complex<f64>>> {
276        let x_complex: Vec<Complex<f64>> = x.iter().map(|&v| Complex::new(v, 0.0)).collect();
277        self.transform(&x_complex)
278    }
279
280    /// Batch transform: process multiple signals efficiently
281    ///
282    /// Each row of the input matrix is a separate signal.
283    ///
284    /// # Errors
285    ///
286    /// Returns an error if column count does not match expected `n`.
287    pub fn transform_batch(
288        &self,
289        signals: &Array2<Complex<f64>>,
290    ) -> FFTResult<Array2<Complex<f64>>> {
291        let (num_signals, signal_len) = signals.dim();
292        if signal_len != self.n {
293            return Err(FFTError::ValueError(format!(
294                "Signal length ({signal_len}) does not match CZT engine size ({})",
295                self.n
296            )));
297        }
298
299        let m = self.contour.m;
300        let mut results = Array2::zeros((num_signals, m));
301
302        for i in 0..num_signals {
303            let row = signals.row(i);
304            let row_vec: Vec<Complex<f64>> = row.iter().copied().collect();
305            let transformed = self.transform(&row_vec)?;
306            for (j, &val) in transformed.iter().enumerate() {
307                results[[i, j]] = val;
308            }
309        }
310
311        Ok(results)
312    }
313
314    /// Get the evaluation points for this CZT
315    pub fn points(&self) -> Array1<Complex<f64>> {
316        self.contour.points()
317    }
318
319    /// Get the contour configuration
320    pub fn contour(&self) -> &SpiralContour {
321        &self.contour
322    }
323}
324
325/// Compute the inverse CZT (reconstruct a signal from its CZT values)
326///
327/// Given M CZT values at known z-plane points, reconstruct an N-point signal.
328/// This uses a least-squares approach via the Vandermonde system.
329///
330/// # Arguments
331///
332/// * `czt_values` - The CZT output values
333/// * `n` - Length of the signal to reconstruct
334/// * `contour` - The contour used in the forward CZT
335///
336/// # Errors
337///
338/// Returns an error if `m < n` (underdetermined system) or if the system is singular.
339pub fn iczt(
340    czt_values: &[Complex<f64>],
341    n: usize,
342    contour: &SpiralContour,
343) -> FFTResult<Array1<Complex<f64>>> {
344    let m = czt_values.len();
345    if m < n {
346        return Err(FFTError::ValueError(format!(
347            "Need at least {n} CZT values to reconstruct {n}-point signal, got {m}"
348        )));
349    }
350
351    // Get the evaluation points z_k
352    let z_points = contour.points();
353
354    // Build the Vandermonde matrix V where V[k, j] = z_k^{-j}
355    let mut v_mat = Array2::zeros((m, n));
356    for k in 0..m {
357        let z_k = z_points[k];
358        let mut z_power = Complex::new(1.0, 0.0);
359        for j in 0..n {
360            v_mat[[k, j]] = z_power;
361            z_power = z_power / z_k; // z_k^{-(j+1)}
362        }
363    }
364
365    // Solve via least-squares using normal equations: V^H V x = V^H b
366    // Compute V^H * b
367    let mut vhb = Array1::zeros(n);
368    for j in 0..n {
369        let mut sum = Complex::new(0.0, 0.0);
370        for k in 0..m {
371            sum += v_mat[[k, j]].conj() * czt_values[k];
372        }
373        vhb[j] = sum;
374    }
375
376    // Compute V^H * V
377    let mut vhv = Array2::zeros((n, n));
378    for i in 0..n {
379        for j in 0..n {
380            let mut sum = Complex::new(0.0, 0.0);
381            for k in 0..m {
382                sum += v_mat[[k, i]].conj() * v_mat[[k, j]];
383            }
384            vhv[[i, j]] = sum;
385        }
386    }
387
388    // Solve via Gaussian elimination with partial pivoting
389    solve_complex_system(&vhv, &vhb)
390}
391
392/// Solve a complex linear system Ax = b via Gaussian elimination with partial pivoting
393fn solve_complex_system(
394    a: &Array2<Complex<f64>>,
395    b: &Array1<Complex<f64>>,
396) -> FFTResult<Array1<Complex<f64>>> {
397    let n = b.len();
398    let mut augmented = Array2::zeros((n, n + 1));
399
400    // Build augmented matrix [A | b]
401    for i in 0..n {
402        for j in 0..n {
403            augmented[[i, j]] = a[[i, j]];
404        }
405        augmented[[i, n]] = b[i];
406    }
407
408    // Forward elimination with partial pivoting
409    for col in 0..n {
410        // Find pivot
411        let mut max_val = augmented[[col, col]].norm();
412        let mut max_row = col;
413        for row in (col + 1)..n {
414            let val = augmented[[row, col]].norm();
415            if val > max_val {
416                max_val = val;
417                max_row = row;
418            }
419        }
420
421        if max_val < 1e-14 {
422            return Err(FFTError::ComputationError(
423                "Singular or near-singular system in ICZT".to_string(),
424            ));
425        }
426
427        // Swap rows
428        if max_row != col {
429            for j in 0..=n {
430                let tmp = augmented[[col, j]];
431                augmented[[col, j]] = augmented[[max_row, j]];
432                augmented[[max_row, j]] = tmp;
433            }
434        }
435
436        // Eliminate below
437        let pivot = augmented[[col, col]];
438        for row in (col + 1)..n {
439            let factor = augmented[[row, col]] / pivot;
440            for j in col..=n {
441                let val = augmented[[col, j]];
442                augmented[[row, j]] = augmented[[row, j]] - factor * val;
443            }
444        }
445    }
446
447    // Back substitution
448    let mut x = Array1::zeros(n);
449    for i in (0..n).rev() {
450        let mut sum = augmented[[i, n]];
451        for j in (i + 1)..n {
452            sum = sum - augmented[[i, j]] * x[j];
453        }
454        x[i] = sum / augmented[[i, i]];
455    }
456
457    Ok(x)
458}
459
460/// CZT-based fast convolution for arbitrary-length sequences
461///
462/// Computes the linear convolution of two sequences using CZT,
463/// which is particularly efficient when the sequences have prime
464/// or awkward lengths where standard FFT would require excessive padding.
465///
466/// # Arguments
467///
468/// * `a` - First input sequence
469/// * `b` - Second input sequence
470///
471/// # Returns
472///
473/// Linear convolution of `a` and `b` (length = len(a) + len(b) - 1)
474///
475/// # Errors
476///
477/// Returns an error if either input is empty.
478pub fn czt_convolve(a: &[f64], b: &[f64]) -> FFTResult<Vec<f64>> {
479    if a.is_empty() || b.is_empty() {
480        return Err(FFTError::ValueError(
481            "Input sequences cannot be empty".to_string(),
482        ));
483    }
484
485    let conv_len = a.len() + b.len() - 1;
486    let nfft = next_fast_len(conv_len, false);
487
488    // Zero-pad and FFT both sequences
489    let mut a_padded: Vec<Complex<f64>> = a.iter().map(|&v| Complex::new(v, 0.0)).collect();
490    a_padded.resize(nfft, Complex::new(0.0, 0.0));
491
492    let mut b_padded: Vec<Complex<f64>> = b.iter().map(|&v| Complex::new(v, 0.0)).collect();
493    b_padded.resize(nfft, Complex::new(0.0, 0.0));
494
495    let a_fft = crate::fft::fft(&a_padded, None)?;
496    let b_fft = crate::fft::fft(&b_padded, None)?;
497
498    // Pointwise multiply
499    let product: Vec<Complex<f64>> = a_fft
500        .iter()
501        .zip(b_fft.iter())
502        .map(|(&ai, &bi)| ai * bi)
503        .collect();
504
505    // Inverse FFT
506    let result_complex = crate::fft::ifft(&product, None)?;
507
508    // Extract real parts, truncated to correct length
509    Ok(result_complex.iter().take(conv_len).map(|c| c.re).collect())
510}
511
512/// Adaptive zoom FFT with automatic resolution selection
513///
514/// Computes the DFT over a specified frequency range with adaptive
515/// resolution based on the signal characteristics.
516///
517/// # Arguments
518///
519/// * `x` - Input signal (real-valued)
520/// * `f0` - Starting normalized frequency (0 to 1)
521/// * `f1` - Ending normalized frequency (0 to 1)
522/// * `min_points` - Minimum number of output points
523/// * `max_points` - Maximum number of output points
524///
525/// # Returns
526///
527/// A tuple of (frequencies, spectrum) where frequencies are normalized [0, 1].
528///
529/// # Errors
530///
531/// Returns an error if frequency range is invalid or if computation fails.
532pub fn adaptive_zoom_fft(
533    x: &[f64],
534    f0: f64,
535    f1: f64,
536    min_points: usize,
537    max_points: usize,
538) -> FFTResult<(Vec<f64>, Array1<Complex<f64>>)> {
539    if x.is_empty() {
540        return Err(FFTError::ValueError("Input signal is empty".to_string()));
541    }
542    if f0 < 0.0 || f1 > 1.0 || f0 >= f1 {
543        return Err(FFTError::ValueError(
544            "Frequency range must satisfy 0 <= f0 < f1 <= 1".to_string(),
545        ));
546    }
547    if min_points == 0 || max_points < min_points {
548        return Err(FFTError::ValueError(
549            "Point count must satisfy 0 < min_points <= max_points".to_string(),
550        ));
551    }
552
553    let n = x.len();
554
555    // Determine resolution: at least Rayleigh resolution (1/N) in the zoom range
556    let freq_range = f1 - f0;
557    let rayleigh_resolution = 1.0 / n as f64;
558    let ideal_points = (freq_range / rayleigh_resolution).ceil() as usize;
559    let m = ideal_points.clamp(min_points, max_points);
560
561    // Set up contour for the zoom range
562    let contour = SpiralContour::zoom_range(m, f0, f1, n)?;
563    let engine = EnhancedCZT::new(n, contour)?;
564
565    let spectrum = engine.transform_real(x)?;
566
567    // Compute frequency axis
568    let frequencies: Vec<f64> = (0..m)
569        .map(|k| {
570            if m > 1 {
571                f0 + k as f64 * (f1 - f0) / (m - 1) as f64
572            } else {
573                f0
574            }
575        })
576        .collect();
577
578    Ok((frequencies, spectrum))
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_unit_circle_contour() {
        let contour = SpiralContour::unit_circle(8).expect("Unit circle contour should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 8);

        // All points should lie on the unit circle
        for p in pts.iter() {
            assert_abs_diff_eq!(p.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_zoom_range_contour() {
        let contour =
            SpiralContour::zoom_range(16, 0.1, 0.3, 64).expect("Zoom range contour should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 16);

        // All points should be on unit circle
        for p in pts.iter() {
            assert_abs_diff_eq!(p.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_log_spiral_contour() {
        let contour =
            SpiralContour::log_spiral(10, 1.0, 0.95, 0.0, 0.1).expect("Log spiral should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 10);

        // First point should be at (1, 0)
        assert_abs_diff_eq!(pts[0].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pts[0].im, 0.0, epsilon = 1e-10);

        // Subsequent points should spiral inward (decreasing radius)
        // since rho < 1 and W = (1/rho)*exp(-j*dtheta), z_k = a * W^{-k} = a * rho^k * exp(j*k*dtheta)
        for k in 1..10 {
            let expected_r = 0.95_f64.powi(k as i32);
            assert_abs_diff_eq!(pts[k].norm(), expected_r, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_enhanced_czt_matches_fft() {
        // CZT on unit circle should match FFT
        let n = 16;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine creation should succeed");

        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();

        let czt_result = engine.transform(&x).expect("Transform should succeed");
        let fft_result_vec = crate::fft::fft(&x, None).expect("FFT should succeed");
        let fft_result = Array1::from_vec(fft_result_vec);

        for i in 0..n {
            assert_abs_diff_eq!(czt_result[i].re, fft_result[i].re, epsilon = 1e-8);
            assert_abs_diff_eq!(czt_result[i].im, fft_result[i].im, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_czt_matches_direct_evaluation_on_spiral() {
        // Off the unit circle, the Bluestein result must agree with the defining
        // sum X_k = sum_n x[n] * z_k^{-n} at the engine's own evaluation points.
        let n = 12;
        let contour = SpiralContour::log_spiral(10, 1.1, 0.97, 0.3, 0.25)
            .expect("Log spiral should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        let x: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new((i as f64 * 0.7).cos(), (i as f64 * 0.3).sin()))
            .collect();

        let czt_result = engine.transform(&x).expect("Transform should succeed");
        let points = engine.points();
        assert_eq!(czt_result.len(), points.len());

        for (k, &z_k) in points.iter().enumerate() {
            let z_inv = Complex::new(1.0, 0.0) / z_k;
            let mut z_pow = Complex::new(1.0, 0.0);
            let mut expected = Complex::new(0.0, 0.0);
            for &xn in &x {
                expected += xn * z_pow;
                z_pow *= z_inv;
            }
            assert_abs_diff_eq!(czt_result[k].re, expected.re, epsilon = 1e-8);
            assert_abs_diff_eq!(czt_result[k].im, expected.im, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_enhanced_czt_real_input() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        let x: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = engine
            .transform_real(&x)
            .expect("Real transform should succeed");

        // DC component should be sum of input
        let expected_dc: f64 = x.iter().sum();
        assert_abs_diff_eq!(result[0].re, expected_dc, epsilon = 1e-8);
    }

    #[test]
    fn test_batch_czt() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        // Create 3 signals
        let mut signals = Array2::zeros((3, n));
        for i in 0..3 {
            for j in 0..n {
                signals[[i, j]] = Complex::new((i * n + j) as f64, 0.0);
            }
        }

        let results = engine
            .transform_batch(&signals)
            .expect("Batch transform should succeed");
        assert_eq!(results.dim(), (3, n));

        // Each row should match individual transforms
        for i in 0..3 {
            let row_vec: Vec<Complex<f64>> = signals.row(i).iter().copied().collect();
            let individual = engine
                .transform(&row_vec)
                .expect("Individual transform should succeed");
            for j in 0..n {
                assert_abs_diff_eq!(results[[i, j]].re, individual[j].re, epsilon = 1e-8);
                assert_abs_diff_eq!(results[[i, j]].im, individual[j].im, epsilon = 1e-8);
            }
        }
    }

    #[test]
    fn test_iczt_roundtrip() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour.clone()).expect("Engine should succeed");

        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64 + 1.0, 0.0)).collect();

        let czt_values = engine.transform(&x).expect("Forward CZT should succeed");
        let czt_vec: Vec<Complex<f64>> = czt_values.iter().copied().collect();
        let recovered = iczt(&czt_vec, n, &contour).expect("ICZT should succeed");

        for i in 0..n {
            assert_abs_diff_eq!(recovered[i].re, x[i].re, epsilon = 1e-6);
            assert_abs_diff_eq!(recovered[i].im, x[i].im, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_czt_convolve() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];

        let result = czt_convolve(&a, &b).expect("Convolution should succeed");
        assert_eq!(result.len(), 4); // len(a) + len(b) - 1

        // Expected: [1*4, 1*5+2*4, 2*5+3*4, 3*5] = [4, 13, 22, 15]
        let expected = [4.0, 13.0, 22.0, 15.0];
        for (&r, &e) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(r, e, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_czt_convolve_identity() {
        // Convolving with delta should give the original signal
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let delta = vec![1.0];

        let result = czt_convolve(&signal, &delta).expect("Identity convolution should succeed");
        assert_eq!(result.len(), signal.len());

        for (&r, &s) in result.iter().zip(signal.iter()) {
            assert_abs_diff_eq!(r, s, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_adaptive_zoom_fft() {
        // Create a signal with a single frequency
        let n = 256;
        let freq = 0.15; // normalized frequency
        let x: Vec<f64> = (0..n).map(|i| (2.0 * PI * freq * i as f64).sin()).collect();

        let (frequencies, spectrum) =
            adaptive_zoom_fft(&x, 0.1, 0.2, 16, 128).expect("Adaptive zoom FFT should succeed");

        assert_eq!(frequencies.len(), spectrum.len());
        assert!(frequencies.len() >= 16);
        assert!(frequencies.len() <= 128);

        // Find the peak
        let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
        let peak_idx = magnitudes
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);

        // Peak should be near the expected frequency
        let peak_freq = frequencies[peak_idx];
        assert!(
            (peak_freq - freq).abs() < 0.02,
            "Peak at {peak_freq:.4} should be near {freq:.4}"
        );
    }

    #[test]
    fn test_parseval_theorem_czt() {
        // On the unit circle, Parseval's theorem should hold: sum|x|^2 = (1/N)*sum|X|^2
        let n = 16;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        let x: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new((2.0 * PI * 3.0 * i as f64 / n as f64).sin(), 0.0))
            .collect();

        let czt_result = engine.transform(&x).expect("Transform should succeed");

        let input_energy: f64 = x.iter().map(|c| c.norm_sqr()).sum();
        let output_energy: f64 = czt_result.iter().map(|c| c.norm_sqr()).sum::<f64>() / n as f64;

        assert_abs_diff_eq!(input_energy, output_energy, epsilon = 1e-8);
    }

    #[test]
    fn test_czt_prime_length() {
        // CZT should work with prime-length inputs
        let n = 13;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();

        let result = engine
            .transform(&x)
            .expect("Prime-length CZT should succeed");
        assert_eq!(result.len(), n);

        // DC should be sum of input
        let expected_dc: f64 = (0..n).map(|i| i as f64).sum();
        assert_abs_diff_eq!(result[0].re, expected_dc, epsilon = 1e-8);
    }

    #[test]
    fn test_zoom_fft_resolves_close_frequencies() {
        // Two close tones. Their separation (0.01) is below the 1/64 Rayleigh limit,
        // so no DFT-grid refinement can separate them; this test only checks that
        // the dense zoom grid captures the spectral energy of the pair.
        let n = 64;
        let f1_norm = 0.15;
        let f2_norm = 0.16;

        let x: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * f1_norm * i as f64).sin() + (2.0 * PI * f2_norm * i as f64).sin())
            .collect();

        // Zoom into the relevant range with many points
        let contour =
            SpiralContour::zoom_range(128, 0.12, 0.20, n).expect("Zoom contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");
        let spectrum = engine.transform_real(&x).expect("Zoom CZT should succeed");

        let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
        let max_mag = magnitudes.iter().copied().fold(0.0_f64, f64::max);

        // There should be significant energy in the zoomed spectrum
        assert!(max_mag > 1.0, "Zoom should find spectral energy");
    }

    #[test]
    fn test_error_handling() {
        // Zero-length input
        assert!(SpiralContour::unit_circle(0).is_err());
        assert!(SpiralContour::zoom_range(0, 0.0, 0.5, 64).is_err());
        assert!(SpiralContour::zoom_range(16, 0.5, 0.3, 64).is_err());
        assert!(SpiralContour::log_spiral(10, -1.0, 0.95, 0.0, 0.1).is_err());

        // Empty inputs
        assert!(czt_convolve(&[], &[1.0]).is_err());
        assert!(adaptive_zoom_fft(&[], 0.0, 0.5, 8, 64).is_err());
    }
}