Skip to main content

scirs2_stats/
spectral_density.rs

1//! Spectral Density Estimation
2//!
3//! This module provides non-parametric spectral density estimation methods
4//! for time series analysis:
5//!
6//! - **Periodogram**: raw spectral estimate (unsmoothed)
7//! - **Welch's method**: averaged modified periodograms with overlapping segments
8//! - **Bartlett's method**: averaged periodograms with non-overlapping segments
9//! - **Cross-spectral density**: joint spectral analysis of two series
10//! - **Coherence function**: squared coherence (normalized cross-spectrum magnitude)
11//! - **Spectral Granger causality**: frequency-domain causality measure
12//!
13//! All methods use a pure-Rust DFT implementation (no external FFT crate required
14//! for correctness; OxiFFT can be plugged in for performance).
15//!
16//! # References
17//!
18//! - Welch, P.D. (1967). The Use of Fast Fourier Transform for the Estimation
19//!   of Power Spectra. IEEE Transactions on Audio and Electroacoustics.
20//! - Bartlett, M.S. (1948). Smoothing Periodograms from Time-Series with
21//!   Continuous Spectra. Nature.
22//! - Geweke, J. (1982). Measurement of Linear Dependence and Feedback Between
23//!   Multiple Time Series. JASA.
24
25use crate::error::{StatsError, StatsResult};
26use scirs2_core::ndarray::{Array1, ArrayView1};
27use std::f64::consts::PI;
28
29// ---------------------------------------------------------------------------
30// Result types
31// ---------------------------------------------------------------------------
32
/// Result of a power spectral density estimation (`periodogram`, `welch`, `bartlett`).
#[derive(Debug, Clone)]
pub struct SpectralDensityResult {
    /// Frequencies in cycles per sample (the estimators use a normalized sampling
    /// rate of 1), spaced `1 / N` apart from 0 up to at most 0.5, where `N` is the
    /// DFT length (whole series for the periodogram, segment length for Welch/Bartlett).
    pub frequencies: Array1<f64>,
    /// One-sided power spectral density estimates, one per entry of `frequencies`.
    pub psd: Array1<f64>,
    /// Number of segments averaged (always 1 for the periodogram).
    pub n_segments: usize,
    /// Frequency bin spacing `fs / N` of the transform. This is not the
    /// window-dependent equivalent noise bandwidth.
    pub bandwidth: f64,
}
45
/// Result of a Welch-style cross-spectral density estimation.
///
/// The cross-spectrum follows the convention `Sxy(f) = E[conj(X(f)) * Y(f)]`.
/// All arrays are one-sided and share the indexing of `frequencies`.
#[derive(Debug, Clone)]
pub struct CrossSpectralResult {
    /// Frequencies in cycles per sample.
    pub frequencies: Array1<f64>,
    /// Real part of the averaged cross-spectral density (co-spectrum).
    pub csd_real: Array1<f64>,
    /// Imaginary part of the averaged cross-spectral density (quadrature spectrum).
    pub csd_imag: Array1<f64>,
    /// Magnitude `sqrt(csd_real^2 + csd_imag^2)` of the averaged cross-spectrum.
    pub csd_magnitude: Array1<f64>,
    /// Phase `atan2(csd_imag, csd_real)` of the averaged cross-spectrum, in radians.
    pub csd_phase: Array1<f64>,
    /// Power spectral density of x, estimated from the same segments and window.
    pub psd_x: Array1<f64>,
    /// Power spectral density of y, estimated from the same segments and window.
    pub psd_y: Array1<f64>,
}
64
/// Result of a coherence analysis between two series.
#[derive(Debug, Clone)]
pub struct CoherenceResult {
    /// Frequencies in cycles per sample.
    pub frequencies: Array1<f64>,
    /// Squared coherence `|Sxy|^2 / (Sxx * Syy)`, clamped to at most 1 and set to 0
    /// where `Sxx * Syy` is numerically zero.
    pub coherence_sq: Array1<f64>,
    /// Phase spectrum of the cross-spectrum, in radians.
    pub phase: Array1<f64>,
    /// Gain spectrum `|Sxy| / Sxx` (0 where `Sxx` is numerically zero).
    pub gain: Array1<f64>,
}
77
/// Result of spectral Granger causality analysis.
#[derive(Debug, Clone)]
pub struct SpectralGrangerResult {
    /// Frequencies in cycles per sample.
    pub frequencies: Array1<f64>,
    /// Spectral Granger causality from x to y at each frequency
    /// (log ratio of residual spectra, clamped to be non-negative).
    pub causality_x_to_y: Array1<f64>,
    /// Spectral Granger causality from y to x at each frequency
    /// (log ratio of residual spectra, clamped to be non-negative).
    pub causality_y_to_x: Array1<f64>,
    /// Sum of the two directional measures at each frequency.
    pub total_interdependence: Array1<f64>,
}
90
91// ---------------------------------------------------------------------------
92// Window functions
93// ---------------------------------------------------------------------------
94
/// Window (taper) functions available for spectral estimation.
///
/// The cosine-family windows are generated in their symmetric form
/// (denominator `n - 1`), so the first and last coefficients are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Window {
    /// Rectangular window: every coefficient is 1 (no tapering).
    Rectangular,
    /// Hann (raised cosine) window: `0.5 * (1 - cos(2*pi*t / (n - 1)))`.
    Hann,
    /// Hamming window: `0.54 - 0.46 * cos(2*pi*t / (n - 1))`.
    Hamming,
    /// Blackman window: `0.42 - 0.5*cos(2*pi*t/(n-1)) + 0.08*cos(4*pi*t/(n-1))`.
    Blackman,
    /// Bartlett (triangular) window: `1 - |2t / (n - 1) - 1|`.
    Bartlett,
    /// Tukey (tapered cosine) window. The taper fraction `alpha` is not stored in
    /// the variant; the spectral functions in this module always use `alpha = 0.5`.
    Tukey,
}
111
112/// Generate window coefficients
113fn window_coefficients(window: Window, n: usize, alpha: f64) -> Array1<f64> {
114    let nf = n as f64;
115    Array1::from_vec(
116        (0..n)
117            .map(|i| {
118                let t = i as f64;
119                match window {
120                    Window::Rectangular => 1.0,
121                    Window::Hann => 0.5 * (1.0 - (2.0 * PI * t / (nf - 1.0)).cos()),
122                    Window::Hamming => 0.54 - 0.46 * (2.0 * PI * t / (nf - 1.0)).cos(),
123                    Window::Blackman => {
124                        0.42 - 0.5 * (2.0 * PI * t / (nf - 1.0)).cos()
125                            + 0.08 * (4.0 * PI * t / (nf - 1.0)).cos()
126                    }
127                    Window::Bartlett => {
128                        if n <= 1 {
129                            1.0
130                        } else {
131                            1.0 - (2.0 * t / (nf - 1.0) - 1.0).abs()
132                        }
133                    }
134                    Window::Tukey => {
135                        let a = alpha.max(0.0).min(1.0);
136                        if a == 0.0 {
137                            1.0
138                        } else if a >= 1.0 {
139                            0.5 * (1.0 - (2.0 * PI * t / (nf - 1.0)).cos())
140                        } else {
141                            let boundary = a * (nf - 1.0) / 2.0;
142                            if t < boundary {
143                                0.5 * (1.0 - (PI * t / boundary).cos())
144                            } else if t > (nf - 1.0) - boundary {
145                                0.5 * (1.0 - (PI * ((nf - 1.0) - t) / boundary).cos())
146                            } else {
147                                1.0
148                            }
149                        }
150                    }
151                }
152            })
153            .collect(),
154    )
155}
156
157/// Window power (sum of squared coefficients / n), used for PSD normalization
158fn window_power(w: &Array1<f64>) -> f64 {
159    let n = w.len() as f64;
160    if n == 0.0 {
161        return 1.0;
162    }
163    w.iter().map(|&v| v * v).sum::<f64>() / n
164}
165
166// ---------------------------------------------------------------------------
167// DFT helpers
168// ---------------------------------------------------------------------------
169
/// Compute the DFT of a real-valued signal for non-negative frequencies only.
///
/// Returns `(real_parts, imag_parts)`, each of length `n / 2 + 1`, where
/// `X_k = sum_t x_t * exp(-2*pi*i*k*t/n)`.
///
/// This is a direct O(n^2) DFT. The twiddle factors `exp(-2*pi*i*m/n)` are
/// precomputed once, and the phase index `(k * t) mod n` is advanced
/// incrementally. This avoids per-term trig calls and the precision loss of
/// evaluating `sin`/`cos` at large arguments `2*pi*k*t/n`, and `k * t` is never
/// formed, so it cannot overflow. An empty input yields a single zero bin.
fn rfft(x: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let n = x.len();
    let n_out = n / 2 + 1;
    let step = 2.0 * PI / n as f64;
    let (cos_tab, sin_tab): (Vec<f64>, Vec<f64>) = (0..n)
        .map(|m| {
            let angle = step * m as f64;
            (angle.cos(), angle.sin())
        })
        .unzip();

    let mut real = Vec::with_capacity(n_out);
    let mut imag = Vec::with_capacity(n_out);
    for k in 0..n_out {
        let mut re = 0.0;
        let mut im = 0.0;
        // idx == (k * t) mod n; since k <= n / 2 < n, one subtraction keeps it in range.
        let mut idx = 0_usize;
        for &value in x {
            re += value * cos_tab[idx];
            im -= value * sin_tab[idx];
            idx += k;
            if idx >= n {
                idx -= n;
            }
        }
        real.push(re);
        imag.push(im);
    }
    (real, imag)
}
192
/// Convert the non-negative-frequency DFT of a length-`n` real signal into a
/// one-sided power spectral density.
///
/// `real`/`imag` hold bins `0..=n/2`, as produced by `rfft`. Each bin's power
/// `|X_k|^2` is scaled by `1 / (fs * n * win_power)`. Bins are doubled to fold
/// in their negative-frequency mirror, except the DC bin and the Nyquist bin
/// `n / 2`. The Nyquist bin only exists when `n` is even; for odd `n` the last
/// returned bin is an ordinary interior bin and is therefore doubled too.
fn dft_to_psd(real: &[f64], imag: &[f64], n: usize, fs: f64, win_power: f64) -> Vec<f64> {
    let scale = 1.0 / (fs * (n as f64) * win_power);
    let has_nyquist = n % 2 == 0;
    real.iter()
        .zip(imag)
        .enumerate()
        .map(|(k, (&re, &im))| {
            let density = (re * re + im * im) * scale;
            // DC and (for even n) Nyquist have no distinct negative-frequency twin.
            let unpaired = k == 0 || (has_nyquist && k == n / 2);
            if unpaired {
                density
            } else {
                2.0 * density
            }
        })
        .collect()
}
209
210// ---------------------------------------------------------------------------
211// Periodogram
212// ---------------------------------------------------------------------------
213
214/// Compute the periodogram (raw spectral estimate) of a time series.
215///
216/// # Arguments
217/// * `x` - Time series data
218/// * `window` - Window function to apply (default: `Hann`)
219/// * `detrend` - If true, remove the mean before computing
220///
221/// # Example
222/// ```
223/// use scirs2_stats::spectral_density::{periodogram, Window};
224/// use scirs2_core::ndarray::Array1;
225///
226/// // Sine wave at frequency 0.1 (cycles/sample)
227/// let n = 256;
228/// let x = Array1::from_vec((0..n).map(|i| {
229///     (2.0 * std::f64::consts::PI * 0.1 * i as f64).sin()
230/// }).collect());
231/// let result = periodogram(&x.view(), Window::Hann, true).expect("periodogram failed");
232/// assert_eq!(result.frequencies.len(), result.psd.len());
233/// // Peak should be near frequency 0.1
234/// let peak_idx = result.psd.iter()
235///     .enumerate()
236///     .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
237///     .map(|(i, _)| i)
238///     .unwrap_or(0);
239/// assert!((result.frequencies[peak_idx] - 0.1).abs() < 0.02);
240/// ```
241pub fn periodogram(
242    x: &ArrayView1<f64>,
243    window: Window,
244    detrend: bool,
245) -> StatsResult<SpectralDensityResult> {
246    let n = x.len();
247    if n < 4 {
248        return Err(StatsError::InsufficientData(
249            "periodogram requires at least 4 data points".into(),
250        ));
251    }
252    let fs = 1.0; // normalized sampling frequency
253                  // Detrend (remove mean)
254    let mean = if detrend {
255        x.iter().sum::<f64>() / (n as f64)
256    } else {
257        0.0
258    };
259    // Apply window
260    let w = window_coefficients(window, n, 0.5);
261    let wp = window_power(&w);
262    let windowed: Vec<f64> = (0..n).map(|i| (x[i] - mean) * w[i]).collect();
263
264    let (real, imag) = rfft(&windowed);
265    let psd_vec = dft_to_psd(&real, &imag, n, fs, wp);
266    let n_out = psd_vec.len();
267    let freqs = Array1::from_vec((0..n_out).map(|k| (k as f64) * fs / (n as f64)).collect());
268
269    Ok(SpectralDensityResult {
270        frequencies: freqs,
271        psd: Array1::from_vec(psd_vec),
272        n_segments: 1,
273        bandwidth: fs / (n as f64),
274    })
275}
276
277// ---------------------------------------------------------------------------
278// Welch's method
279// ---------------------------------------------------------------------------
280
281/// Compute the power spectral density using Welch's method.
282///
283/// Divides the signal into overlapping segments, windows each, computes
284/// modified periodograms, and averages them.
285///
286/// # Arguments
287/// * `x` - Time series data
288/// * `segment_length` - Length of each segment (None for `n/8` rounded to power of 2)
289/// * `overlap` - Fraction of overlap between segments (default: 0.5)
290/// * `window` - Window function
291///
292/// # Example
293/// ```
294/// use scirs2_stats::spectral_density::{welch, Window};
295/// use scirs2_core::ndarray::Array1;
296///
297/// let n = 1024;
298/// let x = Array1::from_vec((0..n).map(|i| {
299///     (2.0 * std::f64::consts::PI * 0.25 * i as f64).sin() + ((i as f64) * 0.7).sin() * 0.1
300/// }).collect());
301/// let result = welch(&x.view(), Some(256), Some(0.5), Window::Hann)
302///     .expect("Welch failed");
303/// assert!(result.n_segments > 1);
304/// ```
305pub fn welch(
306    x: &ArrayView1<f64>,
307    segment_length: Option<usize>,
308    overlap: Option<f64>,
309    window: Window,
310) -> StatsResult<SpectralDensityResult> {
311    let n = x.len();
312    if n < 8 {
313        return Err(StatsError::InsufficientData(
314            "Welch's method requires at least 8 data points".into(),
315        ));
316    }
317    let seg_len = segment_length.unwrap_or_else(|| {
318        // Default: n/8, but at least 8
319        let target = n / 8;
320        target.max(8).min(n)
321    });
322    if seg_len < 4 || seg_len > n {
323        return Err(StatsError::InvalidArgument(format!(
324            "segment_length must be in [4, {}], got {}",
325            n, seg_len
326        )));
327    }
328    let overlap_frac = overlap.unwrap_or(0.5).max(0.0).min(0.99);
329    let step = ((seg_len as f64) * (1.0 - overlap_frac)).round() as usize;
330    let step = step.max(1);
331
332    let fs = 1.0;
333    let w = window_coefficients(window, seg_len, 0.5);
334    let wp = window_power(&w);
335
336    let n_freq = seg_len / 2 + 1;
337    let mut avg_psd = vec![0.0_f64; n_freq];
338    let mut n_segments = 0_usize;
339
340    let mut start = 0;
341    while start + seg_len <= n {
342        // Extract segment and detrend
343        let mean: f64 = (start..start + seg_len).map(|i| x[i]).sum::<f64>() / (seg_len as f64);
344        let windowed: Vec<f64> = (0..seg_len).map(|i| (x[start + i] - mean) * w[i]).collect();
345        let (real, imag) = rfft(&windowed);
346        let psd = dft_to_psd(&real, &imag, seg_len, fs, wp);
347        for k in 0..n_freq {
348            avg_psd[k] += psd[k];
349        }
350        n_segments += 1;
351        start += step;
352    }
353
354    if n_segments == 0 {
355        return Err(StatsError::ComputationError(
356            "Welch: no segments could be formed".into(),
357        ));
358    }
359
360    for k in 0..n_freq {
361        avg_psd[k] /= n_segments as f64;
362    }
363
364    let freqs = Array1::from_vec(
365        (0..n_freq)
366            .map(|k| (k as f64) * fs / (seg_len as f64))
367            .collect(),
368    );
369
370    Ok(SpectralDensityResult {
371        frequencies: freqs,
372        psd: Array1::from_vec(avg_psd),
373        n_segments,
374        bandwidth: fs / (seg_len as f64),
375    })
376}
377
378// ---------------------------------------------------------------------------
379// Bartlett's method
380// ---------------------------------------------------------------------------
381
382/// Compute the power spectral density using Bartlett's method.
383///
384/// Similar to Welch's method but with no overlap and a rectangular window.
385///
386/// # Arguments
387/// * `x` - Time series data
388/// * `n_segments` - Number of non-overlapping segments
389///
390/// # Example
391/// ```
392/// use scirs2_stats::spectral_density::bartlett;
393/// use scirs2_core::ndarray::Array1;
394///
395/// let n = 256;
396/// let x = Array1::from_vec((0..n).map(|i| {
397///     (2.0 * std::f64::consts::PI * 0.1 * i as f64).sin()
398/// }).collect());
399/// let result = bartlett(&x.view(), 4).expect("Bartlett failed");
400/// assert_eq!(result.n_segments, 4);
401/// ```
402pub fn bartlett(x: &ArrayView1<f64>, n_segments: usize) -> StatsResult<SpectralDensityResult> {
403    let n = x.len();
404    if n_segments == 0 || n_segments > n {
405        return Err(StatsError::InvalidArgument(format!(
406            "n_segments must be in [1, {}]",
407            n
408        )));
409    }
410    let seg_len = n / n_segments;
411    if seg_len < 4 {
412        return Err(StatsError::InsufficientData(
413            "Bartlett: segments too short (< 4 points each)".into(),
414        ));
415    }
416    // Bartlett = Welch with rectangular window and no overlap
417    welch(x, Some(seg_len), Some(0.0), Window::Rectangular)
418}
419
420// ---------------------------------------------------------------------------
421// Cross-spectral density
422// ---------------------------------------------------------------------------
423
424/// Compute the cross-spectral density of two time series.
425///
426/// Uses Welch's method to estimate the cross-spectrum Sxy(f) = E[X*(f) Y(f)].
427///
428/// # Arguments
429/// * `x` - First time series
430/// * `y` - Second time series
431/// * `segment_length` - Segment length (None for auto)
432/// * `overlap` - Overlap fraction (default 0.5)
433/// * `window` - Window function
434///
435/// # Example
436/// ```
437/// use scirs2_stats::spectral_density::{cross_spectral_density, Window};
438/// use scirs2_core::ndarray::Array1;
439///
440/// let n = 256;
441/// let x = Array1::from_vec((0..n).map(|i| {
442///     (2.0 * std::f64::consts::PI * 0.1 * i as f64).sin()
443/// }).collect());
444/// let y = Array1::from_vec((0..n).map(|i| {
445///     (2.0 * std::f64::consts::PI * 0.1 * i as f64 + 0.5).sin()
446/// }).collect());
447/// let result = cross_spectral_density(&x.view(), &y.view(), Some(64), Some(0.5), Window::Hann)
448///     .expect("CSD failed");
449/// assert_eq!(result.frequencies.len(), result.csd_magnitude.len());
450/// ```
451pub fn cross_spectral_density(
452    x: &ArrayView1<f64>,
453    y: &ArrayView1<f64>,
454    segment_length: Option<usize>,
455    overlap: Option<f64>,
456    window: Window,
457) -> StatsResult<CrossSpectralResult> {
458    let n = x.len();
459    if n != y.len() {
460        return Err(StatsError::DimensionMismatch(format!(
461            "x and y must have the same length (got {} and {})",
462            n,
463            y.len()
464        )));
465    }
466    if n < 8 {
467        return Err(StatsError::InsufficientData(
468            "cross-spectral density requires at least 8 data points".into(),
469        ));
470    }
471    let seg_len = segment_length.unwrap_or_else(|| (n / 8).max(8).min(n));
472    if seg_len < 4 || seg_len > n {
473        return Err(StatsError::InvalidArgument(format!(
474            "segment_length must be in [4, {}]",
475            n
476        )));
477    }
478    let overlap_frac = overlap.unwrap_or(0.5).max(0.0).min(0.99);
479    let step = ((seg_len as f64) * (1.0 - overlap_frac)).round() as usize;
480    let step = step.max(1);
481
482    let fs = 1.0;
483    let w = window_coefficients(window, seg_len, 0.5);
484    let wp = window_power(&w);
485
486    let n_freq = seg_len / 2 + 1;
487    let mut avg_csd_re = vec![0.0_f64; n_freq];
488    let mut avg_csd_im = vec![0.0_f64; n_freq];
489    let mut avg_psd_x = vec![0.0_f64; n_freq];
490    let mut avg_psd_y = vec![0.0_f64; n_freq];
491    let mut n_seg = 0_usize;
492
493    let mut start = 0;
494    while start + seg_len <= n {
495        let x_mean: f64 = (start..start + seg_len).map(|i| x[i]).sum::<f64>() / (seg_len as f64);
496        let y_mean: f64 = (start..start + seg_len).map(|i| y[i]).sum::<f64>() / (seg_len as f64);
497
498        let wx: Vec<f64> = (0..seg_len)
499            .map(|i| (x[start + i] - x_mean) * w[i])
500            .collect();
501        let wy: Vec<f64> = (0..seg_len)
502            .map(|i| (y[start + i] - y_mean) * w[i])
503            .collect();
504
505        let (xr, xi) = rfft(&wx);
506        let (yr, yi) = rfft(&wy);
507
508        let scale = 1.0 / (fs * (seg_len as f64) * wp);
509        for k in 0..n_freq {
510            // Cross: conj(X) * Y = (xr - j*xi_neg)(yr + j*yi) but xi stored as -sin
511            // X* = (xr, -xi), Y = (yr, yi)
512            // X* * Y = (xr*yr + xi*yi) + j*(xr*yi - xi*yr)
513            // But our rfft stores imag as -sin component, so conj(X) has imag = +xi
514            let csd_re = (xr[k] * yr[k] + xi[k] * yi[k]) * scale;
515            let csd_im = (xr[k] * yi[k] - xi[k] * yr[k]) * scale;
516            let psd_x = (xr[k] * xr[k] + xi[k] * xi[k]) * scale;
517            let psd_y = (yr[k] * yr[k] + yi[k] * yi[k]) * scale;
518            let double = if k > 0 && k < n_freq - 1 { 2.0 } else { 1.0 };
519            avg_csd_re[k] += csd_re * double;
520            avg_csd_im[k] += csd_im * double;
521            avg_psd_x[k] += psd_x * double;
522            avg_psd_y[k] += psd_y * double;
523        }
524        n_seg += 1;
525        start += step;
526    }
527
528    if n_seg == 0 {
529        return Err(StatsError::ComputationError(
530            "no segments formed for cross-spectral density".into(),
531        ));
532    }
533
534    let ns = n_seg as f64;
535    let mut magnitude = vec![0.0_f64; n_freq];
536    let mut phase = vec![0.0_f64; n_freq];
537    for k in 0..n_freq {
538        avg_csd_re[k] /= ns;
539        avg_csd_im[k] /= ns;
540        avg_psd_x[k] /= ns;
541        avg_psd_y[k] /= ns;
542        magnitude[k] = (avg_csd_re[k] * avg_csd_re[k] + avg_csd_im[k] * avg_csd_im[k]).sqrt();
543        phase[k] = avg_csd_im[k].atan2(avg_csd_re[k]);
544    }
545
546    let freqs = Array1::from_vec(
547        (0..n_freq)
548            .map(|k| (k as f64) * fs / (seg_len as f64))
549            .collect(),
550    );
551
552    Ok(CrossSpectralResult {
553        frequencies: freqs,
554        csd_real: Array1::from_vec(avg_csd_re),
555        csd_imag: Array1::from_vec(avg_csd_im),
556        csd_magnitude: Array1::from_vec(magnitude),
557        csd_phase: Array1::from_vec(phase),
558        psd_x: Array1::from_vec(avg_psd_x),
559        psd_y: Array1::from_vec(avg_psd_y),
560    })
561}
562
563// ---------------------------------------------------------------------------
564// Coherence function
565// ---------------------------------------------------------------------------
566
567/// Compute the squared coherence and phase spectrum between two series.
568///
569/// The squared coherence is |Sxy(f)|^2 / (Sxx(f) * Syy(f)), ranging in [0, 1].
570/// A value near 1 indicates strong linear relationship at that frequency.
571///
572/// # Arguments
573/// * `x` - First time series
574/// * `y` - Second time series
575/// * `segment_length` - Segment length for Welch (None for auto)
576/// * `overlap` - Overlap fraction (default 0.5)
577/// * `window` - Window function
578///
579/// # Example
580/// ```
581/// use scirs2_stats::spectral_density::{coherence, Window};
582/// use scirs2_core::ndarray::Array1;
583///
584/// let n = 256;
585/// let x = Array1::from_vec((0..n).map(|i| {
586///     (2.0 * std::f64::consts::PI * 0.1 * i as f64).sin()
587/// }).collect());
588/// // y is a phase-shifted version of x => high coherence
589/// let y = Array1::from_vec((0..n).map(|i| {
590///     (2.0 * std::f64::consts::PI * 0.1 * i as f64 + 1.0).sin()
591/// }).collect());
592/// let result = coherence(&x.view(), &y.view(), Some(64), Some(0.5), Window::Hann)
593///     .expect("coherence failed");
594/// // At the signal frequency, coherence should be high
595/// let peak_idx = result.coherence_sq.iter()
596///     .enumerate()
597///     .skip(1)
598///     .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
599///     .map(|(i, _)| i)
600///     .unwrap_or(0);
601/// assert!(result.coherence_sq[peak_idx] > 0.5);
602/// ```
603pub fn coherence(
604    x: &ArrayView1<f64>,
605    y: &ArrayView1<f64>,
606    segment_length: Option<usize>,
607    overlap: Option<f64>,
608    window: Window,
609) -> StatsResult<CoherenceResult> {
610    let csd = cross_spectral_density(x, y, segment_length, overlap, window)?;
611    let n_freq = csd.frequencies.len();
612    let mut coh_sq = Array1::<f64>::zeros(n_freq);
613    let mut phase = Array1::<f64>::zeros(n_freq);
614    let mut gain = Array1::<f64>::zeros(n_freq);
615
616    for k in 0..n_freq {
617        let sxy_sq = csd.csd_real[k] * csd.csd_real[k] + csd.csd_imag[k] * csd.csd_imag[k];
618        let denom = csd.psd_x[k] * csd.psd_y[k];
619        coh_sq[k] = if denom > 1e-30 {
620            (sxy_sq / denom).min(1.0)
621        } else {
622            0.0
623        };
624        phase[k] = csd.csd_phase[k];
625        gain[k] = if csd.psd_x[k] > 1e-30 {
626            csd.csd_magnitude[k] / csd.psd_x[k]
627        } else {
628            0.0
629        };
630    }
631
632    Ok(CoherenceResult {
633        frequencies: csd.frequencies,
634        coherence_sq: coh_sq,
635        phase,
636        gain,
637    })
638}
639
640// ---------------------------------------------------------------------------
641// Spectral Granger causality helper
642// ---------------------------------------------------------------------------
643
644/// Compute a spectral Granger causality measure between two series.
645///
646/// This is a frequency-domain decomposition of Granger causality based on
647/// comparing the spectral density of the restricted model (univariate AR) with
648/// the full model (bivariate VAR).
649///
650/// The measure at each frequency f is:
651///   GC_{x->y}(f) = ln(S_y(f) / S_y|x(f))
652///
653/// where S_y is the spectrum of y from a univariate AR, and S_y|x is the
654/// spectrum of y from the bivariate VAR residuals.
655///
656/// # Arguments
657/// * `x` - First time series (potential cause)
658/// * `y` - Second time series (potential effect)
659/// * `max_lags` - Maximum number of AR/VAR lags
660/// * `segment_length` - Segment length for spectral estimation
661///
662/// # Example
663/// ```
664/// use scirs2_stats::spectral_density::spectral_granger_causality;
665/// use scirs2_core::ndarray::Array1;
666///
667/// let n = 200;
668/// // x leads y by a few samples
669/// let x = Array1::from_vec((0..n).map(|i| ((i as f64) * 0.3).sin()).collect());
670/// let mut y_vec = vec![0.0_f64; n];
671/// for i in 3..n {
672///     y_vec[i] = 0.7 * x[i-3] + ((i as f64) * 0.5).sin() * 0.3;
673/// }
674/// let y = Array1::from_vec(y_vec);
675/// let result = spectral_granger_causality(&x.view(), &y.view(), 5, Some(64))
676///     .expect("spectral GC failed");
677/// assert_eq!(result.frequencies.len(), result.causality_x_to_y.len());
678/// ```
679pub fn spectral_granger_causality(
680    x: &ArrayView1<f64>,
681    y: &ArrayView1<f64>,
682    max_lags: usize,
683    segment_length: Option<usize>,
684) -> StatsResult<SpectralGrangerResult> {
685    let n = x.len();
686    if n != y.len() {
687        return Err(StatsError::DimensionMismatch(
688            "x and y must have the same length".into(),
689        ));
690    }
691    if n < max_lags + 10 {
692        return Err(StatsError::InsufficientData(
693            "insufficient data for spectral Granger causality".into(),
694        ));
695    }
696
697    // Fit univariate AR(p) for y
698    let resid_y_only = fit_ar_residuals(y, max_lags)?;
699    // Fit bivariate VAR(p) for (x->y direction): y_t = sum a_i*y_{t-i} + b_i*x_{t-i} + e_t
700    let resid_y_full = fit_var_residuals(x, y, max_lags)?;
701    // Similarly for x direction
702    let resid_x_only = fit_ar_residuals(x, max_lags)?;
703    let resid_x_full = fit_var_residuals(y, x, max_lags)?;
704
705    // Compute spectral densities of residuals
706    let seg_len = segment_length.unwrap_or_else(|| (n / 8).max(8).min(n));
707    let spec_y_only = welch(&resid_y_only.view(), Some(seg_len), Some(0.5), Window::Hann)?;
708    let spec_y_full = welch(&resid_y_full.view(), Some(seg_len), Some(0.5), Window::Hann)?;
709    let spec_x_only = welch(&resid_x_only.view(), Some(seg_len), Some(0.5), Window::Hann)?;
710    let spec_x_full = welch(&resid_x_full.view(), Some(seg_len), Some(0.5), Window::Hann)?;
711
712    let n_freq = spec_y_only.psd.len().min(spec_y_full.psd.len());
713    let n_freq = n_freq.min(spec_x_only.psd.len()).min(spec_x_full.psd.len());
714
715    let mut gc_x_to_y = Array1::<f64>::zeros(n_freq);
716    let mut gc_y_to_x = Array1::<f64>::zeros(n_freq);
717    let mut total = Array1::<f64>::zeros(n_freq);
718
719    for k in 0..n_freq {
720        let ratio_xy = spec_y_only.psd[k] / spec_y_full.psd[k].max(1e-30);
721        gc_x_to_y[k] = ratio_xy.max(1.0).ln();
722        let ratio_yx = spec_x_only.psd[k] / spec_x_full.psd[k].max(1e-30);
723        gc_y_to_x[k] = ratio_yx.max(1.0).ln();
724        total[k] = gc_x_to_y[k] + gc_y_to_x[k];
725    }
726
727    let freqs = Array1::from_vec((0..n_freq).map(|k| (k as f64) / (seg_len as f64)).collect());
728
729    Ok(SpectralGrangerResult {
730        frequencies: freqs,
731        causality_x_to_y: gc_x_to_y,
732        causality_y_to_x: gc_y_to_x,
733        total_interdependence: total,
734    })
735}
736
737/// Fit a univariate AR(p) model and return residuals.
738fn fit_ar_residuals(y: &ArrayView1<f64>, p: usize) -> StatsResult<Array1<f64>> {
739    let n = y.len();
740    if n <= p + 1 {
741        return Err(StatsError::InsufficientData(
742            "too few observations for AR model".into(),
743        ));
744    }
745    let n_eff = n - p;
746    // Design: [y_{t-1}, y_{t-2}, ..., y_{t-p}, 1]
747    let n_reg = p + 1;
748    let mut design = scirs2_core::ndarray::Array2::<f64>::zeros((n_eff, n_reg));
749    let dep = Array1::from_vec((p..n).map(|i| y[i]).collect());
750    for i in 0..n_eff {
751        for lag in 1..=p {
752            design[[i, lag - 1]] = y[p + i - lag];
753        }
754        design[[i, p]] = 1.0; // constant
755    }
756    let ols = crate::stationarity::ols_regression(&dep.view(), &design)?;
757    Ok(ols.residuals)
758}
759
760/// Fit a bivariate VAR equation: z_t = sum a_i*z_{t-i} + b_i*cause_{t-i} + c + e_t
761/// Returns the residuals for the effect variable.
762fn fit_var_residuals(
763    cause: &ArrayView1<f64>,
764    effect: &ArrayView1<f64>,
765    p: usize,
766) -> StatsResult<Array1<f64>> {
767    let n = cause.len();
768    if n <= p + 1 {
769        return Err(StatsError::InsufficientData(
770            "too few observations for VAR model".into(),
771        ));
772    }
773    let n_eff = n - p;
774    let n_reg = 2 * p + 1; // p lags of effect + p lags of cause + constant
775    let mut design = scirs2_core::ndarray::Array2::<f64>::zeros((n_eff, n_reg));
776    let dep = Array1::from_vec((p..n).map(|i| effect[i]).collect());
777    for i in 0..n_eff {
778        let mut col = 0;
779        for lag in 1..=p {
780            design[[i, col]] = effect[p + i - lag];
781            col += 1;
782        }
783        for lag in 1..=p {
784            design[[i, col]] = cause[p + i - lag];
785            col += 1;
786        }
787        design[[i, col]] = 1.0;
788    }
789    let ols = crate::stationarity::ols_regression(&dep.view(), &design)?;
790    Ok(ols.residuals)
791}
792
793// ---------------------------------------------------------------------------
794// Tests
795// ---------------------------------------------------------------------------
796
#[cfg(test)]
mod tests {
    //! Unit tests for the spectral density estimators in the parent module.
    //!
    //! Fallible calls are unwrapped with `expect` rather than checked with
    //! `assert!(result.is_ok())`, so a failing test reports the underlying
    //! `StatsError` instead of a bare "assertion failed".

    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Unit-amplitude sinusoid `sin(2π·freq·i)` sampled at `i = 0..n`.
    ///
    /// `freq` is in cycles per sample, the same unit as the `frequencies`
    /// axis returned by the estimators.
    fn make_sine(n: usize, freq: f64) -> Array1<f64> {
        Array1::from_vec(
            (0..n)
                .map(|i| (2.0 * PI * freq * (i as f64)).sin())
                .collect(),
        )
    }

