scirs2_fft/
czt.rs

1//! Chirp Z-Transform (CZT) implementation
2//!
//! This module provides an implementation of the Chirp Z-Transform,
4//! which enables evaluation of the Z-transform along arbitrary
5//! contours in the complex plane, including non-uniform frequency spacing.
6
7use crate::{next_fast_len, FFTError, FFTResult};
8use scirs2_core::ndarray::{
9    s, Array, Array1, ArrayBase, ArrayD, Axis, Data, Dimension, RemoveAxis, Zip,
10};
11use scirs2_core::numeric::Complex;
12use std::f64::consts::PI;
13
14/// Compute points at which the chirp z-transform samples
15///
16/// Returns points on the z-plane where CZT evaluates the z-transform.
17/// The points follow a logarithmic spiral defined by `a` and `w`.
18///
19/// # Parameters
20/// - `m`: Number of output points
21/// - `a`: Starting point on the complex plane (default: 1+0j)
22/// - `w`: Ratio between consecutive points (default: exp(-2j*pi/m))
23///
24/// # Returns
25/// Array of complex points where z-transform is evaluated
26#[allow(dead_code)]
27pub fn czt_points(
28    m: usize,
29    a: Option<Complex<f64>>,
30    w: Option<Complex<f64>>,
31) -> Array1<Complex<f64>> {
32    let a = a.unwrap_or(Complex::new(1.0, 0.0));
33    let k = Array1::linspace(0.0, (m - 1) as f64, m);
34
35    if let Some(w) = w {
36        // Use specified w value
37        k.mapv(|ki| a * w.powf(-ki))
38    } else {
39        // Default to FFT-like equally spaced points on unit circle
40        k.mapv(|ki| a * (Complex::new(0.0, 2.0 * PI * ki / m as f64)).exp())
41    }
42}
43
44/// Chirp Z-Transform implementation
45///
46/// This structure pre-computes constant values for efficient CZT computation
47#[derive(Clone)]
48pub struct CZT {
49    n: usize,
50    m: usize,
51    w: Option<Complex<f64>>,
52    a: Complex<f64>,
53    nfft: usize,
54    awk2: Array1<Complex<f64>>,
55    fwk2: Array1<Complex<f64>>,
56    wk2: Array1<Complex<f64>>,
57}
58
59impl CZT {
60    /// Create a new CZT transform
61    ///
62    /// # Parameters
63    /// - `n`: Size of input signal
64    /// - `m`: Number of output points (default: n)
65    /// - `w`: Ratio between points (default: exp(-2j*pi/m))
66    /// - `a`: Starting point in complex plane (default: 1+0j)
67    pub fn new(
68        n: usize,
69        m: Option<usize>,
70        w: Option<Complex<f64>>,
71        a: Option<Complex<f64>>,
72    ) -> FFTResult<Self> {
73        if n < 1 {
74            return Err(FFTError::ValueError("n must be positive".to_string()));
75        }
76
77        let m = m.unwrap_or(n);
78        if m < 1 {
79            return Err(FFTError::ValueError("m must be positive".to_string()));
80        }
81
82        let a = a.unwrap_or(Complex::new(1.0, 0.0));
83        let max_size = n.max(m);
84        let k = Array1::linspace(0.0, (max_size - 1) as f64, max_size);
85
86        let (w, wk2) = if let Some(w) = w {
87            // User-specified w
88            let wk2 = k.mapv(|ki| w.powf(ki * ki / 2.0));
89            (Some(w), wk2)
90        } else {
91            // Default to FFT-like
92            let w = (-2.0 * PI * Complex::<f64>::i() / m as f64).exp();
93            let wk2 = k.mapv(|ki| {
94                let ki_i64 = ki as i64;
95                let phase = -(PI * ((ki_i64 * ki_i64) % (2 * m as i64)) as f64) / m as f64;
96                Complex::from_polar(1.0, phase)
97            });
98            (Some(w), wk2)
99        };
100
101        // Compute length for FFT
102        let nfft = next_fast_len(n + m - 1, false);
103
104        // Pre-compute A(k) * w_k^2 for the first n values
105        let awk2: Array1<Complex<f64>> = (0..n).map(|k| a.powf(-(k as f64)) * wk2[k]).collect();
106
107        // Pre-compute FFT of the reciprocal chirp
108        let mut chirp_vec = vec![Complex::new(0.0, 0.0); nfft];
109
110        // Fill with 1/wk2 values in specific order
111        for i in 1..n {
112            chirp_vec[n - 1 - i] = Complex::new(1.0, 0.0) / wk2[i];
113        }
114        for i in 0..m {
115            chirp_vec[n - 1 + i] = Complex::new(1.0, 0.0) / wk2[i];
116        }
117
118        let chirp_array = Array1::from_vec(chirp_vec);
119        let fwk2_vec = crate::fft::fft(&chirp_array.to_vec(), None)?;
120        let fwk2 = Array1::from_vec(fwk2_vec);
121
122        Ok(CZT {
123            n,
124            m,
125            w,
126            a,
127            nfft,
128            awk2,
129            fwk2,
130            wk2: wk2.slice(s![..m]).to_owned(),
131        })
132    }
133
134    /// Compute the points where this CZT evaluates the z-transform
135    pub fn points(&self) -> Array1<Complex<f64>> {
136        czt_points(self.m, Some(self.a), self.w)
137    }
138
139    /// Apply the chirp z-transform to a signal
140    ///
141    /// # Parameters
142    /// - `x`: Input signal
143    /// - `axis`: Axis along which to compute CZT (default: -1)
144    pub fn transform<S, D>(
145        &self,
146        x: &ArrayBase<S, D>,
147        axis: Option<i32>,
148    ) -> FFTResult<ArrayD<Complex<f64>>>
149    where
150        S: Data<Elem = Complex<f64>>,
151        D: Dimension + RemoveAxis,
152    {
153        let ndim = x.ndim();
154        let axis = if let Some(ax) = axis {
155            if ax < 0 {
156                let ax_pos = (ndim as i32 + ax) as usize;
157                if ax_pos >= ndim {
158                    return Err(FFTError::ValueError("Invalid axis".to_string()));
159                }
160                ax_pos
161            } else {
162                ax as usize
163            }
164        } else {
165            ndim - 1
166        };
167
168        let axis_len = x.shape()[axis];
169        if axis_len != self.n {
170            return Err(FFTError::ValueError(format!(
171                "Input size ({}) doesn't match CZT size ({})",
172                axis_len, self.n
173            )));
174        }
175
176        // Create output shape - same as input but with m points along specified axis
177        let mut outputshape = x.shape().to_vec();
178        outputshape[axis] = self.m;
179        let mut result = Array::<Complex<f64>, _>::zeros(outputshape).into_dyn();
180
181        // Apply CZT along the specified axis
182        // For 1D array, directly apply the transform
183        if x.ndim() == 1 {
184            let x_1d: Array1<Complex<f64>> = x
185                .to_owned()
186                .into_shape_with_order(x.len())
187                .map_err(|e| {
188                    FFTError::ComputationError(format!("Failed to reshape input array to 1D: {e}"))
189                })?
190                .into_dimensionality()
191                .map_err(|e| {
192                    FFTError::ComputationError(format!(
193                        "Failed to convert array dimensionality: {e}"
194                    ))
195                })?;
196            let y = self.transform_1d(&x_1d)?;
197            return Ok(y.into_dyn());
198        }
199
200        // For higher dimensions, iterate over axis
201        for (i, x_slice) in x.axis_iter(Axis(axis)).enumerate() {
202            // Convert slice to Array1
203            let x_1d: Array1<Complex<f64>> = x_slice
204                .to_owned()
205                .into_shape_with_order(x_slice.len())
206                .map_err(|e| {
207                    FFTError::ComputationError(format!("Failed to reshape slice to 1D array: {e}"))
208                })?;
209            let y = self.transform_1d(&x_1d)?;
210
211            // Dynamic slicing based on the number of dimensions
212            match result.ndim() {
213                2 => {
214                    if axis == 0 {
215                        let mut result_slice = result.slice_mut(s![i, ..]);
216                        result_slice.assign(&y);
217                    } else {
218                        let mut result_slice = result.slice_mut(s![.., i]);
219                        result_slice.assign(&y);
220                    }
221                }
222                _ => {
223                    // For higher dimensions, we need more complex handling
224                    return Err(FFTError::ValueError(
225                        "CZT currently only supports 1D and 2D arrays".to_string(),
226                    ));
227                }
228            }
229        }
230
231        Ok(result)
232    }
233
234    /// Transform a 1D signal
235    fn transform_1d(&self, x: &Array1<Complex<f64>>) -> FFTResult<Array1<Complex<f64>>> {
236        if x.len() != self.n {
237            return Err(FFTError::ValueError(format!(
238                "Input size ({}) doesn't match CZT size ({})",
239                x.len(),
240                self.n
241            )));
242        }
243
244        // Multiply input by A(k) * w_k^2
245        let x_weighted: Array1<Complex<f64>> = Zip::from(x)
246            .and(&self.awk2)
247            .map_collect(|&xi, &awki| xi * awki);
248
249        // Create zero-padded array for FFT
250        let mut padded = Array1::zeros(self.nfft);
251        padded.slice_mut(s![..self.n]).assign(&x_weighted);
252
253        // Forward FFT
254        let x_fft_vec = crate::fft::fft(&padded.to_vec(), None)?;
255        let x_fft = Array1::from_vec(x_fft_vec);
256
257        // Multiply by pre-computed FFT of reciprocal chirp
258        let product: Array1<Complex<f64>> = Zip::from(&x_fft)
259            .and(&self.fwk2)
260            .map_collect(|&xi, &fi| xi * fi);
261
262        // Inverse FFT
263        let y_full_vec = crate::fft::ifft(&product.to_vec(), None)?;
264        let y_full = Array1::from_vec(y_full_vec);
265
266        // Extract relevant portion and multiply by w_k^2
267        let y_slice = y_full.slice(s![self.n - 1..self.n - 1 + self.m]);
268        let result: Array1<Complex<f64>> = Zip::from(&y_slice)
269            .and(&self.wk2)
270            .map_collect(|&yi, &wki| yi * wki);
271
272        Ok(result)
273    }
274}
275
276/// Functional interface to chirp z-transform
277///
278/// # Parameters
279/// - `x`: Input signal
280/// - `m`: Number of output points (default: length of x)
281/// - `w`: Ratio between points (default: exp(-2j*pi/m))
282/// - `a`: Starting point in complex plane (default: 1+0j)
283/// - `axis`: Axis along which to compute CZT (default: -1)
284#[allow(dead_code)]
285pub fn czt<S, D>(
286    x: &ArrayBase<S, D>,
287    m: Option<usize>,
288    w: Option<Complex<f64>>,
289    a: Option<Complex<f64>>,
290    axis: Option<i32>,
291) -> FFTResult<ArrayD<Complex<f64>>>
292where
293    S: Data<Elem = Complex<f64>>,
294    D: Dimension + RemoveAxis,
295{
296    let axis_actual = if let Some(ax) = axis {
297        if ax < 0 {
298            (x.ndim() as i32 + ax) as usize
299        } else {
300            ax as usize
301        }
302    } else {
303        x.ndim() - 1
304    };
305
306    let n = x.shape()[axis_actual];
307    let transform = CZT::new(n, m, w, a)?;
308    transform.transform(x, axis)
309}
310
311/// Compute a zoom FFT - partial DFT on a specified frequency range
312///
313/// Efficiently evaluates the DFT over a subset of frequency range.
314///
315/// # Parameters
316/// - `x`: Input signal
317/// - `m`: Number of output points
318/// - `f0`: Starting normalized frequency (0 to 1)
319/// - `f1`: Ending normalized frequency (0 to 1)
320/// - `oversampling`: Oversampling factor for frequency resolution
321#[allow(dead_code)]
322pub fn zoom_fft<S, D>(
323    x: &ArrayBase<S, D>,
324    m: usize,
325    f0: f64,
326    f1: f64,
327    oversampling: Option<f64>,
328) -> FFTResult<ArrayD<Complex<f64>>>
329where
330    S: Data<Elem = Complex<f64>>,
331    D: Dimension + RemoveAxis,
332{
333    if !(0.0..=1.0).contains(&f0) || !(0.0..=1.0).contains(&f1) {
334        return Err(FFTError::ValueError(
335            "Frequencies must be in range [0, 1]".to_string(),
336        ));
337    }
338
339    if f0 >= f1 {
340        return Err(FFTError::ValueError("f0 must be less than f1".to_string()));
341    }
342
343    let oversampling = oversampling.unwrap_or(2.0);
344    if oversampling < 1.0 {
345        return Err(FFTError::ValueError(
346            "Oversampling must be >= 1".to_string(),
347        ));
348    }
349
350    let ndim = x.ndim();
351    let axis = ndim - 1;
352    let n = x.shape()[axis];
353
354    // Compute CZT parameters for zoom FFT
355    let k0_float = f0 * n as f64 * oversampling;
356    let k1_float = f1 * n as f64 * oversampling;
357    let step = (k1_float - k0_float) / (m - 1) as f64;
358
359    let phi = 2.0 * PI * k0_float / (n as f64 * oversampling);
360    let a = Complex::from_polar(1.0, phi);
361
362    let theta = -2.0 * PI * step / (n as f64 * oversampling);
363    let w = Complex::from_polar(1.0, theta);
364
365    czt(x, Some(m), Some(w), Some(a), Some(axis as i32))
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_czt_points() {
        // Default contour: FFT-like points, all on the unit circle.
        let default_pts = czt_points(4, None, None);
        assert_eq!(default_pts.len(), 4);
        for z in &default_pts {
            assert_abs_diff_eq!(z.norm(), 1.0, epsilon = 1e-10);
        }

        // Custom spiral: the first sample point is the starting point `a`.
        let start = Complex::new(0.8, 0.0);
        let ratio = Complex::from_polar(0.95, 0.1);
        let spiral = czt_points(5, Some(start), Some(ratio));
        assert_eq!(spiral.len(), 5);
        assert!((spiral[0] - start).norm() < 1e-10);
    }

    #[test]
    fn test_czt_as_fft() {
        // With every optional parameter defaulted, the CZT reduces to the DFT.
        let n = 8;
        let signal: Array1<Complex<f64>> = (0..n).map(|v| Complex::new(v as f64, 0.0)).collect();

        let via_czt = czt(&signal.view(), None, None, None, None)
            .expect("CZT computation should succeed for test data");
        assert_eq!(via_czt.ndim(), 1);
        let via_czt: Array1<Complex<f64>> = via_czt
            .into_dimensionality()
            .expect("CZT result should convert to 1D array");

        let via_fft = crate::fft::fft(&signal.to_vec(), None)
            .expect("FFT computation should succeed for test data");

        for i in 0..n {
            let (got, want) = (via_czt[i], via_fft[i]);
            assert!((got.re - want.re).abs() < 1e-10);
            assert!((got.im - want.im).abs() < 1e-10);
        }
    }

    #[test]
    fn test_zoom_fft() {
        // A single 5-cycle sine over 64 samples gives a clear spectral peak.
        let n = 64;
        let signal: Array1<Complex<f64>> = Array1::linspace(0.0, 1.0, n)
            .mapv(|t: f64| Complex::new((2.0 * PI * 5.0 * t).sin(), 0.0));

        // Zoom over the lower half of the band so the peak is included.
        let m = 16;
        let zoomed = zoom_fft(&signal.view(), m, 0.0, 0.5, None)
            .expect("Zoom FFT should succeed for test data");

        assert_eq!(zoomed.ndim(), 1);
        let zoomed: Array1<Complex<f64>> = zoomed
            .into_dimensionality()
            .expect("Zoom FFT result should convert to 1D array");
        assert_eq!(zoomed.len(), m);

        assert!(
            zoomed.iter().any(|c| c.norm() > 1e-10),
            "Zoom FFT should produce some non-zero values"
        );
    }
}