1use scirs2_core::ndarray::{Array1, Array2};
41
42use crate::error::{GraphError, Result};
43use crate::spectral_graph::graph_laplacian;
44
45fn symmetric_eigen(a: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
55 let n = a.nrows();
56 if n == 0 {
57 return Err(GraphError::InvalidGraph("empty matrix".into()));
58 }
59 if a.ncols() != n {
60 return Err(GraphError::InvalidGraph("matrix must be square".into()));
61 }
62
63 let mut m = a.clone();
65 let mut v = Array2::<f64>::eye(n);
67
68 const MAX_SWEEPS: usize = 500;
69 const TOL: f64 = 1e-12;
70
71 for _ in 0..MAX_SWEEPS {
72 let mut max_val = 0.0_f64;
74 let mut p = 0_usize;
75 let mut q = 1_usize;
76 for i in 0..n {
77 for j in (i + 1)..n {
78 let v_ij = m[[i, j]].abs();
79 if v_ij > max_val {
80 max_val = v_ij;
81 p = i;
82 q = j;
83 }
84 }
85 }
86 if max_val < TOL {
87 break;
88 }
89
90 let theta = if (m[[q, q]] - m[[p, p]]).abs() < TOL {
92 std::f64::consts::FRAC_PI_4
93 } else {
94 0.5 * ((2.0 * m[[p, q]]) / (m[[q, q]] - m[[p, p]])).atan()
95 };
96 let cos_t = theta.cos();
97 let sin_t = theta.sin();
98
99 let mut new_m = m.clone();
102 for r in 0..n {
103 if r != p && r != q {
104 new_m[[r, p]] = cos_t * m[[r, p]] - sin_t * m[[r, q]];
105 new_m[[p, r]] = new_m[[r, p]];
106 new_m[[r, q]] = sin_t * m[[r, p]] + cos_t * m[[r, q]];
107 new_m[[q, r]] = new_m[[r, q]];
108 }
109 }
110 new_m[[p, p]] =
111 cos_t * cos_t * m[[p, p]] - 2.0 * sin_t * cos_t * m[[p, q]] + sin_t * sin_t * m[[q, q]];
112 new_m[[q, q]] =
113 sin_t * sin_t * m[[p, p]] + 2.0 * sin_t * cos_t * m[[p, q]] + cos_t * cos_t * m[[q, q]];
114 new_m[[p, q]] = 0.0;
115 new_m[[q, p]] = 0.0;
116 m = new_m;
117
118 let v_old = v.clone();
120 for r in 0..n {
121 v[[r, p]] = cos_t * v_old[[r, p]] - sin_t * v_old[[r, q]];
122 v[[r, q]] = sin_t * v_old[[r, p]] + cos_t * v_old[[r, q]];
123 }
124 }
125
126 let eigenvalues = Array1::from_iter((0..n).map(|i| m[[i, i]]));
128
129 let mut idx: Vec<usize> = (0..n).collect();
131 idx.sort_by(|&a, &b| {
132 eigenvalues[a]
133 .partial_cmp(&eigenvalues[b])
134 .unwrap_or(std::cmp::Ordering::Equal)
135 });
136
137 let sorted_evals = Array1::from_iter(idx.iter().map(|&i| eigenvalues[i]));
138 let mut sorted_evecs = Array2::<f64>::zeros((n, n));
139 for (new_col, &old_col) in idx.iter().enumerate() {
140 for row in 0..n {
141 sorted_evecs[[row, new_col]] = v[[row, old_col]];
142 }
143 }
144
145 Ok((sorted_evals, sorted_evecs))
146}
147
148#[derive(Debug, Clone)]
163pub struct GraphFourierTransform {
164 pub eigenvalues: Array1<f64>,
166 pub eigenvectors: Array2<f64>,
168}
169
170impl GraphFourierTransform {
171 pub fn from_adjacency(adj: &Array2<f64>) -> Result<Self> {
175 let n = adj.nrows();
176 if n == 0 {
177 return Err(GraphError::InvalidGraph("empty adjacency matrix".into()));
178 }
179 let lap = graph_laplacian(adj);
180 let (eigenvalues, eigenvectors) = symmetric_eigen(&lap)?;
181 Ok(Self {
182 eigenvalues,
183 eigenvectors,
184 })
185 }
186
187 pub fn from_laplacian(laplacian: &Array2<f64>) -> Result<Self> {
189 let (eigenvalues, eigenvectors) = symmetric_eigen(laplacian)?;
190 Ok(Self {
191 eigenvalues,
192 eigenvectors,
193 })
194 }
195
196 pub fn num_nodes(&self) -> usize {
198 self.eigenvalues.len()
199 }
200
201 pub fn transform(&self, signal: &Array1<f64>) -> Result<Array1<f64>> {
209 let n = self.num_nodes();
210 if signal.len() != n {
211 return Err(GraphError::InvalidParameter {
212 param: "signal.len()".into(),
213 value: signal.len().to_string(),
214 expected: n.to_string(),
215 context: "GFT forward transform".into(),
216 });
217 }
218 let mut x_hat = Array1::<f64>::zeros(n);
220 for k in 0..n {
221 let mut acc = 0.0_f64;
222 for i in 0..n {
223 acc += self.eigenvectors[[i, k]] * signal[i];
224 }
225 x_hat[k] = acc;
226 }
227 Ok(x_hat)
228 }
229
230 pub fn inverse(&self, freq_signal: &Array1<f64>) -> Result<Array1<f64>> {
238 let n = self.num_nodes();
239 if freq_signal.len() != n {
240 return Err(GraphError::InvalidParameter {
241 param: "freq_signal.len()".into(),
242 value: freq_signal.len().to_string(),
243 expected: n.to_string(),
244 context: "GFT inverse transform".into(),
245 });
246 }
247 let mut x = Array1::<f64>::zeros(n);
249 for i in 0..n {
250 let mut acc = 0.0_f64;
251 for k in 0..n {
252 acc += self.eigenvectors[[i, k]] * freq_signal[k];
253 }
254 x[i] = acc;
255 }
256 Ok(x)
257 }
258}
259
260pub trait GraphFilter {
269 fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>>;
271
272 fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64>;
274}
275
/// Ideal (brick-wall) low-pass filter selecting spectral coefficients by index.
#[derive(Debug, Clone)]
pub struct IdealLowPass {
    /// Number of leading spectral coefficients that are kept.
    pub k: usize,
}