    /// Deterministic "noise-like" test signal `0.5·sin(2.7·i + 0.3)`.
    ///
    /// Despite its name this is *not* random: it is a single tone at
    /// 2.7 rad/sample (≈ 0.43 cycles/sample). That keeps the tests
    /// reproducible while placing its energy far from the 0.1 cycles/sample
    /// tones produced by `make_sine` in these tests.
    fn make_noise(n: usize) -> Array1<f64> {
        Array1::from_vec(
            (0..n)
                .map(|i| ((i as f64) * 2.7 + 0.3).sin() * 0.5)
                .collect(),
        )
    }

    /// Index of the largest PSD value, ignoring the DC bin (index 0).
    ///
    /// NaN comparisons are treated as equal so they never abort the search.
    ///
    /// # Panics
    ///
    /// Panics if `psd` has no bins beyond DC, since a peak test is
    /// meaningless in that case.
    fn peak_index(psd: &Array1<f64>) -> usize {
        psd.iter()
            .enumerate()
            .skip(1)
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .expect("PSD should contain at least one non-DC bin")
    }

    #[test]
    fn test_periodogram_pure_sine() {
        let x = make_sine(256, 0.1);
        let r = periodogram(&x.view(), Window::Hann, true).expect("periodogram should succeed");
        assert_eq!(r.frequencies.len(), 129); // 256/2 + 1
        assert_eq!(r.psd.len(), 129);
        let peak = peak_index(&r.psd);
        assert!(
            (r.frequencies[peak] - 0.1).abs() < 0.02,
            "periodogram peak at {} instead of ~0.1",
            r.frequencies[peak]
        );
    }

