1use super::{ChebyshevExpansion, ScaledLaplacian};
7
8#[derive(Debug, Clone)]
10pub struct WaveletScale {
11 pub scale: f64,
13 pub filter: ChebyshevExpansion,
15}
16
17impl WaveletScale {
18 pub fn mexican_hat(scale: f64, degree: usize) -> Self {
21 let filter = ChebyshevExpansion::from_function(
22 |x| {
23 let lambda = (x + 1.0); lambda * (-lambda * scale).exp()
25 },
26 degree,
27 );
28
29 Self { scale, filter }
30 }
31
32 pub fn heat_derivative(scale: f64, degree: usize) -> Self {
35 Self::mexican_hat(scale, degree)
36 }
37
38 pub fn scaling_function(scale: f64, degree: usize) -> Self {
41 let filter = ChebyshevExpansion::from_function(
42 |x| {
43 let lambda = (x + 1.0);
44 (-lambda * scale).exp()
45 },
46 degree,
47 );
48
49 Self { scale, filter }
50 }
51}
52
53#[derive(Debug, Clone)]
55pub struct GraphWavelet {
56 pub scale: WaveletScale,
58 pub center: usize,
60 pub coefficients: Vec<f64>,
62}
63
64impl GraphWavelet {
65 pub fn at_vertex(laplacian: &ScaledLaplacian, scale: &WaveletScale, center: usize) -> Self {
67 let n = laplacian.n;
68
69 let mut delta = vec![0.0; n];
71 if center < n {
72 delta[center] = 1.0;
73 }
74
75 let coefficients = apply_filter(laplacian, &scale.filter, &delta);
77
78 Self {
79 scale: scale.clone(),
80 center,
81 coefficients,
82 }
83 }
84
85 pub fn inner_product(&self, signal: &[f64]) -> f64 {
87 self.coefficients
88 .iter()
89 .zip(signal.iter())
90 .map(|(&w, &s)| w * s)
91 .sum()
92 }
93
94 pub fn norm(&self) -> f64 {
96 self.coefficients.iter().map(|x| x * x).sum::<f64>().sqrt()
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct SpectralWaveletTransform {
103 laplacian: ScaledLaplacian,
105 scales: Vec<WaveletScale>,
107 scaling: WaveletScale,
109 degree: usize,
111}
112
113impl SpectralWaveletTransform {
114 pub fn new(laplacian: ScaledLaplacian, num_scales: usize, degree: usize) -> Self {
116 let min_scale = 0.1;
118 let max_scale = 2.0 / laplacian.lambda_max;
119
120 let scales: Vec<WaveletScale> = (0..num_scales)
121 .map(|i| {
122 let t = if num_scales > 1 {
123 min_scale * (max_scale / min_scale).powf(i as f64 / (num_scales - 1) as f64)
124 } else {
125 min_scale
126 };
127 WaveletScale::mexican_hat(t, degree)
128 })
129 .collect();
130
131 let scaling = WaveletScale::scaling_function(max_scale, degree);
132
133 Self {
134 laplacian,
135 scales,
136 scaling,
137 degree,
138 }
139 }
140
141 pub fn forward(&self, signal: &[f64]) -> (Vec<f64>, Vec<Vec<f64>>) {
144 let scaling_coeffs = apply_filter(&self.laplacian, &self.scaling.filter, signal);
146
147 let wavelet_coeffs: Vec<Vec<f64>> = self
149 .scales
150 .iter()
151 .map(|s| apply_filter(&self.laplacian, &s.filter, signal))
152 .collect();
153
154 (scaling_coeffs, wavelet_coeffs)
155 }
156
157 pub fn inverse(&self, scaling_coeffs: &[f64], wavelet_coeffs: &[Vec<f64>]) -> Vec<f64> {
160 let n = self.laplacian.n;
161 let mut signal = vec![0.0; n];
162
163 let scaled_scaling = apply_filter(&self.laplacian, &self.scaling.filter, scaling_coeffs);
165 for i in 0..n {
166 signal[i] += scaled_scaling[i];
167 }
168
169 for (scale, coeffs) in self.scales.iter().zip(wavelet_coeffs.iter()) {
171 let scaled_wavelet = apply_filter(&self.laplacian, &scale.filter, coeffs);
172 for i in 0..n {
173 signal[i] += scaled_wavelet[i];
174 }
175 }
176
177 signal
178 }
179
180 pub fn scale_energies(&self, signal: &[f64]) -> Vec<f64> {
182 let (_, wavelet_coeffs) = self.forward(signal);
183
184 wavelet_coeffs
185 .iter()
186 .map(|coeffs| coeffs.iter().map(|x| x * x).sum::<f64>())
187 .collect()
188 }
189
190 pub fn wavelets_at(&self, vertex: usize) -> Vec<GraphWavelet> {
192 self.scales
193 .iter()
194 .map(|s| GraphWavelet::at_vertex(&self.laplacian, s, vertex))
195 .collect()
196 }
197
198 pub fn num_scales(&self) -> usize {
200 self.scales.len()
201 }
202
203 pub fn scale_values(&self) -> Vec<f64> {
205 self.scales.iter().map(|s| s.scale).collect()
206 }
207}
208
209fn apply_filter(
211 laplacian: &ScaledLaplacian,
212 filter: &ChebyshevExpansion,
213 signal: &[f64],
214) -> Vec<f64> {
215 let n = laplacian.n;
216 let coeffs = &filter.coefficients;
217
218 if coeffs.is_empty() || signal.len() != n {
219 return vec![0.0; n];
220 }
221
222 let k = coeffs.len() - 1;
223
224 let mut t_prev: Vec<f64> = signal.to_vec();
225 let mut t_curr: Vec<f64> = laplacian.apply(signal);
226
227 let mut output = vec![0.0; n];
228
229 for i in 0..n {
231 output[i] += coeffs[0] * t_prev[i];
232 }
233
234 if coeffs.len() > 1 {
236 for i in 0..n {
237 output[i] += coeffs[1] * t_curr[i];
238 }
239 }
240
241 for ki in 2..=k {
243 let lt_curr = laplacian.apply(&t_curr);
244 let mut t_next = vec![0.0; n];
245 for i in 0..n {
246 t_next[i] = 2.0 * lt_curr[i] - t_prev[i];
247 }
248
249 for i in 0..n {
250 output[i] += coeffs[ki] * t_next[i];
251 }
252
253 t_prev = t_curr;
254 t_curr = t_next;
255 }
256
257 output
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Laplacian of the unweighted path graph `0 - 1 - ... - (n-1)`.
    fn path_graph_laplacian(n: usize) -> ScaledLaplacian {
        let edges: Vec<(usize, usize, f64)> = (0..n - 1).map(|u| (u, u + 1, 1.0)).collect();
        ScaledLaplacian::from_sparse_adjacency(&edges, n)
    }

    /// Smooth test signal `sin(0.3 * i)` sampled on `len` vertices.
    fn sine_signal(len: usize) -> Vec<f64> {
        (0..len).map(|i| (i as f64 * 0.3).sin()).collect()
    }

    #[test]
    fn test_wavelet_scale() {
        let ws = WaveletScale::mexican_hat(0.5, 10);

        assert_eq!(ws.scale, 0.5);
        assert!(!ws.filter.coefficients.is_empty());
    }

    #[test]
    fn test_graph_wavelet() {
        let lap = path_graph_laplacian(10);
        let ws = WaveletScale::mexican_hat(0.5, 10);

        let psi = GraphWavelet::at_vertex(&lap, &ws, 5);

        assert_eq!(psi.center, 5);
        assert_eq!(psi.coefficients.len(), 10);
        // The wavelet must have non-zero mass at its own centre.
        assert!(psi.coefficients[5].abs() > 0.0);
    }

    #[test]
    fn test_wavelet_transform() {
        let transform = SpectralWaveletTransform::new(path_graph_laplacian(20), 4, 10);
        assert_eq!(transform.num_scales(), 4);

        let (scaling, wavelets) = transform.forward(&sine_signal(20));

        assert_eq!(scaling.len(), 20);
        assert_eq!(wavelets.len(), 4);
        assert!(wavelets.iter().all(|w| w.len() == 20));
    }

    #[test]
    fn test_scale_energies() {
        let transform = SpectralWaveletTransform::new(path_graph_laplacian(20), 4, 10);

        let energies = transform.scale_energies(&sine_signal(20));

        assert_eq!(energies.len(), 4);
        // Energies are sums of squares, so none may be negative.
        assert!(energies.iter().all(|&e| e >= 0.0));
    }

    #[test]
    fn test_wavelets_at_vertex() {
        let transform = SpectralWaveletTransform::new(path_graph_laplacian(10), 3, 8);

        let wavelets = transform.wavelets_at(5);

        assert_eq!(wavelets.len(), 3);
        assert!(wavelets.iter().all(|w| w.center == 5));
    }
}