Skip to main content

ruvector_math/spectral/
wavelets.rs

//! Graph Wavelets
//!
//! Multi-scale analysis on graphs using spectral graph wavelets.
//! Based on Hammond et al., "Wavelets on Graphs via Spectral Graph Theory".

6use super::{ChebyshevExpansion, ScaledLaplacian};
7
8/// Wavelet scale configuration
9#[derive(Debug, Clone)]
10pub struct WaveletScale {
11    /// Scale parameter (larger = coarser)
12    pub scale: f64,
13    /// Chebyshev expansion for this scale
14    pub filter: ChebyshevExpansion,
15}
16
17impl WaveletScale {
18    /// Create wavelet at given scale using Mexican hat kernel
19    /// g(λ) = λ * exp(-λ * scale)
20    pub fn mexican_hat(scale: f64, degree: usize) -> Self {
21        let filter = ChebyshevExpansion::from_function(
22            |x| {
23                let lambda = (x + 1.0); // Map [-1,1] to [0,2]
24                lambda * (-lambda * scale).exp()
25            },
26            degree,
27        );
28
29        Self { scale, filter }
30    }
31
32    /// Create wavelet using heat kernel derivative
33    /// g(λ) = λ * exp(-λ * scale) (same as Mexican hat)
34    pub fn heat_derivative(scale: f64, degree: usize) -> Self {
35        Self::mexican_hat(scale, degree)
36    }
37
38    /// Create scaling function (low-pass for residual)
39    /// h(λ) = exp(-λ * scale)
40    pub fn scaling_function(scale: f64, degree: usize) -> Self {
41        let filter = ChebyshevExpansion::from_function(
42            |x| {
43                let lambda = (x + 1.0);
44                (-lambda * scale).exp()
45            },
46            degree,
47        );
48
49        Self { scale, filter }
50    }
51}
52
53/// Graph wavelet at specific vertex
54#[derive(Debug, Clone)]
55pub struct GraphWavelet {
56    /// Wavelet scale
57    pub scale: WaveletScale,
58    /// Center vertex
59    pub center: usize,
60    /// Wavelet coefficients for all vertices
61    pub coefficients: Vec<f64>,
62}
63
64impl GraphWavelet {
65    /// Compute wavelet centered at vertex
66    pub fn at_vertex(laplacian: &ScaledLaplacian, scale: &WaveletScale, center: usize) -> Self {
67        let n = laplacian.n;
68
69        // Delta function at center
70        let mut delta = vec![0.0; n];
71        if center < n {
72            delta[center] = 1.0;
73        }
74
75        // Apply wavelet filter: ψ_s,v = g(L) δ_v
76        let coefficients = apply_filter(laplacian, &scale.filter, &delta);
77
78        Self {
79            scale: scale.clone(),
80            center,
81            coefficients,
82        }
83    }
84
85    /// Inner product with signal
86    pub fn inner_product(&self, signal: &[f64]) -> f64 {
87        self.coefficients
88            .iter()
89            .zip(signal.iter())
90            .map(|(&w, &s)| w * s)
91            .sum()
92    }
93
94    /// L2 norm
95    pub fn norm(&self) -> f64 {
96        self.coefficients.iter().map(|x| x * x).sum::<f64>().sqrt()
97    }
98}
99
100/// Spectral Wavelet Transform
101#[derive(Debug, Clone)]
102pub struct SpectralWaveletTransform {
103    /// Laplacian
104    laplacian: ScaledLaplacian,
105    /// Wavelet scales (finest to coarsest)
106    scales: Vec<WaveletScale>,
107    /// Scaling function (for residual)
108    scaling: WaveletScale,
109    /// Chebyshev degree
110    degree: usize,
111}
112
113impl SpectralWaveletTransform {
114    /// Create wavelet transform with logarithmically spaced scales
115    pub fn new(laplacian: ScaledLaplacian, num_scales: usize, degree: usize) -> Self {
116        // Scales from fine (small t) to coarse (large t)
117        let min_scale = 0.1;
118        let max_scale = 2.0 / laplacian.lambda_max;
119
120        let scales: Vec<WaveletScale> = (0..num_scales)
121            .map(|i| {
122                let t = if num_scales > 1 {
123                    min_scale * (max_scale / min_scale).powf(i as f64 / (num_scales - 1) as f64)
124                } else {
125                    min_scale
126                };
127                WaveletScale::mexican_hat(t, degree)
128            })
129            .collect();
130
131        let scaling = WaveletScale::scaling_function(max_scale, degree);
132
133        Self {
134            laplacian,
135            scales,
136            scaling,
137            degree,
138        }
139    }
140
141    /// Forward transform: compute wavelet coefficients
142    /// Returns (scaling_coeffs, [wavelet_coeffs_scale_0, wavelet_coeffs_scale_1, ...])
143    pub fn forward(&self, signal: &[f64]) -> (Vec<f64>, Vec<Vec<f64>>) {
144        // Scaling coefficients
145        let scaling_coeffs = apply_filter(&self.laplacian, &self.scaling.filter, signal);
146
147        // Wavelet coefficients at each scale
148        let wavelet_coeffs: Vec<Vec<f64>> = self
149            .scales
150            .iter()
151            .map(|s| apply_filter(&self.laplacian, &s.filter, signal))
152            .collect();
153
154        (scaling_coeffs, wavelet_coeffs)
155    }
156
157    /// Inverse transform: reconstruct signal from coefficients
158    /// Note: Perfect reconstruction requires frame bounds analysis
159    pub fn inverse(&self, scaling_coeffs: &[f64], wavelet_coeffs: &[Vec<f64>]) -> Vec<f64> {
160        let n = self.laplacian.n;
161        let mut signal = vec![0.0; n];
162
163        // Add scaling contribution
164        let scaled_scaling = apply_filter(&self.laplacian, &self.scaling.filter, scaling_coeffs);
165        for i in 0..n {
166            signal[i] += scaled_scaling[i];
167        }
168
169        // Add wavelet contributions
170        for (scale, coeffs) in self.scales.iter().zip(wavelet_coeffs.iter()) {
171            let scaled_wavelet = apply_filter(&self.laplacian, &scale.filter, coeffs);
172            for i in 0..n {
173                signal[i] += scaled_wavelet[i];
174            }
175        }
176
177        signal
178    }
179
180    /// Compute wavelet energy at each scale
181    pub fn scale_energies(&self, signal: &[f64]) -> Vec<f64> {
182        let (_, wavelet_coeffs) = self.forward(signal);
183
184        wavelet_coeffs
185            .iter()
186            .map(|coeffs| coeffs.iter().map(|x| x * x).sum::<f64>())
187            .collect()
188    }
189
190    /// Get all wavelets centered at a vertex
191    pub fn wavelets_at(&self, vertex: usize) -> Vec<GraphWavelet> {
192        self.scales
193            .iter()
194            .map(|s| GraphWavelet::at_vertex(&self.laplacian, s, vertex))
195            .collect()
196    }
197
198    /// Number of scales
199    pub fn num_scales(&self) -> usize {
200        self.scales.len()
201    }
202
203    /// Get scale parameters
204    pub fn scale_values(&self) -> Vec<f64> {
205        self.scales.iter().map(|s| s.scale).collect()
206    }
207}
208
209/// Apply Chebyshev filter to signal using recurrence
210fn apply_filter(
211    laplacian: &ScaledLaplacian,
212    filter: &ChebyshevExpansion,
213    signal: &[f64],
214) -> Vec<f64> {
215    let n = laplacian.n;
216    let coeffs = &filter.coefficients;
217
218    if coeffs.is_empty() || signal.len() != n {
219        return vec![0.0; n];
220    }
221
222    let k = coeffs.len() - 1;
223
224    let mut t_prev: Vec<f64> = signal.to_vec();
225    let mut t_curr: Vec<f64> = laplacian.apply(signal);
226
227    let mut output = vec![0.0; n];
228
229    // c_0 * T_0 * x
230    for i in 0..n {
231        output[i] += coeffs[0] * t_prev[i];
232    }
233
234    // c_1 * T_1 * x
235    if coeffs.len() > 1 {
236        for i in 0..n {
237            output[i] += coeffs[1] * t_curr[i];
238        }
239    }
240
241    // Recurrence
242    for ki in 2..=k {
243        let lt_curr = laplacian.apply(&t_curr);
244        let mut t_next = vec![0.0; n];
245        for i in 0..n {
246            t_next[i] = 2.0 * lt_curr[i] - t_prev[i];
247        }
248
249        for i in 0..n {
250            output[i] += coeffs[ki] * t_next[i];
251        }
252
253        t_prev = t_curr;
254        t_curr = t_next;
255    }
256
257    output
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Laplacian of the unweighted path graph `0 - 1 - ... - (n-1)`.
    ///
    /// `saturating_sub` keeps `n == 0` from underflowing `n - 1`.
    fn path_graph_laplacian(n: usize) -> ScaledLaplacian {
        let edges: Vec<(usize, usize, f64)> =
            (0..n.saturating_sub(1)).map(|i| (i, i + 1, 1.0)).collect();
        ScaledLaplacian::from_sparse_adjacency(&edges, n)
    }

    #[test]
    fn test_wavelet_scale() {
        let scale = WaveletScale::mexican_hat(0.5, 10);
        assert_eq!(scale.scale, 0.5);
        assert!(!scale.filter.coefficients.is_empty());
    }

    #[test]
    fn test_graph_wavelet() {
        let laplacian = path_graph_laplacian(10);
        let scale = WaveletScale::mexican_hat(0.5, 10);

        let wavelet = GraphWavelet::at_vertex(&laplacian, &scale, 5);

        assert_eq!(wavelet.center, 5);
        assert_eq!(wavelet.coefficients.len(), 10);
        // The wavelet should have non-zero mass at its own center.
        assert!(wavelet.coefficients[5].abs() > 0.0);
    }

    #[test]
    fn test_graph_wavelet_out_of_range_center_is_zero() {
        let laplacian = path_graph_laplacian(10);
        let scale = WaveletScale::mexican_hat(0.5, 10);

        let wavelet = GraphWavelet::at_vertex(&laplacian, &scale, 100);

        assert_eq!(wavelet.coefficients.len(), 10);
        assert!(wavelet.coefficients.iter().all(|&c| c == 0.0));
    }

    #[test]
    fn test_wavelet_norm_matches_self_inner_product() {
        let laplacian = path_graph_laplacian(10);
        let scale = WaveletScale::mexican_hat(0.5, 10);
        let wavelet = GraphWavelet::at_vertex(&laplacian, &scale, 3);

        let self_ip = wavelet.inner_product(&wavelet.coefficients);
        assert!((wavelet.norm() - self_ip.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_wavelet_transform() {
        let laplacian = path_graph_laplacian(20);
        let transform = SpectralWaveletTransform::new(laplacian, 4, 10);

        assert_eq!(transform.num_scales(), 4);

        // Test forward transform
        let signal: Vec<f64> = (0..20).map(|i| (i as f64 * 0.3).sin()).collect();
        let (scaling, wavelets) = transform.forward(&signal);

        assert_eq!(scaling.len(), 20);
        assert_eq!(wavelets.len(), 4);
        for w in &wavelets {
            assert_eq!(w.len(), 20);
        }
    }

    #[test]
    fn test_scales_ordered_fine_to_coarse() {
        let transform = SpectralWaveletTransform::new(path_graph_laplacian(20), 5, 10);
        let values = transform.scale_values();

        assert_eq!(values.len(), 5);
        for pair in values.windows(2) {
            assert!(pair[0] <= pair[1], "scales not ascending: {:?}", values);
        }
    }

    #[test]
    fn test_inverse_output_shape() {
        let transform = SpectralWaveletTransform::new(path_graph_laplacian(12), 3, 8);
        let signal: Vec<f64> = (0..12).map(|i| (i as f64 * 0.5).cos()).collect();

        let (scaling, wavelets) = transform.forward(&signal);
        let reconstructed = transform.inverse(&scaling, &wavelets);

        assert_eq!(reconstructed.len(), 12);
        assert!(reconstructed.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_scale_energies() {
        let laplacian = path_graph_laplacian(20);
        let transform = SpectralWaveletTransform::new(laplacian, 4, 10);

        let signal: Vec<f64> = (0..20).map(|i| (i as f64 * 0.3).sin()).collect();
        let energies = transform.scale_energies(&signal);

        assert_eq!(energies.len(), 4);
        // All energies should be non-negative
        for e in energies {
            assert!(e >= 0.0);
        }
    }

    #[test]
    fn test_wavelets_at_vertex() {
        let laplacian = path_graph_laplacian(10);
        let transform = SpectralWaveletTransform::new(laplacian, 3, 8);

        let wavelets = transform.wavelets_at(5);

        assert_eq!(wavelets.len(), 3);
        for w in &wavelets {
            assert_eq!(w.center, 5);
        }
    }
}