impl IdealLowPass {
    /// Creates a low-pass filter that keeps spectral indices `0..k`.
    pub fn new(k: usize) -> Self {
        IdealLowPass { k }
    }
}
297
298impl GraphFilter for IdealLowPass {
299 fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
300 let n = gft.num_nodes();
301 Array1::from_iter((0..n).map(|i| if i < self.k { 1.0 } else { 0.0 }))
302 }
303
304 fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
305 let x_hat = gft.transform(signal)?;
306 let h = self.frequency_response(gft);
307 let filtered_hat = Array1::from_iter(x_hat.iter().zip(h.iter()).map(|(a, b)| a * b));
308 gft.inverse(&filtered_hat)
309 }
310}
311
/// Ideal (brick-wall) high-pass filter selecting spectral coefficients by index.
#[derive(Debug, Clone)]
pub struct IdealHighPass {
    /// Number of leading spectral coefficients that are removed.
    pub k: usize,
}

impl IdealHighPass {
    /// Creates a high-pass filter that removes spectral indices `0..k`.
    pub fn new(k: usize) -> Self {
        IdealHighPass { k }
    }
}
333
334impl GraphFilter for IdealHighPass {
335 fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
336 let n = gft.num_nodes();
337 Array1::from_iter((0..n).map(|i| if i < self.k { 0.0 } else { 1.0 }))
338 }
339
340 fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
341 let x_hat = gft.transform(signal)?;
342 let h = self.frequency_response(gft);
343 let filtered_hat = Array1::from_iter(x_hat.iter().zip(h.iter()).map(|(a, b)| a * b));
344 gft.inverse(&filtered_hat)
345 }
346}
347
/// Ideal band-pass filter selecting a contiguous range of spectral indices.
#[derive(Debug, Clone)]
pub struct GraphBandpass {
    /// First spectral index that is kept (inclusive).
    pub low_k: usize,
    /// First spectral index above the band (exclusive upper bound).
    pub high_k: usize,
}

