1use crate::{next_fast_len, FFTError, FFTResult};
32use scirs2_core::ndarray::{Array1, Array2, Zip};
33use scirs2_core::numeric::Complex;
34use std::f64::consts::PI;
35
/// Parameters of a chirp-z contour: the points `z_k = a · w^{-k}` for `k` in `0..m`
/// (see [`SpiralContour::points`]).
#[derive(Clone, Debug)]
pub struct SpiralContour {
    /// Starting point `z_0` of the contour.
    pub a: Complex<f64>,
    /// Ratio between successive points: `z_{k+1} = z_k / w`.
    pub w: Complex<f64>,
    /// Number of contour points, i.e. the CZT output length.
    pub m: usize,
}
46
47impl SpiralContour {
48 pub fn unit_circle(m: usize) -> FFTResult<Self> {
54 if m == 0 {
55 return Err(FFTError::ValueError(
56 "Number of output points must be positive".to_string(),
57 ));
58 }
59 let w = Complex::from_polar(1.0, -2.0 * PI / m as f64);
60 Ok(SpiralContour {
61 a: Complex::new(1.0, 0.0),
62 w,
63 m,
64 })
65 }
66
67 pub fn zoom_range(m: usize, f0: f64, f1: f64, n: usize) -> FFTResult<Self> {
80 if m == 0 {
81 return Err(FFTError::ValueError(
82 "Number of output points must be positive".to_string(),
83 ));
84 }
85 if f0 < 0.0 || f1 > 1.0 || f0 >= f1 {
86 return Err(FFTError::ValueError(
87 "Frequencies must satisfy 0 <= f0 < f1 <= 1".to_string(),
88 ));
89 }
90
91 let phi_start = 2.0 * PI * f0;
92 let phi_end = 2.0 * PI * f1;
93 let a = Complex::from_polar(1.0, phi_start);
94
95 let step = if m > 1 {
96 (phi_end - phi_start) / (m - 1) as f64
97 } else {
98 0.0
99 };
100 let w = Complex::from_polar(1.0, -step);
101
102 Ok(SpiralContour { a, w, m })
103 }
104
105 pub fn log_spiral(m: usize, r0: f64, rho: f64, theta0: f64, dtheta: f64) -> FFTResult<Self> {
121 if m == 0 {
122 return Err(FFTError::ValueError(
123 "Number of output points must be positive".to_string(),
124 ));
125 }
126 if r0 <= 0.0 {
127 return Err(FFTError::ValueError(
128 "Starting radius must be positive".to_string(),
129 ));
130 }
131
132 let a = Complex::from_polar(r0, theta0);
133 let w = Complex::from_polar(1.0 / rho, -dtheta);
136
137 Ok(SpiralContour { a, w, m })
138 }
139
140 pub fn points(&self) -> Array1<Complex<f64>> {
142 (0..self.m)
143 .map(|k| self.a * self.w.powf(-(k as f64)))
144 .collect()
145 }
146}
147
148#[derive(Clone)]
150pub struct EnhancedCZT {
151 n: usize,
152 contour: SpiralContour,
153 nfft: usize,
154 awk2: Array1<Complex<f64>>,
156 fwk2: Array1<Complex<f64>>,
158 wk2: Array1<Complex<f64>>,
160}
161
162impl EnhancedCZT {
163 pub fn new(n: usize, contour: SpiralContour) -> FFTResult<Self> {
174 if n == 0 {
175 return Err(FFTError::ValueError(
176 "Input length must be positive".to_string(),
177 ));
178 }
179
180 let m = contour.m;
181 let a = contour.a;
182 let w = contour.w;
183 let max_size = n.max(m);
184 let nfft = next_fast_len(n + m - 1, false);
185
186 let wk2_full: Array1<Complex<f64>> = (0..max_size)
188 .map(|k| w.powf(k as f64 * k as f64 / 2.0))
189 .collect();
190
191 let awk2: Array1<Complex<f64>> =
193 (0..n).map(|k| a.powf(-(k as f64)) * wk2_full[k]).collect();
194
195 let mut chirp_vec = vec![Complex::new(0.0, 0.0); nfft];
197
198 for i in 0..m {
200 chirp_vec[n - 1 + i] = Complex::new(1.0, 0.0) / wk2_full[i];
201 }
202 for i in 1..n {
203 chirp_vec[n - 1 - i] = Complex::new(1.0, 0.0) / wk2_full[i];
204 }
205
206 let fwk2_vec = crate::fft::fft(&chirp_vec, None)?;
207 let fwk2 = Array1::from_vec(fwk2_vec);
208
209 let wk2: Array1<Complex<f64>> = wk2_full.slice(scirs2_core::ndarray::s![..m]).to_owned();
211
212 Ok(EnhancedCZT {
213 n,
214 contour,
215 nfft,
216 awk2,
217 fwk2,
218 wk2,
219 })
220 }
221
222 pub fn transform(&self, x: &[Complex<f64>]) -> FFTResult<Array1<Complex<f64>>> {
228 if x.len() != self.n {
229 return Err(FFTError::ValueError(format!(
230 "Input length ({}) does not match CZT engine size ({})",
231 x.len(),
232 self.n
233 )));
234 }
235
236 let x_arr = Array1::from_vec(x.to_vec());
237
238 let x_weighted: Array1<Complex<f64>> = Zip::from(&x_arr)
240 .and(&self.awk2)
241 .map_collect(|&xi, &awki| xi * awki);
242
243 let mut padded = vec![Complex::new(0.0, 0.0); self.nfft];
245 for (i, &val) in x_weighted.iter().enumerate() {
246 padded[i] = val;
247 }
248 let x_fft_vec = crate::fft::fft(&padded, None)?;
249 let x_fft = Array1::from_vec(x_fft_vec);
250
251 let product: Array1<Complex<f64>> = Zip::from(&x_fft)
253 .and(&self.fwk2)
254 .map_collect(|&xi, &fi| xi * fi);
255
256 let y_full_vec = crate::fft::ifft(&product.to_vec(), None)?;
258 let y_full = Array1::from_vec(y_full_vec);
259
260 let m = self.contour.m;
262 let y_slice = y_full.slice(scirs2_core::ndarray::s![self.n - 1..self.n - 1 + m]);
263 let result: Array1<Complex<f64>> = Zip::from(&y_slice)
264 .and(&self.wk2)
265 .map_collect(|&yi, &wki| yi * wki);
266
267 Ok(result)
268 }
269
270 pub fn transform_real(&self, x: &[f64]) -> FFTResult<Array1<Complex<f64>>> {
276 let x_complex: Vec<Complex<f64>> = x.iter().map(|&v| Complex::new(v, 0.0)).collect();
277 self.transform(&x_complex)
278 }
279
280 pub fn transform_batch(
288 &self,
289 signals: &Array2<Complex<f64>>,
290 ) -> FFTResult<Array2<Complex<f64>>> {
291 let (num_signals, signal_len) = signals.dim();
292 if signal_len != self.n {
293 return Err(FFTError::ValueError(format!(
294 "Signal length ({signal_len}) does not match CZT engine size ({})",
295 self.n
296 )));
297 }
298
299 let m = self.contour.m;
300 let mut results = Array2::zeros((num_signals, m));
301
302 for i in 0..num_signals {
303 let row = signals.row(i);
304 let row_vec: Vec<Complex<f64>> = row.iter().copied().collect();
305 let transformed = self.transform(&row_vec)?;
306 for (j, &val) in transformed.iter().enumerate() {
307 results[[i, j]] = val;
308 }
309 }
310
311 Ok(results)
312 }
313
314 pub fn points(&self) -> Array1<Complex<f64>> {
316 self.contour.points()
317 }
318
319 pub fn contour(&self) -> &SpiralContour {
321 &self.contour
322 }
323}
324
325pub fn iczt(
340 czt_values: &[Complex<f64>],
341 n: usize,
342 contour: &SpiralContour,
343) -> FFTResult<Array1<Complex<f64>>> {
344 let m = czt_values.len();
345 if m < n {
346 return Err(FFTError::ValueError(format!(
347 "Need at least {n} CZT values to reconstruct {n}-point signal, got {m}"
348 )));
349 }
350
351 let z_points = contour.points();
353
354 let mut v_mat = Array2::zeros((m, n));
356 for k in 0..m {
357 let z_k = z_points[k];
358 let mut z_power = Complex::new(1.0, 0.0);
359 for j in 0..n {
360 v_mat[[k, j]] = z_power;
361 z_power = z_power / z_k; }
363 }
364
365 let mut vhb = Array1::zeros(n);
368 for j in 0..n {
369 let mut sum = Complex::new(0.0, 0.0);
370 for k in 0..m {
371 sum += v_mat[[k, j]].conj() * czt_values[k];
372 }
373 vhb[j] = sum;
374 }
375
376 let mut vhv = Array2::zeros((n, n));
378 for i in 0..n {
379 for j in 0..n {
380 let mut sum = Complex::new(0.0, 0.0);
381 for k in 0..m {
382 sum += v_mat[[k, i]].conj() * v_mat[[k, j]];
383 }
384 vhv[[i, j]] = sum;
385 }
386 }
387
388 solve_complex_system(&vhv, &vhb)
390}
391
392fn solve_complex_system(
394 a: &Array2<Complex<f64>>,
395 b: &Array1<Complex<f64>>,
396) -> FFTResult<Array1<Complex<f64>>> {
397 let n = b.len();
398 let mut augmented = Array2::zeros((n, n + 1));
399
400 for i in 0..n {
402 for j in 0..n {
403 augmented[[i, j]] = a[[i, j]];
404 }
405 augmented[[i, n]] = b[i];
406 }
407
408 for col in 0..n {
410 let mut max_val = augmented[[col, col]].norm();
412 let mut max_row = col;
413 for row in (col + 1)..n {
414 let val = augmented[[row, col]].norm();
415 if val > max_val {
416 max_val = val;
417 max_row = row;
418 }
419 }
420
421 if max_val < 1e-14 {
422 return Err(FFTError::ComputationError(
423 "Singular or near-singular system in ICZT".to_string(),
424 ));
425 }
426
427 if max_row != col {
429 for j in 0..=n {
430 let tmp = augmented[[col, j]];
431 augmented[[col, j]] = augmented[[max_row, j]];
432 augmented[[max_row, j]] = tmp;
433 }
434 }
435
436 let pivot = augmented[[col, col]];
438 for row in (col + 1)..n {
439 let factor = augmented[[row, col]] / pivot;
440 for j in col..=n {
441 let val = augmented[[col, j]];
442 augmented[[row, j]] = augmented[[row, j]] - factor * val;
443 }
444 }
445 }
446
447 let mut x = Array1::zeros(n);
449 for i in (0..n).rev() {
450 let mut sum = augmented[[i, n]];
451 for j in (i + 1)..n {
452 sum = sum - augmented[[i, j]] * x[j];
453 }
454 x[i] = sum / augmented[[i, i]];
455 }
456
457 Ok(x)
458}
459
460pub fn czt_convolve(a: &[f64], b: &[f64]) -> FFTResult<Vec<f64>> {
479 if a.is_empty() || b.is_empty() {
480 return Err(FFTError::ValueError(
481 "Input sequences cannot be empty".to_string(),
482 ));
483 }
484
485 let conv_len = a.len() + b.len() - 1;
486 let nfft = next_fast_len(conv_len, false);
487
488 let mut a_padded: Vec<Complex<f64>> = a.iter().map(|&v| Complex::new(v, 0.0)).collect();
490 a_padded.resize(nfft, Complex::new(0.0, 0.0));
491
492 let mut b_padded: Vec<Complex<f64>> = b.iter().map(|&v| Complex::new(v, 0.0)).collect();
493 b_padded.resize(nfft, Complex::new(0.0, 0.0));
494
495 let a_fft = crate::fft::fft(&a_padded, None)?;
496 let b_fft = crate::fft::fft(&b_padded, None)?;
497
498 let product: Vec<Complex<f64>> = a_fft
500 .iter()
501 .zip(b_fft.iter())
502 .map(|(&ai, &bi)| ai * bi)
503 .collect();
504
505 let result_complex = crate::fft::ifft(&product, None)?;
507
508 Ok(result_complex.iter().take(conv_len).map(|c| c.re).collect())
510}
511
512pub fn adaptive_zoom_fft(
533 x: &[f64],
534 f0: f64,
535 f1: f64,
536 min_points: usize,
537 max_points: usize,
538) -> FFTResult<(Vec<f64>, Array1<Complex<f64>>)> {
539 if x.is_empty() {
540 return Err(FFTError::ValueError("Input signal is empty".to_string()));
541 }
542 if f0 < 0.0 || f1 > 1.0 || f0 >= f1 {
543 return Err(FFTError::ValueError(
544 "Frequency range must satisfy 0 <= f0 < f1 <= 1".to_string(),
545 ));
546 }
547 if min_points == 0 || max_points < min_points {
548 return Err(FFTError::ValueError(
549 "Point count must satisfy 0 < min_points <= max_points".to_string(),
550 ));
551 }
552
553 let n = x.len();
554
555 let freq_range = f1 - f0;
557 let rayleigh_resolution = 1.0 / n as f64;
558 let ideal_points = (freq_range / rayleigh_resolution).ceil() as usize;
559 let m = ideal_points.clamp(min_points, max_points);
560
561 let contour = SpiralContour::zoom_range(m, f0, f1, n)?;
563 let engine = EnhancedCZT::new(n, contour)?;
564
565 let spectrum = engine.transform_real(x)?;
566
567 let frequencies: Vec<f64> = (0..m)
569 .map(|k| {
570 if m > 1 {
571 f0 + k as f64 * (f1 - f0) / (m - 1) as f64
572 } else {
573 f0
574 }
575 })
576 .collect();
577
578 Ok((frequencies, spectrum))
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_unit_circle_contour() {
        let contour = SpiralContour::unit_circle(8).expect("Unit circle contour should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 8);

        // Every point lies on the unit circle.
        for p in pts.iter() {
            assert_abs_diff_eq!(p.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_zoom_range_contour() {
        let contour =
            SpiralContour::zoom_range(16, 0.1, 0.3, 64).expect("Zoom range contour should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 16);

        // Zoom contours stay on the unit circle.
        for p in pts.iter() {
            assert_abs_diff_eq!(p.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_log_spiral_contour() {
        let contour =
            SpiralContour::log_spiral(10, 1.0, 0.95, 0.0, 0.1).expect("Log spiral should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 10);

        // The first point is r0 at angle theta0.
        assert_abs_diff_eq!(pts[0].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pts[0].im, 0.0, epsilon = 1e-10);

        // Radius decays geometrically as r0 * rho^k.
        for k in 1..10 {
            let expected_r = 0.95_f64.powi(k as i32);
            assert_abs_diff_eq!(pts[k].norm(), expected_r, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_enhanced_czt_matches_fft() {
        let n = 16;
        // On the full unit circle with m == n, the CZT equals the DFT.
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine creation should succeed");

        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();

        let czt_result = engine.transform(&x).expect("Transform should succeed");
        let fft_result_vec = crate::fft::fft(&x, None).expect("FFT should succeed");
        let fft_result = Array1::from_vec(fft_result_vec);

        for i in 0..n {
            assert_abs_diff_eq!(czt_result[i].re, fft_result[i].re, epsilon = 1e-8);
            assert_abs_diff_eq!(czt_result[i].im, fft_result[i].im, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_enhanced_czt_real_input() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        let x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = engine
            .transform_real(&x)
            .expect("Real transform should succeed");

        // The DC bin is the plain sum of the samples.
        let expected_dc: f64 = x.iter().sum();
        assert_abs_diff_eq!(result[0].re, expected_dc, epsilon = 1e-8);
    }

    #[test]
    fn test_batch_czt() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        let mut signals = Array2::zeros((3, n));
        for i in 0..3 {
            for j in 0..n {
                signals[[i, j]] = Complex::new((i * n + j) as f64, 0.0);
            }
        }

        let results = engine
            .transform_batch(&signals)
            .expect("Batch transform should succeed");
        assert_eq!(results.dim(), (3, n));

        // Each batch row must match the single-signal transform.
        for i in 0..3 {
            let row_vec: Vec<Complex<f64>> = signals.row(i).iter().copied().collect();
            let individual = engine
                .transform(&row_vec)
                .expect("Individual transform should succeed");
            for j in 0..n {
                assert_abs_diff_eq!(results[[i, j]].re, individual[j].re, epsilon = 1e-8);
                assert_abs_diff_eq!(results[[i, j]].im, individual[j].im, epsilon = 1e-8);
            }
        }
    }

    #[test]
    fn test_iczt_roundtrip() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour.clone()).expect("Engine should succeed");

        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64 + 1.0, 0.0)).collect();

        let czt_values = engine.transform(&x).expect("Forward CZT should succeed");
        let czt_vec: Vec<Complex<f64>> = czt_values.iter().copied().collect();
        let recovered = iczt(&czt_vec, n, &contour).expect("ICZT should succeed");

        for i in 0..n {
            assert_abs_diff_eq!(recovered[i].re, x[i].re, epsilon = 1e-6);
            assert_abs_diff_eq!(recovered[i].im, x[i].im, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_czt_convolve() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0];

        let result = czt_convolve(&a, &b).expect("Convolution should succeed");
        assert_eq!(result.len(), 4);

        // [1, 2, 3] * [4, 5] = [4, 13, 22, 15]
        let expected = [4.0, 13.0, 22.0, 15.0];
        for (&r, &e) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(r, e, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_czt_convolve_identity() {
        let signal = [1.0, 2.0, 3.0, 4.0, 5.0];
        // Convolving with a unit impulse returns the signal unchanged.
        let delta = [1.0];

        let result = czt_convolve(&signal, &delta).expect("Identity convolution should succeed");
        assert_eq!(result.len(), signal.len());

        for (&r, &s) in result.iter().zip(signal.iter()) {
            assert_abs_diff_eq!(r, s, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_adaptive_zoom_fft() {
        let n = 256;
        // Single tone at 0.15 cycles/sample, inside the zoom band [0.1, 0.2].
        let freq = 0.15;
        let x: Vec<f64> = (0..n).map(|i| (2.0 * PI * freq * i as f64).sin()).collect();

        let (frequencies, spectrum) =
            adaptive_zoom_fft(&x, 0.1, 0.2, 16, 128).expect("Adaptive zoom FFT should succeed");

        assert_eq!(frequencies.len(), spectrum.len());
        assert!(frequencies.len() >= 16);
        assert!(frequencies.len() <= 128);

        let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
        let peak_idx = magnitudes
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);

        let peak_freq = frequencies[peak_idx];
        assert!(
            (peak_freq - freq).abs() < 0.02,
            "Peak at {peak_freq:.4} should be near {freq:.4}"
        );
    }

    #[test]
    fn test_parseval_theorem_czt() {
        let n = 16;
        // On the full unit circle: sum |X_k|^2 / n == sum |x_j|^2.
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        let x: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new((2.0 * PI * 3.0 * i as f64 / n as f64).sin(), 0.0))
            .collect();

        let czt_result = engine.transform(&x).expect("Transform should succeed");

        let input_energy: f64 = x.iter().map(|c| c.norm_sqr()).sum();
        let output_energy: f64 = czt_result.iter().map(|c| c.norm_sqr()).sum::<f64>() / n as f64;

        assert_abs_diff_eq!(input_energy, output_energy, epsilon = 1e-8);
    }

    #[test]
    fn test_czt_prime_length() {
        let n = 13;
        // Prime lengths exercise the arbitrary-length Bluestein path.
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");

        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();

        let result = engine
            .transform(&x)
            .expect("Prime-length CZT should succeed");
        assert_eq!(result.len(), n);

        let expected_dc: f64 = (0..n).map(|i| i as f64).sum();
        assert_abs_diff_eq!(result[0].re, expected_dc, epsilon = 1e-8);
    }

    #[test]
    fn test_zoom_fft_resolves_close_frequencies() {
        let n = 64;
        // Two tones 0.01 cycles/sample apart.
        let f1_norm = 0.15;
        let f2_norm = 0.16;

        let x: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * f1_norm * i as f64).sin() + (2.0 * PI * f2_norm * i as f64).sin())
            .collect();

        let contour =
            SpiralContour::zoom_range(128, 0.12, 0.20, n).expect("Zoom contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");
        let spectrum = engine.transform_real(&x).expect("Zoom CZT should succeed");

        let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
        let max_mag = magnitudes.iter().copied().fold(0.0_f64, f64::max);

        assert!(max_mag > 1.0, "Zoom should find spectral energy");
    }

    #[test]
    fn test_error_handling() {
        assert!(SpiralContour::unit_circle(0).is_err());
        assert!(SpiralContour::zoom_range(0, 0.0, 0.5, 64).is_err());
        assert!(SpiralContour::zoom_range(16, 0.5, 0.3, 64).is_err());
        assert!(SpiralContour::log_spiral(10, -1.0, 0.95, 0.0, 0.1).is_err());

        assert!(czt_convolve(&[], &[1.0]).is_err());
        assert!(adaptive_zoom_fft(&[], 0.0, 0.5, 8, 64).is_err());
    }
}