1use scirs2_core::ndarray::{Array1, Array2};
14use scirs2_core::random::{Rng, RngExt};
15
16use crate::error::{GraphError, Result};
17use crate::gnn::gcn::CsrMatrix;
18
19#[derive(Debug, Clone)]
29pub struct RandomWalkPe {
30 pub walk_length: usize,
32 pub projection: Array2<f64>,
34 pub pe_dim: usize,
36}
37
38impl RandomWalkPe {
39 pub fn new(walk_length: usize, pe_dim: usize) -> Self {
45 let mut rng = scirs2_core::random::rng();
46 let scale = (6.0_f64 / (walk_length + pe_dim) as f64).sqrt();
47 let projection = Array2::from_shape_fn((walk_length, pe_dim), |_| {
48 (rng.random::<f64>() * 2.0 - 1.0) * scale
49 });
50
51 RandomWalkPe {
52 walk_length,
53 projection,
54 pe_dim,
55 }
56 }
57
58 pub fn compute_landing_probs(&self, adj: &CsrMatrix) -> Array2<f64> {
62 let n = adj.n_rows;
63
64 let row_sums = adj.row_sums();
67 let mut p_data: Vec<f64> = Vec::with_capacity(adj.nnz());
68 for (row, _col, val) in adj.iter() {
69 let d = row_sums[row];
70 if d > 0.0 {
71 p_data.push(val / d);
72 } else {
73 p_data.push(0.0);
74 }
75 }
76
77 let p = CsrMatrix {
78 n_rows: adj.n_rows,
79 n_cols: adj.n_cols,
80 indptr: adj.indptr.clone(),
81 indices: adj.indices.clone(),
82 data: p_data,
83 };
84
85 let mut landing = Array2::zeros((n, self.walk_length));
88
89 if n <= 500 {
100 let mut power = Array2::<f64>::eye(n);
102 for k in 0..self.walk_length {
103 let mut new_power = Array2::zeros((n, n));
105 for (row, col, val) in p.iter() {
106 for j in 0..n {
107 new_power[[row, j]] += val * power[[col, j]];
108 }
109 }
110 power = new_power;
111 for i in 0..n {
113 landing[[i, k]] = power[[i, i]];
114 }
115 }
116 } else {
117 for i in 0..n {
119 let mut vec_cur = vec![0.0f64; n];
120 vec_cur[i] = 1.0;
121
122 for k in 0..self.walk_length {
123 let mut vec_next = vec![0.0f64; n];
124 for (row, col, val) in p.iter() {
125 vec_next[row] += val * vec_cur[col];
126 }
127 landing[[i, k]] = vec_next[i];
128 vec_cur = vec_next;
129 }
130 }
131 }
132
133 landing
134 }
135
136 pub fn forward(&self, adj: &CsrMatrix) -> Array2<f64> {
144 let landing = self.compute_landing_probs(adj);
145 let n = adj.n_rows;
146
147 let mut pe = Array2::zeros((n, self.pe_dim));
149 for i in 0..n {
150 for j in 0..self.pe_dim {
151 let mut s = 0.0;
152 for k in 0..self.walk_length {
153 s += landing[[i, k]] * self.projection[[k, j]];
154 }
155 pe[[i, j]] = s;
156 }
157 }
158
159 pe
160 }
161}
162
163#[derive(Debug, Clone)]
169pub struct LaplacianPe {
170 pub k: usize,
172 pub projection: Array2<f64>,
174 pub pe_dim: usize,
176}
177
178impl LaplacianPe {
179 pub fn new(k: usize, pe_dim: usize) -> Self {
185 let mut rng = scirs2_core::random::rng();
186 let scale = (6.0_f64 / (k + pe_dim) as f64).sqrt();
187 let projection =
188 Array2::from_shape_fn((k, pe_dim), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale);
189
190 LaplacianPe {
191 k,
192 projection,
193 pe_dim,
194 }
195 }
196
197 pub fn compute_eigenvectors(&self, adj: &CsrMatrix) -> Array2<f64> {
202 let n = adj.n_rows;
203 let actual_k = self.k.min(n.saturating_sub(1));
204 if actual_k == 0 || n < 2 {
205 return Array2::zeros((n, self.k));
206 }
207
208 let row_sums = adj.row_sums();
210 let mut lap = Array2::zeros((n, n));
211 for i in 0..n {
212 lap[[i, i]] = row_sums[i];
213 }
214 for (row, col, val) in adj.iter() {
215 lap[[row, col]] -= val;
216 }
217
218 let mut eigvecs = Array2::zeros((n, self.k));
222
223 let max_lambda_estimate = row_sums.iter().cloned().fold(0.0_f64, f64::max) * 2.0 + 1.0;
227
228 let mut m_mat = Array2::zeros((n, n));
230 for i in 0..n {
231 for j in 0..n {
232 m_mat[[i, j]] = -lap[[i, j]];
233 }
234 m_mat[[i, i]] += max_lambda_estimate;
235 }
236
237 let mut found_vecs: Vec<Vec<f64>> = Vec::new();
238
239 let trivial: Vec<f64> = vec![1.0 / (n as f64).sqrt(); n];
241 found_vecs.push(trivial);
242
243 let num_iters = 200;
244
245 for _ev_idx in 0..actual_k {
246 let mut rng = scirs2_core::random::rng();
248 let mut v: Vec<f64> = (0..n).map(|_| rng.random::<f64>() * 2.0 - 1.0).collect();
249
250 for fv in &found_vecs {
252 let dot: f64 = v.iter().zip(fv.iter()).map(|(a, b)| a * b).sum();
253 for (vi, fi) in v.iter_mut().zip(fv.iter()) {
254 *vi -= dot * fi;
255 }
256 }
257
258 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
260 v.iter_mut().for_each(|x| *x /= norm);
261
262 for _ in 0..num_iters {
263 let mut v_new = vec![0.0f64; n];
265 for i in 0..n {
266 for j in 0..n {
267 v_new[i] += m_mat[[i, j]] * v[j];
268 }
269 }
270
271 for fv in &found_vecs {
273 let dot: f64 = v_new.iter().zip(fv.iter()).map(|(a, b)| a * b).sum();
274 for (vi, fi) in v_new.iter_mut().zip(fv.iter()) {
275 *vi -= dot * fi;
276 }
277 }
278
279 let norm: f64 = v_new.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
281 v_new.iter_mut().for_each(|x| *x /= norm);
282
283 v = v_new;
284 }
285
286 found_vecs.push(v);
288 }
289
290 for (idx, fv) in found_vecs.iter().skip(1).take(self.k).enumerate() {
292 for i in 0..n {
293 eigvecs[[i, idx]] = fv[i];
294 }
295 }
296 eigvecs
299 }
300
301 pub fn forward(&self, adj: &CsrMatrix) -> Array2<f64> {
309 let eigvecs = self.compute_eigenvectors(adj);
310 let n = adj.n_rows;
311
312 let mut pe = Array2::zeros((n, self.pe_dim));
314 for i in 0..n {
315 for j in 0..self.pe_dim {
316 let mut s = 0.0;
317 for m in 0..self.k {
318 s += eigvecs[[i, m]] * self.projection[[m, j]];
319 }
320 pe[[i, j]] = s;
321 }
322 }
323
324 pe
325 }
326}
327
328#[derive(Debug, Clone)]
339struct GinLocal {
340 w1: Array2<f64>,
342 w2: Array2<f64>,
344 b1: Array1<f64>,
346 b2: Array1<f64>,
348 eps: f64,
350 hidden_dim: usize,
352}
353
354impl GinLocal {
355 fn new(hidden_dim: usize) -> Self {
356 let mut rng = scirs2_core::random::rng();
357 let scale = (6.0_f64 / (2 * hidden_dim) as f64).sqrt();
358
359 GinLocal {
360 w1: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
361 (rng.random::<f64>() * 2.0 - 1.0) * scale
362 }),
363 w2: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
364 (rng.random::<f64>() * 2.0 - 1.0) * scale
365 }),
366 b1: Array1::zeros(hidden_dim),
367 b2: Array1::zeros(hidden_dim),
368 eps: 0.0,
369 hidden_dim,
370 }
371 }
372
373 fn forward(&self, x: &Array2<f64>, adj: &CsrMatrix) -> Array2<f64> {
374 let n = x.dim().0;
375 let d = self.hidden_dim;
376
377 let mut agg = Array2::zeros((n, d));
379 for i in 0..n {
380 for j in 0..d {
381 agg[[i, j]] = (1.0 + self.eps) * x[[i, j]];
382 }
383 }
384 for (row, col, _) in adj.iter() {
385 for j in 0..d {
386 agg[[row, j]] += x[[col, j]];
387 }
388 }
389
390 let mut h = Array2::zeros((n, d));
392 for i in 0..n {
393 for j in 0..d {
394 let mut s = self.b1[j];
395 for m in 0..d {
396 s += agg[[i, m]] * self.w1[[m, j]];
397 }
398 h[[i, j]] = s.max(0.0); }
400 }
401
402 let mut out = Array2::zeros((n, d));
403 for i in 0..n {
404 for j in 0..d {
405 let mut s = self.b2[j];
406 for m in 0..d {
407 s += h[[i, m]] * self.w2[[m, j]];
408 }
409 out[[i, j]] = s;
410 }
411 }
412
413 out
414 }
415}
416
417#[derive(Debug, Clone)]
423struct GlobalAttention {
424 w_q: Array2<f64>,
425 w_k: Array2<f64>,
426 w_v: Array2<f64>,
427 w_o: Array2<f64>,
428 num_heads: usize,
429 hidden_dim: usize,
430 head_dim: usize,
431}
432
433impl GlobalAttention {
434 fn new(hidden_dim: usize, num_heads: usize) -> Result<Self> {
435 if !hidden_dim.is_multiple_of(num_heads) {
436 return Err(GraphError::InvalidParameter {
437 param: "hidden_dim".to_string(),
438 value: format!("{hidden_dim}"),
439 expected: format!("divisible by num_heads={num_heads}"),
440 context: "GlobalAttention::new".to_string(),
441 });
442 }
443
444 let head_dim = hidden_dim / num_heads;
445 let mut rng = scirs2_core::random::rng();
446 let scale = (6.0_f64 / (2 * hidden_dim) as f64).sqrt();
447
448 let mut init = |r, c| -> Array2<f64> {
449 Array2::from_shape_fn((r, c), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
450 };
451
452 Ok(GlobalAttention {
453 w_q: init(hidden_dim, hidden_dim),
454 w_k: init(hidden_dim, hidden_dim),
455 w_v: init(hidden_dim, hidden_dim),
456 w_o: init(hidden_dim, hidden_dim),
457 num_heads,
458 hidden_dim,
459 head_dim,
460 })
461 }
462
463 fn forward(&self, x: &Array2<f64>) -> Array2<f64> {
464 let n = x.dim().0;
465 let d = self.hidden_dim;
466 let h = self.num_heads;
467 let dk = self.head_dim;
468 let scale = 1.0 / (dk as f64).sqrt();
469
470 let mut q = Array2::zeros((n, d));
472 let mut k = Array2::zeros((n, d));
473 let mut v = Array2::zeros((n, d));
474
475 for i in 0..n {
476 for j in 0..d {
477 let mut sq = 0.0;
478 let mut sk = 0.0;
479 let mut sv = 0.0;
480 for m in 0..d {
481 let xi = x[[i, m]];
482 sq += xi * self.w_q[[m, j]];
483 sk += xi * self.w_k[[m, j]];
484 sv += xi * self.w_v[[m, j]];
485 }
486 q[[i, j]] = sq;
487 k[[i, j]] = sk;
488 v[[i, j]] = sv;
489 }
490 }
491
492 let mut output = Array2::<f64>::zeros((n, d));
493
494 for head in 0..h {
495 let offset = head * dk;
496
497 let mut scores = vec![vec![0.0f64; n]; n];
499 for i in 0..n {
500 for j in 0..n {
501 let mut dot = 0.0;
502 for m in 0..dk {
503 dot += q[[i, offset + m]] * k[[j, offset + m]];
504 }
505 scores[i][j] = dot * scale;
506 }
507 }
508
509 for i in 0..n {
511 let alphas = softmax_row(&scores[i]);
512 for j in 0..n {
513 let a = alphas[j];
514 for m in 0..dk {
515 output[[i, offset + m]] += a * v[[j, offset + m]];
516 }
517 }
518 }
519 }
520
521 let mut projected = Array2::zeros((n, d));
523 for i in 0..n {
524 for j in 0..d {
525 let mut s = 0.0;
526 for m in 0..d {
527 s += output[[i, m]] * self.w_o[[m, j]];
528 }
529 projected[[i, j]] = s;
530 }
531 }
532
533 projected
534 }
535}
536
/// Numerically stable softmax of a single row of scores.
///
/// The row maximum is subtracted before exponentiating so large scores cannot overflow.
/// The normaliser is clamped to at least `1e-12`. An empty row yields an empty vector.
fn softmax_row(row: &[f64]) -> Vec<f64> {
    // On an empty row the fold yields -inf and the map below produces nothing,
    // so the result is an empty Vec without a separate guard.
    let peak = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut weights: Vec<f64> = row.iter().map(|&score| (score - peak).exp()).collect();
    let total = weights.iter().sum::<f64>().max(1e-12);
    for w in &mut weights {
        *w /= total;
    }
    weights
}
547
/// Normalises `x` in place to zero mean and (approximately) unit variance.
///
/// Uses the population variance, and adds `eps` to the variance before the square root.
/// There is no learned scale or shift. An empty slice is left untouched.
fn layer_norm_vec(x: &mut [f64], eps: f64) {
    if x.is_empty() {
        return;
    }
    let count = x.len() as f64;
    let mean = x.iter().sum::<f64>() / count;
    let variance = x
        .iter()
        .map(|&v| {
            let centred = v - mean;
            centred * centred
        })
        .sum::<f64>()
        / count;
    let inv_std = (variance + eps).sqrt().recip();
    x.iter_mut().for_each(|v| *v = (*v - mean) * inv_std);
}
561
/// Local message-passing block requested for each GPS layer.
///
/// NOTE(review): in the visible code, `GpsLayer` always builds a GIN block and
/// `GpsConfig::local_model` is never read, so `Gat` currently behaves like `Gin`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalModel {
    /// Graph Isomorphism Network aggregation.
    Gin,
    /// Graph attention (not wired into the layers yet; see the type-level note).
    Gat,
}
574
575#[derive(Debug, Clone)]
577pub struct GpsConfig {
578 pub in_dim: usize,
580 pub hidden_dim: usize,
582 pub num_heads: usize,
584 pub num_layers: usize,
586 pub ffn_dim: usize,
588 pub local_model: LocalModel,
590 pub pe_dim: usize,
592 pub rw_walk_length: usize,
594 pub layer_norm_eps: f64,
596}
597
598impl Default for GpsConfig {
599 fn default() -> Self {
600 GpsConfig {
601 in_dim: 64,
602 hidden_dim: 64,
603 num_heads: 4,
604 num_layers: 3,
605 ffn_dim: 256,
606 local_model: LocalModel::Gin,
607 pe_dim: 16,
608 rw_walk_length: 8,
609 layer_norm_eps: 1e-5,
610 }
611 }
612}
613
614#[derive(Debug, Clone)]
624pub struct GpsLayer {
625 gin_local: GinLocal,
627 global_attn: GlobalAttention,
629 ffn_w1: Array2<f64>,
631 ffn_w2: Array2<f64>,
633 ffn_b1: Array1<f64>,
635 ffn_b2: Array1<f64>,
637 hidden_dim: usize,
639 layer_norm_eps: f64,
641}
642
643impl GpsLayer {
644 pub fn new(
646 hidden_dim: usize,
647 num_heads: usize,
648 ffn_dim: usize,
649 layer_norm_eps: f64,
650 ) -> Result<Self> {
651 let mut rng = scirs2_core::random::rng();
652 let ffn_scale = (6.0_f64 / (hidden_dim + ffn_dim) as f64).sqrt();
653
654 Ok(GpsLayer {
655 gin_local: GinLocal::new(hidden_dim),
656 global_attn: GlobalAttention::new(hidden_dim, num_heads)?,
657 ffn_w1: Array2::from_shape_fn((hidden_dim, ffn_dim), |_| {
658 (rng.random::<f64>() * 2.0 - 1.0) * ffn_scale
659 }),
660 ffn_w2: Array2::from_shape_fn((ffn_dim, hidden_dim), |_| {
661 (rng.random::<f64>() * 2.0 - 1.0) * ffn_scale
662 }),
663 ffn_b1: Array1::zeros(ffn_dim),
664 ffn_b2: Array1::zeros(hidden_dim),
665 hidden_dim,
666 layer_norm_eps,
667 })
668 }
669
670 fn ffn(&self, x: &Array2<f64>) -> Array2<f64> {
672 let n = x.dim().0;
673 let d = self.hidden_dim;
674 let ffn_dim = self.ffn_w1.dim().1;
675
676 let mut h = Array2::zeros((n, ffn_dim));
677 for i in 0..n {
678 for j in 0..ffn_dim {
679 let mut s = self.ffn_b1[j];
680 for m in 0..d {
681 s += x[[i, m]] * self.ffn_w1[[m, j]];
682 }
683 h[[i, j]] = s.max(0.0); }
685 }
686
687 let mut out = Array2::zeros((n, d));
688 for i in 0..n {
689 for j in 0..d {
690 let mut s = self.ffn_b2[j];
691 for m in 0..ffn_dim {
692 s += h[[i, m]] * self.ffn_w2[[m, j]];
693 }
694 out[[i, j]] = s;
695 }
696 }
697
698 out
699 }
700
701 pub fn forward(&self, x: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
707 let (n, d) = x.dim();
708 if d != self.hidden_dim {
709 return Err(GraphError::InvalidParameter {
710 param: "x".to_string(),
711 value: format!("dim={d}"),
712 expected: format!("dim={}", self.hidden_dim),
713 context: "GpsLayer::forward".to_string(),
714 });
715 }
716
717 let local_out = self.gin_local.forward(x, adj);
719
720 let global_out = self.global_attn.forward(x);
722
723 let ffn_out = self.ffn(x);
725
726 let mut out = x.clone();
728 for i in 0..n {
729 for j in 0..d {
730 out[[i, j]] += local_out[[i, j]] + global_out[[i, j]] + ffn_out[[i, j]];
731 }
732 }
733
734 for i in 0..n {
736 let mut row: Vec<f64> = (0..d).map(|j| out[[i, j]]).collect();
737 layer_norm_vec(&mut row, self.layer_norm_eps);
738 for j in 0..d {
739 out[[i, j]] = row[j];
740 }
741 }
742
743 Ok(out)
744 }
745}
746
747#[derive(Debug, Clone)]
753pub struct GpsModel {
754 pub input_proj: Array2<f64>,
756 pub rwpe: RandomWalkPe,
758 pub layers: Vec<GpsLayer>,
760 pub config: GpsConfig,
762}
763
764impl GpsModel {
765 pub fn new(config: GpsConfig) -> Result<Self> {
767 let mut rng = scirs2_core::random::rng();
768 let total_in = config.in_dim + config.pe_dim;
769 let proj_scale = (6.0_f64 / (total_in + config.hidden_dim) as f64).sqrt();
770 let input_proj = Array2::from_shape_fn((total_in, config.hidden_dim), |_| {
771 (rng.random::<f64>() * 2.0 - 1.0) * proj_scale
772 });
773
774 let rwpe = RandomWalkPe::new(config.rw_walk_length, config.pe_dim);
775
776 let mut layers = Vec::with_capacity(config.num_layers);
777 for _ in 0..config.num_layers {
778 layers.push(GpsLayer::new(
779 config.hidden_dim,
780 config.num_heads,
781 config.ffn_dim,
782 config.layer_norm_eps,
783 )?);
784 }
785
786 Ok(GpsModel {
787 input_proj,
788 rwpe,
789 layers,
790 config,
791 })
792 }
793
794 pub fn forward(&self, features: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
803 let (n, in_dim) = features.dim();
804 if in_dim != self.config.in_dim {
805 return Err(GraphError::InvalidParameter {
806 param: "features".to_string(),
807 value: format!("in_dim={in_dim}"),
808 expected: format!("in_dim={}", self.config.in_dim),
809 context: "GpsModel::forward".to_string(),
810 });
811 }
812 if adj.n_rows != n {
813 return Err(GraphError::InvalidParameter {
814 param: "adj".to_string(),
815 value: format!("n_rows={}", adj.n_rows),
816 expected: format!("n_rows={n}"),
817 context: "GpsModel::forward".to_string(),
818 });
819 }
820
821 let pe = self.rwpe.forward(adj);
823
824 let total_in = self.config.in_dim + self.config.pe_dim;
826 let mut concat = Array2::zeros((n, total_in));
827 for i in 0..n {
828 for j in 0..in_dim {
829 concat[[i, j]] = features[[i, j]];
830 }
831 for j in 0..self.config.pe_dim {
832 concat[[i, in_dim + j]] = pe[[i, j]];
833 }
834 }
835
836 let d = self.config.hidden_dim;
838 let mut h = Array2::zeros((n, d));
839 for i in 0..n {
840 for j in 0..d {
841 let mut s = 0.0;
842 for m in 0..total_in {
843 s += concat[[i, m]] * self.input_proj[[m, j]];
844 }
845 h[[i, j]] = s;
846 }
847 }
848
849 for layer in &self.layers {
851 h = layer.forward(&h, adj)?;
852 }
853
854 Ok(h)
855 }
856}
857
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected 3-cycle: every node is adjacent to the other two.
    fn triangle_csr() -> CsrMatrix {
        let coo = vec![
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 1.0),
            (2, 1, 1.0),
            (0, 2, 1.0),
            (2, 0, 1.0),
        ];
        CsrMatrix::from_coo(3, 3, &coo).expect("triangle CSR")
    }

    /// Undirected path 0 – 1 – 2 – 3.
    fn path_csr() -> CsrMatrix {
        let coo = vec![
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 1.0),
            (2, 1, 1.0),
            (2, 3, 1.0),
            (3, 2, 1.0),
        ];
        CsrMatrix::from_coo(4, 4, &coo).expect("path CSR")
    }

    /// Deterministic `n × d` features: entry `(i, j)` is `0.1 · (i·d + j)`.
    fn feats(n: usize, d: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 * 0.1)
    }

    #[test]
    fn test_rwpe_landing_probs_shape() {
        let landing = RandomWalkPe::new(4, 8).compute_landing_probs(&triangle_csr());
        assert_eq!(landing.dim(), (3, 4));
        for v in landing.iter().copied() {
            assert!(v.is_finite(), "landing prob should be finite, got {v}");
            assert!(v >= 0.0, "landing prob should be non-negative, got {v}");
        }
    }

    #[test]
    fn test_rwpe_produces_correct_features() {
        let adj = triangle_csr();
        let rwpe = RandomWalkPe::new(3, 6);
        assert_eq!(rwpe.forward(&adj).dim(), (3, 6));

        // Every node of the triangle is structurally equivalent, so the return
        // probabilities must agree across nodes at each step.
        let landing = rwpe.compute_landing_probs(&adj);
        for k in 0..3 {
            let (val0, val1, val2) = (landing[[0, k]], landing[[1, k]], landing[[2, k]]);
            assert!(
                (val0 - val1).abs() < 1e-10 && (val1 - val2).abs() < 1e-10,
                "symmetric graph should have equal landing probs at step {k}: {val0}, {val1}, {val2}"
            );
        }
    }

    #[test]
    fn test_rwpe_path_graph_different_probs() {
        let landing = RandomWalkPe::new(3, 4).compute_landing_probs(&path_csr());
        assert_eq!(landing.dim(), (4, 3));

        // Node 0 is an endpoint (degree 1); node 1 is interior (degree 2).
        let end_prob = landing[[0, 0]];
        let mid_prob = landing[[1, 0]];
        assert!(end_prob.is_finite());
        assert!(mid_prob.is_finite());
    }

    #[test]
    fn test_laplacian_pe_shape() {
        let pe = LaplacianPe::new(2, 6).forward(&triangle_csr());
        assert_eq!(pe.dim(), (3, 6));
        for v in pe.iter().copied() {
            assert!(v.is_finite(), "Laplacian PE should be finite, got {v}");
        }
    }

    #[test]
    fn test_gps_hybrid_combines_local_and_global() {
        let config = GpsConfig {
            in_dim: 8,
            hidden_dim: 8,
            num_heads: 2,
            num_layers: 1,
            ffn_dim: 16,
            local_model: LocalModel::Gin,
            pe_dim: 4,
            rw_walk_length: 3,
            ..Default::default()
        };
        let model = GpsModel::new(config).expect("GPS model");
        let out = model
            .forward(&feats(3, 8), &triangle_csr())
            .expect("GPS forward");
        assert_eq!(out.dim(), (3, 8));

        for v in out.iter().copied() {
            assert!(v.is_finite(), "GPS output should be finite, got {v}");
        }
        assert!(
            out.iter().any(|&v| v.abs() > 1e-12),
            "GPS output should have non-trivial values"
        );
    }

    #[test]
    fn test_gps_layer_forward_shape() {
        let layer = GpsLayer::new(8, 2, 16, 1e-5).expect("GPS layer");
        let out = layer
            .forward(&feats(3, 8), &triangle_csr())
            .expect("GPS layer forward");
        assert_eq!(out.dim(), (3, 8));

        // Layer norm centres each row.
        for i in 0..3 {
            let mean = (0..8).map(|j| out[[i, j]]).sum::<f64>() / 8.0;
            assert!(
                mean.abs() < 0.1,
                "after layer norm, mean should be near 0, got {mean}"
            );
        }
    }

    #[test]
    fn test_gps_multi_layer() {
        let config = GpsConfig {
            in_dim: 4,
            hidden_dim: 8,
            num_heads: 2,
            num_layers: 3,
            ffn_dim: 16,
            pe_dim: 4,
            rw_walk_length: 3,
            ..Default::default()
        };
        let model = GpsModel::new(config).expect("GPS model");
        let out = model
            .forward(&feats(4, 4), &path_csr())
            .expect("GPS forward");
        assert_eq!(out.dim(), (4, 8));
        assert!(
            out.iter().all(|v| v.is_finite()),
            "multi-layer GPS output should be finite"
        );
    }

    #[test]
    fn test_gps_invalid_dim_error() {
        // 7 is not divisible by 4 heads, so construction must fail.
        let config = GpsConfig {
            in_dim: 4,
            hidden_dim: 7,
            num_heads: 4,
            ..Default::default()
        };
        assert!(GpsModel::new(config).is_err());
    }

    #[test]
    fn test_gin_local_aggregation() {
        let out = GinLocal::new(8).forward(&feats(3, 8), &triangle_csr());
        assert_eq!(out.dim(), (3, 8));
        assert!(out.iter().all(|v| v.is_finite()));
    }
}