    #[test]
    fn test_periodogram_rectangular() {
        let x = make_sine(128, 0.2);
        periodogram(&x.view(), Window::Rectangular, false)
            .expect("rectangular-window periodogram should succeed");
    }

    #[test]
    fn test_periodogram_blackman() {
        let x = make_sine(128, 0.15);
        periodogram(&x.view(), Window::Blackman, true)
            .expect("Blackman-window periodogram should succeed");
    }

    #[test]
    fn test_periodogram_insufficient() {
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let result = periodogram(&x.view(), Window::Hann, true);
        assert!(result.is_err());
    }

    #[test]
    fn test_welch_basic() {
        let x = make_sine(512, 0.1);
        let r = welch(&x.view(), Some(128), Some(0.5), Window::Hann).expect("Welch should succeed");
        assert!(r.n_segments > 1);
        assert_eq!(r.psd.len(), 65); // 128/2 + 1
    }

    #[test]
    fn test_welch_auto_segment() {
        let x = make_sine(1024, 0.25);
        let r = welch(&x.view(), None, None, Window::Hamming).expect("Welch auto should succeed");
        assert!(r.n_segments >= 1);
    }

    #[test]
    fn test_welch_peak_detection() {
        let x = make_sine(1024, 0.1);
        let r = welch(&x.view(), Some(256), Some(0.5), Window::Hann).expect("Welch should succeed");
        let peak = peak_index(&r.psd);
        assert!(
            (r.frequencies[peak] - 0.1).abs() < 0.01,
            "Welch peak at {} instead of ~0.1",
            r.frequencies[peak]
        );
    }

