Skip to main content

scirs2_graph/signal_processing/
gsp.rs

1//! Graph Signal Processing (GSP) — spectral and wavelet methods on graphs.
2//!
3//! This module implements the core GSP toolkit:
4//! - **Graph Fourier Transform (GFT)** and its inverse using Laplacian eigenvectors.
5//! - **Spectral graph filters** — ideal low-pass, high-pass, band-pass.
6//! - **Graph Wavelets** via diffusion (heat-kernel wavelets).
7//! - **Graph Signal Smoother** via Tikhonov (graph-Laplacian) regularization.
8//!
9//! All algorithms operate on `Array2<f64>` weighted adjacency matrices and
10//! `Array1<f64>` graph signals (one value per node).
11//!
12//! ## Mathematical Background
13//!
14//! Let `L = D − A` be the combinatorial graph Laplacian with eigendecomposition
15//! `L = U Λ Uᵀ`.  The **Graph Fourier Transform** of a signal `x` is
16//! `x̂ = Uᵀ x` and the inverse is `x = U x̂`.  Spectral filters are applied
17//! by multiplying `x̂` component-wise: `ŷ = h(Λ) x̂`.
18//!
19//! ## Example
20//! ```rust,no_run
21//! use scirs2_core::ndarray::{Array1, Array2};
22//! use scirs2_graph::signal_processing::gsp::{GraphFourierTransform, IdealLowPass, GraphFilter};
23//!
24//! // Path graph: 0-1-2-3
25//! let mut adj = Array2::<f64>::zeros((4, 4));
26//! adj[[0,1]] = 1.0; adj[[1,0]] = 1.0;
27//! adj[[1,2]] = 1.0; adj[[2,1]] = 1.0;
28//! adj[[2,3]] = 1.0; adj[[3,2]] = 1.0;
29//!
30//! let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
31//! let signal = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
32//! let freq = gft.transform(&signal).unwrap();
33//! let rec  = gft.inverse(&freq).unwrap();
34//!
35//! // Low-pass filter retaining lowest 2 frequency components
36//! let lp = IdealLowPass::new(2);
37//! let smoothed = lp.apply(&gft, &signal).unwrap();
38//! ```
39
40use scirs2_core::ndarray::{Array1, Array2};
41
42use crate::error::{GraphError, Result};
43use crate::spectral_graph::graph_laplacian;
44
45// ─────────────────────────────────────────────────────────────────────────────
// Helpers: dense symmetric eigendecomposition via classical Jacobi rotations
47// ─────────────────────────────────────────────────────────────────────────────
48
49/// Compute eigenvalues and eigenvectors of a real symmetric matrix via the
50/// classical Jacobi iteration method.
51///
52/// Returns `(eigenvalues, eigenvectors)` where eigenvectors are stored as
53/// columns of the returned matrix.
54fn symmetric_eigen(a: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
55    let n = a.nrows();
56    if n == 0 {
57        return Err(GraphError::InvalidGraph("empty matrix".into()));
58    }
59    if a.ncols() != n {
60        return Err(GraphError::InvalidGraph("matrix must be square".into()));
61    }
62
63    // Work copy
64    let mut m = a.clone();
65    // Accumulate rotations in V (starts as identity)
66    let mut v = Array2::<f64>::eye(n);
67
68    const MAX_SWEEPS: usize = 500;
69    const TOL: f64 = 1e-12;
70
71    for _ in 0..MAX_SWEEPS {
72        // Find the largest off-diagonal element
73        let mut max_val = 0.0_f64;
74        let mut p = 0_usize;
75        let mut q = 1_usize;
76        for i in 0..n {
77            for j in (i + 1)..n {
78                let v_ij = m[[i, j]].abs();
79                if v_ij > max_val {
80                    max_val = v_ij;
81                    p = i;
82                    q = j;
83                }
84            }
85        }
86        if max_val < TOL {
87            break;
88        }
89
90        // Compute Jacobi rotation angle
91        let theta = if (m[[q, q]] - m[[p, p]]).abs() < TOL {
92            std::f64::consts::FRAC_PI_4
93        } else {
94            0.5 * ((2.0 * m[[p, q]]) / (m[[q, q]] - m[[p, p]])).atan()
95        };
96        let cos_t = theta.cos();
97        let sin_t = theta.sin();
98
99        // Apply rotation: M' = R^T M R, V' = V R
100        // Update rows / cols p and q of m
101        let mut new_m = m.clone();
102        for r in 0..n {
103            if r != p && r != q {
104                new_m[[r, p]] = cos_t * m[[r, p]] - sin_t * m[[r, q]];
105                new_m[[p, r]] = new_m[[r, p]];
106                new_m[[r, q]] = sin_t * m[[r, p]] + cos_t * m[[r, q]];
107                new_m[[q, r]] = new_m[[r, q]];
108            }
109        }
110        new_m[[p, p]] =
111            cos_t * cos_t * m[[p, p]] - 2.0 * sin_t * cos_t * m[[p, q]] + sin_t * sin_t * m[[q, q]];
112        new_m[[q, q]] =
113            sin_t * sin_t * m[[p, p]] + 2.0 * sin_t * cos_t * m[[p, q]] + cos_t * cos_t * m[[q, q]];
114        new_m[[p, q]] = 0.0;
115        new_m[[q, p]] = 0.0;
116        m = new_m;
117
118        // Update eigenvector matrix
119        let v_old = v.clone();
120        for r in 0..n {
121            v[[r, p]] = cos_t * v_old[[r, p]] - sin_t * v_old[[r, q]];
122            v[[r, q]] = sin_t * v_old[[r, p]] + cos_t * v_old[[r, q]];
123        }
124    }
125
126    // Eigenvalues are diagonal entries of m
127    let eigenvalues = Array1::from_iter((0..n).map(|i| m[[i, i]]));
128
129    // Sort eigenvalues (and eigenvectors) in ascending order
130    let mut idx: Vec<usize> = (0..n).collect();
131    idx.sort_by(|&a, &b| {
132        eigenvalues[a]
133            .partial_cmp(&eigenvalues[b])
134            .unwrap_or(std::cmp::Ordering::Equal)
135    });
136
137    let sorted_evals = Array1::from_iter(idx.iter().map(|&i| eigenvalues[i]));
138    let mut sorted_evecs = Array2::<f64>::zeros((n, n));
139    for (new_col, &old_col) in idx.iter().enumerate() {
140        for row in 0..n {
141            sorted_evecs[[row, new_col]] = v[[row, old_col]];
142        }
143    }
144
145    Ok((sorted_evals, sorted_evecs))
146}
147
148// ─────────────────────────────────────────────────────────────────────────────
149// GraphFourierTransform
150// ─────────────────────────────────────────────────────────────────────────────
151
152/// Graph Fourier Transform (GFT) based on the graph Laplacian eigenvectors.
153///
154/// The GFT projects a graph signal onto the frequency basis defined by the
155/// eigenvectors of the graph Laplacian `L = D − A`.  Low-frequency components
156/// correspond to smooth signals (slowly varying across edges); high-frequency
157/// components to rapidly oscillating signals.
158///
159/// # Fields
160/// - `eigenvalues` — sorted Laplacian spectrum (graph frequencies) λ₀ ≤ λ₁ ≤ …
161/// - `eigenvectors` — columns are eigenvectors (basis functions); shape `(n, n)`
162#[derive(Debug, Clone)]
163pub struct GraphFourierTransform {
164    /// Graph frequencies (Laplacian eigenvalues), sorted ascending.
165    pub eigenvalues: Array1<f64>,
166    /// Frequency basis: eigenvectors as columns, shape `(n, n)`.
167    pub eigenvectors: Array2<f64>,
168}
169
170impl GraphFourierTransform {
171    /// Build a GFT from a weighted adjacency matrix.
172    ///
173    /// Computes `L = D − A` and its full eigendecomposition.
174    pub fn from_adjacency(adj: &Array2<f64>) -> Result<Self> {
175        let n = adj.nrows();
176        if n == 0 {
177            return Err(GraphError::InvalidGraph("empty adjacency matrix".into()));
178        }
179        let lap = graph_laplacian(adj);
180        let (eigenvalues, eigenvectors) = symmetric_eigen(&lap)?;
181        Ok(Self {
182            eigenvalues,
183            eigenvectors,
184        })
185    }
186
187    /// Build a GFT directly from a precomputed Laplacian matrix.
188    pub fn from_laplacian(laplacian: &Array2<f64>) -> Result<Self> {
189        let (eigenvalues, eigenvectors) = symmetric_eigen(laplacian)?;
190        Ok(Self {
191            eigenvalues,
192            eigenvectors,
193        })
194    }
195
196    /// Number of nodes (= size of frequency basis).
197    pub fn num_nodes(&self) -> usize {
198        self.eigenvalues.len()
199    }
200
201    /// Forward GFT: `x̂ = Uᵀ x`.
202    ///
203    /// # Arguments
204    /// * `signal` — graph signal of length `n` (one value per node).
205    ///
206    /// # Returns
207    /// Spectral coefficients `x̂` of length `n`.
208    pub fn transform(&self, signal: &Array1<f64>) -> Result<Array1<f64>> {
209        let n = self.num_nodes();
210        if signal.len() != n {
211            return Err(GraphError::InvalidParameter {
212                param: "signal.len()".into(),
213                value: signal.len().to_string(),
214                expected: n.to_string(),
215                context: "GFT forward transform".into(),
216            });
217        }
218        // x̂_k = sum_i U[i,k] * x[i]  (Uᵀ applied to x)
219        let mut x_hat = Array1::<f64>::zeros(n);
220        for k in 0..n {
221            let mut acc = 0.0_f64;
222            for i in 0..n {
223                acc += self.eigenvectors[[i, k]] * signal[i];
224            }
225            x_hat[k] = acc;
226        }
227        Ok(x_hat)
228    }
229
230    /// Inverse GFT: `x = U x̂`.
231    ///
232    /// # Arguments
233    /// * `freq_signal` — spectral coefficients `x̂` of length `n`.
234    ///
235    /// # Returns
236    /// Reconstructed graph signal `x` of length `n`.
237    pub fn inverse(&self, freq_signal: &Array1<f64>) -> Result<Array1<f64>> {
238        let n = self.num_nodes();
239        if freq_signal.len() != n {
240            return Err(GraphError::InvalidParameter {
241                param: "freq_signal.len()".into(),
242                value: freq_signal.len().to_string(),
243                expected: n.to_string(),
244                context: "GFT inverse transform".into(),
245            });
246        }
247        // x_i = sum_k U[i,k] * x̂_k
248        let mut x = Array1::<f64>::zeros(n);
249        for i in 0..n {
250            let mut acc = 0.0_f64;
251            for k in 0..n {
252                acc += self.eigenvectors[[i, k]] * freq_signal[k];
253            }
254            x[i] = acc;
255        }
256        Ok(x)
257    }
258}
259
260// ─────────────────────────────────────────────────────────────────────────────
261// GraphFilter trait
262// ─────────────────────────────────────────────────────────────────────────────
263
264/// Trait for spectral graph filters.
265///
266/// A filter takes a graph signal and returns a filtered version by modifying
267/// the spectral coefficients according to a frequency-response function `h(λ)`.
268pub trait GraphFilter {
269    /// Apply the filter to `signal` using the precomputed GFT basis.
270    fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>>;
271
272    /// Return the frequency response `h(λ)` for each eigenvalue in `gft`.
273    fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64>;
274}
275
276// ─────────────────────────────────────────────────────────────────────────────
277// IdealLowPass
278// ─────────────────────────────────────────────────────────────────────────────
279
280/// Ideal low-pass spectral graph filter.
281///
282/// Retains the `k` lowest-frequency graph Fourier components and zeroes out
283/// all higher-frequency components.  This is the graph analogue of the ideal
284/// rectangular low-pass filter in classical DSP.
285#[derive(Debug, Clone)]
286pub struct IdealLowPass {
287    /// Number of low-frequency components to retain.
288    pub k: usize,
289}
290
291impl IdealLowPass {
292    /// Create a new ideal low-pass filter retaining `k` components.
293    pub fn new(k: usize) -> Self {
294        Self { k }
295    }
296}
297
298impl GraphFilter for IdealLowPass {
299    fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
300        let n = gft.num_nodes();
301        Array1::from_iter((0..n).map(|i| if i < self.k { 1.0 } else { 0.0 }))
302    }
303
304    fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
305        let x_hat = gft.transform(signal)?;
306        let h = self.frequency_response(gft);
307        let filtered_hat = Array1::from_iter(x_hat.iter().zip(h.iter()).map(|(a, b)| a * b));
308        gft.inverse(&filtered_hat)
309    }
310}
311
312// ─────────────────────────────────────────────────────────────────────────────
313// IdealHighPass
314// ─────────────────────────────────────────────────────────────────────────────
315
316/// Ideal high-pass spectral graph filter.
317///
318/// Zeroes out the `k` lowest-frequency graph Fourier components and retains
319/// all higher-frequency components.  Useful for highlighting rapidly varying
320/// parts of a graph signal (e.g. edge features, anomalies).
321#[derive(Debug, Clone)]
322pub struct IdealHighPass {
323    /// Number of low-frequency components to suppress.
324    pub k: usize,
325}
326
327impl IdealHighPass {
328    /// Create a new ideal high-pass filter suppressing the `k` lowest components.
329    pub fn new(k: usize) -> Self {
330        Self { k }
331    }
332}
333
334impl GraphFilter for IdealHighPass {
335    fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
336        let n = gft.num_nodes();
337        Array1::from_iter((0..n).map(|i| if i < self.k { 0.0 } else { 1.0 }))
338    }
339
340    fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
341        let x_hat = gft.transform(signal)?;
342        let h = self.frequency_response(gft);
343        let filtered_hat = Array1::from_iter(x_hat.iter().zip(h.iter()).map(|(a, b)| a * b));
344        gft.inverse(&filtered_hat)
345    }
346}
347
348// ─────────────────────────────────────────────────────────────────────────────
349// GraphBandpass
350// ─────────────────────────────────────────────────────────────────────────────
351
352/// Ideal band-pass spectral graph filter.
353///
354/// Retains only frequency components whose indices fall in `[low_k, high_k)`,
355/// i.e. the band between the `low_k`-th and `high_k`-th eigenvalue.
356#[derive(Debug, Clone)]
357pub struct GraphBandpass {
358    /// Inclusive lower index of the retained frequency band.
359    pub low_k: usize,
360    /// Exclusive upper index of the retained frequency band.
361    pub high_k: usize,
362}
363
364impl GraphBandpass {
365    /// Create a new band-pass filter for the frequency band `[low_k, high_k)`.
366    pub fn new(low_k: usize, high_k: usize) -> Self {
367        Self { low_k, high_k }
368    }
369}
370
371impl GraphFilter for GraphBandpass {
372    fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
373        let n = gft.num_nodes();
374        Array1::from_iter((0..n).map(|i| {
375            if i >= self.low_k && i < self.high_k {
376                1.0
377            } else {
378                0.0
379            }
380        }))
381    }
382
383    fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
384        let x_hat = gft.transform(signal)?;
385        let h = self.frequency_response(gft);
386        let filtered_hat = Array1::from_iter(x_hat.iter().zip(h.iter()).map(|(a, b)| a * b));
387        gft.inverse(&filtered_hat)
388    }
389}
390
391// ─────────────────────────────────────────────────────────────────────────────
392// GraphWavelet (diffusion / heat-kernel wavelets)
393// ─────────────────────────────────────────────────────────────────────────────
394
395/// Graph wavelet based on the heat / diffusion kernel.
396///
397/// The diffusion kernel at scale `t` is `K_t = U exp(−t Λ) Uᵀ`.
398/// The wavelet at node `s` is the `s`-th column (or row) of `K_t`.
399/// This provides a spatially-localized, multi-scale representation of
400/// graph signals.
401///
402/// # Reference
403/// Hammond et al. (2011). "Wavelets on graphs via spectral graph theory."
404/// *Applied and Computational Harmonic Analysis*, 30(2), 129–150.
405#[derive(Debug, Clone)]
406pub struct GraphWavelet {
407    /// Diffusion scale parameter `t > 0`.
408    pub scale: f64,
409    /// Pre-computed kernel matrix `K_t = U exp(−t Λ) Uᵀ`, shape `(n, n)`.
410    kernel: Array2<f64>,
411}
412
413impl GraphWavelet {
414    /// Build the diffusion wavelet kernel at scale `t` from a GFT.
415    ///
416    /// # Arguments
417    /// * `gft` — precomputed graph Fourier transform.
418    /// * `scale` — diffusion scale `t > 0`.
419    pub fn new(gft: &GraphFourierTransform, scale: f64) -> Result<Self> {
420        if scale <= 0.0 {
421            return Err(GraphError::InvalidParameter {
422                param: "scale".into(),
423                value: scale.to_string(),
424                expected: "> 0".into(),
425                context: "GraphWavelet construction".into(),
426            });
427        }
428        let n = gft.num_nodes();
429        // h(λ) = exp(−t λ)
430        let h: Vec<f64> = gft
431            .eigenvalues
432            .iter()
433            .map(|&lam| (-scale * lam).exp())
434            .collect();
435
436        // K_t[i,j] = sum_k U[i,k] h[k] U[j,k]
437        let mut kernel = Array2::<f64>::zeros((n, n));
438        for i in 0..n {
439            for j in 0..n {
440                let mut acc = 0.0_f64;
441                for k in 0..n {
442                    acc += gft.eigenvectors[[i, k]] * h[k] * gft.eigenvectors[[j, k]];
443                }
444                kernel[[i, j]] = acc;
445            }
446        }
447        Ok(Self { scale, kernel })
448    }
449
450    /// Apply the wavelet kernel to a signal: `y = K_t x`.
451    pub fn apply(&self, signal: &Array1<f64>) -> Result<Array1<f64>> {
452        let n = self.kernel.nrows();
453        if signal.len() != n {
454            return Err(GraphError::InvalidParameter {
455                param: "signal.len()".into(),
456                value: signal.len().to_string(),
457                expected: n.to_string(),
458                context: "GraphWavelet apply".into(),
459            });
460        }
461        let mut out = Array1::<f64>::zeros(n);
462        for i in 0..n {
463            let mut acc = 0.0_f64;
464            for j in 0..n {
465                acc += self.kernel[[i, j]] * signal[j];
466            }
467            out[i] = acc;
468        }
469        Ok(out)
470    }
471
472    /// Return the wavelet atom centered at node `s` (column `s` of `K_t`).
473    pub fn wavelet_atom(&self, s: usize) -> Result<Array1<f64>> {
474        let n = self.kernel.nrows();
475        if s >= n {
476            return Err(GraphError::InvalidParameter {
477                param: "s".into(),
478                value: s.to_string(),
479                expected: format!("< {n}"),
480                context: "GraphWavelet atom".into(),
481            });
482        }
483        Ok(self.kernel.column(s).to_owned())
484    }
485
486    /// Return the full kernel matrix `K_t` (shape `n × n`).
487    pub fn kernel(&self) -> &Array2<f64> {
488        &self.kernel
489    }
490}
491
492// ─────────────────────────────────────────────────────────────────────────────
493// GraphSignalSmoother (Tikhonov / graph Laplacian regularization)
494// ─────────────────────────────────────────────────────────────────────────────
495
496/// Graph signal smoother via Tikhonov (graph Laplacian) regularization.
497///
498/// Solves the regularized least-squares problem:
499///
500///   minimize  ‖x − y‖² + α xᵀ L x
501///
502/// where `y` is the observed (potentially noisy) signal, `L` is the graph
503/// Laplacian, and `α > 0` is the regularization strength.
504///
505/// The closed-form solution is:
506///
507///   x* = (I + α L)⁻¹ y  =  U (I + α Λ)⁻¹ Uᵀ y
508///
509/// which can be computed efficiently using the GFT as a spectral filter with
510/// frequency response `h(λ) = 1 / (1 + α λ)`.
511#[derive(Debug, Clone)]
512pub struct GraphSignalSmoother {
513    /// Regularization strength α > 0.
514    pub alpha: f64,
515}
516
517impl GraphSignalSmoother {
518    /// Create a smoother with regularization strength `alpha`.
519    pub fn new(alpha: f64) -> Result<Self> {
520        if alpha <= 0.0 {
521            return Err(GraphError::InvalidParameter {
522                param: "alpha".into(),
523                value: alpha.to_string(),
524                expected: "> 0".into(),
525                context: "GraphSignalSmoother construction".into(),
526            });
527        }
528        Ok(Self { alpha })
529    }
530
531    /// Smooth the observed signal `y` using the GFT basis.
532    ///
533    /// Returns `x* = U (I + α Λ)⁻¹ Uᵀ y`.
534    pub fn smooth(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
535        let n = gft.num_nodes();
536        if signal.len() != n {
537            return Err(GraphError::InvalidParameter {
538                param: "signal.len()".into(),
539                value: signal.len().to_string(),
540                expected: n.to_string(),
541                context: "GraphSignalSmoother smooth".into(),
542            });
543        }
544        let y_hat = gft.transform(signal)?;
545        // Apply frequency response h(λ) = 1 / (1 + α λ)
546        let x_hat = Array1::from_iter(
547            y_hat
548                .iter()
549                .zip(gft.eigenvalues.iter())
550                .map(|(&c, &lam)| c / (1.0 + self.alpha * lam)),
551        );
552        gft.inverse(&x_hat)
553    }
554
555    /// Return the frequency response function h(λ) = 1/(1 + α λ) evaluated at
556    /// each graph frequency.
557    pub fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
558        Array1::from_iter(
559            gft.eigenvalues
560                .iter()
561                .map(|&lam| 1.0 / (1.0 + self.alpha * lam)),
562        )
563    }
564
565    /// Compute the total variation (graph smoothness) of a signal:
566    /// `TV(x) = xᵀ L x = Σ_{(i,j)∈E} w_ij (x_i − x_j)²`.
567    pub fn total_variation(adj: &Array2<f64>, signal: &Array1<f64>) -> Result<f64> {
568        let n = adj.nrows();
569        if signal.len() != n {
570            return Err(GraphError::InvalidParameter {
571                param: "signal.len()".into(),
572                value: signal.len().to_string(),
573                expected: n.to_string(),
574                context: "total_variation".into(),
575            });
576        }
577        let mut tv = 0.0_f64;
578        for i in 0..n {
579            for j in (i + 1)..n {
580                let w = adj[[i, j]];
581                if w != 0.0 {
582                    let diff = signal[i] - signal[j];
583                    tv += w * diff * diff;
584                }
585            }
586        }
587        Ok(tv)
588    }
589}
590
591// ─────────────────────────────────────────────────────────────────────────────
592// Tests
593// ─────────────────────────────────────────────────────────────────────────────
594
#[cfg(test)]
mod tests {
    use super::*;