impl GraphBandpass {
    /// Creates a band-pass filter for the spectral indices `low_k..high_k`.
    ///
    /// The bounds are stored as given; no ordering check is performed.
    pub fn new(low_k: usize, high_k: usize) -> Self {
        GraphBandpass { low_k, high_k }
    }
}
370
371impl GraphFilter for GraphBandpass {
372 fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
373 let n = gft.num_nodes();
374 Array1::from_iter((0..n).map(|i| {
375 if i >= self.low_k && i < self.high_k {
376 1.0
377 } else {
378 0.0
379 }
380 }))
381 }
382
383 fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
384 let x_hat = gft.transform(signal)?;
385 let h = self.frequency_response(gft);
386 let filtered_hat = Array1::from_iter(x_hat.iter().zip(h.iter()).map(|(a, b)| a * b));
387 gft.inverse(&filtered_hat)
388 }
389}
390
391#[derive(Debug, Clone)]
406pub struct GraphWavelet {
407 pub scale: f64,
409 kernel: Array2<f64>,
411}
412
413impl GraphWavelet {
414 pub fn new(gft: &GraphFourierTransform, scale: f64) -> Result<Self> {
420 if scale <= 0.0 {
421 return Err(GraphError::InvalidParameter {
422 param: "scale".into(),
423 value: scale.to_string(),
424 expected: "> 0".into(),
425 context: "GraphWavelet construction".into(),
426 });
427 }
428 let n = gft.num_nodes();
429 let h: Vec<f64> = gft
431 .eigenvalues
432 .iter()
433 .map(|&lam| (-scale * lam).exp())
434 .collect();
435
436 let mut kernel = Array2::<f64>::zeros((n, n));
438 for i in 0..n {
439 for j in 0..n {
440 let mut acc = 0.0_f64;
441 for k in 0..n {
442 acc += gft.eigenvectors[[i, k]] * h[k] * gft.eigenvectors[[j, k]];
443 }
444 kernel[[i, j]] = acc;
445 }
446 }
447 Ok(Self { scale, kernel })
448 }
449
450 pub fn apply(&self, signal: &Array1<f64>) -> Result<Array1<f64>> {
452 let n = self.kernel.nrows();
453 if signal.len() != n {
454 return Err(GraphError::InvalidParameter {
455 param: "signal.len()".into(),
456 value: signal.len().to_string(),
457 expected: n.to_string(),
458 context: "GraphWavelet apply".into(),
459 });
460 }
461 let mut out = Array1::<f64>::zeros(n);
462 for i in 0..n {
463 let mut acc = 0.0_f64;
464 for j in 0..n {
465 acc += self.kernel[[i, j]] * signal[j];
466 }
467 out[i] = acc;
468 }
469 Ok(out)
470 }
471
472 pub fn wavelet_atom(&self, s: usize) -> Result<Array1<f64>> {
474 let n = self.kernel.nrows();
475 if s >= n {
476 return Err(GraphError::InvalidParameter {
477 param: "s".into(),
478 value: s.to_string(),
479 expected: format!("< {n}"),
480 context: "GraphWavelet atom".into(),
481 });
482 }
483 Ok(self.kernel.column(s).to_owned())
484 }
485
486 pub fn kernel(&self) -> &Array2<f64> {
488 &self.kernel
489 }
490}
491
492#[derive(Debug, Clone)]
512pub struct GraphSignalSmoother {
513 pub alpha: f64,
515}
516
517impl GraphSignalSmoother {
518 pub fn new(alpha: f64) -> Result<Self> {
520 if alpha <= 0.0 {
521 return Err(GraphError::InvalidParameter {
522 param: "alpha".into(),
523 value: alpha.to_string(),
524 expected: "> 0".into(),
525 context: "GraphSignalSmoother construction".into(),
526 });
527 }
528 Ok(Self { alpha })
529 }
530
531 pub fn smooth(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
535 let n = gft.num_nodes();
536 if signal.len() != n {
537 return Err(GraphError::InvalidParameter {
538 param: "signal.len()".into(),
539 value: signal.len().to_string(),
540 expected: n.to_string(),
541 context: "GraphSignalSmoother smooth".into(),
542 });
543 }
544 let y_hat = gft.transform(signal)?;
545 let x_hat = Array1::from_iter(
547 y_hat
548 .iter()
549 .zip(gft.eigenvalues.iter())
550 .map(|(&c, &lam)| c / (1.0 + self.alpha * lam)),
551 );
552 gft.inverse(&x_hat)
553 }
554
555 pub fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
558 Array1::from_iter(
559 gft.eigenvalues
560 .iter()
561 .map(|&lam| 1.0 / (1.0 + self.alpha * lam)),
562 )
563 }
564
565 pub fn total_variation(adj: &Array2<f64>, signal: &Array1<f64>) -> Result<f64> {
568 let n = adj.nrows();
569 if signal.len() != n {
570 return Err(GraphError::InvalidParameter {
571 param: "signal.len()".into(),
572 value: signal.len().to_string(),
573 expected: n.to_string(),
574 context: "total_variation".into(),
575 });
576 }
577 let mut tv = 0.0_f64;
578 for i in 0..n {
579 for j in (i + 1)..n {
580 let w = adj[[i, j]];
581 if w != 0.0 {
582 let diff = signal[i] - signal[j];
583 tv += w * diff * diff;
584 }
585 }
586 }
587 Ok(tv)
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Adjacency matrix of the unweighted path graph `0 - 1 - … - (n-1)`.
    fn path_graph_adj(n: usize) -> Array2<f64> {
        let mut adj = Array2::<f64>::zeros((n, n));
        // `saturating_sub` keeps `n == 0` from underflowing the edge count.
        for i in 0..n.saturating_sub(1) {
            adj[[i, i + 1]] = 1.0;
            adj[[i + 1, i]] = 1.0;
        }
        adj
    }

    /// Forward followed by inverse transform must reproduce the signal.
    #[test]
    fn test_gft_reconstruction() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let freq = gft.transform(&signal).unwrap();
        let rec = gft.inverse(&freq).unwrap();
        for (a, b) in signal.iter().zip(rec.iter()) {
            assert!((a - b).abs() < 1e-9, "Reconstruction error: {a} vs {b}");
        }
    }

    /// Low-pass filtering an alternating signal must lower its total variation.
    #[test]
    fn test_low_pass_smoothing() {
        let adj = path_graph_adj(6);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0]);
        let lp = IdealLowPass::new(2);
        let smoothed = lp.apply(&gft, &signal).unwrap();
        let tv_orig = GraphSignalSmoother::total_variation(&adj, &signal).unwrap();
        let tv_smooth = GraphSignalSmoother::total_variation(&adj, &smoothed).unwrap();
        assert!(
            tv_smooth < tv_orig,
            "LP filter should reduce TV: {tv_smooth} vs {tv_orig}"
        );
    }

    /// A constant signal lives entirely in spectral index 0, which the
    /// high-pass filter removes.
    #[test]
    fn test_high_pass_removes_dc() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let dc_signal = Array1::from_vec(vec![1.0, 1.0, 1.0, 1.0, 1.0]);
        let hp = IdealHighPass::new(1);
        let out = hp.apply(&gft, &dc_signal).unwrap();
        for v in out.iter() {
            assert!(v.abs() < 1e-9, "HP filter should remove DC: got {v}");
        }
    }

    /// The band-pass output must contain exactly the in-band coefficients of
    /// the input and nothing outside the band.
    #[test]
    fn test_bandpass() {
        let adj = path_graph_adj(8);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, 0.5, 0.0, -0.5, -1.0, -0.5, 0.0, 0.5]);
        let bp = GraphBandpass::new(2, 5);
        let out = bp.apply(&gft, &signal).unwrap();
        assert_eq!(out.len(), 8);

        // Re-analyse the filtered signal in the spectral domain.
        let in_hat = gft.transform(&signal).unwrap();
        let out_hat = gft.transform(&out).unwrap();
        for idx in 0..8 {
            let got = out_hat[idx];
            if (2..5).contains(&idx) {
                let want = in_hat[idx];
                assert!(
                    (got - want).abs() < 1e-9,
                    "in-band coefficient {idx} changed: {got} vs {want}"
                );
            } else {
                assert!(
                    got.abs() < 1e-9,
                    "out-of-band coefficient {idx} not removed: {got}"
                );
            }
        }
    }

    /// The kernel `U h(Λ) Uᵀ` must be symmetric.
    #[test]
    fn test_wavelet_kernel_symmetry() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let wv = GraphWavelet::new(&gft, 1.0).unwrap();
        let k = wv.kernel();
        for i in 0..5 {
            for j in 0..5 {
                assert!((k[[i, j]] - k[[j, i]]).abs() < 1e-10);
            }
        }
    }

    /// Smoothing an alternating signal must lower its total variation.
    #[test]
    fn test_smoother_reduces_variation() {
        let adj = path_graph_adj(6);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let noisy = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
        let smoother = GraphSignalSmoother::new(5.0).unwrap();
        let smoothed = smoother.smooth(&gft, &noisy).unwrap();
        let tv_noisy = GraphSignalSmoother::total_variation(&adj, &noisy).unwrap();
        let tv_smooth = GraphSignalSmoother::total_variation(&adj, &smoothed).unwrap();
        assert!(tv_smooth < tv_noisy);
    }
}