1use std::collections::VecDeque;
15
16use scirs2_core::ndarray::{Array1, Array2};
17use scirs2_core::random::{Rng, RngExt};
18
19use crate::error::{GraphError, Result};
20use crate::gnn::gcn::CsrMatrix;
21
#[derive(Debug, Clone)]
/// Learnable degree-centrality encoding: adds per-degree embeddings to node
/// features so attention can take each node's connectivity into account.
pub struct CentralityEncoding {
    /// Embedding table indexed by clamped in-degree, shape (max_degree + 1, hidden_dim).
    pub in_degree_embed: Array2<f64>,
    /// Embedding table indexed by clamped out-degree, shape (max_degree + 1, hidden_dim).
    pub out_degree_embed: Array2<f64>,
    /// Degrees above this value are clamped to it before the table lookup.
    pub max_degree: usize,
    /// Feature dimension of each embedding row.
    pub hidden_dim: usize,
}
41
42impl CentralityEncoding {
43 pub fn new(max_degree: usize, hidden_dim: usize) -> Self {
49 let mut rng = scirs2_core::random::rng();
50 let scale = (1.0 / hidden_dim as f64).sqrt();
51
52 let in_degree_embed = Array2::from_shape_fn((max_degree + 1, hidden_dim), |_| {
53 (rng.random::<f64>() * 2.0 - 1.0) * scale
54 });
55 let out_degree_embed = Array2::from_shape_fn((max_degree + 1, hidden_dim), |_| {
56 (rng.random::<f64>() * 2.0 - 1.0) * scale
57 });
58
59 CentralityEncoding {
60 in_degree_embed,
61 out_degree_embed,
62 max_degree,
63 hidden_dim,
64 }
65 }
66
67 pub fn compute_degrees(&self, adj: &CsrMatrix) -> (Vec<usize>, Vec<usize>) {
69 let n = adj.n_rows;
70 let mut in_deg = vec![0usize; n];
71 let mut out_deg = vec![0usize; n];
72
73 for (row, col, _) in adj.iter() {
74 out_deg[row] += 1;
75 if col < n {
76 in_deg[col] += 1;
77 }
78 }
79
80 (in_deg, out_deg)
81 }
82
83 pub fn forward(&self, features: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
92 let (n, dim) = features.dim();
93 if dim != self.hidden_dim {
94 return Err(GraphError::InvalidParameter {
95 param: "features".to_string(),
96 value: format!("dim={dim}"),
97 expected: format!("dim={}", self.hidden_dim),
98 context: "CentralityEncoding::forward".to_string(),
99 });
100 }
101
102 let (in_deg, out_deg) = self.compute_degrees(adj);
103 let mut output = features.clone();
104
105 for i in 0..n {
106 let in_d = in_deg[i].min(self.max_degree);
107 let out_d = out_deg[i].min(self.max_degree);
108 for j in 0..dim {
109 output[[i, j]] +=
110 self.in_degree_embed[[in_d, j]] + self.out_degree_embed[[out_d, j]];
111 }
112 }
113
114 Ok(output)
115 }
116}
117
#[derive(Debug, Clone)]
/// Shortest-path-distance attention bias: one learnable scalar per hop count.
pub struct SpatialEncoding {
    /// Bias scalar per distance in `0..=max_distance`.
    pub spatial_bias: Array1<f64>,
    /// BFS cap; pairs at or beyond this distance share the last bias slot.
    pub max_distance: usize,
}
134
135impl SpatialEncoding {
136 pub fn new(max_distance: usize) -> Self {
141 let mut rng = scirs2_core::random::rng();
142 let spatial_bias =
143 Array1::from_iter((0..=max_distance).map(|_| (rng.random::<f64>() * 2.0 - 1.0) * 0.1));
144
145 SpatialEncoding {
146 spatial_bias,
147 max_distance,
148 }
149 }
150
151 pub fn compute_spd_matrix(&self, adj: &CsrMatrix) -> Array2<usize> {
157 let n = adj.n_rows;
158 let unreachable = self.max_distance + 1;
159 let mut spd = Array2::from_elem((n, n), unreachable);
160
161 let mut adj_list: Vec<Vec<usize>> = vec![Vec::new(); n];
163 for (row, col, _) in adj.iter() {
164 adj_list[row].push(col);
165 }
166
167 for src in 0..n {
169 spd[[src, src]] = 0;
170 let mut queue = VecDeque::new();
171 queue.push_back(src);
172 let mut visited = vec![false; n];
173 visited[src] = true;
174
175 while let Some(u) = queue.pop_front() {
176 let dist = spd[[src, u]];
177 if dist >= self.max_distance {
178 continue;
179 }
180 for &v in &adj_list[u] {
181 if !visited[v] {
182 visited[v] = true;
183 spd[[src, v]] = dist + 1;
184 queue.push_back(v);
185 }
186 }
187 }
188 }
189
190 spd
191 }
192
193 pub fn forward(&self, adj: &CsrMatrix) -> Array2<f64> {
201 let spd = self.compute_spd_matrix(adj);
202 let n = adj.n_rows;
203 let mut bias = Array2::zeros((n, n));
204
205 for i in 0..n {
206 for j in 0..n {
207 let d = spd[[i, j]].min(self.max_distance);
208 bias[[i, j]] = self.spatial_bias[d];
209 }
210 }
211
212 bias
213 }
214}
215
#[derive(Debug, Clone)]
/// Edge-feature attention bias: averages edge-type embeddings along a
/// BFS-tree path between two nodes and projects the mean to a scalar.
pub struct EdgeEncoding {
    /// One embedding row per edge type, shape (max_edge_types, hidden_dim).
    pub edge_embed: Array2<f64>,
    /// Projection vector that maps a mean path embedding to a scalar bias.
    pub projection: Array1<f64>,
    /// Number of distinct edge types; derived types are clamped below this.
    pub max_edge_types: usize,
    /// Embedding dimension.
    pub hidden_dim: usize,
}
236
237impl EdgeEncoding {
238 pub fn new(max_edge_types: usize, hidden_dim: usize) -> Self {
244 let mut rng = scirs2_core::random::rng();
245 let scale = (1.0 / hidden_dim as f64).sqrt();
246
247 let edge_embed = Array2::from_shape_fn((max_edge_types, hidden_dim), |_| {
248 (rng.random::<f64>() * 2.0 - 1.0) * scale
249 });
250 let projection =
251 Array1::from_iter((0..hidden_dim).map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale));
252
253 EdgeEncoding {
254 edge_embed,
255 projection,
256 max_edge_types,
257 hidden_dim,
258 }
259 }
260
261 pub fn forward(&self, adj: &CsrMatrix, spd: &Array2<usize>) -> Array2<f64> {
274 let n = adj.n_rows;
275 let mut bias = Array2::zeros((n, n));
276
277 let mut adj_list: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n];
279 for (row, col, val) in adj.iter() {
280 let edge_type = (val.abs() as usize).min(self.max_edge_types - 1);
282 adj_list[row].push((col, edge_type));
283 }
284
285 for src in 0..n {
287 let mut parent: Vec<Option<(usize, usize)>> = vec![None; n]; let mut visited = vec![false; n];
290 visited[src] = true;
291 let mut queue = VecDeque::new();
292 queue.push_back(src);
293
294 while let Some(u) = queue.pop_front() {
295 for &(v, etype) in &adj_list[u] {
296 if !visited[v] {
297 visited[v] = true;
298 parent[v] = Some((u, etype));
299 queue.push_back(v);
300 }
301 }
302 }
303
304 for dst in 0..n {
306 if src == dst || spd[[src, dst]] == 0 {
307 continue;
308 }
309 if parent[dst].is_none() {
310 continue; }
312
313 let mut avg_embed = vec![0.0f64; self.hidden_dim];
315 let mut path_len = 0usize;
316 let mut cur = dst;
317
318 while let Some((p, etype)) = parent[cur] {
319 for k in 0..self.hidden_dim {
320 avg_embed[k] += self.edge_embed[[etype, k]];
321 }
322 path_len += 1;
323 cur = p;
324 if cur == src {
325 break;
326 }
327 }
328
329 if path_len > 0 {
330 let inv = 1.0 / path_len as f64;
331 let mut scalar = 0.0f64;
332 for k in 0..self.hidden_dim {
333 scalar += avg_embed[k] * inv * self.projection[k];
334 }
335 bias[[src, dst]] = scalar;
336 }
337 }
338 }
339
340 bias
341 }
342}
343
/// Numerically stable softmax over one attention row.
///
/// Subtracts the row maximum before exponentiating and clamps the normaliser
/// away from zero, so the output is finite for any finite input. Returns an
/// empty vector for an empty row.
fn softmax_row(row: &[f64]) -> Vec<f64> {
    if row.is_empty() {
        return Vec::new();
    }
    let mut max_val = f64::NEG_INFINITY;
    for &x in row {
        max_val = max_val.max(x);
    }
    let mut exps = Vec::with_capacity(row.len());
    let mut total = 0.0f64;
    for &x in row {
        let e = (x - max_val).exp();
        total += e;
        exps.push(e);
    }
    let denom = total.max(1e-12);
    for e in exps.iter_mut() {
        *e /= denom;
    }
    exps
}
358
/// In-place layer normalisation: shifts `x` to zero mean and scales it to
/// unit variance, with `eps` guarding the division for near-constant rows.
/// Empty slices are left untouched.
fn layer_norm(x: &mut [f64], eps: f64) {
    if x.is_empty() {
        return;
    }
    let len = x.len() as f64;
    let mean = x.iter().sum::<f64>() / len;
    let var = x
        .iter()
        .map(|v| {
            let d = v - mean;
            d * d
        })
        .sum::<f64>()
        / len;
    let inv_std = 1.0 / (var + eps).sqrt();
    for v in x {
        *v = (*v - mean) * inv_std;
    }
}
372
#[derive(Debug, Clone)]
/// Hyper-parameters for building a Graphormer model.
pub struct GraphormerConfig {
    /// Dimension of the raw input node features.
    pub in_dim: usize,
    /// Hidden dimension used throughout the transformer layers.
    pub hidden_dim: usize,
    /// Number of attention heads; must divide `hidden_dim` evenly.
    pub num_heads: usize,
    /// Number of stacked Graphormer layers.
    pub num_layers: usize,
    /// Inner dimension of the feed-forward sublayer.
    pub ffn_dim: usize,
    /// Shortest-path-distance cap for the spatial encoding.
    pub max_distance: usize,
    /// Degree cap for the centrality encoding.
    pub max_degree: usize,
    /// Number of discrete edge types for the edge encoding.
    pub max_edge_types: usize,
    /// Dropout probability.
    // NOTE(review): not referenced by any forward pass in this file —
    // confirm whether dropout is applied elsewhere or still unimplemented.
    pub dropout: f64,
    /// Epsilon added to the variance inside layer normalisation.
    pub layer_norm_eps: f64,
}
401
402impl Default for GraphormerConfig {
403 fn default() -> Self {
404 GraphormerConfig {
405 in_dim: 64,
406 hidden_dim: 64,
407 num_heads: 4,
408 num_layers: 3,
409 ffn_dim: 256,
410 max_distance: 10,
411 max_degree: 50,
412 max_edge_types: 4,
413 dropout: 0.1,
414 layer_norm_eps: 1e-5,
415 }
416 }
417}
418
#[derive(Debug, Clone)]
/// One pre-norm Graphormer transformer layer: structurally-biased multi-head
/// self-attention followed by a GELU feed-forward network, each wrapped in a
/// residual connection.
pub struct GraphormerLayer {
    /// Query projection, shape (hidden_dim, hidden_dim).
    pub w_q: Array2<f64>,
    /// Key projection, shape (hidden_dim, hidden_dim).
    pub w_k: Array2<f64>,
    /// Value projection, shape (hidden_dim, hidden_dim).
    pub w_v: Array2<f64>,
    /// Output projection applied after head concatenation.
    pub w_o: Array2<f64>,
    /// First FFN weight, shape (hidden_dim, ffn_dim).
    pub ffn_w1: Array2<f64>,
    /// Second FFN weight, shape (ffn_dim, hidden_dim).
    pub ffn_w2: Array2<f64>,
    /// First FFN bias, length ffn_dim.
    pub ffn_b1: Array1<f64>,
    /// Second FFN bias, length hidden_dim.
    pub ffn_b2: Array1<f64>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Total feature dimension.
    pub hidden_dim: usize,
    /// Per-head dimension: hidden_dim / num_heads.
    pub head_dim: usize,
    /// Epsilon used by layer normalisation.
    pub layer_norm_eps: f64,
}
457
impl GraphormerLayer {
    /// Creates a layer whose attention and FFN weights are initialised with
    /// Xavier/Glorot-style uniform values; both FFN biases start at zero.
    ///
    /// # Errors
    /// Returns [`GraphError::InvalidParameter`] when `hidden_dim` is not
    /// divisible by `num_heads` (each head needs an equal slice of features).
    pub fn new(
        hidden_dim: usize,
        num_heads: usize,
        ffn_dim: usize,
        layer_norm_eps: f64,
    ) -> Result<Self> {
        if !hidden_dim.is_multiple_of(num_heads) {
            return Err(GraphError::InvalidParameter {
                param: "hidden_dim".to_string(),
                value: format!("{hidden_dim}"),
                expected: format!("divisible by num_heads={num_heads}"),
                context: "GraphormerLayer::new".to_string(),
            });
        }

        let head_dim = hidden_dim / num_heads;
        let mut rng = scirs2_core::random::rng();
        // Xavier/Glorot uniform bound: sqrt(6 / (fan_in + fan_out)).
        let w_scale = (6.0_f64 / (hidden_dim + hidden_dim) as f64).sqrt();
        let ffn_scale = (6.0_f64 / (hidden_dim + ffn_dim) as f64).sqrt();

        // Samples a (rows, cols) matrix uniformly from [-scale, scale].
        let mut init_w = |rows: usize, cols: usize, scale: f64| -> Array2<f64> {
            Array2::from_shape_fn((rows, cols), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
        };

        Ok(GraphormerLayer {
            w_q: init_w(hidden_dim, hidden_dim, w_scale),
            w_k: init_w(hidden_dim, hidden_dim, w_scale),
            w_v: init_w(hidden_dim, hidden_dim, w_scale),
            w_o: init_w(hidden_dim, hidden_dim, w_scale),
            ffn_w1: init_w(hidden_dim, ffn_dim, ffn_scale),
            ffn_w2: init_w(ffn_dim, hidden_dim, ffn_scale),
            ffn_b1: Array1::zeros(ffn_dim),
            ffn_b2: Array1::zeros(hidden_dim),
            num_heads,
            hidden_dim,
            head_dim,
            layer_norm_eps,
        })
    }

    /// Multi-head scaled dot-product self-attention over all node pairs.
    ///
    /// `spatial_bias[[i, j]]` and `edge_bias[[i, j]]` are added to every
    /// head's attention score before the softmax, injecting graph structure
    /// into otherwise fully-connected attention. Returns the
    /// `(n, hidden_dim)` attended-and-projected features (no residual here).
    fn multi_head_attention(
        &self,
        x: &Array2<f64>,
        spatial_bias: &Array2<f64>,
        edge_bias: &Array2<f64>,
    ) -> Array2<f64> {
        let n = x.dim().0;
        let d = self.hidden_dim;
        let h = self.num_heads;
        let dk = self.head_dim;
        // Standard 1/sqrt(d_k) attention temperature.
        let scale = 1.0 / (dk as f64).sqrt();

        // Dense projections: q = x W_q, k = x W_k, v = x W_v.
        let mut q = Array2::zeros((n, d));
        let mut k = Array2::zeros((n, d));
        let mut v = Array2::zeros((n, d));

        for i in 0..n {
            for j in 0..d {
                let mut sq = 0.0;
                let mut sk = 0.0;
                let mut sv = 0.0;
                for m in 0..d {
                    let xi = x[[i, m]];
                    sq += xi * self.w_q[[m, j]];
                    sk += xi * self.w_k[[m, j]];
                    sv += xi * self.w_v[[m, j]];
                }
                q[[i, j]] = sq;
                k[[i, j]] = sk;
                v[[i, j]] = sv;
            }
        }

        let mut output = Array2::<f64>::zeros((n, d));

        // Each head works on its own `dk`-wide column slice at `offset`.
        for head in 0..h {
            let offset = head * dk;

            let mut scores = vec![vec![0.0f64; n]; n];
            for i in 0..n {
                for j in 0..n {
                    let mut dot = 0.0;
                    for m in 0..dk {
                        dot += q[[i, offset + m]] * k[[j, offset + m]];
                    }
                    // The structural biases are shared across all heads.
                    scores[i][j] = dot * scale + spatial_bias[[i, j]] + edge_bias[[i, j]];
                }
            }

            for i in 0..n {
                let alphas = softmax_row(&scores[i]);
                for j in 0..n {
                    let a = alphas[j];
                    for m in 0..dk {
                        output[[i, offset + m]] += a * v[[j, offset + m]];
                    }
                }
            }
        }

        // Final output projection W_o over the concatenated heads.
        let mut projected = Array2::zeros((n, d));
        for i in 0..n {
            for j in 0..d {
                let mut s = 0.0;
                for m in 0..d {
                    s += output[[i, m]] * self.w_o[[m, j]];
                }
                projected[[i, j]] = s;
            }
        }

        projected
    }

    /// Position-wise feed-forward network: `GELU(x W1 + b1) W2 + b2`.
    fn ffn(&self, x: &Array2<f64>) -> Array2<f64> {
        let n = x.dim().0;
        let ffn_dim = self.ffn_w1.dim().1;
        let d = self.hidden_dim;

        let mut h = Array2::zeros((n, ffn_dim));
        for i in 0..n {
            for j in 0..ffn_dim {
                let mut s = self.ffn_b1[j];
                for m in 0..d {
                    s += x[[i, m]] * self.ffn_w1[[m, j]];
                }
                // tanh approximation of GELU:
                // 0.5 * s * (1 + tanh(sqrt(2/pi) * (s + 0.044715 * s^3))).
                // FRAC_2_PI is 2/pi, so .sqrt() yields sqrt(2/pi).
                let x3 = s * s * s;
                let inner = std::f64::consts::FRAC_2_PI.sqrt() * (s + 0.044715 * x3);
                h[[i, j]] = 0.5 * s * (1.0 + inner.tanh());
            }
        }

        let mut out = Array2::zeros((n, d));
        for i in 0..n {
            for j in 0..d {
                let mut s = self.ffn_b2[j];
                for m in 0..ffn_dim {
                    s += h[[i, m]] * self.ffn_w2[[m, j]];
                }
                out[[i, j]] = s;
            }
        }

        out
    }

    /// Pre-norm transformer block:
    /// `out = x + Attn(LN(x))`, then `out += FFN(LN(out))`.
    ///
    /// # Errors
    /// Returns [`GraphError::InvalidParameter`] when `x`'s feature dimension
    /// does not equal `hidden_dim`.
    pub fn forward(
        &self,
        x: &Array2<f64>,
        spatial_bias: &Array2<f64>,
        edge_bias: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let (n, d) = x.dim();
        if d != self.hidden_dim {
            return Err(GraphError::InvalidParameter {
                param: "x".to_string(),
                value: format!("dim={d}"),
                expected: format!("dim={}", self.hidden_dim),
                context: "GraphormerLayer::forward".to_string(),
            });
        }

        // LayerNorm applied row-by-row before attention (pre-norm).
        let mut normed = x.clone();
        for i in 0..n {
            let mut row: Vec<f64> = (0..d).map(|j| normed[[i, j]]).collect();
            layer_norm(&mut row, self.layer_norm_eps);
            for j in 0..d {
                normed[[i, j]] = row[j];
            }
        }

        // First residual: out = x + attention(LN(x)).
        let attn_out = self.multi_head_attention(&normed, spatial_bias, edge_bias);
        let mut out = x.clone();
        for i in 0..n {
            for j in 0..d {
                out[[i, j]] += attn_out[[i, j]];
            }
        }

        // Second pre-norm ahead of the FFN sublayer.
        let mut normed2 = out.clone();
        for i in 0..n {
            let mut row: Vec<f64> = (0..d).map(|j| normed2[[i, j]]).collect();
            layer_norm(&mut row, self.layer_norm_eps);
            for j in 0..d {
                normed2[[i, j]] = row[j];
            }
        }

        // Second residual: out += FFN(LN(out)).
        let ffn_out = self.ffn(&normed2);
        for i in 0..n {
            for j in 0..d {
                out[[i, j]] += ffn_out[[i, j]];
            }
        }

        Ok(out)
    }
}
693
#[derive(Debug, Clone)]
/// Complete Graphormer: input projection, structural encodings, and a stack
/// of transformer layers.
pub struct GraphormerModel {
    /// Input feature projection, shape (in_dim, hidden_dim).
    pub input_proj: Array2<f64>,
    /// Degree-based node encoding added after the input projection.
    pub centrality_encoding: CentralityEncoding,
    /// Shortest-path-distance attention bias.
    pub spatial_encoding: SpatialEncoding,
    /// Edge-type path attention bias.
    pub edge_encoding: EdgeEncoding,
    /// Transformer layer stack (`config.num_layers` entries).
    pub layers: Vec<GraphormerLayer>,
    /// Hyper-parameters the model was built from.
    pub config: GraphormerConfig,
}
715
716impl GraphormerModel {
717 pub fn new(config: GraphormerConfig) -> Result<Self> {
719 let mut rng = scirs2_core::random::rng();
720 let proj_scale = (6.0_f64 / (config.in_dim + config.hidden_dim) as f64).sqrt();
721 let input_proj = Array2::from_shape_fn((config.in_dim, config.hidden_dim), |_| {
722 (rng.random::<f64>() * 2.0 - 1.0) * proj_scale
723 });
724
725 let centrality_encoding = CentralityEncoding::new(config.max_degree, config.hidden_dim);
726 let spatial_encoding = SpatialEncoding::new(config.max_distance);
727 let edge_encoding = EdgeEncoding::new(config.max_edge_types, config.hidden_dim);
728
729 let mut layers = Vec::with_capacity(config.num_layers);
730 for _ in 0..config.num_layers {
731 layers.push(GraphormerLayer::new(
732 config.hidden_dim,
733 config.num_heads,
734 config.ffn_dim,
735 config.layer_norm_eps,
736 )?);
737 }
738
739 Ok(GraphormerModel {
740 input_proj,
741 centrality_encoding,
742 spatial_encoding,
743 edge_encoding,
744 layers,
745 config,
746 })
747 }
748
749 pub fn forward(&self, features: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
758 let (n, in_dim) = features.dim();
759 if in_dim != self.config.in_dim {
760 return Err(GraphError::InvalidParameter {
761 param: "features".to_string(),
762 value: format!("in_dim={in_dim}"),
763 expected: format!("in_dim={}", self.config.in_dim),
764 context: "GraphormerModel::forward".to_string(),
765 });
766 }
767 if adj.n_rows != n {
768 return Err(GraphError::InvalidParameter {
769 param: "adj".to_string(),
770 value: format!("n_rows={}", adj.n_rows),
771 expected: format!("n_rows={n}"),
772 context: "GraphormerModel::forward".to_string(),
773 });
774 }
775
776 let d = self.config.hidden_dim;
778 let mut h = Array2::zeros((n, d));
779 for i in 0..n {
780 for j in 0..d {
781 let mut s = 0.0;
782 for m in 0..in_dim {
783 s += features[[i, m]] * self.input_proj[[m, j]];
784 }
785 h[[i, j]] = s;
786 }
787 }
788
789 h = self.centrality_encoding.forward(&h, adj)?;
791
792 let spatial_bias = self.spatial_encoding.forward(adj);
794 let spd = self.spatial_encoding.compute_spd_matrix(adj);
795 let edge_bias = self.edge_encoding.forward(adj, &spd);
796
797 for layer in &self.layers {
799 h = layer.forward(&h, &spatial_bias, &edge_bias)?;
800 }
801
802 Ok(h)
803 }
804}
805
806#[cfg(test)]
811mod tests {
812 use super::*;
813
814 fn triangle_csr() -> CsrMatrix {
815 let coo = vec![
816 (0, 1, 1.0),
817 (1, 0, 1.0),
818 (1, 2, 1.0),
819 (2, 1, 1.0),
820 (0, 2, 1.0),
821 (2, 0, 1.0),
822 ];
823 CsrMatrix::from_coo(3, 3, &coo).expect("triangle CSR")
824 }
825
826 fn path_csr() -> CsrMatrix {
827 let coo = vec![
829 (0, 1, 1.0),
830 (1, 0, 1.0),
831 (1, 2, 1.0),
832 (2, 1, 1.0),
833 (2, 3, 1.0),
834 (3, 2, 1.0),
835 ];
836 CsrMatrix::from_coo(4, 4, &coo).expect("path CSR")
837 }
838
839 fn feats(n: usize, d: usize) -> Array2<f64> {
840 Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 * 0.1)
841 }
842
843 #[test]
844 fn test_spatial_encoding_spd_matrix() {
845 let adj = path_csr();
846 let se = SpatialEncoding::new(10);
847 let spd = se.compute_spd_matrix(&adj);
848
849 for i in 0..4 {
851 assert_eq!(spd[[i, i]], 0, "self-distance should be 0 for node {i}");
852 }
853
854 assert_eq!(spd[[0, 1]], 1);
856 assert_eq!(spd[[1, 2]], 1);
857 assert_eq!(spd[[2, 3]], 1);
858
859 assert_eq!(spd[[0, 2]], 2);
861 assert_eq!(spd[[0, 3]], 3);
862 assert_eq!(spd[[1, 3]], 2);
863
864 for i in 0..4 {
866 for j in 0..4 {
867 assert_eq!(spd[[i, j]], spd[[j, i]], "SPD should be symmetric");
868 }
869 }
870 }
871
872 #[test]
873 fn test_centrality_encoding_degrees() {
874 let adj = triangle_csr();
875 let ce = CentralityEncoding::new(10, 8);
876 let (in_deg, out_deg) = ce.compute_degrees(&adj);
877
878 for i in 0..3 {
880 assert_eq!(in_deg[i], 2, "in-degree of node {i}");
881 assert_eq!(out_deg[i], 2, "out-degree of node {i}");
882 }
883 }
884
885 #[test]
886 fn test_centrality_encoding_forward_shape() {
887 let adj = triangle_csr();
888 let ce = CentralityEncoding::new(10, 8);
889 let features = feats(3, 8);
890 let result = ce.forward(&features, &adj).expect("centrality forward");
891 assert_eq!(result.dim(), (3, 8));
892
893 let mut differs = false;
895 for i in 0..3 {
896 for j in 0..8 {
897 if (result[[i, j]] - features[[i, j]]).abs() > 1e-12 {
898 differs = true;
899 }
900 }
901 }
902 assert!(differs, "centrality encoding should modify features");
903 }
904
905 #[test]
906 fn test_graphormer_attention_with_bias_output_shape() {
907 let adj = triangle_csr();
908 let config = GraphormerConfig {
909 in_dim: 4,
910 hidden_dim: 8,
911 num_heads: 2,
912 num_layers: 1,
913 ffn_dim: 16,
914 max_distance: 5,
915 max_degree: 10,
916 max_edge_types: 2,
917 ..Default::default()
918 };
919
920 let layer = GraphormerLayer::new(8, 2, 16, 1e-5).expect("layer");
921 let x = feats(3, 8);
922 let se = SpatialEncoding::new(5);
923 let spatial_bias = se.forward(&adj);
924 let edge_bias = Array2::zeros((3, 3));
925
926 let out = layer
927 .forward(&x, &spatial_bias, &edge_bias)
928 .expect("forward");
929 assert_eq!(out.dim(), (3, 8));
930 for &v in out.iter() {
931 assert!(v.is_finite(), "output should be finite, got {v}");
932 }
933 }
934
935 #[test]
936 fn test_graphormer_model_forward() {
937 let adj = triangle_csr();
938 let config = GraphormerConfig {
939 in_dim: 4,
940 hidden_dim: 8,
941 num_heads: 2,
942 num_layers: 2,
943 ffn_dim: 16,
944 max_distance: 5,
945 max_degree: 10,
946 max_edge_types: 2,
947 ..Default::default()
948 };
949
950 let model = GraphormerModel::new(config).expect("model");
951 let features = feats(3, 4);
952 let out = model.forward(&features, &adj).expect("forward");
953 assert_eq!(out.dim(), (3, 8));
954 for &v in out.iter() {
955 assert!(v.is_finite(), "output should be finite, got {v}");
956 }
957 }
958
959 #[test]
960 fn test_graphormer_edge_encoding() {
961 let adj = path_csr();
962 let se = SpatialEncoding::new(5);
963 let spd = se.compute_spd_matrix(&adj);
964 let ee = EdgeEncoding::new(2, 4);
965 let bias = ee.forward(&adj, &spd);
966
967 assert_eq!(bias.dim(), (4, 4));
968 for i in 0..4 {
970 assert!(bias[[i, i]].abs() < 1e-12, "self edge bias should be 0");
971 }
972 for &v in bias.iter() {
974 assert!(v.is_finite(), "edge bias should be finite");
975 }
976 }
977
978 #[test]
979 fn test_graphormer_invalid_hidden_dim() {
980 let result = GraphormerLayer::new(7, 2, 16, 1e-5);
982 assert!(result.is_err());
983 }
984
985 #[test]
986 fn test_spatial_encoding_disconnected() {
987 let coo = vec![(0, 1, 1.0), (1, 0, 1.0), (2, 3, 1.0), (3, 2, 1.0)];
989 let adj = CsrMatrix::from_coo(4, 4, &coo).expect("disconnected CSR");
990 let se = SpatialEncoding::new(5);
991 let spd = se.compute_spd_matrix(&adj);
992
993 assert_eq!(spd[[0, 1]], 1);
995 assert_eq!(spd[[2, 3]], 1);
996
997 assert_eq!(spd[[0, 2]], 6);
999 assert_eq!(spd[[0, 3]], 6);
1000 assert_eq!(spd[[1, 2]], 6);
1001 }
1002}