1use super::windows::dpss;
14use crate::error::{SignalError, SignalResult};
15use crate::simd_advanced::{simd_apply_window, SimdConfig};
16
17use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
18use scirs2_core::numeric::Complex64;
19use scirs2_core::numeric::{Float, NumCast};
20use scirs2_core::parallel_ops::*;
21use scirs2_core::random::{Rng, RngExt};
22use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
23use scirs2_core::validation::check_positive;
24use statrs::distribution::{ChiSquared, ContinuousCDF};
25use std::fmt::Debug;
26use std::sync::Arc;
27
/// Output of a multitaper power spectral density estimate.
#[allow(unused_imports)]
#[derive(Debug, Clone)]
pub struct EnhancedMultitaperResult {
    /// Frequency grid the estimate is reported on (multiples of `fs / nfft`).
    pub frequencies: Vec<f64>,
    /// Power spectral density estimate.
    pub psd: Vec<f64>,
    /// `(lower, upper)` confidence bounds; `None` unless a confidence level
    /// was configured. The chunked processing path always reports `None`.
    pub confidence_intervals: Option<(Vec<f64>, Vec<f64>)>,
    /// Degrees-of-freedom figure reported alongside the estimate.
    pub dof: Option<f64>,
    /// Taper matrix as returned by `dpss`; populated only when
    /// `MultitaperConfig::return_tapers` is set (never on the chunked path).
    pub tapers: Option<Array2<f64>>,
    /// Eigenvalues returned by `dpss` together with the tapers; populated
    /// under the same condition as `tapers`.
    pub eigenvalues: Option<Array1<f64>>,
}
45
/// Settings that control a multitaper spectral estimate.
#[derive(Debug, Clone)]
pub struct MultitaperConfig {
    /// Sampling frequency; sets the frequency axis and the PSD scaling.
    /// Must be positive.
    pub fs: f64,
    /// Time-bandwidth product handed to `dpss`. Must be at least 1.0.
    pub nw: f64,
    /// Number of tapers. Must be at least 1 and must not exceed `2 * nw`.
    pub k: usize,
    /// FFT length. `None` selects the next power of two that is not smaller
    /// than the signal length.
    pub nfft: Option<usize>,
    /// When `true`, only the `nfft / 2 + 1` non-negative-frequency bins are
    /// returned and the PSD is scaled by `2 / fs` instead of `1 / fs`.
    pub onesided: bool,
    /// Selects the iteratively re-weighted (adaptive) combination of the
    /// per-taper spectra instead of the fixed eigenvalue-weighted average.
    pub adaptive: bool,
    /// Confidence level for the interval estimate; must lie strictly between
    /// 0 and 1. `None` disables confidence intervals.
    pub confidence: Option<f64>,
    /// When `true`, the tapers and their eigenvalues are included in the
    /// result.
    pub return_tapers: bool,
    /// Allows the per-taper FFTs to be computed in parallel.
    pub parallel: bool,
    /// Minimum signal length at which the parallel path is taken.
    pub parallel_threshold: usize,
    /// Forces the chunked processing path regardless of signal length.
    pub memory_optimized: bool,
}
72
73impl Default for MultitaperConfig {
74 fn default() -> Self {
75 Self {
76 fs: 1.0,
77 nw: 4.0,
78 k: 7, nfft: None,
80 onesided: true,
81 adaptive: true,
82 confidence: None,
83 return_tapers: false,
84 parallel: true,
85 parallel_threshold: 1024,
86 memory_optimized: false,
87 }
88 }
89}
90
91#[allow(dead_code)]
136pub fn enhanced_pmtm<T>(
137 x: &[T],
138 config: &MultitaperConfig,
139) -> SignalResult<EnhancedMultitaperResult>
140where
141 T: Float + NumCast + Debug + Send + Sync,
142{
143 if x.is_empty() {
145 return Err(SignalError::ValueError("Input signal is empty".to_string()));
146 }
147
148 check_positive(config.nw, "nw")?;
149 check_positive(config.k, "k")?;
150 check_positive(config.fs, "fs")?;
151
152 let x_f64: Vec<f64> = x
154 .iter()
155 .map(|&val| {
156 NumCast::from(val).ok_or_else(|| {
157 SignalError::ValueError(format!("Could not convert {:?} to f64", val))
158 })
159 })
160 .collect::<SignalResult<Vec<f64>>>()?;
161
162 if let Some(confidence) = config.confidence {
164 if confidence <= 0.0 || confidence >= 1.0 {
165 return Err(SignalError::ValueError(format!(
166 "Confidence level must be between 0 and 1, got {}",
167 confidence
168 )));
169 }
170 }
171
172 if config.k > (2.0 * config.nw) as usize {
174 return Err(SignalError::ValueError(format!(
175 "Number of tapers k={} should not exceed 2*nw={}",
176 config.k,
177 2.0 * config.nw
178 )));
179 }
180
181 if config.nw < 1.0 {
183 return Err(SignalError::ValueError(format!(
184 "Time-bandwidth product nw={} must be at least 1.0",
185 config.nw
186 )));
187 }
188
189 if config.k == 0 {
190 return Err(SignalError::ValueError(
191 "Number of tapers k must be at least 1".to_string(),
192 ));
193 }
194
195 let min_signal_length = (4.0 * config.nw) as usize;
197 if x_f64.len() < min_signal_length {
198 return Err(SignalError::ValueError(format!(
199 "Signal length {} too short for nw={}. Minimum length is {}",
200 x_f64.len(),
201 config.nw,
202 min_signal_length
203 )));
204 }
205
206 if x_f64.len() < (8.0 * config.nw) as usize {
208 eprintln!("Warning: Signal length {} is relatively short for nw={}. Consider reducing nw or using a longer signal.",
209 x_f64.len(), config.nw);
210 }
211
212 for (i, &val) in x_f64.iter().enumerate() {
214 if !val.is_finite() {
215 return Err(SignalError::ValueError(format!(
216 "Non-finite value at index {}: {}",
217 i, val
218 )));
219 }
220 }
221
222 let n = x_f64.len();
223 let nfft = config.nfft.unwrap_or(next_power_of_two(n));
224
225 let memory_threshold = if config.k > 10 {
227 500_000 } else {
229 1_000_000 };
231 let use_chunked_processing = n > memory_threshold || config.memory_optimized;
232
233 if use_chunked_processing {
234 return compute_pmtm_chunked(&x_f64, config, nfft);
235 }
236
237 let (tapers, eigenvalues_opt) = dpss(n, config.nw, config.k, true)?;
239
240 let eigenvalues = eigenvalues_opt.ok_or_else(|| {
241 SignalError::ComputationError("Eigenvalues required but not returned from dpss".to_string())
242 })?;
243
244 for i in 1..eigenvalues.len() {
247 if eigenvalues[i] > eigenvalues[i - 1] {
248 return Err(SignalError::ComputationError(
249 "DPSS eigenvalues are not in descending order".to_string(),
250 ));
251 }
252 }
253
254 let min_concentration = 0.9;
256 for (i, &eigenval) in eigenvalues.iter().enumerate() {
257 if eigenval < min_concentration && i < config.k {
258 eprintln!("Warning: Taper {} has low concentration ratio {:.3}. Consider reducing k or increasing nw.",
259 i, eigenval);
260 }
261 }
262
263 for i in 0..config.k {
265 for j in (i + 1)..config.k {
266 let dot_product: f64 = tapers.row(i).dot(&tapers.row(j));
267 if dot_product.abs() > 1e-10 {
268 eprintln!(
269 "Warning: Tapers {} and {} have non-orthogonal dot product {:.2e}",
270 i, j, dot_product
271 );
272 }
273 }
274 }
275
276 let spectra = if config.parallel && n >= config.parallel_threshold {
278 compute_tapered_ffts_parallel(&x_f64, &tapers, nfft)?
279 } else {
280 compute_tapered_ffts_simd(&x_f64, &tapers, nfft)?
281 };
282
283 for i in 0..spectra.nrows() {
285 for j in 0..spectra.ncols() {
286 let val = spectra[[i, j]];
287 if !val.is_finite() || val < 0.0 {
288 return Err(SignalError::ComputationError(format!(
289 "Invalid spectral value at taper {}, frequency bin {}: {}",
290 i, j, val
291 )));
292 }
293 }
294 }
295
296 let (frequencies, psd) = if config.adaptive {
298 combine_spectra_adaptive(&spectra, &eigenvalues, config.fs, nfft, config.onesided)?
299 } else {
300 combine_spectra_standard(&spectra, &eigenvalues, config.fs, nfft, config.onesided)?
301 };
302
303 for (i, &val) in psd.iter().enumerate() {
305 if !val.is_finite() || val < 0.0 {
306 return Err(SignalError::ComputationError(format!(
307 "Invalid PSD value at frequency bin {}: {}",
308 i, val
309 )));
310 }
311 }
312
313 let confidence_intervals = if let Some(confidence_level) = config.confidence {
315 Some(compute_confidence_intervals(
316 &spectra,
317 &eigenvalues,
318 confidence_level,
319 )?)
320 } else {
321 None
322 };
323
324 let dof = Some(compute_effective_dof(&eigenvalues));
326
327 Ok(EnhancedMultitaperResult {
328 frequencies,
329 psd,
330 confidence_intervals,
331 dof,
332 tapers: if config.return_tapers {
333 Some(tapers)
334 } else {
335 None
336 },
337 eigenvalues: if config.return_tapers {
338 Some(eigenvalues)
339 } else {
340 None
341 },
342 })
343}
344
345#[allow(dead_code)]
347fn compute_tapered_ffts_simd(
348 signal: &[f64],
349 tapers: &Array2<f64>,
350 nfft: usize,
351) -> SignalResult<Array2<f64>> {
352 let k = tapers.nrows();
353 let n = signal.len();
354 let mut spectra = Array2::zeros((k, nfft));
355
356 let caps = PlatformCapabilities::detect();
358 let use_advanced_simd = caps.simd_available;
359
360 let memory_efficient = k > 20 || n > 50_000;
362
363 if memory_efficient {
364 let batch_size = if k > 50 { 8 } else { k };
366
367 for batch_start in (0..k).step_by(batch_size) {
368 let batch_end = (batch_start + batch_size).min(k);
369
370 for i in batch_start..batch_end {
371 let result = compute_single_tapered_fft_simd(
372 signal,
373 tapers.row(i),
374 nfft,
375 use_advanced_simd,
376 )?;
377 for (j, &val) in result.iter().enumerate() {
378 spectra[[i, j]] = val;
379 }
380 }
381 }
382 } else {
383 for i in 0..k {
385 let result =
386 compute_single_tapered_fft_simd(signal, tapers.row(i), nfft, use_advanced_simd)?;
387 for (j, &val) in result.iter().enumerate() {
388 spectra[[i, j]] = val;
389 }
390 }
391 }
392
393 validate_spectral_matrix(&spectra)?;
395
396 Ok(spectra)
397}
398
399#[allow(dead_code)]
401fn compute_single_tapered_fft_simd(
402 signal: &[f64],
403 taper: ArrayView1<f64>,
404 nfft: usize,
405 use_advanced_simd: bool,
406) -> SignalResult<Vec<f64>> {
407 let n = signal.len();
408
409 let mut tapered = vec![0.0; n];
411 let mut simd_success = false;
412
413 if use_advanced_simd && n >= 64 {
414 if let Ok(()) = try_advanced_simd_tapering(signal, &taper, &mut tapered) {
416 simd_success = true;
417 }
418 }
419
420 if !simd_success && n >= 16 {
421 match try_basic_simd_tapering(signal, &taper, &mut tapered) {
423 Ok(()) => simd_success = true,
424 Err(_) => {
425 if try_chunked_simd_tapering(signal, &taper, &mut tapered).is_ok() {
427 simd_success = true;
428 }
429 }
430 }
431 }
432
433 if !simd_success {
434 scalar_tapering_optimized(signal, &taper, &mut tapered);
436 }
437
438 validate_tapered_signal(&tapered)?;
440
441 let spectrum = enhanced_simd_fft(&tapered, nfft)?;
443
444 compute_validated_power_spectrum(&spectrum)
446}
447
448#[allow(dead_code)]
450fn validate_spectral_matrix(spectra: &Array2<f64>) -> SignalResult<()> {
451 let (k, nfft) = spectra.dim();
452
453 for i in 0..k {
454 for j in 0..nfft {
455 let val = spectra[[i, j]];
456
457 if !val.is_finite() {
458 return Err(SignalError::ComputationError(format!(
459 "Non-finite spectral value at taper {}, frequency bin {}: {}",
460 i, j, val
461 )));
462 }
463
464 if val < 0.0 {
465 return Err(SignalError::ComputationError(format!(
466 "Negative spectral value at taper {}, frequency bin {}: {}",
467 i, j, val
468 )));
469 }
470
471 if val > 1e200 {
473 return Err(SignalError::ComputationError(format!(
474 "Extremely large spectral value at taper {}, frequency bin {}: {:.2e}",
475 i, j, val
476 )));
477 }
478 }
479 }
480
481 for i in 0..k {
483 let row_sum: f64 = (0..nfft).map(|j| spectra[[i, j]]).sum();
484
485 if row_sum < 1e-100 {
486 return Err(SignalError::ComputationError(format!(
487 "Taper {} has extremely low total energy: {:.2e}",
488 i, row_sum
489 )));
490 }
491
492 if row_sum > 1e100 {
493 eprintln!(
494 "Warning: Taper {} has very high total energy: {:.2e}",
495 i, row_sum
496 );
497 }
498 }
499
500 Ok(())
501}
502
503#[allow(dead_code)]
505fn compute_tapered_ffts_parallel(
506 signal: &[f64],
507 tapers: &Array2<f64>,
508 nfft: usize,
509) -> SignalResult<Array2<f64>> {
510 let k = tapers.nrows();
511 let n = signal.len();
512 let signal_arc = Arc::new(signal.to_vec());
513
514 let results: Result<Vec<Vec<f64>>, SignalError> = (0..k)
516 .into_par_iter()
517 .map(|i| {
518 let signal_ref = signal_arc.clone();
519 let taper = tapers.row(i).to_owned();
520
521 let mut tapered = vec![0.0; n];
523 for j in 0..n {
524 tapered[j] = signal_ref[j] * taper[j];
525 }
526
527 let spectrum = enhanced_simd_fft(&tapered, nfft)?;
529
530 Ok(spectrum.iter().map(|c| c.norm_sqr()).collect())
532 })
533 .collect();
534
535 let results = results?;
536
537 let mut spectra = Array2::zeros((k, nfft));
539 for (i, row) in results.iter().enumerate() {
540 for (j, &val) in row.iter().enumerate() {
541 spectra[[i, j]] = val;
542 }
543 }
544
545 Ok(spectra)
546}
547
548#[allow(dead_code)]
550fn try_advanced_simd_tapering(
551 signal: &[f64],
552 taper: &ArrayView1<f64>,
553 tapered: &mut [f64],
554) -> Result<(), Box<dyn std::error::Error>> {
555 let config = SimdConfig::default();
556 let taper_vec: Vec<f64> = taper.iter().copied().collect();
557
558 simd_apply_window(signal, &taper_vec, tapered, &config).map_err(|e| format!("{}", e).into())
559}
560
561#[allow(dead_code)]
563fn try_basic_simd_tapering(
564 signal: &[f64],
565 taper: &ArrayView1<f64>,
566 tapered: &mut [f64],
567) -> SignalResult<()> {
568 let signal_view = ArrayView1::from(signal);
569 let _tapered_view = ArrayView1::from_shape(signal.len(), tapered)
570 .map_err(|e| SignalError::ComputationError(format!("Shape error: {}", e)))?;
571
572 let result = f64::simd_mul(&signal_view, taper);
573 for (i, &val) in result.iter().enumerate() {
574 tapered[i] = val;
575 }
576 Ok(())
577}
578
579#[allow(dead_code)]
581fn try_chunked_simd_tapering(
582 signal: &[f64],
583 taper: &ArrayView1<f64>,
584 tapered: &mut [f64],
585) -> SignalResult<()> {
586 let chunk_size = 256; for (_chunk_idx, chunk_data) in signal
589 .chunks(chunk_size)
590 .zip(
591 taper
592 .as_slice()
593 .expect("Operation failed")
594 .chunks(chunk_size),
595 )
596 .zip(tapered.chunks_mut(chunk_size))
597 .enumerate()
598 {
599 let ((sig_chunk, tap_chunk), out_chunk) = chunk_data;
600 let sig_view = ArrayView1::from(sig_chunk);
601 let tap_view = ArrayView1::from(tap_chunk);
602
603 let result = f64::simd_mul(&sig_view, &tap_view);
604 for (i, &val) in result.iter().enumerate() {
605 if i < out_chunk.len() {
606 out_chunk[i] = val;
607 }
608 }
609 }
610
611 Ok(())
612}
613
614#[allow(dead_code)]
616fn scalar_tapering_optimized(signal: &[f64], taper: &ArrayView1<f64>, tapered: &mut [f64]) {
617 let taper_slice = taper.as_slice().expect("Operation failed");
619 let chunks = signal.len() / 4;
620
621 for i in 0..chunks {
623 let base_idx = i * 4;
624 tapered[base_idx] = signal[base_idx] * taper_slice[base_idx];
625 tapered[base_idx + 1] = signal[base_idx + 1] * taper_slice[base_idx + 1];
626 tapered[base_idx + 2] = signal[base_idx + 2] * taper_slice[base_idx + 2];
627 tapered[base_idx + 3] = signal[base_idx + 3] * taper_slice[base_idx + 3];
628 }
629
630 for i in (chunks * 4)..signal.len() {
632 tapered[i] = signal[i] * taper_slice[i];
633 }
634}
635
636#[allow(dead_code)]
638fn validate_tapered_signal(tapered: &[f64]) -> SignalResult<()> {
639 for (i, &val) in tapered.iter().enumerate() {
641 if !val.is_finite() {
642 return Err(SignalError::ComputationError(format!(
643 "Non-finite value in _tapered signal at index {}: {}",
644 i, val
645 )));
646 }
647 }
648 Ok(())
649}
650
651#[allow(dead_code)]
653fn enhanced_simd_fft(x: &[f64], nfft: usize) -> SignalResult<Vec<Complex64>> {
654 if nfft == 0 {
656 return Err(SignalError::ValueError(
657 "FFT length cannot be zero".to_string(),
658 ));
659 }
660
661 if !nfft.is_power_of_two() {
662 eprintln!(
663 "Warning: FFT length {} is not a power of two, performance may be suboptimal",
664 nfft
665 );
666 }
667
668 if nfft > 1_000_000 {
669 eprintln!(
670 "Warning: Very large FFT length {}, consider chunked processing",
671 nfft
672 );
673 }
674
675 let mut padded = vec![Complex64::new(0.0, 0.0); nfft];
677 let copy_len = x.len().min(nfft);
678
679 if copy_len >= 64 {
681 let config = SimdConfig::default();
682 let unity_window = vec![1.0; copy_len];
683 let mut temp_real = vec![0.0; copy_len];
684
685 if simd_apply_window(&x[..copy_len], &unity_window, &mut temp_real, &config).is_ok() {
687 for (i, &val) in temp_real.iter().enumerate() {
688 padded[i] = Complex64::new(val, 0.0);
689 }
690 } else {
691 for i in 0..copy_len {
693 padded[i] = Complex64::new(x[i], 0.0);
694 }
695 }
696 } else {
697 for i in 0..copy_len {
698 padded[i] = Complex64::new(x[i], 0.0);
699 }
700 }
701
702 let mut buffer = padded.clone();
705
706 for (i, &val) in buffer.iter().enumerate() {
708 if !val.is_finite() {
709 return Err(SignalError::ComputationError(format!(
710 "Non-finite value in FFT input at index {}: {}",
711 i, val
712 )));
713 }
714 }
715
716 if nfft > 8192 {
718 let start = std::time::Instant::now();
719 let fft_result = scirs2_fft::fft(&buffer, Some(buffer.len()))
720 .map_err(|e| SignalError::ComputationError(format!("FFT failed: {}", e)))?;
721 for (i, c) in fft_result.iter().enumerate() {
722 buffer[i] = *c;
723 }
724 let duration = start.elapsed();
725
726 if duration.as_millis() > 1000 {
728 eprintln!(
729 "Warning: Large FFT took {:.2}s for length {}",
730 duration.as_secs_f64(),
731 nfft
732 );
733 }
734 } else {
735 let fft_result = scirs2_fft::fft(&buffer, Some(buffer.len()))
736 .map_err(|e| SignalError::ComputationError(format!("FFT failed: {}", e)))?;
737 for (i, c) in fft_result.iter().enumerate() {
738 buffer[i] = *c;
739 }
740 }
741
742 for (i, &val) in buffer.iter().enumerate() {
744 if !val.is_finite() {
745 return Err(SignalError::ComputationError(format!(
746 "Non-finite value in FFT output at index {}: {}",
747 i, val
748 )));
749 }
750 }
751
752 Ok(buffer)
753}
754
755#[allow(dead_code)]
757fn compute_validated_power_spectrum(spectrum: &[Complex64]) -> SignalResult<Vec<f64>> {
758 let mut power_spectrum = Vec::with_capacity(spectrum.len());
759 let mut max_power = 0.0;
760 let mut suspicious_values = 0;
761
762 for (i, &val) in spectrum.iter().enumerate() {
763 let power = val.norm_sqr();
764
765 if !power.is_finite() {
767 return Err(SignalError::ComputationError(format!(
768 "Non-finite power _spectrum value at frequency bin {}: {}",
769 i, power
770 )));
771 }
772
773 if power < 0.0 {
774 return Err(SignalError::ComputationError(format!(
775 "Negative power _spectrum value at frequency bin {}: {}",
776 i, power
777 )));
778 }
779
780 if power > 1e50 {
782 suspicious_values += 1;
783 if suspicious_values > spectrum.len() / 10 {
784 return Err(SignalError::ComputationError(
785 "Too many extremely large power _spectrum values detected".to_string(),
786 ));
787 }
788 }
789
790 max_power = max_power.max(power);
791 power_spectrum.push(power);
792 }
793
794 if max_power > 1e100 {
796 eprintln!(
797 "Warning: Very large maximum power _spectrum value: {:.2e}",
798 max_power
799 );
800 }
801
802 if suspicious_values > 0 {
803 eprintln!(
804 "Warning: {} suspicious power _spectrum values detected",
805 suspicious_values
806 );
807 }
808
809 Ok(power_spectrum)
810}
811
812#[allow(dead_code)]
814fn combine_spectra_standard(
815 spectra: &Array2<f64>,
816 eigenvalues: &Array1<f64>,
817 fs: f64,
818 nfft: usize,
819 onesided: bool,
820) -> SignalResult<(Vec<f64>, Vec<f64>)> {
821 let k = spectra.nrows();
822
823 let frequencies = if onesided {
825 let n_freqs = nfft / 2 + 1;
826 (0..n_freqs).map(|i| i as f64 * fs / nfft as f64).collect()
827 } else {
828 (0..nfft)
829 .map(|i| {
830 if i <= nfft / 2 {
831 i as f64 * fs / nfft as f64
832 } else {
833 (i as f64 - nfft as f64) * fs / nfft as f64
834 }
835 })
836 .collect()
837 };
838
839 let n_freqs = if onesided { nfft / 2 + 1 } else { nfft };
841 let mut psd = vec![0.0; n_freqs];
842
843 let weight_sum: f64 = eigenvalues.sum();
844 let scaling = if onesided {
845 2.0 / (fs * weight_sum)
846 } else {
847 1.0 / (fs * weight_sum)
848 };
849
850 for j in 0..n_freqs {
851 let mut weighted_sum = 0.0;
852 for i in 0..k {
853 weighted_sum += eigenvalues[i] * spectra[[i, j]];
854 }
855 psd[j] = weighted_sum * scaling;
856 }
857
858 Ok((frequencies, psd))
859}
860
861#[allow(dead_code)]
869fn combine_spectra_adaptive(
870 spectra: &Array2<f64>,
871 eigenvalues: &Array1<f64>,
872 fs: f64,
873 nfft: usize,
874 onesided: bool,
875) -> SignalResult<(Vec<f64>, Vec<f64>)> {
876 let k = spectra.nrows();
877 let n_freqs = if onesided { nfft / 2 + 1 } else { nfft };
878
879 if k == 0 || n_freqs == 0 {
881 return Err(SignalError::ValueError(
882 "Invalid dimensions for adaptive spectral combination".to_string(),
883 ));
884 }
885
886 let mut weights = Array2::zeros((k, n_freqs));
888 let mut psd = vec![0.0; n_freqs];
889
890 let max_iter = 20; let base_tolerance = 1e-12; let min_weight = 1e-16; let eigenvalue_ratio = eigenvalues[eigenvalues.len() - 1] / eigenvalues[0];
897 let regularization = if eigenvalue_ratio < 1e-10 {
898 1e-10 } else {
900 1e-14 };
902
903 let damping_start = 8; let damping_factor = 0.85; let lambda_sum: f64 = eigenvalues.sum();
908 for i in 0..k {
909 for j in 0..n_freqs {
910 weights[[i, j]] = eigenvalues[i] / lambda_sum;
911 }
912 }
913
914 let mut converged = false;
916 let mut convergence_history = Vec::new();
917 let mut oscillation_detector = 0;
918
919 for iter in 0..max_iter {
920 let old_psd = psd.clone();
921
922 for j in 0..n_freqs {
924 let mut weighted_sum = 0.0;
925 let mut weight_sum = 0.0;
926
927 for i in 0..k {
928 let w = weights[[i, j]];
929 weighted_sum += w * spectra[[i, j]];
930 weight_sum += w;
931 }
932
933 if weight_sum > min_weight {
935 psd[j] = weighted_sum / weight_sum;
936 } else {
937 let mut fallback_sum = 0.0;
939 for i in 0..k {
940 fallback_sum += eigenvalues[i] * spectra[[i, j]];
941 }
942 psd[j] = fallback_sum / lambda_sum;
943 }
944
945 psd[j] = psd[j].max(regularization);
947 }
948
949 for j in 0..n_freqs {
951 let mut new_weight_sum = 0.0;
952
953 for i in 0..k {
954 let lambda = eigenvalues[i];
955 let spectrum_val = spectra[[i, j]].max(regularization);
956 let psd_val = psd[j].max(regularization);
957
958 let ratio = psd_val / spectrum_val;
960 let bias_factor = if ratio > 1e-6 {
961 lambda / (lambda + ratio.powi(2))
962 } else {
963 lambda };
965
966 weights[[i, j]] = bias_factor.max(min_weight);
967 new_weight_sum += weights[[i, j]];
968 }
969
970 if new_weight_sum > min_weight {
972 for i in 0..k {
973 weights[[i, j]] /= new_weight_sum;
974 }
975 } else {
976 for i in 0..k {
978 weights[[i, j]] = 1.0 / k as f64;
979 }
980 }
981 }
982
983 let max_change = old_psd
985 .iter()
986 .zip(psd.iter())
987 .map(|(old, new)| {
988 let denominator = old.abs().max(new.abs()).max(regularization);
989 ((old - new) / denominator).abs()
990 })
991 .fold(0.0, f64::max);
992
993 let mean_change = old_psd
994 .iter()
995 .zip(psd.iter())
996 .map(|(old, new)| {
997 let denominator = old.abs().max(new.abs()).max(regularization);
998 ((old - new) / denominator).abs()
999 })
1000 .sum::<f64>()
1001 / n_freqs as f64;
1002
1003 let rms_change = (old_psd
1005 .iter()
1006 .zip(psd.iter())
1007 .map(|(old, new)| {
1008 let denominator = old.abs().max(new.abs()).max(regularization);
1009 ((old - new) / denominator).powi(2)
1010 })
1011 .sum::<f64>()
1012 / n_freqs as f64)
1013 .sqrt();
1014
1015 convergence_history.push(mean_change);
1017 if convergence_history.len() > 4 {
1018 convergence_history.remove(0);
1019
1020 if convergence_history.len() >= 4 {
1022 let recent_trend = convergence_history[3] - convergence_history[1];
1023 let prev_trend = convergence_history[2] - convergence_history[0];
1024 if recent_trend * prev_trend < 0.0 && recent_trend.abs() > base_tolerance {
1025 oscillation_detector += 1;
1026 }
1027 }
1028 }
1029
1030 let adaptive_tolerance = base_tolerance * (1.0 + iter as f64 * 0.1);
1032
1033 if max_change < adaptive_tolerance
1035 && mean_change < adaptive_tolerance * 0.1
1036 && rms_change < adaptive_tolerance * 0.5
1037 {
1038 converged = true;
1039 break;
1040 }
1041
1042 if oscillation_detector >= 3 && mean_change < adaptive_tolerance * 10.0 {
1044 converged = true;
1045 break;
1046 }
1047
1048 if iter > damping_start {
1050 let convergence_rate = if iter > 0 && convergence_history.len() > 1 {
1052 convergence_history[convergence_history.len() - 1]
1053 / convergence_history[convergence_history.len() - 2].max(1e-15)
1054 } else {
1055 1.0
1056 };
1057
1058 let adaptive_damping = if oscillation_detector > 0 {
1059 0.6 } else if mean_change > adaptive_tolerance * 50.0 {
1061 0.7 } else if convergence_rate > 0.95 {
1063 0.8 } else {
1065 damping_factor };
1067
1068 for j in 0..n_freqs {
1070 let local_change = (psd[j] - old_psd[j]).abs()
1071 / (psd[j].abs().max(old_psd[j].abs()).max(regularization));
1072 let local_damping = if local_change > adaptive_tolerance * 100.0 {
1073 adaptive_damping * 0.8 } else {
1075 adaptive_damping
1076 };
1077
1078 psd[j] = local_damping * psd[j] + (1.0 - local_damping) * old_psd[j];
1079 }
1080 }
1081 }
1082
1083 if !converged {
1085 let final_change = if let Some(&last_change) = convergence_history.last() {
1086 last_change
1087 } else {
1088 1.0
1089 };
1090
1091 if final_change < base_tolerance * 50.0 {
1093 eprintln!(
1094 "Info: Adaptive multitaper algorithm achieved acceptable convergence (final change: {:.2e}, oscillations: {})",
1095 final_change, oscillation_detector
1096 );
1097 } else if oscillation_detector > 0 {
1098 eprintln!(
1099 "Warning: Adaptive algorithm experienced {} oscillations but stabilized (final change: {:.2e})",
1100 oscillation_detector, final_change
1101 );
1102 } else {
1103 eprintln!(
1104 "Warning: Adaptive multitaper convergence incomplete after {} iterations (final change: {:.2e})",
1105 max_iter, final_change
1106 );
1107 eprintln!("Consider: increasing signal length, reducing k, or adjusting nw parameter");
1108 }
1109 }
1110
1111 let frequencies = if onesided {
1113 (0..n_freqs).map(|i| i as f64 * fs / nfft as f64).collect()
1114 } else {
1115 (0..nfft)
1116 .map(|i| {
1117 if i <= nfft / 2 {
1118 i as f64 * fs / nfft as f64
1119 } else {
1120 (i as f64 - nfft as f64) * fs / nfft as f64
1121 }
1122 })
1123 .collect()
1124 };
1125
1126 let scaling = if onesided { 2.0 / fs } else { 1.0 / fs };
1128 psd.iter_mut().for_each(|p| *p *= scaling);
1129
1130 Ok((frequencies, psd))
1131}
1132
1133#[allow(dead_code)]
1138fn compute_confidence_intervals(
1139 spectra: &Array2<f64>,
1140 eigenvalues: &Array1<f64>,
1141 confidence_level: f64,
1142) -> SignalResult<(Vec<f64>, Vec<f64>)> {
1143 let _k = spectra.nrows() as f64;
1144 let effective_k = compute_effective_dof(eigenvalues) / 2.0;
1146 let dof = 2.0 * effective_k; let chi2 = ChiSquared::new(dof).map_err(|e| {
1150 SignalError::ComputationError(format!("Failed to create chi-squared distribution: {}", e))
1151 })?;
1152
1153 let alpha = 1.0 - confidence_level;
1155 let lower_quantile = chi2.inverse_cdf(alpha / 2.0);
1156 let upper_quantile = chi2.inverse_cdf(1.0 - alpha / 2.0);
1157
1158 let lower_factor = dof / upper_quantile;
1159 let upper_factor = dof / lower_quantile;
1160
1161 let n_freqs = spectra.ncols();
1163 let mut lower_ci = vec![0.0; n_freqs];
1164 let mut upper_ci = vec![0.0; n_freqs];
1165
1166 let weight_sum: f64 = eigenvalues.sum();
1167
1168 for j in 0..n_freqs {
1169 let mut weighted_sum = 0.0;
1170 let mut variance_estimate = 0.0;
1171
1172 for i in 0..spectra.nrows() {
1174 weighted_sum += eigenvalues[i] * spectra[[i, j]];
1175 }
1176 let psd_estimate = weighted_sum / weight_sum;
1177
1178 for i in 0..spectra.nrows() {
1180 let deviation = spectra[[i, j]] - psd_estimate;
1181 variance_estimate += eigenvalues[i] * deviation * deviation;
1182 }
1183 variance_estimate /= weight_sum;
1184
1185 let scale_factor = (1.0 + variance_estimate / (psd_estimate * psd_estimate + 1e-15)).sqrt();
1187
1188 lower_ci[j] = psd_estimate * lower_factor / scale_factor;
1189 upper_ci[j] = psd_estimate * upper_factor * scale_factor;
1190
1191 lower_ci[j] = lower_ci[j].max(1e-15);
1193 upper_ci[j] = upper_ci[j].max(lower_ci[j] * 1.01); }
1195
1196 Ok((lower_ci, upper_ci))
1197}
1198
1199#[allow(dead_code)]
1201fn compute_effective_dof(eigenvalues: &Array1<f64>) -> f64 {
1202 let sum_lambda: f64 = eigenvalues.sum();
1203 let sum_lambda_sq: f64 = eigenvalues.iter().map(|&x| x * x).sum();
1204
1205 if sum_lambda_sq < 1e-15 || sum_lambda < 1e-15 {
1207 return 2.0; }
1209
1210 let dof = 2.0 * sum_lambda.powi(2) / sum_lambda_sq;
1211
1212 if dof < 1.0 {
1214 eprintln!(
1215 "Warning: Computed DOF ({:.2}) is less than 1, using minimum value",
1216 dof
1217 );
1218 2.0
1219 } else if dof > 2.0 * eigenvalues.len() as f64 {
1220 eprintln!(
1221 "Warning: Computed DOF ({:.2}) exceeds theoretical maximum",
1222 dof
1223 );
1224 2.0 * eigenvalues.len() as f64
1225 } else {
1226 dof
1227 }
1228}
1229
1230#[allow(dead_code)]
1238fn compute_pmtm_chunked(
1239 signal: &[f64],
1240 config: &MultitaperConfig,
1241 nfft: usize,
1242) -> SignalResult<EnhancedMultitaperResult> {
1243 let n = signal.len();
1244
1245 let signal_complexity = estimate_signal_complexity(signal);
1247 let memory_factor = if config.memory_optimized { 0.5 } else { 1.0 };
1248
1249 let base_chunk_size = match (config.k, signal_complexity) {
1250 (k, _) if k > 30 => 30_000, (k, complexity) if k > 15 && complexity > 2.0 => 40_000, (k, _) if k > 10 => 60_000, _ => 100_000, };
1255
1256 let chunk_size = ((base_chunk_size as f64 * memory_factor) as usize)
1257 .min(n / 8) .max((config.k * 25).min(n / 2)); let overlap_ratio = if signal_complexity > 3.0 {
1262 0.3 } else if config.k > 20 {
1264 0.25 } else {
1266 0.2 };
1268
1269 let overlap = (chunk_size as f64 * overlap_ratio) as usize;
1270 let step = chunk_size.saturating_sub(overlap).max(chunk_size / 2);
1271
1272 let n_chunks = (n + step - 1) / step; let n_freqs = if config.onesided { nfft / 2 + 1 } else { nfft };
1277 let mut psd_accumulator = vec![0.0; n_freqs];
1278 let mut weight_accumulator = vec![0.0; n_freqs];
1279 let mut frequencies = Vec::new();
1280
1281 for chunk_idx in 0..n_chunks {
1283 let start = chunk_idx * step;
1284 let end = (start + chunk_size).min(n);
1285
1286 let chunk_len = end - start;
1287 if chunk_len < config.k * 15 {
1288 continue;
1291 }
1292
1293 let chunk = &signal[start..end];
1295 let chunk_energy: f64 = chunk.iter().map(|&x| x * x).sum();
1296 if chunk_energy < 1e-20 {
1297 continue;
1299 }
1300
1301 let chunk = &signal[start..end];
1302 let chunk_len = chunk.len();
1303
1304 let (tapers, eigenvalues_opt) = dpss(chunk_len, config.nw, config.k, true)?;
1306 let eigenvalues = eigenvalues_opt.ok_or_else(|| {
1307 SignalError::ComputationError(
1308 "Eigenvalues required but not returned from dpss".to_string(),
1309 )
1310 })?;
1311
1312 let chunk_nfft = next_power_of_two(chunk_len);
1314
1315 let spectra = if config.parallel && chunk_len >= config.parallel_threshold {
1317 compute_tapered_ffts_parallel(chunk, &tapers, chunk_nfft)?
1318 } else {
1319 compute_tapered_ffts_simd(chunk, &tapers, chunk_nfft)?
1320 };
1321
1322 let (chunk_freqs, chunk_psd) = if config.adaptive {
1324 combine_spectra_adaptive(
1325 &spectra,
1326 &eigenvalues,
1327 config.fs,
1328 chunk_nfft,
1329 config.onesided,
1330 )?
1331 } else {
1332 combine_spectra_standard(
1333 &spectra,
1334 &eigenvalues,
1335 config.fs,
1336 chunk_nfft,
1337 config.onesided,
1338 )?
1339 };
1340
1341 if chunk_idx == 0 {
1343 frequencies = chunk_freqs.clone();
1344 }
1345
1346 let interpolated_psd = if chunk_freqs.len() != frequencies.len() {
1348 interpolate_psd(&chunk_freqs, &chunk_psd, &frequencies)?
1349 } else {
1350 chunk_psd
1351 };
1352
1353 let chunk_len_actual = end - start;
1355 let chunk_weight = (chunk_len_actual as f64 / n as f64)
1356 * (chunk_len_actual as f64 / chunk_size as f64).sqrt(); for (i, &psd_val) in interpolated_psd.iter().enumerate() {
1359 if i < psd_accumulator.len() && psd_val.is_finite() && psd_val > 0.0 {
1360 psd_accumulator[i] += psd_val * chunk_weight;
1361 weight_accumulator[i] += chunk_weight;
1362 }
1363 }
1364 }
1365
1366 for i in 0..psd_accumulator.len() {
1368 if weight_accumulator[i] > 0.0 {
1369 psd_accumulator[i] /= weight_accumulator[i];
1370 }
1371 }
1372
1373 Ok(EnhancedMultitaperResult {
1377 frequencies,
1378 psd: psd_accumulator,
1379 confidence_intervals: None, dof: Some(2.0 * config.k as f64 * n_chunks as f64), tapers: None, eigenvalues: None, })
1384}
1385
1386#[allow(dead_code)]
1388fn interpolate_psd(
1389 source_freqs: &[f64],
1390 source_psd: &[f64],
1391 target_freqs: &[f64],
1392) -> SignalResult<Vec<f64>> {
1393 if source_freqs.is_empty() || source_psd.is_empty() || target_freqs.is_empty() {
1394 return Err(SignalError::ValueError(
1395 "Empty frequency or PSD arrays".to_string(),
1396 ));
1397 }
1398
1399 let mut result = vec![0.0; target_freqs.len()];
1400
1401 for (i, &target_freq) in target_freqs.iter().enumerate() {
1402 let mut lower_idx = 0;
1404 let mut upper_idx = source_freqs.len() - 1;
1405
1406 for (j, &freq) in source_freqs.iter().enumerate() {
1407 if freq <= target_freq {
1408 lower_idx = j;
1409 } else {
1410 upper_idx = j;
1411 break;
1412 }
1413 }
1414
1415 if lower_idx == upper_idx {
1416 result[i] = source_psd[lower_idx];
1418 } else {
1419 let f1 = source_freqs[lower_idx];
1421 let f2 = source_freqs[upper_idx];
1422 let p1 = source_psd[lower_idx];
1423 let p2 = source_psd[upper_idx];
1424
1425 if (f2 - f1).abs() > 1e-15 {
1426 let weight = (target_freq - f1) / (f2 - f1);
1427 result[i] = p1 + weight * (p2 - p1);
1428 } else {
1429 result[i] = (p1 + p2) / 2.0;
1430 }
1431 }
1432 }
1433
1434 Ok(result)
1435}
1436
/// Heuristic complexity score for a signal, used to pick chunk sizes.
///
/// Signals shorter than 64 samples score 1.0 and signals with a standard
/// deviation below `1e-12` score 0.5. Otherwise the score is 1 plus two terms,
/// each capped at 2: a tenth of the peak-to-peak range measured in standard
/// deviations, and five times the ratio of mean squared first difference to
/// variance. The total is capped at 5.
#[allow(dead_code)]
fn estimate_signal_complexity(signal: &[f64]) -> f64 {
    let count = signal.len();
    if count < 64 {
        return 1.0;
    }
    let count_f = count as f64;

    let mean = signal.iter().sum::<f64>() / count_f;
    let variance = signal
        .iter()
        .map(|&sample| (sample - mean).powi(2))
        .sum::<f64>()
        / count_f;
    let std_dev = variance.sqrt();

    if std_dev < 1e-12 {
        return 0.5;
    }

    // Peak-to-peak range in units of the standard deviation.
    let (min_val, max_val) = signal.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &sample| (lo.min(sample), hi.max(sample)),
    );
    let dynamic_range = (max_val - min_val) / std_dev;

    // Energy of the first difference relative to the variance.
    let diff_energy = signal
        .windows(2)
        .fold(0.0_f64, |acc, pair| acc + (pair[1] - pair[0]).powi(2))
        / count_f;
    let high_freq_ratio = diff_energy / variance.max(1e-12);

    let range_term = (dynamic_range / 10.0).min(2.0);
    let high_freq_term = (high_freq_ratio * 5.0).min(2.0);

    (1.0 + range_term + high_freq_term).min(5.0)
}
1478
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// `0` maps to `1`.
///
/// # Panics
///
/// Panics if the result does not fit in a `usize` (i.e. `n` is larger than
/// the highest power of two representable). The previous shift-based loop
/// never terminated for such inputs because the shifted value wrapped to 0.
#[allow(dead_code)]
fn next_power_of_two(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next power of two does not fit in usize")
}
1492
1493#[allow(dead_code)]
1507pub fn enhanced_multitaper_spectrogram<T>(
1508 x: &[T],
1509 config: &SpectrogramConfig,
1510) -> SignalResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
1511where
1512 T: Float + NumCast + Debug + Send + Sync,
1513{
1514 if x.is_empty() {
1516 return Err(SignalError::ValueError("Input signal is empty".to_string()));
1517 }
1518
1519 check_positive(config.window_size, "window_size")?;
1520 check_positive(config.step, "step")?;
1521
1522 let x_f64: Vec<f64> = x
1524 .iter()
1525 .map(|&val| {
1526 NumCast::from(val).ok_or_else(|| {
1527 SignalError::ValueError(format!("Could not convert {:?} to f64", val))
1528 })
1529 })
1530 .collect::<SignalResult<Vec<f64>>>()?;
1531
1532 for (i, &val) in x_f64.iter().enumerate() {
1534 if !val.is_finite() {
1535 return Err(SignalError::ValueError(format!(
1536 "Non-finite value at index {}: {}",
1537 i, val
1538 )));
1539 }
1540 }
1541
1542 let n = x_f64.len();
1543 let window_size = config.window_size;
1544 let step = config.step;
1545
1546 if window_size > n {
1548 return Err(SignalError::ValueError(
1549 "Window size larger than signal length".to_string(),
1550 ));
1551 }
1552
1553 let n_windows = (n - window_size) / step + 1;
1554 if n_windows == 0 {
1555 return Err(SignalError::ValueError(
1556 "No complete windows in signal".to_string(),
1557 ));
1558 }
1559
1560 let mut mt_config = config.multitaper.clone();
1562 mt_config.nfft = Some(config.window_size);
1563
1564 let times: Vec<f64> = (0..n_windows)
1566 .map(|i| (i * step + window_size / 2) as f64 / config.fs)
1567 .collect();
1568
1569 let results: Vec<EnhancedMultitaperResult> = if config.multitaper.parallel
1571 && n_windows >= config.multitaper.parallel_threshold / window_size
1572 {
1573 let x_arc = Arc::new(x_f64);
1574
1575 (0..n_windows)
1576 .into_par_iter()
1577 .map(|i| {
1578 let start = i * step;
1579 let end = start + window_size;
1580 let window = &x_arc[start..end];
1581
1582 enhanced_pmtm(window, &mt_config).expect("Operation failed")
1583 })
1584 .collect()
1585 } else {
1586 (0..n_windows)
1588 .map(|i| {
1589 let start = i * step;
1590 let end = start + window_size;
1591 let window = &x_f64[start..end];
1592
1593 enhanced_pmtm(window, &mt_config)
1594 })
1595 .collect::<SignalResult<Vec<_>>>()?
1596 };
1597
1598 let frequencies = results[0].frequencies.clone();
1600 let n_freqs = frequencies.len();
1601
1602 let mut spectrogram = Array2::zeros((n_freqs, n_windows));
1604
1605 for (j, result) in results.iter().enumerate() {
1606 for (i, &psd_val) in result.psd.iter().enumerate() {
1607 spectrogram[[i, j]] = psd_val;
1608 }
1609 }
1610
1611 let epsilon = 1e-10;
1613 spectrogram.mapv_inplace(|x| (x + epsilon).log10() * 10.0); Ok((times, frequencies, spectrogram))
1616}
1617
/// Configuration for [`enhanced_multitaper_spectrogram`].
#[derive(Debug, Clone)]
pub struct SpectrogramConfig {
    /// Sampling frequency; window-centre sample indices are divided by this
    /// value to produce the time axis.
    pub fs: f64,
    /// Number of samples in each analysis window. Also used as the `nfft`
    /// passed to the per-window multitaper estimate.
    pub window_size: usize,
    /// Number of samples between the starts of consecutive windows.
    pub step: usize,
    /// Multitaper settings applied to every window. Its `nfft` is overridden
    /// with `window_size`; its `parallel` and `parallel_threshold` fields
    /// decide whether windows are processed in parallel.
    pub multitaper: MultitaperConfig,
}
1630
// Unit tests; gated on `cfg(test)` so the module and its imports are not
// compiled into regular builds.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// A basic non-adaptive estimate yields matching frequency/PSD lengths
    /// and reports degrees of freedom.
    #[test]
    fn test_enhanced_pmtm_basic() {
        let n = 256;
        // 10 Hz sine sampled at 100 Hz.
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 10.0 * i as f64 / 100.0).sin())
            .collect();

        let config = MultitaperConfig {
            k: 4,
            adaptive: false,
            ..MultitaperConfig::default()
        };
        let result = enhanced_pmtm(&signal, &config).expect("Operation failed");

        assert_eq!(result.frequencies.len(), result.psd.len());
        assert!(result.dof.is_some());
    }

    /// The FFT helper returns `nfft` finite complex values.
    #[test]
    fn test_enhanced_simd_fft() {
        let signal = vec![1.0, 0.0, -1.0, 0.0];
        let result = enhanced_simd_fft(&signal, 4).expect("Operation failed");
        assert_eq!(result.len(), 4);
        for val in result {
            assert!(val.re.is_finite() && val.im.is_finite());
        }
    }

    /// DPSS tapers for a small window have shape `(k, n)` and come with ratios.
    #[test]
    fn test_dpss_small() {
        use super::super::windows::dpss;
        let n = 64;
        let nw = 2.5;
        let k = 4;
        let (tapers, ratios) = dpss(n, nw, k, true).expect("DPSS computation failed");
        assert_eq!(tapers.nrows(), k);
        assert_eq!(tapers.ncols(), n);
        assert!(ratios.is_some());
    }

    /// Same shape checks for `n = 256`, printing the elapsed time to stderr.
    #[test]
    fn test_dpss_n256() {
        use super::super::windows::dpss;
        eprintln!("Testing DPSS with n=256...");
        let n = 256;
        let nw = 4.0;
        let k = 4;
        let start = std::time::Instant::now();
        let (tapers, ratios) = dpss(n, nw, k, true).expect("DPSS computation failed");
        eprintln!("DPSS took {:?}", start.elapsed());
        assert_eq!(tapers.nrows(), k);
        assert_eq!(tapers.ncols(), n);
        assert!(ratios.is_some());
    }
}