1use crate::{next_fast_len, FFTError, FFTResult};
8use scirs2_core::ndarray::{
9 s, Array, Array1, ArrayBase, ArrayD, Axis, Data, Dimension, RemoveAxis, Zip,
10};
11use scirs2_core::numeric::Complex;
12use std::f64::consts::PI;
13
14#[allow(dead_code)]
27pub fn czt_points(
28 m: usize,
29 a: Option<Complex<f64>>,
30 w: Option<Complex<f64>>,
31) -> Array1<Complex<f64>> {
32 let a = a.unwrap_or(Complex::new(1.0, 0.0));
33 let k = Array1::linspace(0.0, (m - 1) as f64, m);
34
35 if let Some(w) = w {
36 k.mapv(|ki| a * w.powf(-ki))
38 } else {
39 k.mapv(|ki| a * (Complex::new(0.0, 2.0 * PI * ki / m as f64)).exp())
41 }
42}
43
44#[derive(Clone)]
48pub struct CZT {
49 n: usize,
50 m: usize,
51 w: Option<Complex<f64>>,
52 a: Complex<f64>,
53 nfft: usize,
54 awk2: Array1<Complex<f64>>,
55 fwk2: Array1<Complex<f64>>,
56 wk2: Array1<Complex<f64>>,
57}
58
59impl CZT {
60 pub fn new(
68 n: usize,
69 m: Option<usize>,
70 w: Option<Complex<f64>>,
71 a: Option<Complex<f64>>,
72 ) -> FFTResult<Self> {
73 if n < 1 {
74 return Err(FFTError::ValueError("n must be positive".to_string()));
75 }
76
77 let m = m.unwrap_or(n);
78 if m < 1 {
79 return Err(FFTError::ValueError("m must be positive".to_string()));
80 }
81
82 let a = a.unwrap_or(Complex::new(1.0, 0.0));
83 let max_size = n.max(m);
84 let k = Array1::linspace(0.0, (max_size - 1) as f64, max_size);
85
86 let (w, wk2) = if let Some(w) = w {
87 let wk2 = k.mapv(|ki| w.powf(ki * ki / 2.0));
89 (Some(w), wk2)
90 } else {
91 let w = (-2.0 * PI * Complex::<f64>::i() / m as f64).exp();
93 let wk2 = k.mapv(|ki| {
94 let ki_i64 = ki as i64;
95 let phase = -(PI * ((ki_i64 * ki_i64) % (2 * m as i64)) as f64) / m as f64;
96 Complex::from_polar(1.0, phase)
97 });
98 (Some(w), wk2)
99 };
100
101 let nfft = next_fast_len(n + m - 1, false);
103
104 let awk2: Array1<Complex<f64>> = (0..n).map(|k| a.powf(-(k as f64)) * wk2[k]).collect();
106
107 let mut chirp_vec = vec![Complex::new(0.0, 0.0); nfft];
109
110 for i in 1..n {
112 chirp_vec[n - 1 - i] = Complex::new(1.0, 0.0) / wk2[i];
113 }
114 for i in 0..m {
115 chirp_vec[n - 1 + i] = Complex::new(1.0, 0.0) / wk2[i];
116 }
117
118 let chirp_array = Array1::from_vec(chirp_vec);
119 let fwk2_vec = crate::fft::fft(&chirp_array.to_vec(), None)?;
120 let fwk2 = Array1::from_vec(fwk2_vec);
121
122 Ok(CZT {
123 n,
124 m,
125 w,
126 a,
127 nfft,
128 awk2,
129 fwk2,
130 wk2: wk2.slice(s![..m]).to_owned(),
131 })
132 }
133
134 pub fn points(&self) -> Array1<Complex<f64>> {
136 czt_points(self.m, Some(self.a), self.w)
137 }
138
139 pub fn transform<S, D>(
145 &self,
146 x: &ArrayBase<S, D>,
147 axis: Option<i32>,
148 ) -> FFTResult<ArrayD<Complex<f64>>>
149 where
150 S: Data<Elem = Complex<f64>>,
151 D: Dimension + RemoveAxis,
152 {
153 let ndim = x.ndim();
154 let axis = if let Some(ax) = axis {
155 if ax < 0 {
156 let ax_pos = (ndim as i32 + ax) as usize;
157 if ax_pos >= ndim {
158 return Err(FFTError::ValueError("Invalid axis".to_string()));
159 }
160 ax_pos
161 } else {
162 ax as usize
163 }
164 } else {
165 ndim - 1
166 };
167
168 let axis_len = x.shape()[axis];
169 if axis_len != self.n {
170 return Err(FFTError::ValueError(format!(
171 "Input size ({}) doesn't match CZT size ({})",
172 axis_len, self.n
173 )));
174 }
175
176 let mut outputshape = x.shape().to_vec();
178 outputshape[axis] = self.m;
179 let mut result = Array::<Complex<f64>, _>::zeros(outputshape).into_dyn();
180
181 if x.ndim() == 1 {
184 let x_1d: Array1<Complex<f64>> = x
185 .to_owned()
186 .into_shape_with_order(x.len())
187 .map_err(|e| {
188 FFTError::ComputationError(format!("Failed to reshape input array to 1D: {e}"))
189 })?
190 .into_dimensionality()
191 .map_err(|e| {
192 FFTError::ComputationError(format!(
193 "Failed to convert array dimensionality: {e}"
194 ))
195 })?;
196 let y = self.transform_1d(&x_1d)?;
197 return Ok(y.into_dyn());
198 }
199
200 for (i, x_slice) in x.axis_iter(Axis(axis)).enumerate() {
202 let x_1d: Array1<Complex<f64>> = x_slice
204 .to_owned()
205 .into_shape_with_order(x_slice.len())
206 .map_err(|e| {
207 FFTError::ComputationError(format!("Failed to reshape slice to 1D array: {e}"))
208 })?;
209 let y = self.transform_1d(&x_1d)?;
210
211 match result.ndim() {
213 2 => {
214 if axis == 0 {
215 let mut result_slice = result.slice_mut(s![i, ..]);
216 result_slice.assign(&y);
217 } else {
218 let mut result_slice = result.slice_mut(s![.., i]);
219 result_slice.assign(&y);
220 }
221 }
222 _ => {
223 return Err(FFTError::ValueError(
225 "CZT currently only supports 1D and 2D arrays".to_string(),
226 ));
227 }
228 }
229 }
230
231 Ok(result)
232 }
233
234 fn transform_1d(&self, x: &Array1<Complex<f64>>) -> FFTResult<Array1<Complex<f64>>> {
236 if x.len() != self.n {
237 return Err(FFTError::ValueError(format!(
238 "Input size ({}) doesn't match CZT size ({})",
239 x.len(),
240 self.n
241 )));
242 }
243
244 let x_weighted: Array1<Complex<f64>> = Zip::from(x)
246 .and(&self.awk2)
247 .map_collect(|&xi, &awki| xi * awki);
248
249 let mut padded = Array1::zeros(self.nfft);
251 padded.slice_mut(s![..self.n]).assign(&x_weighted);
252
253 let x_fft_vec = crate::fft::fft(&padded.to_vec(), None)?;
255 let x_fft = Array1::from_vec(x_fft_vec);
256
257 let product: Array1<Complex<f64>> = Zip::from(&x_fft)
259 .and(&self.fwk2)
260 .map_collect(|&xi, &fi| xi * fi);
261
262 let y_full_vec = crate::fft::ifft(&product.to_vec(), None)?;
264 let y_full = Array1::from_vec(y_full_vec);
265
266 let y_slice = y_full.slice(s![self.n - 1..self.n - 1 + self.m]);
268 let result: Array1<Complex<f64>> = Zip::from(&y_slice)
269 .and(&self.wk2)
270 .map_collect(|&yi, &wki| yi * wki);
271
272 Ok(result)
273 }
274}
275
276#[allow(dead_code)]
285pub fn czt<S, D>(
286 x: &ArrayBase<S, D>,
287 m: Option<usize>,
288 w: Option<Complex<f64>>,
289 a: Option<Complex<f64>>,
290 axis: Option<i32>,
291) -> FFTResult<ArrayD<Complex<f64>>>
292where
293 S: Data<Elem = Complex<f64>>,
294 D: Dimension + RemoveAxis,
295{
296 let axis_actual = if let Some(ax) = axis {
297 if ax < 0 {
298 (x.ndim() as i32 + ax) as usize
299 } else {
300 ax as usize
301 }
302 } else {
303 x.ndim() - 1
304 };
305
306 let n = x.shape()[axis_actual];
307 let transform = CZT::new(n, m, w, a)?;
308 transform.transform(x, axis)
309}
310
311#[allow(dead_code)]
322pub fn zoom_fft<S, D>(
323 x: &ArrayBase<S, D>,
324 m: usize,
325 f0: f64,
326 f1: f64,
327 oversampling: Option<f64>,
328) -> FFTResult<ArrayD<Complex<f64>>>
329where
330 S: Data<Elem = Complex<f64>>,
331 D: Dimension + RemoveAxis,
332{
333 if !(0.0..=1.0).contains(&f0) || !(0.0..=1.0).contains(&f1) {
334 return Err(FFTError::ValueError(
335 "Frequencies must be in range [0, 1]".to_string(),
336 ));
337 }
338
339 if f0 >= f1 {
340 return Err(FFTError::ValueError("f0 must be less than f1".to_string()));
341 }
342
343 let oversampling = oversampling.unwrap_or(2.0);
344 if oversampling < 1.0 {
345 return Err(FFTError::ValueError(
346 "Oversampling must be >= 1".to_string(),
347 ));
348 }
349
350 let ndim = x.ndim();
351 let axis = ndim - 1;
352 let n = x.shape()[axis];
353
354 let k0_float = f0 * n as f64 * oversampling;
356 let k1_float = f1 * n as f64 * oversampling;
357 let step = (k1_float - k0_float) / (m - 1) as f64;
358
359 let phi = 2.0 * PI * k0_float / (n as f64 * oversampling);
360 let a = Complex::from_polar(1.0, phi);
361
362 let theta = -2.0 * PI * step / (n as f64 * oversampling);
363 let w = Complex::from_polar(1.0, theta);
364
365 czt(x, Some(m), Some(w), Some(a), Some(axis as i32))
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_czt_points() {
        // Default spiral: every point lies on the unit circle.
        let unit_points = czt_points(4, None, None);
        assert_eq!(unit_points.len(), 4);
        for z in &unit_points {
            assert_abs_diff_eq!(z.norm(), 1.0, epsilon = 1e-10);
        }

        // Custom spiral: the first point is always the starting point `a`.
        let start = Complex::new(0.8, 0.0);
        let ratio = Complex::from_polar(0.95, 0.1);
        let spiral = czt_points(5, Some(start), Some(ratio));
        assert_eq!(spiral.len(), 5);
        assert!((spiral[0] - start).norm() < 1e-10);
    }

    #[test]
    fn test_czt_as_fft() {
        // With default m, w and a the CZT reduces to the ordinary DFT.
        let n = 8;
        let signal: Array1<Complex<f64>> =
            Array1::linspace(0.0, 7.0, n).mapv(|v| Complex::new(v, 0.0));

        let czt_out = czt(&signal.view(), None, None, None, None)
            .expect("CZT computation should succeed for test data");
        assert_eq!(czt_out.ndim(), 1);
        let czt_out: Array1<Complex<f64>> = czt_out
            .into_dimensionality()
            .expect("CZT result should convert to 1D array");

        let reference = crate::fft::fft(&signal.to_vec(), None)
            .expect("FFT computation should succeed for test data");

        for k in 0..n {
            let (got, want) = (czt_out[k], reference[k]);
            assert!((got.re - want.re).abs() < 1e-10);
            assert!((got.im - want.im).abs() < 1e-10);
        }
    }

    #[test]
    fn test_zoom_fft() {
        let n = 64;
        let m = 16;

        // Five cycles of a real sine wave sampled on [0, 1].
        let signal: Array1<Complex<f64>> = Array1::linspace(0.0, 1.0, n)
            .mapv(|t: f64| Complex::new((2.0 * PI * 5.0 * t).sin(), 0.0));

        let zoomed = zoom_fft(&signal.view(), m, 0.0, 0.5, None)
            .expect("Zoom FFT should succeed for test data");
        assert_eq!(zoomed.ndim(), 1);

        let zoomed: Array1<Complex<f64>> = zoomed
            .into_dimensionality()
            .expect("Zoom FFT result should convert to 1D array");
        assert_eq!(zoomed.len(), m);

        assert!(
            zoomed.iter().any(|c| c.norm() > 1e-10),
            "Zoom FFT should produce some non-zero values"
        );
    }
}