1#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use scirs2_core::ndarray::Array2;
20use torsh_tensor::{
21 creation::{from_vec, randn, zeros},
22 Tensor,
23};
24
/// Normalization applied by [`SpectralGraphAnalysis::compute_laplacian`].
///
/// With binary symmetric adjacency matrix `A` and diagonal degree matrix `D`:
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LaplacianType {
    /// Combinatorial Laplacian `L = D - A`.
    Unnormalized,
    /// Symmetric normalized Laplacian `L = I - D^{-1/2} A D^{-1/2}`.
    Symmetric,
    /// Random-walk normalized Laplacian `L = I - D^{-1} A`.
    RandomWalk,
}
35
/// Stateless namespace for dense spectral-graph helpers.
///
/// All functionality is exposed as associated functions: `compute_laplacian`,
/// `spectral_embedding`, `compute_spectrum` and `spectral_clustering`.
pub struct SpectralGraphAnalysis;
38
39impl SpectralGraphAnalysis {
40 pub fn compute_laplacian(graph: &GraphData, laplacian_type: LaplacianType) -> Array2<f32> {
42 let num_nodes = graph.num_nodes;
43 let edge_data = graph
44 .edge_index
45 .to_vec()
46 .expect("conversion should succeed");
47
48 let mut adj = Array2::zeros((num_nodes, num_nodes));
50
51 for i in (0..edge_data.len()).step_by(2) {
52 if i + 1 < edge_data.len() {
53 let src = edge_data[i] as usize;
54 let dst = edge_data[i + 1] as usize;
55
56 if src < num_nodes && dst < num_nodes {
57 adj[[src, dst]] = 1.0;
58 adj[[dst, src]] = 1.0; }
60 }
61 }
62
63 let mut degrees = vec![0.0; num_nodes];
65 for i in 0..num_nodes {
66 for j in 0..num_nodes {
67 degrees[i] += adj[[i, j]];
68 }
69 }
70
71 match laplacian_type {
73 LaplacianType::Unnormalized => {
74 let mut laplacian = Array2::zeros((num_nodes, num_nodes));
75 for i in 0..num_nodes {
76 laplacian[[i, i]] = degrees[i];
77 for j in 0..num_nodes {
78 laplacian[[i, j]] -= adj[[i, j]];
79 }
80 }
81 laplacian
82 }
83 LaplacianType::Symmetric => {
84 let mut laplacian = Array2::zeros((num_nodes, num_nodes));
85
86 let mut d_inv_sqrt = vec![0.0; num_nodes];
88 for i in 0..num_nodes {
89 d_inv_sqrt[i] = if degrees[i] > 0.0 {
90 1.0 / degrees[i].sqrt()
91 } else {
92 0.0
93 };
94 }
95
96 for i in 0..num_nodes {
98 laplacian[[i, i]] = 1.0;
99 for j in 0..num_nodes {
100 laplacian[[i, j]] -= d_inv_sqrt[i] * adj[[i, j]] * d_inv_sqrt[j];
101 }
102 }
103 laplacian
104 }
105 LaplacianType::RandomWalk => {
106 let mut laplacian = Array2::zeros((num_nodes, num_nodes));
107
108 let mut d_inv = vec![0.0; num_nodes];
110 for i in 0..num_nodes {
111 d_inv[i] = if degrees[i] > 0.0 {
112 1.0 / degrees[i]
113 } else {
114 0.0
115 };
116 }
117
118 for i in 0..num_nodes {
120 laplacian[[i, i]] = 1.0;
121 for j in 0..num_nodes {
122 laplacian[[i, j]] -= d_inv[i] * adj[[i, j]];
123 }
124 }
125 laplacian
126 }
127 }
128 }
129
130 pub fn spectral_embedding(graph: &GraphData, num_components: usize) -> Tensor {
132 let laplacian = Self::compute_laplacian(graph, LaplacianType::Symmetric);
133 let num_nodes = graph.num_nodes;
134
135 let mut embeddings = Vec::new();
138
139 for _comp in 0..num_components {
140 let mut v = vec![0.0; num_nodes];
142 let mut rng = scirs2_core::random::thread_rng();
143 for val in v.iter_mut() {
144 *val = rng.gen_range(-0.5..0.5);
145 }
146
147 for _ in 0..50 {
149 let mut new_v = vec![0.0; num_nodes];
150
151 for i in 0..num_nodes {
152 for j in 0..num_nodes {
153 new_v[i] += laplacian[[i, j]] * v[j];
154 }
155 }
156
157 let norm: f32 = new_v.iter().map(|x| x * x).sum::<f32>().sqrt();
159 if norm > 0.0 {
160 for val in new_v.iter_mut() {
161 *val /= norm;
162 }
163 }
164
165 v = new_v;
166 }
167
168 embeddings.extend(v);
169 }
170
171 from_vec(
172 embeddings,
173 &[num_nodes, num_components],
174 torsh_core::device::DeviceType::Cpu,
175 )
176 .expect("from_vec embeddings should succeed")
177 }
178
179 pub fn compute_spectrum(graph: &GraphData, num_eigenvalues: usize) -> Vec<f32> {
181 let _laplacian = Self::compute_laplacian(graph, LaplacianType::Symmetric);
182 let num_nodes = graph.num_nodes;
183
184 let mut eigenvalues = Vec::new();
187
188 for k in 0..num_eigenvalues.min(num_nodes) {
189 let lambda =
190 2.0 * (1.0 - ((k as f32 * std::f32::consts::PI) / (num_nodes as f32)).cos());
191 eigenvalues.push(lambda);
192 }
193
194 eigenvalues
195 }
196
197 pub fn spectral_clustering(graph: &GraphData, num_clusters: usize) -> Vec<usize> {
199 let num_nodes = graph.num_nodes;
200
201 let embedding = Self::spectral_embedding(graph, num_clusters);
203 let embedding_data = embedding.to_vec().expect("conversion should succeed");
204
205 let mut labels = vec![0; num_nodes];
207 let mut centroids = vec![vec![0.0; num_clusters]; num_clusters];
208
209 let mut rng = scirs2_core::random::thread_rng();
211 for k in 0..num_clusters {
212 let idx = rng.gen_range(0..num_nodes);
213 for d in 0..num_clusters {
214 centroids[k][d] = embedding_data[idx * num_clusters + d];
215 }
216 }
217
218 for _ in 0..100 {
220 for i in 0..num_nodes {
222 let mut min_dist = f32::MAX;
223 let mut best_cluster = 0;
224
225 for k in 0..num_clusters {
226 let mut dist = 0.0;
227 for d in 0..num_clusters {
228 let diff = embedding_data[i * num_clusters + d] - centroids[k][d];
229 dist += diff * diff;
230 }
231
232 if dist < min_dist {
233 min_dist = dist;
234 best_cluster = k;
235 }
236 }
237
238 labels[i] = best_cluster;
239 }
240
241 let mut counts = vec![0; num_clusters];
243 let mut new_centroids = vec![vec![0.0; num_clusters]; num_clusters];
244
245 for i in 0..num_nodes {
246 let cluster = labels[i];
247 counts[cluster] += 1;
248
249 for d in 0..num_clusters {
250 new_centroids[cluster][d] += embedding_data[i * num_clusters + d];
251 }
252 }
253
254 for k in 0..num_clusters {
255 if counts[k] > 0 {
256 for d in 0..num_clusters {
257 new_centroids[k][d] /= counts[k] as f32;
258 }
259 }
260 }
261
262 centroids = new_centroids;
263 }
264
265 labels
266 }
267}
268
269#[derive(Debug)]
271pub struct ChebConv {
272 in_features: usize,
273 out_features: usize,
274 k: usize, weights: Vec<Parameter>,
278
279 bias: Option<Parameter>,
280}
281
282impl ChebConv {
283 pub fn new(in_features: usize, out_features: usize, k: usize, use_bias: bool) -> Self {
285 let mut weights = Vec::new();
286
287 for _ in 0..k {
288 weights.push(Parameter::new(
289 randn(&[in_features, out_features]).expect("randn weights should succeed"),
290 ));
291 }
292
293 let bias = if use_bias {
294 Some(Parameter::new(
295 zeros(&[out_features]).expect("zeros bias should succeed"),
296 ))
297 } else {
298 None
299 };
300
301 Self {
302 in_features,
303 out_features,
304 k,
305 weights,
306 bias,
307 }
308 }
309
310 pub fn forward(&self, graph: &GraphData) -> GraphData {
312 let num_nodes = graph.num_nodes;
313
314 let laplacian = SpectralGraphAnalysis::compute_laplacian(graph, LaplacianType::Symmetric);
316
317 let lap_data: Vec<f32> = laplacian.iter().copied().collect();
319 let lap_tensor = from_vec(
320 lap_data,
321 &[num_nodes, num_nodes],
322 torsh_core::device::DeviceType::Cpu,
323 )
324 .expect("from_vec laplacian should succeed");
325
326 let mut chebyshev_polynomials = Vec::new();
328
329 chebyshev_polynomials.push(graph.x.clone());
331
332 if self.k > 1 {
334 let t1 = lap_tensor
335 .matmul(&graph.x)
336 .expect("operation should succeed");
337 chebyshev_polynomials.push(t1);
338 }
339
340 for i in 2..self.k {
342 let term1 = lap_tensor
343 .matmul(&chebyshev_polynomials[i - 1])
344 .expect("operation should succeed");
345 let term1_scaled = term1.mul_scalar(2.0).expect("operation should succeed");
346 let t_k = term1_scaled
347 .sub(&chebyshev_polynomials[i - 2])
348 .expect("operation should succeed");
349 chebyshev_polynomials.push(t_k);
350 }
351
352 let mut output =
354 zeros::<f32>(&[num_nodes, self.out_features]).expect("zeros output should succeed");
355
356 for (i, t_k) in chebyshev_polynomials.iter().enumerate().take(self.k) {
357 let weighted = t_k
358 .matmul(&self.weights[i].clone_data())
359 .expect("operation should succeed");
360 output = output.add(&weighted).expect("operation should succeed");
361 }
362
363 if let Some(ref bias) = self.bias {
365 output = output
366 .add(&bias.clone_data())
367 .expect("operation should succeed");
368 }
369
370 let mut output_graph = graph.clone();
371 output_graph.x = output;
372 output_graph
373 }
374}
375
376impl GraphLayer for ChebConv {
377 fn forward(&self, graph: &GraphData) -> GraphData {
378 self.forward(graph)
379 }
380
381 fn parameters(&self) -> Vec<Tensor> {
382 let mut params: Vec<_> = self.weights.iter().map(|w| w.clone_data()).collect();
383
384 if let Some(ref bias) = self.bias {
385 params.push(bias.clone_data());
386 }
387
388 params
389 }
390}
391
392#[derive(Debug)]
394pub struct SpectralConv {
395 in_features: usize,
396 out_features: usize,
397 num_filters: usize,
398
399 spectral_weights: Parameter,
401
402 spatial_weight: Parameter,
404
405 bias: Option<Parameter>,
406}
407
408impl SpectralConv {
409 pub fn new(
411 in_features: usize,
412 out_features: usize,
413 num_filters: usize,
414 use_bias: bool,
415 ) -> Self {
416 let spectral_weights = Parameter::new(
417 randn(&[num_filters, in_features]).expect("randn spectral_weights should succeed"),
418 );
419 let spatial_weight = Parameter::new(
420 randn(&[in_features, out_features]).expect("randn spatial_weight should succeed"),
421 );
422
423 let bias = if use_bias {
424 Some(Parameter::new(
425 zeros(&[out_features]).expect("zeros bias should succeed"),
426 ))
427 } else {
428 None
429 };
430
431 Self {
432 in_features,
433 out_features,
434 num_filters,
435 spectral_weights,
436 spatial_weight,
437 bias,
438 }
439 }
440
441 pub fn forward(&self, graph: &GraphData) -> GraphData {
443 let _num_nodes = graph.num_nodes;
444
445 let spectral_features = SpectralGraphAnalysis::spectral_embedding(graph, self.num_filters);
447
448 let filtered = spectral_features
452 .matmul(&self.spectral_weights.clone_data())
453 .expect("operation should succeed");
454
455 let combined = filtered.add(&graph.x).expect("operation should succeed");
457
458 let mut output = combined
460 .matmul(&self.spatial_weight.clone_data())
461 .expect("operation should succeed");
462
463 if let Some(ref bias) = self.bias {
465 output = output
466 .add(&bias.clone_data())
467 .expect("operation should succeed");
468 }
469
470 let mut output_graph = graph.clone();
471 output_graph.x = output;
472 output_graph
473 }
474}
475
476impl GraphLayer for SpectralConv {
477 fn forward(&self, graph: &GraphData) -> GraphData {
478 self.forward(graph)
479 }
480
481 fn parameters(&self) -> Vec<Tensor> {
482 let mut params = vec![
483 self.spectral_weights.clone_data(),
484 self.spatial_weight.clone_data(),
485 ];
486
487 if let Some(ref bias) = self.bias {
488 params.push(bias.clone_data());
489 }
490
491 params
492 }
493}
494
/// Stateless namespace for graph signal processing helpers.
///
/// Provides the graph Fourier transform, its inverse, and low/high-pass filters. The spectral
/// basis comes from `SpectralGraphAnalysis::spectral_embedding`.
pub struct GraphSignalProcessing;
497
498impl GraphSignalProcessing {
499 pub fn graph_fourier_transform(graph: &GraphData, signal: &Tensor) -> Tensor {
501 let num_nodes = graph.num_nodes;
503 let embedding = SpectralGraphAnalysis::spectral_embedding(graph, num_nodes);
504
505 embedding
507 .t()
508 .expect("operation should succeed")
509 .matmul(signal)
510 .expect("operation should succeed")
511 }
512
513 pub fn inverse_graph_fourier_transform(graph: &GraphData, spectral_signal: &Tensor) -> Tensor {
515 let num_nodes = graph.num_nodes;
516 let embedding = SpectralGraphAnalysis::spectral_embedding(graph, num_nodes);
517
518 embedding
520 .matmul(spectral_signal)
521 .expect("operation should succeed")
522 }
523
524 pub fn low_pass_filter(graph: &GraphData, signal: &Tensor, cutoff: usize) -> Tensor {
526 let spectral = Self::graph_fourier_transform(graph, signal);
528
529 let mut filtered_data = spectral.to_vec().expect("conversion should succeed");
531 let _signal_dim = signal.shape().dims()[1];
532
533 for i in cutoff..filtered_data.len() {
534 filtered_data[i] = 0.0;
535 }
536
537 let filtered_spectral = from_vec(
538 filtered_data,
539 spectral.shape().dims(),
540 torsh_core::device::DeviceType::Cpu,
541 )
542 .expect("from_vec filtered_spectral should succeed");
543
544 Self::inverse_graph_fourier_transform(graph, &filtered_spectral)
546 }
547
548 pub fn high_pass_filter(graph: &GraphData, signal: &Tensor, cutoff: usize) -> Tensor {
550 let spectral = Self::graph_fourier_transform(graph, signal);
552
553 let mut filtered_data = spectral.to_vec().expect("conversion should succeed");
555
556 for i in 0..cutoff.min(filtered_data.len()) {
557 filtered_data[i] = 0.0;
558 }
559
560 let filtered_spectral = from_vec(
561 filtered_data,
562 spectral.shape().dims(),
563 torsh_core::device::DeviceType::Cpu,
564 )
565 .expect("from_vec filtered_spectral should succeed");
566
567 Self::inverse_graph_fourier_transform(graph, &filtered_spectral)
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_laplacian_computation() {
        let features = randn(&[4, 3]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let symmetric = SpectralGraphAnalysis::compute_laplacian(&graph, LaplacianType::Symmetric);
        assert_eq!(symmetric.shape(), [4, 4]);
        for i in 0..4 {
            assert!((symmetric[[i, i]] - 1.0).abs() < 1e-6);
            for j in 0..4 {
                assert!((symmetric[[i, j]] - symmetric[[j, i]]).abs() < 1e-6);
            }
        }

        // Every node has at least one neighbour, so both of these have zero row sums.
        for kind in [LaplacianType::Unnormalized, LaplacianType::RandomWalk] {
            let laplacian = SpectralGraphAnalysis::compute_laplacian(&graph, kind);
            for i in 0..4 {
                let row_sum: f32 = laplacian.row(i).iter().sum();
                assert!(
                    row_sum.abs() < 1e-6,
                    "{:?} row {} sums to {}",
                    kind,
                    i,
                    row_sum
                );
            }
        }
    }

    #[test]
    fn test_spectral_embedding() {
        let features = randn(&[5, 3]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let embedding = SpectralGraphAnalysis::spectral_embedding(&graph, 3);

        assert_eq!(embedding.shape().dims(), &[5, 3]);
    }

    #[test]
    fn test_compute_spectrum() {
        let features = randn(&[5, 3]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let spectrum = SpectralGraphAnalysis::compute_spectrum(&graph, 3);
        assert_eq!(spectrum.len(), 3);
        assert!(spectrum.windows(2).all(|w| w[0] <= w[1] + 1e-6));
        assert!(spectrum.iter().all(|&lambda| lambda >= -1e-4));

        // Never more eigenvalues than nodes.
        assert_eq!(SpectralGraphAnalysis::compute_spectrum(&graph, 10).len(), 5);
    }

    #[test]
    fn test_spectral_clustering() {
        let features = randn(&[6, 2]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 3.0, 4.0, 4.0, 5.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let labels = SpectralGraphAnalysis::spectral_clustering(&graph, 2);

        assert_eq!(labels.len(), 6);
        assert!(labels.iter().all(|&label| label < 2));
    }

    #[test]
    fn test_cheb_conv() {
        let features = randn(&[4, 6]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0];
        let edge_index = from_vec(edges, &[2, 3], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let cheb = ChebConv::new(6, 8, 3, true);
        let output = cheb.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[4, 8]);
        // Three per-term weights plus the bias.
        assert_eq!(cheb.parameters().len(), 4);

        // A single term (T_0 only) and no bias.
        let cheb_single = ChebConv::new(6, 8, 1, false);
        let output_single = cheb_single.forward(&graph);
        assert_eq!(output_single.x.shape().dims(), &[4, 8]);
        assert_eq!(cheb_single.parameters().len(), 1);
    }

    #[test]
    fn test_spectral_conv() {
        let features = randn(&[5, 4]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let spec_conv = SpectralConv::new(4, 6, 3, true);
        let output = spec_conv.forward(&graph);

        assert_eq!(output.x.shape().dims(), &[5, 6]);
        assert_eq!(spec_conv.parameters().len(), 3);
    }

    #[test]
    fn test_graph_fourier_transform() {
        let features = randn(&[4, 3]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0];
        let edge_index = from_vec(edges, &[2, 3], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features.clone(), edge_index);

        let spectral = GraphSignalProcessing::graph_fourier_transform(&graph, &features);
        let reconstructed =
            GraphSignalProcessing::inverse_graph_fourier_transform(&graph, &spectral);

        assert_eq!(reconstructed.shape().dims(), features.shape().dims());
    }

    #[test]
    fn test_low_pass_filter() {
        let features = randn(&[5, 4]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features.clone(), edge_index);

        let filtered = GraphSignalProcessing::low_pass_filter(&graph, &features, 2);
        assert_eq!(filtered.shape().dims(), features.shape().dims());

        // A zero cutoff removes every frequency.
        let removed = GraphSignalProcessing::low_pass_filter(&graph, &features, 0);
        assert!(removed.to_vec().unwrap().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_high_pass_filter() {
        let features = randn(&[5, 4]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features.clone(), edge_index);

        let filtered = GraphSignalProcessing::high_pass_filter(&graph, &features, 2);
        assert_eq!(filtered.shape().dims(), features.shape().dims());

        // A cutoff beyond every frequency removes everything.
        let removed = GraphSignalProcessing::high_pass_filter(&graph, &features, 100);
        assert!(removed.to_vec().unwrap().iter().all(|&v| v == 0.0));
    }
}