    #[test]
    fn test_bartlett_basic() {
        let x = make_sine(256, 0.1);
        let r = bartlett(&x.view(), 4).expect("Bartlett should succeed");
        assert!(r.n_segments >= 1);
    }

    #[test]
    fn test_bartlett_invalid_segments() {
        let x = make_sine(16, 0.1);
        let result = bartlett(&x.view(), 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_cross_spectral_density_basic() {
        let x = make_sine(256, 0.1);
        let y = make_sine(256, 0.1); // same frequency
        let r = cross_spectral_density(&x.view(), &y.view(), Some(64), Some(0.5), Window::Hann)
            .expect("CSD should succeed");
        assert_eq!(r.csd_magnitude.len(), r.frequencies.len());
    }

    #[test]
    fn test_cross_spectral_density_different_freqs() {
        let x = make_sine(256, 0.1);
        let y = make_sine(256, 0.3);
        cross_spectral_density(&x.view(), &y.view(), Some(64), Some(0.5), Window::Hann)
            .expect("CSD of different-frequency signals should succeed");
    }

    #[test]
    fn test_cross_spectral_density_length_mismatch() {
        let x = make_sine(100, 0.1);
        let y = make_sine(200, 0.1);
        let result = cross_spectral_density(&x.view(), &y.view(), None, None, Window::Hann);
        assert!(result.is_err());
    }

    #[test]
    fn test_coherence_same_signal() {
        let x = make_sine(256, 0.1);
        let r = coherence(&x.view(), &x.view(), Some(64), Some(0.5), Window::Hann)
            .expect("coherence should succeed");
        // Coherence of a signal with itself should be very high.
        let max_coh = r
            .coherence_sq
            .iter()
            .skip(1)
            .cloned()
            .fold(0.0_f64, f64::max);
        assert!(max_coh > 0.9, "self-coherence too low: {}", max_coh);
    }

    #[test]
    fn test_coherence_values_bounded() {
        let x = make_sine(256, 0.1);
        let y = make_noise(256);
        let r = coherence(&x.view(), &y.view(), Some(64), Some(0.5), Window::Hann)
            .expect("coherence should succeed");
        for &c in r.coherence_sq.iter() {
            assert!(c >= 0.0, "coherence must be >= 0, got {}", c);
            assert!(c <= 1.0 + 1e-10, "coherence must be <= 1, got {}", c);
        }
    }

    #[test]
    fn test_spectral_granger_causality() {
        let n = 200;
        let x = make_sine(n, 0.1);
        // y is driven by x with a 3-sample lag plus an independent tone;
        // the first 3 samples have no lagged input and stay at zero.
        let mut y_vec = vec![0.0_f64; n];
        for (i, yi) in y_vec.iter_mut().enumerate().skip(3) {
            *yi = 0.7 * x[i - 3] + ((i as f64) * 0.5).sin() * 0.3;
        }
        let y = Array1::from_vec(y_vec);
        let r = spectral_granger_causality(&x.view(), &y.view(), 5, Some(32))
            .expect("spectral GC should succeed");
        assert_eq!(r.causality_x_to_y.len(), r.frequencies.len());
        // All GC values should be non-negative.
        for &gc in r.causality_x_to_y.iter() {
            assert!(gc >= 0.0, "GC should be non-negative, got {}", gc);
        }
    }

    #[test]
    fn test_window_coefficients_hann() {
        let w = window_coefficients(Window::Hann, 8, 0.5);
        assert_eq!(w.len(), 8);
        // Hann window starts and ends at 0.
        assert!((w[0]).abs() < 1e-10);
        assert!((w[7]).abs() < 1e-10);
    }

    #[test]
    fn test_window_coefficients_rectangular() {
        let w = window_coefficients(Window::Rectangular, 10, 0.5);
        for &v in w.iter() {
            assert!((v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_window_coefficients_bartlett() {
        let w = window_coefficients(Window::Bartlett, 5, 0.5);
        // Bartlett is triangular, peaks at center.
        assert!(w[2] > w[0]);
        assert!(w[2] > w[4]);
    }

    #[test]
    fn test_psd_non_negative() {
        let x = make_noise(128);
        let r = periodogram(&x.view(), Window::Hann, true).expect("periodogram should succeed");
        for &p in r.psd.iter() {
            assert!(p >= 0.0, "PSD must be non-negative, got {}", p);
        }
    }

    #[test]
    fn test_spectral_granger_insufficient() {
        let x = Array1::from_vec(vec![1.0; 5]);
        let y = Array1::from_vec(vec![2.0; 5]);
        let result = spectral_granger_causality(&x.view(), &y.view(), 10, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_welch_overlap_zero() {
        let x = make_sine(256, 0.1);
        let r = welch(&x.view(), Some(64), Some(0.0), Window::Hann)
            .expect("Welch with 0 overlap should succeed");
        assert_eq!(r.n_segments, 4); // 256/64 = 4 non-overlapping segments
    }

    #[test]
    fn test_tukey_window() {
        let w = window_coefficients(Window::Tukey, 100, 0.5);
        assert_eq!(w.len(), 100);
        // Middle should be 1.0 (flat top of the tapered cosine).
        assert!((w[50] - 1.0).abs() < 0.01);
    }
}
1038        assert_eq!(w.len(), 100);
1039        // Middle should be 1.0
1040        assert!((w[50] - 1.0).abs() < 0.01);
1041    }
1042}