1use super::{ChebyshevExpansion, ScaledLaplacian};
7
8#[derive(Debug, Clone)]
10pub struct WaveletScale {
11 pub scale: f64,
13 pub filter: ChebyshevExpansion,
15}
16
17impl WaveletScale {
18 pub fn mexican_hat(scale: f64, degree: usize) -> Self {
21 let filter = ChebyshevExpansion::from_function(
22 |x| {
23 let lambda = (x + 1.0); lambda * (-lambda * scale).exp()
25 },
26 degree,
27 );
28
29 Self { scale, filter }
30 }
31
32 pub fn heat_derivative(scale: f64, degree: usize) -> Self {
35 Self::mexican_hat(scale, degree)
36 }
37
38 pub fn scaling_function(scale: f64, degree: usize) -> Self {
41 let filter = ChebyshevExpansion::from_function(
42 |x| {
43 let lambda = (x + 1.0);
44 (-lambda * scale).exp()
45 },
46 degree,
47 );
48
49 Self { scale, filter }
50 }
51}
52
53#[derive(Debug, Clone)]
55pub struct GraphWavelet {
56 pub scale: WaveletScale,
58 pub center: usize,
60 pub coefficients: Vec<f64>,
62}
63
64impl GraphWavelet {
65 pub fn at_vertex(
67 laplacian: &ScaledLaplacian,
68 scale: &WaveletScale,
69 center: usize,
70 ) -> Self {
71 let n = laplacian.n;
72
73 let mut delta = vec![0.0; n];
75 if center < n {
76 delta[center] = 1.0;
77 }
78
79 let coefficients = apply_filter(laplacian, &scale.filter, &delta);
81
82 Self {
83 scale: scale.clone(),
84 center,
85 coefficients,
86 }
87 }
88
89 pub fn inner_product(&self, signal: &[f64]) -> f64 {
91 self.coefficients
92 .iter()
93 .zip(signal.iter())
94 .map(|(&w, &s)| w * s)
95 .sum()
96 }
97
98 pub fn norm(&self) -> f64 {
100 self.coefficients.iter().map(|x| x * x).sum::<f64>().sqrt()
101 }
102}
103
104#[derive(Debug, Clone)]
106pub struct SpectralWaveletTransform {
107 laplacian: ScaledLaplacian,
109 scales: Vec<WaveletScale>,
111 scaling: WaveletScale,
113 degree: usize,
115}
116
117impl SpectralWaveletTransform {
118 pub fn new(laplacian: ScaledLaplacian, num_scales: usize, degree: usize) -> Self {
120 let min_scale = 0.1;
122 let max_scale = 2.0 / laplacian.lambda_max;
123
124 let scales: Vec<WaveletScale> = (0..num_scales)
125 .map(|i| {
126 let t = if num_scales > 1 {
127 min_scale * (max_scale / min_scale).powf(i as f64 / (num_scales - 1) as f64)
128 } else {
129 min_scale
130 };
131 WaveletScale::mexican_hat(t, degree)
132 })
133 .collect();
134
135 let scaling = WaveletScale::scaling_function(max_scale, degree);
136
137 Self {
138 laplacian,
139 scales,
140 scaling,
141 degree,
142 }
143 }
144
145 pub fn forward(&self, signal: &[f64]) -> (Vec<f64>, Vec<Vec<f64>>) {
148 let scaling_coeffs = apply_filter(&self.laplacian, &self.scaling.filter, signal);
150
151 let wavelet_coeffs: Vec<Vec<f64>> = self
153 .scales
154 .iter()
155 .map(|s| apply_filter(&self.laplacian, &s.filter, signal))
156 .collect();
157
158 (scaling_coeffs, wavelet_coeffs)
159 }
160
161 pub fn inverse(&self, scaling_coeffs: &[f64], wavelet_coeffs: &[Vec<f64>]) -> Vec<f64> {
164 let n = self.laplacian.n;
165 let mut signal = vec![0.0; n];
166
167 let scaled_scaling = apply_filter(&self.laplacian, &self.scaling.filter, scaling_coeffs);
169 for i in 0..n {
170 signal[i] += scaled_scaling[i];
171 }
172
173 for (scale, coeffs) in self.scales.iter().zip(wavelet_coeffs.iter()) {
175 let scaled_wavelet = apply_filter(&self.laplacian, &scale.filter, coeffs);
176 for i in 0..n {
177 signal[i] += scaled_wavelet[i];
178 }
179 }
180
181 signal
182 }
183
184 pub fn scale_energies(&self, signal: &[f64]) -> Vec<f64> {
186 let (_, wavelet_coeffs) = self.forward(signal);
187
188 wavelet_coeffs
189 .iter()
190 .map(|coeffs| coeffs.iter().map(|x| x * x).sum::<f64>())
191 .collect()
192 }
193
194 pub fn wavelets_at(&self, vertex: usize) -> Vec<GraphWavelet> {
196 self.scales
197 .iter()
198 .map(|s| GraphWavelet::at_vertex(&self.laplacian, s, vertex))
199 .collect()
200 }
201
202 pub fn num_scales(&self) -> usize {
204 self.scales.len()
205 }
206
207 pub fn scale_values(&self) -> Vec<f64> {
209 self.scales.iter().map(|s| s.scale).collect()
210 }
211}
212
213fn apply_filter(laplacian: &ScaledLaplacian, filter: &ChebyshevExpansion, signal: &[f64]) -> Vec<f64> {
215 let n = laplacian.n;
216 let coeffs = &filter.coefficients;
217
218 if coeffs.is_empty() || signal.len() != n {
219 return vec![0.0; n];
220 }
221
222 let k = coeffs.len() - 1;
223
224 let mut t_prev: Vec<f64> = signal.to_vec();
225 let mut t_curr: Vec<f64> = laplacian.apply(signal);
226
227 let mut output = vec![0.0; n];
228
229 for i in 0..n {
231 output[i] += coeffs[0] * t_prev[i];
232 }
233
234 if coeffs.len() > 1 {
236 for i in 0..n {
237 output[i] += coeffs[1] * t_curr[i];
238 }
239 }
240
241 for ki in 2..=k {
243 let lt_curr = laplacian.apply(&t_curr);
244 let mut t_next = vec![0.0; n];
245 for i in 0..n {
246 t_next[i] = 2.0 * lt_curr[i] - t_prev[i];
247 }
248
249 for i in 0..n {
250 output[i] += coeffs[ki] * t_next[i];
251 }
252
253 t_prev = t_curr;
254 t_curr = t_next;
255 }
256
257 output
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Laplacian built from the unit-weight edges `(i, i + 1)` of a path on
    /// `n` vertices.
    fn path_graph_laplacian(n: usize) -> ScaledLaplacian {
        let mut edges: Vec<(usize, usize, f64)> = Vec::new();
        for i in 0..n - 1 {
            edges.push((i, i + 1, 1.0));
        }
        ScaledLaplacian::from_sparse_adjacency(&edges, n)
    }

    /// Samples `sin(0.3 * i)` for `i` in `0..len`.
    fn sine_signal(len: usize) -> Vec<f64> {
        (0..len).map(|i| (i as f64 * 0.3).sin()).collect()
    }

    #[test]
    fn test_wavelet_scale() {
        let WaveletScale { scale, filter } = WaveletScale::mexican_hat(0.5, 10);

        assert_eq!(scale, 0.5);
        assert!(!filter.coefficients.is_empty());
    }

    #[test]
    fn test_graph_wavelet() {
        let laplacian = path_graph_laplacian(10);
        let scale = WaveletScale::mexican_hat(0.5, 10);

        let wavelet = GraphWavelet::at_vertex(&laplacian, &scale, 5);

        assert_eq!(wavelet.center, 5);
        assert_eq!(wavelet.coefficients.len(), 10);
        // The response at the centre vertex itself must be non-zero.
        assert!(wavelet.coefficients[5].abs() > 0.0);
    }

    #[test]
    fn test_wavelet_transform() {
        let transform = SpectralWaveletTransform::new(path_graph_laplacian(20), 4, 10);
        assert_eq!(transform.num_scales(), 4);

        let signal = sine_signal(20);
        let (scaling, wavelets) = transform.forward(&signal);

        assert_eq!(scaling.len(), 20);
        assert_eq!(wavelets.len(), 4);
        assert!(wavelets.iter().all(|band| band.len() == 20));
    }

    #[test]
    fn test_scale_energies() {
        let transform = SpectralWaveletTransform::new(path_graph_laplacian(20), 4, 10);

        let energies = transform.scale_energies(&sine_signal(20));

        assert_eq!(energies.len(), 4);
        // Energies are sums of squares, so none may be negative (or NaN).
        assert!(energies.iter().all(|&energy| energy >= 0.0));
    }

    #[test]
    fn test_wavelets_at_vertex() {
        let transform = SpectralWaveletTransform::new(path_graph_laplacian(10), 3, 8);

        let wavelets = transform.wavelets_at(5);

        assert_eq!(wavelets.len(), 3);
        assert!(wavelets.iter().all(|wavelet| wavelet.center == 5));
    }
}