    /// Adjacency matrix of the unweighted path graph `0 - 1 - … - (n-1)`.
    fn path_graph_adj(n: usize) -> Array2<f64> {
        let mut adj = Array2::<f64>::zeros((n, n));
        // Iterating from 1 avoids the `n - 1` underflow when `n == 0`.
        for i in 1..n {
            adj[[i - 1, i]] = 1.0;
            adj[[i, i - 1]] = 1.0;
        }
        adj
    }

    #[test]
    fn test_gft_reconstruction() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let freq = gft.transform(&signal).unwrap();
        let rec = gft.inverse(&freq).unwrap();
        for (a, b) in signal.iter().zip(rec.iter()) {
            assert!((a - b).abs() < 1e-9, "Reconstruction error: {a} vs {b}");
        }
    }

    #[test]
    fn test_gft_basis_is_sorted_orthonormal_eigenbasis() {
        let n = 6;
        let adj = path_graph_adj(n);
        let lap = graph_laplacian(&adj);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();

        // Eigenvalues ascending.
        for k in 1..n {
            assert!(gft.eigenvalues[k - 1] <= gft.eigenvalues[k]);
        }
        // Uᵀ U = I.
        for a in 0..n {
            for b in 0..n {
                let dot: f64 = (0..n)
                    .map(|i| gft.eigenvectors[[i, a]] * gft.eigenvectors[[i, b]])
                    .sum();
                let expected = if a == b { 1.0 } else { 0.0 };
                assert!((dot - expected).abs() < 1e-9, "UᵀU[{a},{b}] = {dot}");
            }
        }
        // L u_k = λ_k u_k.
        for k in 0..n {
            for i in 0..n {
                let lu: f64 = (0..n).map(|j| lap[[i, j]] * gft.eigenvectors[[j, k]]).sum();
                let expected = gft.eigenvalues[k] * gft.eigenvectors[[i, k]];
                assert!((lu - expected).abs() < 1e-8, "eigen residual at ({i},{k})");
            }
        }
    }

    #[test]
    fn test_low_pass_smoothing() {
        let adj = path_graph_adj(6);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0]);
        let lp = IdealLowPass::new(2);
        let smoothed = lp.apply(&gft, &signal).unwrap();
        // The smoothed signal should have lower total variation
        let tv_orig = GraphSignalSmoother::total_variation(&adj, &signal).unwrap();
        let tv_smooth = GraphSignalSmoother::total_variation(&adj, &smoothed).unwrap();
        assert!(
            tv_smooth < tv_orig,
            "LP filter should reduce TV: {tv_smooth} vs {tv_orig}"
        );
    }

    #[test]
    fn test_high_pass_removes_dc() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        // Constant signal = DC component only
        let dc_signal = Array1::from_vec(vec![1.0, 1.0, 1.0, 1.0, 1.0]);
        let hp = IdealHighPass::new(1);
        let out = hp.apply(&gft, &dc_signal).unwrap();
        for v in out.iter() {
            assert!(v.abs() < 1e-9, "HP filter should remove DC: got {v}");
        }
    }

    #[test]
    fn test_bandpass() {
        let adj = path_graph_adj(8);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, 0.5, 0.0, -0.5, -1.0, -0.5, 0.0, 0.5]);
        let bp = GraphBandpass::new(2, 5);
        let band = bp.apply(&gft, &signal).unwrap();
        assert_eq!(band.len(), 8);

        // LP(2) + BP(2,5) + HP(5) partition the spectrum, so they sum to the input.
        let low = IdealLowPass::new(2).apply(&gft, &signal).unwrap();
        let high = IdealHighPass::new(5).apply(&gft, &signal).unwrap();
        for i in 0..8 {
            let total = low[i] + band[i] + high[i];
            assert!(
                (total - signal[i]).abs() < 1e-9,
                "spectral partition broken at node {i}: {total} vs {}",
                signal[i]
            );
        }

        // A band-pass is a projection: filtering twice changes nothing.
        let twice = bp.apply(&gft, &band).unwrap();
        for (a, b) in band.iter().zip(twice.iter()) {
            assert!((a - b).abs() < 1e-9);
        }
    }

    #[test]
    fn test_wavelet_kernel_symmetry() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let wv = GraphWavelet::new(&gft, 1.0).unwrap();
        let k = wv.kernel();
        for i in 0..5 {
            for j in 0..5 {
                assert!((k[[i, j]] - k[[j, i]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_wavelet_atom_matches_kernel_and_apply() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let wv = GraphWavelet::new(&gft, 0.5).unwrap();
        let atom = wv.wavelet_atom(2).unwrap();
        let mut delta = Array1::<f64>::zeros(5);
        delta[2] = 1.0;
        let applied = wv.apply(&delta).unwrap();
        for i in 0..5 {
            assert!((atom[i] - wv.kernel()[[i, 2]]).abs() < 1e-12);
            assert!((applied[i] - atom[i]).abs() < 1e-12);
        }
        assert!(wv.wavelet_atom(5).is_err());
    }

    #[test]
    fn test_heat_kernel_preserves_constants() {
        // exp(−t·0) = 1 on the constant eigenvector, so K_t maps 1 to 1.
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let wv = GraphWavelet::new(&gft, 2.0).unwrap();
        let ones = Array1::from_vec(vec![1.0; 5]);
        let out = wv.apply(&ones).unwrap();
        for v in out.iter() {
            assert!((v - 1.0).abs() < 1e-9, "heat kernel should preserve DC: {v}");
        }
    }

    #[test]
    fn test_smoother_reduces_variation() {
        let adj = path_graph_adj(6);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let noisy = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
        let smoother = GraphSignalSmoother::new(5.0).unwrap();
        let smoothed = smoother.smooth(&gft, &noisy).unwrap();
        let tv_noisy = GraphSignalSmoother::total_variation(&adj, &noisy).unwrap();
        let tv_smooth = GraphSignalSmoother::total_variation(&adj, &smoothed).unwrap();
        assert!(tv_smooth < tv_noisy);
    }

    #[test]
    fn test_smoother_frequency_response_is_decreasing_in_unit_interval() {
        let adj = path_graph_adj(6);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let smoother = GraphSignalSmoother::new(2.0).unwrap();
        let h = smoother.frequency_response(&gft);
        // λ₀ = 0 for a connected graph, so the DC gain is 1.
        assert!((h[0] - 1.0).abs() < 1e-9);
        for k in 1..h.len() {
            assert!(h[k] > 0.0 && h[k] <= 1.0 + 1e-12);
            assert!(h[k] <= h[k - 1] + 1e-12);
        }
    }

    #[test]
    fn test_invalid_parameters_rejected() {
        let adj = path_graph_adj(4);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        assert!(GraphWavelet::new(&gft, 0.0).is_err());
        assert!(GraphWavelet::new(&gft, -1.0).is_err());
        assert!(GraphSignalSmoother::new(0.0).is_err());
        assert!(GraphSignalSmoother::new(-2.0).is_err());
        assert!(gft.transform(&Array1::zeros(3)).is_err());
        assert!(gft.inverse(&Array1::zeros(5)).is_err());
    }
}