Skip to main content

scirs2_graph/signal_processing/
sampling.rs

1//! Graph signal sampling: optimal set selection, bandlimited reconstruction,
2//! bandwidth estimation (Gershgorin), and the graph uncertainty principle.
3//!
4//! ## Overview
5//!
6//! Classical Nyquist–Shannon sampling requires that a signal's bandwidth be at
7//! most half the sampling rate.  On graphs, an analogous theory exists where
8//! "frequency" is defined through the Laplacian spectrum:
9//!
10//! - A signal is **k-bandlimited** if its GFT coefficients beyond index `k` are zero.
//! - A **sampling set** `S ⊆ V` is *k-valid* if the restriction `U_k|_S` (the
//!   first `k` eigenvectors, i.e. the `k` lowest graph frequencies, restricted
//!   to rows in `S`) has full column rank.
13//! - **Gershgorin circles** provide fast spectral radius bounds without computing
14//!   the full eigendecomposition.
15//! - The **graph uncertainty principle** quantifies the trade-off between spatial
16//!   spread and spectral spread of a signal.
17//!
18//! ## References
19//! - Pesenson (2008). "Sampling in Paley-Wiener spaces on combinatorial graphs."
20//! - Shomorony & Avestimehr (2014). "Sampling large graphs."
21//! - Agaskar & Lu (2013). "A spectral graph uncertainty principle."
22//!
23//! ## Example
24//! ```rust,no_run
25//! use scirs2_core::ndarray::Array2;
26//! use scirs2_graph::signal_processing::sampling::{GraphSampling, BandlimitedReconstruction};
27//! use scirs2_graph::signal_processing::gsp::GraphFourierTransform;
28//!
29//! let mut adj = Array2::<f64>::zeros((6, 6));
30//! for i in 0..5_usize { adj[[i, i+1]] = 1.0; adj[[i+1, i]] = 1.0; }
31//! let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
32//!
33//! // Find a 2-valid sampling set
34//! let sampler = GraphSampling::new(2);
35//! let set = sampler.greedy_sampling_set(&gft).unwrap();
36//! println!("Sampling set: {set:?}");
37//! ```
38
39use scirs2_core::ndarray::{Array1, Array2};
40
41use crate::error::{GraphError, Result};
42use crate::signal_processing::gsp::GraphFourierTransform;
43
44// ─────────────────────────────────────────────────────────────────────────────
45// GraphSampling — greedy optimal sampling set
46// ─────────────────────────────────────────────────────────────────────────────
47
/// Optimal graph signal sampling set selection.
///
/// Selects a set of `k` nodes from which a `k`-bandlimited graph signal can
/// be uniquely reconstructed.  Nodes are chosen greedily: each step adds the
/// node that contributes the most new volume (determinant) to the submatrix
/// formed by the `k` smoothest eigenvectors restricted to the sampling set.
///
/// NOTE(review): greedy volume (log-det) maximisation is a D-optimal design
/// heuristic.  It is not strictly equivalent to maximising the minimum
/// singular value of the restricted eigenvector matrix (the quantity that
/// governs reconstruction stability), although in practice it tends to keep
/// that value away from zero.
#[derive(Debug, Clone)]
pub struct GraphSampling {
    /// Number of frequency components (bandwidth) of the signal.  This is
    /// also the number of nodes the greedy selection returns.
    pub bandwidth: usize,
}
64
65impl GraphSampling {
66    /// Create a `GraphSampling` instance for signals with `bandwidth` components.
67    pub fn new(bandwidth: usize) -> Self {
68        Self { bandwidth }
69    }
70
71    /// Greedy sampling set selection.
72    ///
73    /// Returns a list of `bandwidth` node indices that form a valid sampling set
74    /// (i.e. the restriction of the first `bandwidth` eigenvectors to these rows
75    /// has full column rank).
76    ///
77    /// The algorithm iteratively adds the node that maximises the volume
78    /// (log-det of `Aᵀ A` where `A` is the current restricted eigenvector
79    /// submatrix).  This is a classic D-optimal experimental design step.
80    pub fn greedy_sampling_set(&self, gft: &GraphFourierTransform) -> Result<Vec<usize>> {
81        let n = gft.num_nodes();
82        let k = self.bandwidth;
83        if k == 0 {
84            return Ok(Vec::new());
85        }
86        if k > n {
87            return Err(GraphError::InvalidParameter {
88                param: "bandwidth".into(),
89                value: k.to_string(),
90                expected: format!("<= n ({})", n),
91                context: "GraphSampling".into(),
92            });
93        }
94
95        // The first k columns of the eigenvector matrix U (shape n×k)
96        let u_k = gft
97            .eigenvectors
98            .slice(scirs2_core::ndarray::s![.., ..k])
99            .to_owned();
100
101        let mut selected: Vec<usize> = Vec::with_capacity(k);
102        let mut remaining: Vec<usize> = (0..n).collect();
103
104        for _ in 0..k {
105            // Find the node that maximises the volume increment.
106            // Volume criterion: choose node `r` maximising the squared norm of
107            // the projection of u_k[r, :] onto the complement of the span of
108            // already-selected rows.
109            let best = Self::pick_best_node(&u_k, &selected, &remaining, n, k)?;
110            selected.push(best);
111            remaining.retain(|&x| x != best);
112        }
113
114        Ok(selected)
115    }
116
117    /// Pick the node index from `remaining` that maximises the leverage score
118    /// (projection onto the complement of the span of currently selected rows).
119    fn pick_best_node(
120        u_k: &Array2<f64>,
121        selected: &[usize],
122        remaining: &[usize],
123        _n: usize,
124        k: usize,
125    ) -> Result<usize> {
126        // Build the current selection matrix (rows of u_k for selected nodes)
127        // Compute the orthogonal projection onto the span of selected rows.
128        // Leverage score of row r = ‖(I − Π) u_k[r,:]‖²
129
130        let mut best_score = -1.0_f64;
131        let mut best_node = remaining[0];
132
133        for &r in remaining {
134            let row = u_k.row(r);
135            let score = if selected.is_empty() {
136                // No projection yet; score = ‖row‖²
137                row.iter().map(|&x| x * x).sum::<f64>()
138            } else {
139                // Project row onto the complement of the span of selected rows.
140                let s = selected.len();
141                // Assemble sub-matrix S (s × k)
142                let mut sub = Array2::<f64>::zeros((s, k));
143                for (new_r, &old_r) in selected.iter().enumerate() {
144                    for c in 0..k {
145                        sub[[new_r, c]] = u_k[[old_r, c]];
146                    }
147                }
148                // QR-free: use Gram-Schmidt residual
149                let row_vec: Vec<f64> = row.to_vec();
150                let residual = gram_schmidt_residual(&row_vec, &sub);
151                residual.iter().map(|&x| x * x).sum::<f64>()
152            };
153            if score > best_score {
154                best_score = score;
155                best_node = r;
156            }
157        }
158
159        Ok(best_node)
160    }
161
162    /// Check whether a given node set is a valid `bandwidth`-sampling set,
163    /// i.e. the restricted eigenvector matrix has full column rank.
164    ///
165    /// Validity is checked by verifying that the minimum singular value of
166    /// the restricted matrix exceeds `tol`.
167    pub fn is_valid_sampling_set(
168        &self,
169        gft: &GraphFourierTransform,
170        set: &[usize],
171        tol: f64,
172    ) -> Result<bool> {
173        let k = self.bandwidth;
174        let n = gft.num_nodes();
175        if set.len() < k {
176            return Ok(false);
177        }
178        for &s in set {
179            if s >= n {
180                return Err(GraphError::InvalidParameter {
181                    param: "set node index".into(),
182                    value: s.to_string(),
183                    expected: format!("< {n}"),
184                    context: "is_valid_sampling_set".into(),
185                });
186            }
187        }
188        // Build restricted matrix (|set| × k)
189        let m = set.len();
190        let mut r_mat = Array2::<f64>::zeros((m, k));
191        for (new_r, &old_r) in set.iter().enumerate() {
192            for c in 0..k {
193                r_mat[[new_r, c]] = gft.eigenvectors[[old_r, c]];
194            }
195        }
196        // Minimum singular value via power iteration on (A^T A)
197        let min_sv = min_singular_value(&r_mat);
198        Ok(min_sv > tol)
199    }
200}
201
202/// Gram-Schmidt residual of `v` w.r.t. rows of `basis` (each row is a basis vector).
203fn gram_schmidt_residual(v: &[f64], basis: &Array2<f64>) -> Vec<f64> {
204    let k = v.len();
205    let m = basis.nrows();
206    let mut res = v.to_vec();
207    for i in 0..m {
208        let row = basis.row(i);
209        let dot: f64 = res.iter().zip(row.iter()).map(|(&a, &b)| a * b).sum();
210        let norm_sq: f64 = row.iter().map(|&x| x * x).sum();
211        if norm_sq > 1e-14 {
212            for (r, &b) in res.iter_mut().zip(row.iter()) {
213                *r -= (dot / norm_sq) * b;
214            }
215        }
216    }
217    let _ = k; // used implicitly
218    res
219}
220
221/// Approximate minimum singular value of matrix `a` via inverse power iteration
222/// applied to `aᵀ a`.  Returns a lower bound.
223fn min_singular_value(a: &Array2<f64>) -> f64 {
224    let m = a.nrows();
225    let k = a.ncols();
226    if m < k {
227        return 0.0;
228    }
229    // Compute ata = a^T a  (k × k)
230    let mut ata = Array2::<f64>::zeros((k, k));
231    for i in 0..k {
232        for j in 0..k {
233            let mut acc = 0.0_f64;
234            for r in 0..m {
235                acc += a[[r, i]] * a[[r, j]];
236            }
237            ata[[i, j]] = acc;
238        }
239    }
240
241    // Power iteration on ata to find max eigenvalue, then invert
242    // (approximation: we use the ratio of norms as a cheap estimate)
243    // A more robust approach: Frobenius norm as upper bound, then
244    // check conditioning.
245    let frob: f64 = ata.iter().map(|&x| x * x).sum::<f64>().sqrt();
246    if frob < 1e-14 {
247        return 0.0;
248    }
249
250    // One step of inverse power iteration with random start
251    let mut v: Vec<f64> = (0..k).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
252    for _ in 0..30 {
253        // Solve ata * w = v  (via Gaussian elimination on small k×k system)
254        let w = solve_linear(&ata, &v);
255        let norm: f64 = w.iter().map(|&x| x * x).sum::<f64>().sqrt();
256        if norm < 1e-14 {
257            return 0.0;
258        }
259        v = w.iter().map(|&x| x / norm).collect();
260    }
261
262    // Rayleigh quotient: v^T ata v / v^T v
263    let ata_v = matvec(&ata, &v);
264    let num: f64 = v.iter().zip(ata_v.iter()).map(|(&a, &b)| a * b).sum();
265    let den: f64 = v.iter().map(|&x| x * x).sum();
266    if den < 1e-14 {
267        return 0.0;
268    }
269    // This Rayleigh quotient converges to the MIN eigenvalue of ata,
270    // hence the min singular value of a is sqrt(min_eig(ata)).
271    (num / den).max(0.0).sqrt()
272}
273
274fn matvec(a: &Array2<f64>, v: &[f64]) -> Vec<f64> {
275    let n = a.nrows();
276    let k = a.ncols();
277    let mut out = vec![0.0_f64; n];
278    for i in 0..n {
279        for j in 0..k {
280            out[i] += a[[i, j]] * v[j];
281        }
282    }
283    out
284}
285
286/// Solve `a x = b` via Gaussian elimination with partial pivoting (small k×k).
287fn solve_linear(a: &Array2<f64>, b: &[f64]) -> Vec<f64> {
288    let n = a.nrows();
289    // Augmented matrix [A | b]
290    let mut aug: Vec<Vec<f64>> = (0..n)
291        .map(|i| {
292            let mut row: Vec<f64> = (0..n).map(|j| a[[i, j]]).collect();
293            row.push(b[i]);
294            row
295        })
296        .collect();
297
298    for col in 0..n {
299        // Find pivot
300        let mut max_row = col;
301        let mut max_val = aug[col][col].abs();
302        for row in (col + 1)..n {
303            if aug[row][col].abs() > max_val {
304                max_val = aug[row][col].abs();
305                max_row = row;
306            }
307        }
308        aug.swap(col, max_row);
309        let pivot = aug[col][col];
310        if pivot.abs() < 1e-14 {
311            // Singular; return zeros
312            return vec![0.0; n];
313        }
314        for row in (col + 1)..n {
315            let factor = aug[row][col] / pivot;
316            for k in col..=n {
317                let val = aug[col][k];
318                aug[row][k] -= factor * val;
319            }
320        }
321    }
322
323    // Back-substitution
324    let mut x = vec![0.0_f64; n];
325    for i in (0..n).rev() {
326        let mut sum = aug[i][n];
327        for j in (i + 1)..n {
328            sum -= aug[i][j] * x[j];
329        }
330        x[i] = sum / aug[i][i];
331    }
332    x
333}
334
335// ─────────────────────────────────────────────────────────────────────────────
336// BandlimitedReconstruction — reconstruct from samples
337// ─────────────────────────────────────────────────────────────────────────────
338
/// Reconstruct a `k`-bandlimited graph signal from its samples on a node set.
///
/// Given samples `y = x[S]` at nodes `S`, and assuming the signal `x` is
/// `k`-bandlimited (only first `k` GFT components are non-zero), recover `x`
/// by solving the least-squares system:
///
///   U_k[S, :] α = y   =>   α = (U_k[S,:]ᵀ U_k[S,:])⁻¹ U_k[S,:]ᵀ y
///   x = U_k α
///
/// where `U_k` is the matrix of the `k` smoothest eigenvectors (the first `k`
/// columns of the GFT eigenvector matrix).
#[derive(Debug, Clone)]
pub struct BandlimitedReconstruction {
    /// Number of frequency components `k`, i.e. how many leading eigenvector
    /// columns are used as the reconstruction basis.
    pub bandwidth: usize,
}
354
355impl BandlimitedReconstruction {
356    /// Create a reconstruction instance for `k`-bandlimited signals.
357    pub fn new(bandwidth: usize) -> Self {
358        Self { bandwidth }
359    }
360
361    /// Reconstruct the full graph signal from samples.
362    ///
363    /// # Arguments
364    /// * `gft` — precomputed GFT for the graph.
365    /// * `sampling_set` — indices of sampled nodes.
366    /// * `samples` — signal values at the sampled nodes (length = `|sampling_set|`).
367    ///
368    /// # Returns
369    /// Reconstructed signal of length `n` (all nodes).
370    pub fn reconstruct(
371        &self,
372        gft: &GraphFourierTransform,
373        sampling_set: &[usize],
374        samples: &Array1<f64>,
375    ) -> Result<Array1<f64>> {
376        let n = gft.num_nodes();
377        let k = self.bandwidth;
378        let s = sampling_set.len();
379
380        if samples.len() != s {
381            return Err(GraphError::InvalidParameter {
382                param: "samples.len()".into(),
383                value: samples.len().to_string(),
384                expected: format!("{s} (= |sampling_set|)"),
385                context: "BandlimitedReconstruction".into(),
386            });
387        }
388        if s < k {
389            return Err(GraphError::InvalidParameter {
390                param: "sampling_set.len()".into(),
391                value: s.to_string(),
392                expected: format!(">= bandwidth ({})", k),
393                context: "BandlimitedReconstruction".into(),
394            });
395        }
396
397        // Build U_k[S, :] — shape (s, k)
398        let mut u_s = Array2::<f64>::zeros((s, k));
399        for (new_r, &old_r) in sampling_set.iter().enumerate() {
400            if old_r >= n {
401                return Err(GraphError::InvalidParameter {
402                    param: "sampling_set node".into(),
403                    value: old_r.to_string(),
404                    expected: format!("< {n}"),
405                    context: "BandlimitedReconstruction".into(),
406                });
407            }
408            for c in 0..k {
409                u_s[[new_r, c]] = gft.eigenvectors[[old_r, c]];
410            }
411        }
412
413        // Normal equations: (U_s^T U_s) α = U_s^T y
414        // Build Gram matrix G = U_s^T U_s  (k × k)
415        let mut gram = Array2::<f64>::zeros((k, k));
416        for i in 0..k {
417            for j in 0..k {
418                let mut acc = 0.0_f64;
419                for r in 0..s {
420                    acc += u_s[[r, i]] * u_s[[r, j]];
421                }
422                gram[[i, j]] = acc;
423            }
424        }
425
426        // Build right-hand side: rhs = U_s^T y  (length k)
427        let rhs: Vec<f64> = (0..k)
428            .map(|c| (0..s).map(|r| u_s[[r, c]] * samples[r]).sum::<f64>())
429            .collect();
430
431        // Solve Gram α = rhs
432        let alpha = solve_linear(&gram, &rhs);
433
434        // Reconstruct: x = U_k α  (length n)
435        let mut x = Array1::<f64>::zeros(n);
436        for i in 0..n {
437            let mut acc = 0.0_f64;
438            for c in 0..k {
439                acc += gft.eigenvectors[[i, c]] * alpha[c];
440            }
441            x[i] = acc;
442        }
443        Ok(x)
444    }
445}
446
447// ─────────────────────────────────────────────────────────────────────────────
448// GershgorinBound — graph signal bandwidth estimation
449// ─────────────────────────────────────────────────────────────────────────────
450
/// Gershgorin-circle-based bounds on the graph Laplacian spectrum.
///
/// The Gershgorin circle theorem states that every eigenvalue `λ` of a matrix
/// `M` lies in at least one Gershgorin disc centred at `M[i,i]` with radius
/// `R_i = Σ_{j≠i} |M[i,j]|`.
///
/// For the graph Laplacian `L = D − A`:
///   - `L[i,i] = degree(i)`
///   - `R_i = degree(i)` (sum of off-diagonal magnitudes)
///
/// Therefore all eigenvalues lie in `[0, 2 * max_degree]` and we get
/// a fast upper bound on the graph bandwidth without computing eigenvalues.
#[derive(Debug, Clone)]
pub struct GershgorinBound {
    /// Upper bound on the spectral radius (max Laplacian eigenvalue):
    /// the largest `centre + radius` over all nodes.
    pub lambda_max_upper: f64,
    /// Lower bound: 0 (Laplacian is PSD).
    pub lambda_min_lower: f64,
    /// Per-node Gershgorin radii, indexed by node.
    pub radii: Vec<f64>,
    /// Per-node disc centres `L[i,i]` (= degree of node i), indexed by node.
    pub centres: Vec<f64>,
}
474
475impl GershgorinBound {
476    /// Compute Gershgorin bounds from a weighted adjacency matrix.
477    pub fn from_adjacency(adj: &Array2<f64>) -> Result<Self> {
478        let n = adj.nrows();
479        if n == 0 {
480            return Err(GraphError::InvalidGraph("empty adjacency".into()));
481        }
482        let mut centres = vec![0.0_f64; n];
483        let mut radii = vec![0.0_f64; n];
484        for i in 0..n {
485            let deg = adj.row(i).iter().map(|&x| x.abs()).sum::<f64>();
486            centres[i] = deg; // L[i,i] = degree
487            radii[i] = deg; // R_i = sum of |L[i,j]| for j != i = same deg for L
488        }
489        let lambda_max_upper = centres
490            .iter()
491            .zip(radii.iter())
492            .map(|(&c, &r)| c + r)
493            .fold(0.0_f64, f64::max);
494        Ok(Self {
495            lambda_max_upper,
496            lambda_min_lower: 0.0,
497            radii,
498            centres,
499        })
500    }
501
502    /// Estimate the effective bandwidth of a graph signal based on the
503    /// fraction of spectral energy concentrated in low frequencies.
504    ///
505    /// Computes the GFT and returns the smallest `k` such that the
506    /// cumulative spectral energy in the first `k` components exceeds `threshold`.
507    pub fn signal_bandwidth(
508        gft: &GraphFourierTransform,
509        signal: &Array1<f64>,
510        threshold: f64,
511    ) -> Result<usize> {
512        let x_hat = gft.transform(signal)?;
513        let total_energy: f64 = x_hat.iter().map(|&c| c * c).sum();
514        if total_energy < 1e-14 {
515            return Ok(0);
516        }
517        let mut cumulative = 0.0_f64;
518        for (k, &c) in x_hat.iter().enumerate() {
519            cumulative += c * c;
520            if cumulative / total_energy >= threshold {
521                return Ok(k + 1);
522            }
523        }
524        Ok(x_hat.len())
525    }
526}
527
528// ─────────────────────────────────────────────────────────────────────────────
529// GraphUncertaintyPrinciple
530// ─────────────────────────────────────────────────────────────────────────────
531
/// Graph uncertainty principle: spatial spread vs. spectral spread tradeoff.
///
/// Analogous to Heisenberg's uncertainty principle, signals on graphs cannot
/// be simultaneously concentrated in both the vertex domain and the spectral
/// domain.
///
/// **Spatial spread** of signal `x` about reference node `v₀`:
///   Δ_V²(x) = Σ_i d(i, v₀)² x_i² / ‖x‖²
///
/// **Spectral spread** about reference frequency `λ₀`:
///   Δ_S²(x) = Σ_k (λ_k − λ₀)² x̂_k² / ‖x‖²
///
/// The product `Δ_V Δ_S ≥ C` for some graph-dependent constant.
#[derive(Debug, Clone)]
pub struct GraphUncertaintyPrinciple {
    /// Squared distance from each node to the reference node `v₀`, indexed by
    /// node.  The constructor fills this with squared BFS hop counts; nodes
    /// that cannot be reached from `v₀` hold `f64::INFINITY`.
    pub spatial_distances_sq: Array1<f64>,
    /// Reference frequency `λ₀` (usually the DC frequency = 0).
    pub ref_frequency: f64,
}
552
553impl GraphUncertaintyPrinciple {
554    /// Build from BFS shortest-path distances from reference node `v0`.
555    ///
556    /// # Arguments
557    /// * `adj` — weighted adjacency matrix.
558    /// * `v0` — reference node index (typically the "centre" of the signal).
559    /// * `ref_frequency` — spectral reference point (default: 0 for DC).
560    pub fn new(adj: &Array2<f64>, v0: usize, ref_frequency: f64) -> Result<Self> {
561        let n = adj.nrows();
562        if v0 >= n {
563            return Err(GraphError::InvalidParameter {
564                param: "v0".into(),
565                value: v0.to_string(),
566                expected: format!("< {n}"),
567                context: "GraphUncertaintyPrinciple".into(),
568            });
569        }
570        // Compute shortest-path distances from v0 via Dijkstra (unweighted BFS)
571        let mut dist = vec![f64::INFINITY; n];
572        dist[v0] = 0.0;
573        let mut queue = std::collections::VecDeque::new();
574        queue.push_back(v0);
575        while let Some(u) = queue.pop_front() {
576            for v in 0..n {
577                if adj[[u, v]] != 0.0 && dist[v].is_infinite() {
578                    dist[v] = dist[u] + 1.0;
579                    queue.push_back(v);
580                }
581            }
582        }
583        let spatial_distances_sq = Array1::from_iter(dist.iter().map(|&d| d * d));
584        Ok(Self {
585            spatial_distances_sq,
586            ref_frequency,
587        })
588    }
589
590    /// Compute the spatial spread Δ_V(x) for signal `x`.
591    pub fn spatial_spread(&self, signal: &Array1<f64>) -> Result<f64> {
592        let n = signal.len();
593        if n != self.spatial_distances_sq.len() {
594            return Err(GraphError::InvalidParameter {
595                param: "signal.len()".into(),
596                value: n.to_string(),
597                expected: self.spatial_distances_sq.len().to_string(),
598                context: "spatial_spread".into(),
599            });
600        }
601        let norm_sq: f64 = signal.iter().map(|&x| x * x).sum();
602        if norm_sq < 1e-14 {
603            return Ok(0.0);
604        }
605        let spread_sq: f64 = signal
606            .iter()
607            .zip(self.spatial_distances_sq.iter())
608            .map(|(&x, &d2)| d2 * x * x)
609            .sum::<f64>()
610            / norm_sq;
611        Ok(spread_sq.sqrt())
612    }
613
614    /// Compute the spectral spread Δ_S(x) for signal `x` using the GFT.
615    pub fn spectral_spread(
616        &self,
617        gft: &GraphFourierTransform,
618        signal: &Array1<f64>,
619    ) -> Result<f64> {
620        let x_hat = gft.transform(signal)?;
621        let norm_sq: f64 = x_hat.iter().map(|&c| c * c).sum();
622        if norm_sq < 1e-14 {
623            return Ok(0.0);
624        }
625        let spread_sq: f64 = x_hat
626            .iter()
627            .zip(gft.eigenvalues.iter())
628            .map(|(&c, &lam)| {
629                let d = lam - self.ref_frequency;
630                d * d * c * c
631            })
632            .sum::<f64>()
633            / norm_sq;
634        Ok(spread_sq.sqrt())
635    }
636
637    /// Compute both spatial and spectral spreads and their product.
638    ///
639    /// Returns `(spatial_spread, spectral_spread, product)`.
640    pub fn uncertainty(
641        &self,
642        gft: &GraphFourierTransform,
643        signal: &Array1<f64>,
644    ) -> Result<(f64, f64, f64)> {
645        let dv = self.spatial_spread(signal)?;
646        let ds = self.spectral_spread(gft, signal)?;
647        Ok((dv, ds, dv * ds))
648    }
649}
650
651// ─────────────────────────────────────────────────────────────────────────────
652// Tests
653// ─────────────────────────────────────────────────────────────────────────────
654
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Adjacency matrix of the unweighted path graph on `n` nodes.
    fn path_adj(n: usize) -> Array2<f64> {
        let mut adj = Array2::<f64>::zeros((n, n));
        let last = n - 1;
        for (i, j) in (0..last).zip(1..) {
            adj[[i, j]] = 1.0;
            adj[[j, i]] = 1.0;
        }
        adj
    }

    /// The greedy selection returns `bandwidth` distinct, in-range nodes.
    #[test]
    fn test_greedy_sampling_set_size() {
        let gft = GraphFourierTransform::from_adjacency(&path_adj(8)).unwrap();
        let picked = GraphSampling::new(3).greedy_sampling_set(&gft).unwrap();
        assert_eq!(picked.len(), 3);
        // All indices should be valid
        assert!(picked.iter().all(|&s| s < 8));
        // All indices should be unique
        let distinct: std::collections::HashSet<usize> = picked.iter().copied().collect();
        assert_eq!(distinct.len(), picked.len());
    }

    /// Sampling every node reproduces a bandlimited signal exactly.
    #[test]
    fn test_bandlimited_reconstruction() {
        let n = 6;
        let k = 2;
        let gft = GraphFourierTransform::from_adjacency(&path_adj(n)).unwrap();

        // Build a k-bandlimited signal (only first k GFT components)
        let mut spectrum = Array1::<f64>::zeros(n);
        spectrum[0] = 2.0;
        spectrum[1] = 1.0;
        let original = gft.inverse(&spectrum).unwrap();

        // Sample all nodes (trivial reconstruction)
        let set: Vec<usize> = (0..n).collect();
        let samples = Array1::from_iter(set.iter().map(|&i| original[i]));
        let reconstructor = BandlimitedReconstruction::new(k);
        let rec = reconstructor.reconstruct(&gft, &set, &samples).unwrap();

        for (a, b) in original.iter().zip(rec.iter()) {
            assert!((a - b).abs() < 1e-8, "Reconstruction mismatch: {a} vs {b}");
        }
    }

    /// Gershgorin bounds of a path graph: lower bound 0, upper bound in (0, 4].
    #[test]
    fn test_gershgorin_bounds() {
        let bounds = GershgorinBound::from_adjacency(&path_adj(5)).unwrap();
        assert!(bounds.lambda_min_lower == 0.0);
        assert!(bounds.lambda_max_upper > 0.0);
        // For a path graph, all eigenvalues are in [0, 4], so upper bound <= 4
        assert!(bounds.lambda_max_upper <= 4.0 + 1e-9);
    }

    /// A constant signal concentrates its energy in the lowest components.
    #[test]
    fn test_signal_bandwidth() {
        let gft = GraphFourierTransform::from_adjacency(&path_adj(8)).unwrap();
        // DC signal: all energy in component 0
        let dc = Array1::from_vec(vec![1.0; 8]);
        let bw = GershgorinBound::signal_bandwidth(&gft, &dc, 0.99).unwrap();
        assert!(bw <= 2, "DC signal should have bandwidth 1 (got {bw})");
    }

    /// A signal localised on the reference node has (near) zero spatial spread.
    #[test]
    fn test_uncertainty_principle() {
        let adj = path_adj(7);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let up = GraphUncertaintyPrinciple::new(&adj, 3, 0.0).unwrap();
        // Localised signal at node 3
        let mut local = Array1::<f64>::zeros(7);
        local[3] = 1.0;
        let (dv, ds, prod) = up.uncertainty(&gft, &local).unwrap();
        assert!(dv >= 0.0);
        assert!(ds >= 0.0);
        assert!(prod >= 0.0);
        // A localised vertex signal should have small spatial spread
        assert!(
            dv < 1.0,
            "Vertex-localised signal should have small dv: {dv}"
        );
    }
}