scirs2_fft/
czt.rs

1//! Chirp Z-Transform (CZT) implementation
2//!
//! This module provides an implementation of the Chirp Z-Transform,
//! which evaluates the Z-transform at points along a logarithmic spiral
//! in the complex plane (for example, a zoomed-in frequency band with an
//! arbitrary starting frequency and resolution).
6
7use crate::{next_fast_len, FFTError, FFTResult};
8use ndarray::{s, Array, Array1, ArrayBase, ArrayD, Axis, Data, Dimension, RemoveAxis, Zip};
9use num_complex::Complex;
10use std::f64::consts::PI;
11
12/// Compute points at which the chirp z-transform samples
13///
14/// Returns points on the z-plane where CZT evaluates the z-transform.
15/// The points follow a logarithmic spiral defined by `a` and `w`.
16///
17/// # Parameters
18/// - `m`: Number of output points
19/// - `a`: Starting point on the complex plane (default: 1+0j)
20/// - `w`: Ratio between consecutive points (default: exp(-2j*pi/m))
21///
22/// # Returns
23/// Array of complex points where z-transform is evaluated
24#[allow(dead_code)]
25pub fn czt_points(
26    m: usize,
27    a: Option<Complex<f64>>,
28    w: Option<Complex<f64>>,
29) -> Array1<Complex<f64>> {
30    let a = a.unwrap_or(Complex::new(1.0, 0.0));
31    let k = Array1::linspace(0.0, (m - 1) as f64, m);
32
33    if let Some(w) = w {
34        // Use specified w value
35        k.mapv(|ki| a * w.powf(-ki))
36    } else {
37        // Default to FFT-like equally spaced points on unit circle
38        k.mapv(|ki| a * (Complex::new(0.0, 2.0 * PI * ki / m as f64)).exp())
39    }
40}
41
42/// Chirp Z-Transform implementation
43///
44/// This structure pre-computes constant values for efficient CZT computation
45#[derive(Clone)]
46pub struct CZT {
47    n: usize,
48    m: usize,
49    w: Option<Complex<f64>>,
50    a: Complex<f64>,
51    nfft: usize,
52    awk2: Array1<Complex<f64>>,
53    fwk2: Array1<Complex<f64>>,
54    wk2: Array1<Complex<f64>>,
55}
56
57impl CZT {
58    /// Create a new CZT transform
59    ///
60    /// # Parameters
61    /// - `n`: Size of input signal
62    /// - `m`: Number of output points (default: n)
63    /// - `w`: Ratio between points (default: exp(-2j*pi/m))
64    /// - `a`: Starting point in complex plane (default: 1+0j)
65    pub fn new(
66        n: usize,
67        m: Option<usize>,
68        w: Option<Complex<f64>>,
69        a: Option<Complex<f64>>,
70    ) -> FFTResult<Self> {
71        if n < 1 {
72            return Err(FFTError::ValueError("n must be positive".to_string()));
73        }
74
75        let m = m.unwrap_or(n);
76        if m < 1 {
77            return Err(FFTError::ValueError("m must be positive".to_string()));
78        }
79
80        let a = a.unwrap_or(Complex::new(1.0, 0.0));
81        let max_size = n.max(m);
82        let k = Array1::linspace(0.0, (max_size - 1) as f64, max_size);
83
84        let (w, wk2) = if let Some(w) = w {
85            // User-specified w
86            let wk2 = k.mapv(|ki| w.powf(ki * ki / 2.0));
87            (Some(w), wk2)
88        } else {
89            // Default to FFT-like
90            let w = (-2.0 * PI * Complex::<f64>::i() / m as f64).exp();
91            let wk2 = k.mapv(|ki| {
92                let ki_i64 = ki as i64;
93                let phase = -(PI * ((ki_i64 * ki_i64) % (2 * m as i64)) as f64) / m as f64;
94                Complex::from_polar(1.0, phase)
95            });
96            (Some(w), wk2)
97        };
98
99        // Compute length for FFT
100        let nfft = next_fast_len(n + m - 1, false);
101
102        // Pre-compute A(k) * w_k^2 for the first n values
103        let awk2: Array1<Complex<f64>> = (0..n).map(|k| a.powf(-(k as f64)) * wk2[k]).collect();
104
105        // Pre-compute FFT of the reciprocal chirp
106        let mut chirp_vec = vec![Complex::new(0.0, 0.0); nfft];
107
108        // Fill with 1/wk2 values in specific order
109        for i in 1..n {
110            chirp_vec[n - 1 - i] = Complex::new(1.0, 0.0) / wk2[i];
111        }
112        for i in 0..m {
113            chirp_vec[n - 1 + i] = Complex::new(1.0, 0.0) / wk2[i];
114        }
115
116        let chirp_array = Array1::from_vec(chirp_vec);
117        let fwk2_vec = crate::fft::fft(&chirp_array.to_vec(), None)?;
118        let fwk2 = Array1::from_vec(fwk2_vec);
119
120        Ok(CZT {
121            n,
122            m,
123            w,
124            a,
125            nfft,
126            awk2,
127            fwk2,
128            wk2: wk2.slice(s![..m]).to_owned(),
129        })
130    }
131
132    /// Compute the points where this CZT evaluates the z-transform
133    pub fn points(&self) -> Array1<Complex<f64>> {
134        czt_points(self.m, Some(self.a), self.w)
135    }
136
137    /// Apply the chirp z-transform to a signal
138    ///
139    /// # Parameters
140    /// - `x`: Input signal
141    /// - `axis`: Axis along which to compute CZT (default: -1)
142    pub fn transform<S, D>(
143        &self,
144        x: &ArrayBase<S, D>,
145        axis: Option<i32>,
146    ) -> FFTResult<ArrayD<Complex<f64>>>
147    where
148        S: Data<Elem = Complex<f64>>,
149        D: Dimension + RemoveAxis,
150    {
151        let ndim = x.ndim();
152        let axis = if let Some(ax) = axis {
153            if ax < 0 {
154                let ax_pos = (ndim as i32 + ax) as usize;
155                if ax_pos >= ndim {
156                    return Err(FFTError::ValueError("Invalid axis".to_string()));
157                }
158                ax_pos
159            } else {
160                ax as usize
161            }
162        } else {
163            ndim - 1
164        };
165
166        let axis_len = x.shape()[axis];
167        if axis_len != self.n {
168            return Err(FFTError::ValueError(format!(
169                "Input size ({}) doesn't match CZT size ({})",
170                axis_len, self.n
171            )));
172        }
173
174        // Create output shape - same as input but with m points along specified axis
175        let mut outputshape = x.shape().to_vec();
176        outputshape[axis] = self.m;
177        let mut result = Array::<Complex<f64>, _>::zeros(outputshape).into_dyn();
178
179        // Apply CZT along the specified axis
180        // For 1D array, directly apply the transform
181        if x.ndim() == 1 {
182            let x_1d: Array1<Complex<f64>> = x
183                .to_owned()
184                .into_shape_with_order(x.len())
185                .map_err(|e| {
186                    FFTError::ComputationError(format!("Failed to reshape input array to 1D: {e}"))
187                })?
188                .into_dimensionality()
189                .map_err(|e| {
190                    FFTError::ComputationError(format!(
191                        "Failed to convert array dimensionality: {e}"
192                    ))
193                })?;
194            let y = self.transform_1d(&x_1d)?;
195            return Ok(y.into_dyn());
196        }
197
198        // For higher dimensions, iterate over axis
199        for (i, x_slice) in x.axis_iter(Axis(axis)).enumerate() {
200            // Convert slice to Array1
201            let x_1d: Array1<Complex<f64>> = x_slice
202                .to_owned()
203                .into_shape_with_order(x_slice.len())
204                .map_err(|e| {
205                    FFTError::ComputationError(format!("Failed to reshape slice to 1D array: {e}"))
206                })?;
207            let y = self.transform_1d(&x_1d)?;
208
209            // Dynamic slicing based on the number of dimensions
210            match result.ndim() {
211                2 => {
212                    if axis == 0 {
213                        let mut result_slice = result.slice_mut(s![i, ..]);
214                        result_slice.assign(&y);
215                    } else {
216                        let mut result_slice = result.slice_mut(s![.., i]);
217                        result_slice.assign(&y);
218                    }
219                }
220                _ => {
221                    // For higher dimensions, we need more complex handling
222                    return Err(FFTError::ValueError(
223                        "CZT currently only supports 1D and 2D arrays".to_string(),
224                    ));
225                }
226            }
227        }
228
229        Ok(result)
230    }
231
232    /// Transform a 1D signal
233    fn transform_1d(&self, x: &Array1<Complex<f64>>) -> FFTResult<Array1<Complex<f64>>> {
234        if x.len() != self.n {
235            return Err(FFTError::ValueError(format!(
236                "Input size ({}) doesn't match CZT size ({})",
237                x.len(),
238                self.n
239            )));
240        }
241
242        // Multiply input by A(k) * w_k^2
243        let x_weighted: Array1<Complex<f64>> = Zip::from(x)
244            .and(&self.awk2)
245            .map_collect(|&xi, &awki| xi * awki);
246
247        // Create zero-padded array for FFT
248        let mut padded = Array1::zeros(self.nfft);
249        padded.slice_mut(s![..self.n]).assign(&x_weighted);
250
251        // Forward FFT
252        let x_fft_vec = crate::fft::fft(&padded.to_vec(), None)?;
253        let x_fft = Array1::from_vec(x_fft_vec);
254
255        // Multiply by pre-computed FFT of reciprocal chirp
256        let product: Array1<Complex<f64>> = Zip::from(&x_fft)
257            .and(&self.fwk2)
258            .map_collect(|&xi, &fi| xi * fi);
259
260        // Inverse FFT
261        let y_full_vec = crate::fft::ifft(&product.to_vec(), None)?;
262        let y_full = Array1::from_vec(y_full_vec);
263
264        // Extract relevant portion and multiply by w_k^2
265        let y_slice = y_full.slice(s![self.n - 1..self.n - 1 + self.m]);
266        let result: Array1<Complex<f64>> = Zip::from(&y_slice)
267            .and(&self.wk2)
268            .map_collect(|&yi, &wki| yi * wki);
269
270        Ok(result)
271    }
272}
273
274/// Functional interface to chirp z-transform
275///
276/// # Parameters
277/// - `x`: Input signal
278/// - `m`: Number of output points (default: length of x)
279/// - `w`: Ratio between points (default: exp(-2j*pi/m))
280/// - `a`: Starting point in complex plane (default: 1+0j)
281/// - `axis`: Axis along which to compute CZT (default: -1)
282#[allow(dead_code)]
283pub fn czt<S, D>(
284    x: &ArrayBase<S, D>,
285    m: Option<usize>,
286    w: Option<Complex<f64>>,
287    a: Option<Complex<f64>>,
288    axis: Option<i32>,
289) -> FFTResult<ArrayD<Complex<f64>>>
290where
291    S: Data<Elem = Complex<f64>>,
292    D: Dimension + RemoveAxis,
293{
294    let axis_actual = if let Some(ax) = axis {
295        if ax < 0 {
296            (x.ndim() as i32 + ax) as usize
297        } else {
298            ax as usize
299        }
300    } else {
301        x.ndim() - 1
302    };
303
304    let n = x.shape()[axis_actual];
305    let transform = CZT::new(n, m, w, a)?;
306    transform.transform(x, axis)
307}
308
309/// Compute a zoom FFT - partial DFT on a specified frequency range
310///
311/// Efficiently evaluates the DFT over a subset of frequency range.
312///
313/// # Parameters
314/// - `x`: Input signal
315/// - `m`: Number of output points
316/// - `f0`: Starting normalized frequency (0 to 1)
317/// - `f1`: Ending normalized frequency (0 to 1)
318/// - `oversampling`: Oversampling factor for frequency resolution
319#[allow(dead_code)]
320pub fn zoom_fft<S, D>(
321    x: &ArrayBase<S, D>,
322    m: usize,
323    f0: f64,
324    f1: f64,
325    oversampling: Option<f64>,
326) -> FFTResult<ArrayD<Complex<f64>>>
327where
328    S: Data<Elem = Complex<f64>>,
329    D: Dimension + RemoveAxis,
330{
331    if !(0.0..=1.0).contains(&f0) || !(0.0..=1.0).contains(&f1) {
332        return Err(FFTError::ValueError(
333            "Frequencies must be in range [0, 1]".to_string(),
334        ));
335    }
336
337    if f0 >= f1 {
338        return Err(FFTError::ValueError("f0 must be less than f1".to_string()));
339    }
340
341    let oversampling = oversampling.unwrap_or(2.0);
342    if oversampling < 1.0 {
343        return Err(FFTError::ValueError(
344            "Oversampling must be >= 1".to_string(),
345        ));
346    }
347
348    let ndim = x.ndim();
349    let axis = ndim - 1;
350    let n = x.shape()[axis];
351
352    // Compute CZT parameters for zoom FFT
353    let k0_float = f0 * n as f64 * oversampling;
354    let k1_float = f1 * n as f64 * oversampling;
355    let step = (k1_float - k0_float) / (m - 1) as f64;
356
357    let phi = 2.0 * PI * k0_float / (n as f64 * oversampling);
358    let a = Complex::from_polar(1.0, phi);
359
360    let theta = -2.0 * PI * step / (n as f64 * oversampling);
361    let w = Complex::from_polar(1.0, theta);
362
363    czt(x, Some(m), Some(w), Some(a), Some(axis as i32))
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_czt_points() {
        // Default parameters: FFT-like points, all on the unit circle.
        let default_pts = czt_points(4, None, None);
        assert_eq!(default_pts.len(), 4);
        for z in default_pts.iter() {
            assert_abs_diff_eq!(z.norm(), 1.0, epsilon = 1e-10);
        }

        // Custom spiral: the first point must be the starting point `a`.
        let start = Complex::new(0.8, 0.0);
        let ratio = Complex::from_polar(0.95, 0.1);
        let spiral = czt_points(5, Some(start), Some(ratio));
        assert_eq!(spiral.len(), 5);
        assert!((spiral[0] - start).norm() < 1e-10);
    }

    #[test]
    fn test_czt_as_fft() {
        // With default parameters the CZT reduces to the ordinary DFT.
        let len = 8;
        let signal: Array1<Complex<f64>> =
            Array1::linspace(0.0, 7.0, len).mapv(|v| Complex::new(v, 0.0));

        // The functional interface returns ArrayD; narrow it to 1-D.
        let via_czt: Array1<Complex<f64>> = {
            let out = czt(&signal.view(), None, None, None, None)
                .expect("CZT computation should succeed for test data");
            assert_eq!(out.ndim(), 1);
            out.into_dimensionality()
                .expect("CZT result should convert to 1D array")
        };

        let via_fft = Array1::from_vec(
            crate::fft::fft(&signal.to_vec(), None)
                .expect("FFT computation should succeed for test data"),
        );

        for i in 0..len {
            let (got, want) = (via_czt[i], via_fft[i]);
            assert!((got.re - want.re).abs() < 1e-10);
            assert!((got.im - want.im).abs() < 1e-10);
        }
    }

    #[test]
    fn test_zoom_fft() {
        // A single sinusoid (5 cycles over the sampled interval).
        let samples = 64;
        let signal: Array1<Complex<f64>> = Array1::<f64>::linspace(0.0, 1.0, samples)
            .mapv(|t| Complex::new((2.0 * PI * 5.0 * t).sin(), 0.0));

        // Zoom over a wide band so the tone is certainly captured.
        let bins = 16;
        let zoomed: Array1<Complex<f64>> = {
            let out = zoom_fft(&signal.view(), bins, 0.0, 0.5, None)
                .expect("Zoom FFT should succeed for test data");
            assert_eq!(out.ndim(), 1);
            out.into_dimensionality()
                .expect("Zoom FFT result should convert to 1D array")
        };
        assert_eq!(zoomed.len(), bins);

        // Sanity check: the zoomed spectrum is not identically zero.
        assert!(
            zoomed.iter().any(|c| c.norm() > 1e-10),
            "Zoom FFT should produce some non-zero values"
        );
    }
}