// scirs2_graph/gnn/transformers/gps.rs

//! GPS - General Powerful Scalable Graph Transformer
//!
//! Implements the GPS architecture from Rampasek et al. (2022),
//! "Recipe for a General, Powerful, Scalable Graph Transformer".
//!
//! Key components:
//! - **Hybrid architecture**: local MPNN + global attention combined
//! - **Local message passing**: GIN-style aggregation for local structure
//! - **Global attention**: standard multi-head attention over all nodes
//! - **Positional/structural encoding**: Random Walk PE (RWPE), Laplacian PE
//! - **Layer design**: `output = MPNN(x) + Attention(x) + FFN(x)` with residuals
13use scirs2_core::ndarray::{Array1, Array2};
14use scirs2_core::random::{Rng, RngExt};
15
16use crate::error::{GraphError, Result};
17use crate::gnn::gcn::CsrMatrix;
18
19// ============================================================================
20// Positional Encodings
21// ============================================================================
22
/// Random Walk Positional Encoding (RWPE).
///
/// Computes the diagonal of P^k for k=1..K where P is the random walk
/// transition matrix. The landing probabilities R_ii = (P^k)_ii encode
/// local structural information around each node.
#[derive(Debug, Clone)]
pub struct RandomWalkPe {
    /// Number of random walk steps K (one landing probability per step).
    pub walk_length: usize,
    /// Linear projection from landing-probability features to the output
    /// encoding: shape `[walk_length, pe_dim]`.
    pub projection: Array2<f64>,
    /// Output positional-encoding dimension.
    pub pe_dim: usize,
}
37
38impl RandomWalkPe {
39    /// Create a new RWPE module.
40    ///
41    /// # Arguments
42    /// * `walk_length` - Number of random walk steps K
43    /// * `pe_dim` - Output dimension for the positional encoding
44    pub fn new(walk_length: usize, pe_dim: usize) -> Self {
45        let mut rng = scirs2_core::random::rng();
46        let scale = (6.0_f64 / (walk_length + pe_dim) as f64).sqrt();
47        let projection = Array2::from_shape_fn((walk_length, pe_dim), |_| {
48            (rng.random::<f64>() * 2.0 - 1.0) * scale
49        });
50
51        RandomWalkPe {
52            walk_length,
53            projection,
54            pe_dim,
55        }
56    }
57
    /// Compute random walk landing probabilities.
    ///
    /// Returns `[n, walk_length]` where entry `(i, k)` = `(P^{k+1})_{ii}`,
    /// i.e. the probability that a walk started at node `i` is back at `i`
    /// after `k + 1` steps of the transition matrix `P = D^{-1} A`.
    pub fn compute_landing_probs(&self, adj: &CsrMatrix) -> Array2<f64> {
        let n = adj.n_rows;

        // Build transition matrix P = D^{-1} A by row-normalizing the
        // adjacency values. Isolated nodes (degree 0) get an all-zero row,
        // so their landing probabilities are 0 rather than NaN.
        let row_sums = adj.row_sums();
        let mut p_data: Vec<f64> = Vec::with_capacity(adj.nnz());
        for (row, _col, val) in adj.iter() {
            let d = row_sums[row];
            if d > 0.0 {
                p_data.push(val / d);
            } else {
                p_data.push(0.0);
            }
        }

        // P shares the sparsity pattern of adj; only the values change.
        let p = CsrMatrix {
            n_rows: adj.n_rows,
            n_cols: adj.n_cols,
            indptr: adj.indptr.clone(),
            indices: adj.indices.clone(),
            data: p_data,
        };

        // Only the diagonals of P^1 .. P^K are needed. Two strategies:
        //  - small n: keep the full dense power P^k and read its diagonal
        //    (O(K * nnz * n) time, O(n^2) memory);
        //  - large n: push each basis vector e_i through P for K steps
        //    (same asymptotic time, but only O(n) extra memory at a time).
        let mut landing = Array2::zeros((n, self.walk_length));

        if n <= 500 {
            // Dense approach for small graphs: `power` holds P^{k+1} after
            // iteration k.
            let mut power = Array2::<f64>::eye(n);
            for k in 0..self.walk_length {
                // power = P * power (sparse-dense multiplication)
                let mut new_power = Array2::zeros((n, n));
                for (row, col, val) in p.iter() {
                    for j in 0..n {
                        new_power[[row, j]] += val * power[[col, j]];
                    }
                }
                power = new_power;
                // Diagonal of P^{k+1} = return probabilities after k+1 steps.
                for i in 0..n {
                    landing[[i, k]] = power[[i, i]];
                }
            }
        } else {
            // For large graphs, avoid the O(n^2) dense power: propagate the
            // one-hot vector e_i with sparse mat-vec products and read back
            // component i after each step.
            for i in 0..n {
                let mut vec_cur = vec![0.0f64; n];
                vec_cur[i] = 1.0;

                for k in 0..self.walk_length {
                    let mut vec_next = vec![0.0f64; n];
                    for (row, col, val) in p.iter() {
                        vec_next[row] += val * vec_cur[col];
                    }
                    // (P^{k+1} e_i)_i == (P^{k+1})_{ii}
                    landing[[i, k]] = vec_next[i];
                    vec_cur = vec_next;
                }
            }
        }

        landing
    }
135
136    /// Compute RWPE and project to pe_dim.
137    ///
138    /// # Arguments
139    /// * `adj` - Adjacency matrix
140    ///
141    /// # Returns
142    /// Positional encoding `[n, pe_dim]`
143    pub fn forward(&self, adj: &CsrMatrix) -> Array2<f64> {
144        let landing = self.compute_landing_probs(adj);
145        let n = adj.n_rows;
146
147        // Project: [n, walk_length] @ [walk_length, pe_dim] -> [n, pe_dim]
148        let mut pe = Array2::zeros((n, self.pe_dim));
149        for i in 0..n {
150            for j in 0..self.pe_dim {
151                let mut s = 0.0;
152                for k in 0..self.walk_length {
153                    s += landing[[i, k]] * self.projection[[k, j]];
154                }
155                pe[[i, j]] = s;
156            }
157        }
158
159        pe
160    }
161}
162
/// Laplacian Positional Encoding.
///
/// Uses the eigenvectors of the graph Laplacian `L = D - A` as positional
/// encodings. The k smallest non-trivial eigenvectors are approximated with
/// power iteration on a spectrally shifted matrix plus deflation
/// (see [`LaplacianPe::compute_eigenvectors`]).
#[derive(Debug, Clone)]
pub struct LaplacianPe {
    /// Number of Laplacian eigenvectors to use.
    pub k: usize,
    /// Linear projection from eigenvector features to the output encoding:
    /// shape `[k, pe_dim]`.
    pub projection: Array2<f64>,
    /// Output positional-encoding dimension.
    pub pe_dim: usize,
}
177
178impl LaplacianPe {
179    /// Create a new Laplacian PE module.
180    ///
181    /// # Arguments
182    /// * `k` - Number of Laplacian eigenvectors to use
183    /// * `pe_dim` - Output PE dimension
184    pub fn new(k: usize, pe_dim: usize) -> Self {
185        let mut rng = scirs2_core::random::rng();
186        let scale = (6.0_f64 / (k + pe_dim) as f64).sqrt();
187        let projection =
188            Array2::from_shape_fn((k, pe_dim), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale);
189
190        LaplacianPe {
191            k,
192            projection,
193            pe_dim,
194        }
195    }
196
    /// Compute the k smallest non-trivial eigenvectors of the Laplacian.
    ///
    /// Builds `L = D - A` densely, then runs plain power iteration on the
    /// spectrally shifted matrix `M = lambda_max_est * I - L` (NOT inverse
    /// iteration, despite what earlier docs said): the largest eigenvalues
    /// of `M` correspond to the smallest of `L`. The constant (trivial)
    /// eigenvector is deflated up front, and each converged vector joins the
    /// deflation set so the next pass finds the next-smallest eigenpair.
    ///
    /// NOTE(review): a disconnected graph has several zero eigenvalues, but
    /// only the constant vector is deflated a priori — the other null-space
    /// vectors are then returned as "non-trivial" ones. Confirm this is the
    /// intended behavior for multi-component graphs.
    ///
    /// Returns `[n, k]` matrix of eigenvectors (columns beyond `actual_k`
    /// remain zero when `k > n - 1`).
    pub fn compute_eigenvectors(&self, adj: &CsrMatrix) -> Array2<f64> {
        let n = adj.n_rows;
        // At most n - 1 non-trivial eigenvectors exist.
        let actual_k = self.k.min(n.saturating_sub(1));
        if actual_k == 0 || n < 2 {
            return Array2::zeros((n, self.k));
        }

        // Build Laplacian L = D - A as dense (for small-moderate graphs)
        let row_sums = adj.row_sums();
        let mut lap = Array2::zeros((n, n));
        for i in 0..n {
            lap[[i, i]] = row_sums[i];
        }
        for (row, col, val) in adj.iter() {
            lap[[row, col]] -= val;
        }

        let mut eigvecs = Array2::zeros((n, self.k));

        // Spectral upper bound: 2 * max_degree >= lambda_max(L); the +1
        // keeps every eigenvalue of the shifted matrix strictly positive.
        let max_lambda_estimate = row_sums.iter().cloned().fold(0.0_f64, f64::max) * 2.0 + 1.0;

        // Build shifted matrix M = max_lambda * I - L
        let mut m_mat = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                m_mat[[i, j]] = -lap[[i, j]];
            }
            m_mat[[i, i]] += max_lambda_estimate;
        }

        let mut found_vecs: Vec<Vec<f64>> = Vec::new();

        // Seed the deflation set with the normalized all-ones vector — the
        // trivial eigenvector of L (eigenvalue 0) — so iteration skips it.
        let trivial: Vec<f64> = vec![1.0 / (n as f64).sqrt(); n];
        found_vecs.push(trivial);

        // Fixed iteration budget; there is no explicit convergence test.
        let num_iters = 200;

        for _ev_idx in 0..actual_k {
            // Random start vector with entries in [-1, 1).
            let mut rng = scirs2_core::random::rng();
            let mut v: Vec<f64> = (0..n).map(|_| rng.random::<f64>() * 2.0 - 1.0).collect();

            // Gram-Schmidt against all previously found vectors.
            for fv in &found_vecs {
                let dot: f64 = v.iter().zip(fv.iter()).map(|(a, b)| a * b).sum();
                for (vi, fi) in v.iter_mut().zip(fv.iter()) {
                    *vi -= dot * fi;
                }
            }

            // Normalize; max(1e-12) guards against a (near-)zero vector.
            let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
            v.iter_mut().for_each(|x| *x /= norm);

            for _ in 0..num_iters {
                // Power step: v_new = M * v
                let mut v_new = vec![0.0f64; n];
                for i in 0..n {
                    for j in 0..n {
                        v_new[i] += m_mat[[i, j]] * v[j];
                    }
                }

                // Re-orthogonalize every step so rounding error does not
                // drift the iterate back toward earlier eigenvectors.
                for fv in &found_vecs {
                    let dot: f64 = v_new.iter().zip(fv.iter()).map(|(a, b)| a * b).sum();
                    for (vi, fi) in v_new.iter_mut().zip(fv.iter()) {
                        *vi -= dot * fi;
                    }
                }

                // Normalize
                let norm: f64 = v_new.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
                v_new.iter_mut().for_each(|x| *x /= norm);

                v = v_new;
            }

            // Store eigenvector; it is deflated in subsequent passes.
            found_vecs.push(v);
        }

        // Copy found eigenvectors (skip trivial) into output columns;
        // remaining columns stay zero when actual_k < self.k.
        for (idx, fv) in found_vecs.iter().skip(1).take(self.k).enumerate() {
            for i in 0..n {
                eigvecs[[i, idx]] = fv[i];
            }
        }

        eigvecs
    }
300
301    /// Compute Laplacian PE and project.
302    ///
303    /// # Arguments
304    /// * `adj` - Adjacency matrix
305    ///
306    /// # Returns
307    /// Positional encoding `[n, pe_dim]`
308    pub fn forward(&self, adj: &CsrMatrix) -> Array2<f64> {
309        let eigvecs = self.compute_eigenvectors(adj);
310        let n = adj.n_rows;
311
312        // Project: [n, k] @ [k, pe_dim] -> [n, pe_dim]
313        let mut pe = Array2::zeros((n, self.pe_dim));
314        for i in 0..n {
315            for j in 0..self.pe_dim {
316                let mut s = 0.0;
317                for m in 0..self.k {
318                    s += eigvecs[[i, m]] * self.projection[[m, j]];
319                }
320                pe[[i, j]] = s;
321            }
322        }
323
324        pe
325    }
326}
327
328// ============================================================================
329// Local MPNN: GIN-style
330// ============================================================================
331
/// GIN (Graph Isomorphism Network) style local message passing.
///
/// Update rule:
/// ```text
/// h_i' = MLP( (1 + eps) * h_i + sum_{j in N(i)} h_j )
/// ```
#[derive(Debug, Clone)]
struct GinLocal {
    /// MLP first-layer weights: `[hidden_dim, hidden_dim]`
    w1: Array2<f64>,
    /// MLP second-layer weights: `[hidden_dim, hidden_dim]`
    w2: Array2<f64>,
    /// MLP hidden-layer bias: `[hidden_dim]`
    b1: Array1<f64>,
    /// MLP output bias: `[hidden_dim]`
    b2: Array1<f64>,
    /// Self-weight in `(1 + eps) * h_i`; fixed at 0 by `GinLocal::new`
    /// (the GIN-0 variant).
    eps: f64,
    /// Feature dimension shared by input and output.
    hidden_dim: usize,
}
353
354impl GinLocal {
355    fn new(hidden_dim: usize) -> Self {
356        let mut rng = scirs2_core::random::rng();
357        let scale = (6.0_f64 / (2 * hidden_dim) as f64).sqrt();
358
359        GinLocal {
360            w1: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
361                (rng.random::<f64>() * 2.0 - 1.0) * scale
362            }),
363            w2: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
364                (rng.random::<f64>() * 2.0 - 1.0) * scale
365            }),
366            b1: Array1::zeros(hidden_dim),
367            b2: Array1::zeros(hidden_dim),
368            eps: 0.0,
369            hidden_dim,
370        }
371    }
372
373    fn forward(&self, x: &Array2<f64>, adj: &CsrMatrix) -> Array2<f64> {
374        let n = x.dim().0;
375        let d = self.hidden_dim;
376
377        // Aggregate: (1 + eps) * x_i + sum_{j in N(i)} x_j
378        let mut agg = Array2::zeros((n, d));
379        for i in 0..n {
380            for j in 0..d {
381                agg[[i, j]] = (1.0 + self.eps) * x[[i, j]];
382            }
383        }
384        for (row, col, _) in adj.iter() {
385            for j in 0..d {
386                agg[[row, j]] += x[[col, j]];
387            }
388        }
389
390        // MLP: ReLU(W1 * agg + b1) then W2 + b2
391        let mut h = Array2::zeros((n, d));
392        for i in 0..n {
393            for j in 0..d {
394                let mut s = self.b1[j];
395                for m in 0..d {
396                    s += agg[[i, m]] * self.w1[[m, j]];
397                }
398                h[[i, j]] = s.max(0.0); // ReLU
399            }
400        }
401
402        let mut out = Array2::zeros((n, d));
403        for i in 0..n {
404            for j in 0..d {
405                let mut s = self.b2[j];
406                for m in 0..d {
407                    s += h[[i, m]] * self.w2[[m, j]];
408                }
409                out[[i, j]] = s;
410            }
411        }
412
413        out
414    }
415}
416
417// ============================================================================
418// Global Attention
419// ============================================================================
420
/// Standard multi-head self-attention over all nodes (no graph structure bias).
#[derive(Debug, Clone)]
struct GlobalAttention {
    /// Query projection: `[hidden_dim, hidden_dim]`
    w_q: Array2<f64>,
    /// Key projection: `[hidden_dim, hidden_dim]`
    w_k: Array2<f64>,
    /// Value projection: `[hidden_dim, hidden_dim]`
    w_v: Array2<f64>,
    /// Output projection mixing the concatenated heads: `[hidden_dim, hidden_dim]`
    w_o: Array2<f64>,
    /// Number of attention heads.
    num_heads: usize,
    /// Total model dimension (`num_heads * head_dim`).
    hidden_dim: usize,
    /// Per-head dimension (`hidden_dim / num_heads`).
    head_dim: usize,
}
432
433impl GlobalAttention {
434    fn new(hidden_dim: usize, num_heads: usize) -> Result<Self> {
435        if !hidden_dim.is_multiple_of(num_heads) {
436            return Err(GraphError::InvalidParameter {
437                param: "hidden_dim".to_string(),
438                value: format!("{hidden_dim}"),
439                expected: format!("divisible by num_heads={num_heads}"),
440                context: "GlobalAttention::new".to_string(),
441            });
442        }
443
444        let head_dim = hidden_dim / num_heads;
445        let mut rng = scirs2_core::random::rng();
446        let scale = (6.0_f64 / (2 * hidden_dim) as f64).sqrt();
447
448        let mut init = |r, c| -> Array2<f64> {
449            Array2::from_shape_fn((r, c), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
450        };
451
452        Ok(GlobalAttention {
453            w_q: init(hidden_dim, hidden_dim),
454            w_k: init(hidden_dim, hidden_dim),
455            w_v: init(hidden_dim, hidden_dim),
456            w_o: init(hidden_dim, hidden_dim),
457            num_heads,
458            hidden_dim,
459            head_dim,
460        })
461    }
462
    /// Scaled dot-product multi-head self-attention over all `n` nodes.
    ///
    /// Dense O(n^2) attention with no masking and no graph bias: every node
    /// attends to every node. Input and output are `[n, hidden_dim]`; each
    /// head operates on its own `head_dim`-wide slice of the Q/K/V columns.
    fn forward(&self, x: &Array2<f64>) -> Array2<f64> {
        let n = x.dim().0;
        let d = self.hidden_dim;
        let h = self.num_heads;
        let dk = self.head_dim;
        // Standard 1/sqrt(d_k) temperature for the dot-product scores.
        let scale = 1.0 / (dk as f64).sqrt();

        // Q, K, V projections (all heads at once: [n, d])
        let mut q = Array2::zeros((n, d));
        let mut k = Array2::zeros((n, d));
        let mut v = Array2::zeros((n, d));

        for i in 0..n {
            for j in 0..d {
                let mut sq = 0.0;
                let mut sk = 0.0;
                let mut sv = 0.0;
                for m in 0..d {
                    // One pass over x reused for all three projections.
                    let xi = x[[i, m]];
                    sq += xi * self.w_q[[m, j]];
                    sk += xi * self.w_k[[m, j]];
                    sv += xi * self.w_v[[m, j]];
                }
                q[[i, j]] = sq;
                k[[i, j]] = sk;
                v[[i, j]] = sv;
            }
        }

        let mut output = Array2::<f64>::zeros((n, d));

        for head in 0..h {
            // Column offset of this head's slice in Q/K/V/output.
            let offset = head * dk;

            // Attention scores: scores[i][j] = (q_i . k_j) / sqrt(d_k)
            let mut scores = vec![vec![0.0f64; n]; n];
            for i in 0..n {
                for j in 0..n {
                    let mut dot = 0.0;
                    for m in 0..dk {
                        dot += q[[i, offset + m]] * k[[j, offset + m]];
                    }
                    scores[i][j] = dot * scale;
                }
            }

            // Row-wise softmax, then weighted sum of value vectors.
            for i in 0..n {
                let alphas = softmax_row(&scores[i]);
                for j in 0..n {
                    let a = alphas[j];
                    for m in 0..dk {
                        output[[i, offset + m]] += a * v[[j, offset + m]];
                    }
                }
            }
        }

        // Final output projection mixing the concatenated heads.
        let mut projected = Array2::zeros((n, d));
        for i in 0..n {
            for j in 0..d {
                let mut s = 0.0;
                for m in 0..d {
                    s += output[[i, m]] * self.w_o[[m, j]];
                }
                projected[[i, j]] = s;
            }
        }

        projected
    }
535}
536
/// Numerically-stable softmax over a slice.
///
/// Subtracts the row maximum before exponentiating so large scores cannot
/// overflow, and clamps the normalizer away from zero.
fn softmax_row(row: &[f64]) -> Vec<f64> {
    if row.is_empty() {
        return Vec::new();
    }
    let max_val = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = row.iter().map(|&x| (x - max_val).exp()).collect();
    let denom = exps.iter().sum::<f64>().max(1e-12);
    exps.into_iter().map(|e| e / denom).collect()
}
547
/// Layer normalization over the feature dimension, in place.
///
/// Centers the slice to zero mean and scales it to unit variance; `eps`
/// guards against division by zero for (near-)constant inputs.
fn layer_norm_vec(x: &mut [f64], eps: f64) {
    if x.is_empty() {
        return;
    }
    let count = x.len() as f64;
    let mean = x.iter().sum::<f64>() / count;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / count;
    let inv_std = (var + eps).sqrt().recip();
    x.iter_mut().for_each(|v| *v = (*v - mean) * inv_std);
}
561
562// ============================================================================
563// GPS Configuration
564// ============================================================================
565
/// Which local model to use in GPS.
///
/// NOTE(review): only `Gin` is currently wired up — `GpsLayer` always builds
/// a GIN block and never consults `GpsConfig::local_model`, so selecting
/// `Gat` has no effect yet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalModel {
    /// GIN (Graph Isomorphism Network) local aggregation
    Gin,
    /// GAT (Graph Attention Network) local aggregation (simplified)
    Gat,
}
574
/// Configuration for the GPS model.
#[derive(Debug, Clone)]
pub struct GpsConfig {
    /// Input feature dimension
    pub in_dim: usize,
    /// Hidden dimension (must be divisible by `num_heads`)
    pub hidden_dim: usize,
    /// Number of attention heads for global attention
    pub num_heads: usize,
    /// Number of GPS layers
    pub num_layers: usize,
    /// FFN intermediate dimension
    pub ffn_dim: usize,
    /// Local model type
    /// (NOTE(review): not yet consulted by `GpsLayer`, which always uses GIN)
    pub local_model: LocalModel,
    /// PE dimension (concatenated with `in_dim` before the input projection)
    pub pe_dim: usize,
    /// Random walk length for RWPE
    pub rw_walk_length: usize,
    /// Layer norm epsilon
    pub layer_norm_eps: f64,
}
597
598impl Default for GpsConfig {
599    fn default() -> Self {
600        GpsConfig {
601            in_dim: 64,
602            hidden_dim: 64,
603            num_heads: 4,
604            num_layers: 3,
605            ffn_dim: 256,
606            local_model: LocalModel::Gin,
607            pe_dim: 16,
608            rw_walk_length: 8,
609            layer_norm_eps: 1e-5,
610        }
611    }
612}
613
614// ============================================================================
615// GPS Layer
616// ============================================================================
617
/// A single GPS layer combining local MPNN and global attention.
///
/// ```text
/// output = LayerNorm(x + MPNN(x) + Attention(x) + FFN(x))
/// ```
///
/// All three branches read the same layer input; their outputs are summed
/// residually and then layer-normalized per node.
#[derive(Debug, Clone)]
pub struct GpsLayer {
    /// Local GIN aggregation
    gin_local: GinLocal,
    /// Global multi-head attention
    global_attn: GlobalAttention,
    /// FFN first layer: `[hidden_dim, ffn_dim]`
    ffn_w1: Array2<f64>,
    /// FFN second layer: `[ffn_dim, hidden_dim]`
    ffn_w2: Array2<f64>,
    /// FFN hidden-layer bias: `[ffn_dim]`
    ffn_b1: Array1<f64>,
    /// FFN output bias: `[hidden_dim]`
    ffn_b2: Array1<f64>,
    /// Hidden dimension
    hidden_dim: usize,
    /// Layer norm epsilon
    layer_norm_eps: f64,
}
642
643impl GpsLayer {
644    /// Create a new GPS layer.
645    pub fn new(
646        hidden_dim: usize,
647        num_heads: usize,
648        ffn_dim: usize,
649        layer_norm_eps: f64,
650    ) -> Result<Self> {
651        let mut rng = scirs2_core::random::rng();
652        let ffn_scale = (6.0_f64 / (hidden_dim + ffn_dim) as f64).sqrt();
653
654        Ok(GpsLayer {
655            gin_local: GinLocal::new(hidden_dim),
656            global_attn: GlobalAttention::new(hidden_dim, num_heads)?,
657            ffn_w1: Array2::from_shape_fn((hidden_dim, ffn_dim), |_| {
658                (rng.random::<f64>() * 2.0 - 1.0) * ffn_scale
659            }),
660            ffn_w2: Array2::from_shape_fn((ffn_dim, hidden_dim), |_| {
661                (rng.random::<f64>() * 2.0 - 1.0) * ffn_scale
662            }),
663            ffn_b1: Array1::zeros(ffn_dim),
664            ffn_b2: Array1::zeros(hidden_dim),
665            hidden_dim,
666            layer_norm_eps,
667        })
668    }
669
670    /// FFN with ReLU activation.
671    fn ffn(&self, x: &Array2<f64>) -> Array2<f64> {
672        let n = x.dim().0;
673        let d = self.hidden_dim;
674        let ffn_dim = self.ffn_w1.dim().1;
675
676        let mut h = Array2::zeros((n, ffn_dim));
677        for i in 0..n {
678            for j in 0..ffn_dim {
679                let mut s = self.ffn_b1[j];
680                for m in 0..d {
681                    s += x[[i, m]] * self.ffn_w1[[m, j]];
682                }
683                h[[i, j]] = s.max(0.0); // ReLU
684            }
685        }
686
687        let mut out = Array2::zeros((n, d));
688        for i in 0..n {
689            for j in 0..d {
690                let mut s = self.ffn_b2[j];
691                for m in 0..ffn_dim {
692                    s += h[[i, m]] * self.ffn_w2[[m, j]];
693                }
694                out[[i, j]] = s;
695            }
696        }
697
698        out
699    }
700
701    /// Forward pass of one GPS layer.
702    ///
703    /// # Arguments
704    /// * `x` - Input features `[n, hidden_dim]`
705    /// * `adj` - Adjacency matrix for local MPNN
706    pub fn forward(&self, x: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
707        let (n, d) = x.dim();
708        if d != self.hidden_dim {
709            return Err(GraphError::InvalidParameter {
710                param: "x".to_string(),
711                value: format!("dim={d}"),
712                expected: format!("dim={}", self.hidden_dim),
713                context: "GpsLayer::forward".to_string(),
714            });
715        }
716
717        // Local MPNN
718        let local_out = self.gin_local.forward(x, adj);
719
720        // Global attention
721        let global_out = self.global_attn.forward(x);
722
723        // FFN
724        let ffn_out = self.ffn(x);
725
726        // Combine with residual: x + local + global + ffn
727        let mut out = x.clone();
728        for i in 0..n {
729            for j in 0..d {
730                out[[i, j]] += local_out[[i, j]] + global_out[[i, j]] + ffn_out[[i, j]];
731            }
732        }
733
734        // Layer normalization
735        for i in 0..n {
736            let mut row: Vec<f64> = (0..d).map(|j| out[[i, j]]).collect();
737            layer_norm_vec(&mut row, self.layer_norm_eps);
738            for j in 0..d {
739                out[[i, j]] = row[j];
740            }
741        }
742
743        Ok(out)
744    }
745}
746
747// ============================================================================
748// GPS Model
749// ============================================================================
750
/// Full GPS (General Powerful Scalable) Graph Transformer model.
///
/// Pipeline: concat(features, RWPE) -> linear input projection -> stack of
/// [`GpsLayer`]s (see `GpsModel::forward`).
#[derive(Debug, Clone)]
pub struct GpsModel {
    /// Input projection applied after PE concatenation:
    /// `[in_dim + pe_dim, hidden_dim]`
    pub input_proj: Array2<f64>,
    /// Random Walk PE module
    pub rwpe: RandomWalkPe,
    /// Stack of GPS layers
    pub layers: Vec<GpsLayer>,
    /// Configuration
    pub config: GpsConfig,
}
763
764impl GpsModel {
765    /// Create a new GPS model from configuration.
766    pub fn new(config: GpsConfig) -> Result<Self> {
767        let mut rng = scirs2_core::random::rng();
768        let total_in = config.in_dim + config.pe_dim;
769        let proj_scale = (6.0_f64 / (total_in + config.hidden_dim) as f64).sqrt();
770        let input_proj = Array2::from_shape_fn((total_in, config.hidden_dim), |_| {
771            (rng.random::<f64>() * 2.0 - 1.0) * proj_scale
772        });
773
774        let rwpe = RandomWalkPe::new(config.rw_walk_length, config.pe_dim);
775
776        let mut layers = Vec::with_capacity(config.num_layers);
777        for _ in 0..config.num_layers {
778            layers.push(GpsLayer::new(
779                config.hidden_dim,
780                config.num_heads,
781                config.ffn_dim,
782                config.layer_norm_eps,
783            )?);
784        }
785
786        Ok(GpsModel {
787            input_proj,
788            rwpe,
789            layers,
790            config,
791        })
792    }
793
    /// Forward pass of the full GPS model.
    ///
    /// Steps: (1) compute RWPE from `adj`; (2) concatenate `[features | pe]`;
    /// (3) project the concatenation down to `hidden_dim`; (4) run the
    /// stacked GPS layers.
    ///
    /// # Arguments
    /// * `features` - Input node features `[n_nodes, in_dim]`
    /// * `adj` - Sparse adjacency matrix
    ///
    /// # Returns
    /// Node embeddings `[n_nodes, hidden_dim]`
    ///
    /// # Errors
    /// Returns `InvalidParameter` if `features` has the wrong width or `adj`
    /// the wrong number of rows; layer errors are propagated.
    pub fn forward(&self, features: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
        let (n, in_dim) = features.dim();
        if in_dim != self.config.in_dim {
            return Err(GraphError::InvalidParameter {
                param: "features".to_string(),
                value: format!("in_dim={in_dim}"),
                expected: format!("in_dim={}", self.config.in_dim),
                context: "GpsModel::forward".to_string(),
            });
        }
        if adj.n_rows != n {
            return Err(GraphError::InvalidParameter {
                param: "adj".to_string(),
                value: format!("n_rows={}", adj.n_rows),
                expected: format!("n_rows={n}"),
                context: "GpsModel::forward".to_string(),
            });
        }

        // Compute RWPE (structural features derived from topology only)
        let pe = self.rwpe.forward(adj);

        // Concatenate features with PE: [n, in_dim + pe_dim]
        let total_in = self.config.in_dim + self.config.pe_dim;
        let mut concat = Array2::zeros((n, total_in));
        for i in 0..n {
            for j in 0..in_dim {
                concat[[i, j]] = features[[i, j]];
            }
            for j in 0..self.config.pe_dim {
                concat[[i, in_dim + j]] = pe[[i, j]];
            }
        }

        // Project to hidden_dim: concat @ input_proj
        let d = self.config.hidden_dim;
        let mut h = Array2::zeros((n, d));
        for i in 0..n {
            for j in 0..d {
                let mut s = 0.0;
                for m in 0..total_in {
                    s += concat[[i, m]] * self.input_proj[[m, j]];
                }
                h[[i, j]] = s;
            }
        }

        // Apply GPS layers sequentially
        for layer in &self.layers {
            h = layer.forward(&h, adj)?;
        }

        Ok(h)
    }
856}
857
858// ============================================================================
859// Tests
860// ============================================================================
861
862#[cfg(test)]
863mod tests {
864    use super::*;
865
    /// Undirected triangle (K3) as CSR: each of the 3 edges is stored in
    /// both directions with weight 1.0, so every node has degree 2.
    fn triangle_csr() -> CsrMatrix {
        let coo = vec![
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 1.0),
            (2, 1, 1.0),
            (0, 2, 1.0),
            (2, 0, 1.0),
        ];
        CsrMatrix::from_coo(3, 3, &coo).expect("triangle CSR")
    }
877
    /// Undirected path graph 0-1-2-3 as CSR: each of the 3 edges stored in
    /// both directions with weight 1.0 (endpoints have degree 1, interior
    /// nodes degree 2).
    fn path_csr() -> CsrMatrix {
        let coo = vec![
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 1.0),
            (2, 1, 1.0),
            (2, 3, 1.0),
            (3, 2, 1.0),
        ];
        CsrMatrix::from_coo(4, 4, &coo).expect("path CSR")
    }
889
    /// Deterministic feature matrix: entry (i, j) = 0.1 * (i * d + j).
    fn feats(n: usize, d: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 * 0.1)
    }
893
    /// Landing probabilities must have shape [n, walk_length] and contain
    /// only finite, non-negative values (they are probabilities).
    #[test]
    fn test_rwpe_landing_probs_shape() {
        let adj = triangle_csr();
        let rwpe = RandomWalkPe::new(4, 8);
        let landing = rwpe.compute_landing_probs(&adj);
        assert_eq!(landing.dim(), (3, 4));
        for &v in landing.iter() {
            assert!(v.is_finite(), "landing prob should be finite, got {v}");
            assert!(v >= 0.0, "landing prob should be non-negative, got {v}");
        }
    }
905
906    #[test]
907    fn test_rwpe_produces_correct_features() {
908        let adj = triangle_csr();
909        let rwpe = RandomWalkPe::new(3, 6);
910        let pe = rwpe.forward(&adj);
911        assert_eq!(pe.dim(), (3, 6));
912
913        // For a complete triangle, all nodes have the same structure
914        // so their landing probabilities should be identical
915        let landing = rwpe.compute_landing_probs(&adj);
916        for k in 0..3 {
917            let val0 = landing[[0, k]];
918            let val1 = landing[[1, k]];
919            let val2 = landing[[2, k]];
920            assert!(
921                (val0 - val1).abs() < 1e-10 && (val1 - val2).abs() < 1e-10,
922                "symmetric graph should have equal landing probs at step {k}: {val0}, {val1}, {val2}"
923            );
924        }
925    }
926
927    #[test]
928    fn test_rwpe_path_graph_different_probs() {
929        let adj = path_csr();
930        let rwpe = RandomWalkPe::new(3, 4);
931        let landing = rwpe.compute_landing_probs(&adj);
932        assert_eq!(landing.dim(), (4, 3));
933
934        // Endpoints (degree 1) vs interior (degree 2) should differ
935        let end_prob = landing[[0, 0]]; // P^1 diagonal for endpoint
936        let mid_prob = landing[[1, 0]]; // P^1 diagonal for middle node
937
938        // P^1 diagonal = probability of returning in 1 step = 0 for all (no self-loops)
939        // P^2 diagonal should differ: endpoint returns with prob 1 (only neighbor sends back),
940        // middle node returns with prob 1/2 from each of 2 neighbors = 1/2
941        // Actually, let's just check the values are finite and make sense
942        assert!(end_prob.is_finite());
943        assert!(mid_prob.is_finite());
944    }
945
946    #[test]
947    fn test_laplacian_pe_shape() {
948        let adj = triangle_csr();
949        let lpe = LaplacianPe::new(2, 6);
950        let pe = lpe.forward(&adj);
951        assert_eq!(pe.dim(), (3, 6));
952        for &v in pe.iter() {
953            assert!(v.is_finite(), "Laplacian PE should be finite, got {v}");
954        }
955    }
956
957    #[test]
958    fn test_gps_hybrid_combines_local_and_global() {
959        let adj = triangle_csr();
960        let features = feats(3, 8);
961
962        let config = GpsConfig {
963            in_dim: 8,
964            hidden_dim: 8,
965            num_heads: 2,
966            num_layers: 1,
967            ffn_dim: 16,
968            local_model: LocalModel::Gin,
969            pe_dim: 4,
970            rw_walk_length: 3,
971            ..Default::default()
972        };
973
974        let model = GpsModel::new(config).expect("GPS model");
975        let out = model.forward(&features, &adj).expect("GPS forward");
976        assert_eq!(out.dim(), (3, 8));
977
978        for &v in out.iter() {
979            assert!(v.is_finite(), "GPS output should be finite, got {v}");
980        }
981
982        // Output should differ from trivially projected input
983        let has_variation = out.iter().any(|&v| v.abs() > 1e-12);
984        assert!(has_variation, "GPS output should have non-trivial values");
985    }
986
987    #[test]
988    fn test_gps_layer_forward_shape() {
989        let adj = triangle_csr();
990        let x = feats(3, 8);
991        let layer = GpsLayer::new(8, 2, 16, 1e-5).expect("GPS layer");
992        let out = layer.forward(&x, &adj).expect("GPS layer forward");
993        assert_eq!(out.dim(), (3, 8));
994
995        // After layer norm, output should have approximately zero mean per node
996        for i in 0..3 {
997            let mean: f64 = (0..8).map(|j| out[[i, j]]).sum::<f64>() / 8.0;
998            assert!(
999                mean.abs() < 0.1,
1000                "after layer norm, mean should be near 0, got {mean}"
1001            );
1002        }
1003    }
1004
1005    #[test]
1006    fn test_gps_multi_layer() {
1007        let adj = path_csr();
1008        let config = GpsConfig {
1009            in_dim: 4,
1010            hidden_dim: 8,
1011            num_heads: 2,
1012            num_layers: 3,
1013            ffn_dim: 16,
1014            pe_dim: 4,
1015            rw_walk_length: 3,
1016            ..Default::default()
1017        };
1018
1019        let model = GpsModel::new(config).expect("GPS model");
1020        let features = feats(4, 4);
1021        let out = model.forward(&features, &adj).expect("GPS forward");
1022        assert_eq!(out.dim(), (4, 8));
1023        for &v in out.iter() {
1024            assert!(v.is_finite(), "multi-layer GPS output should be finite");
1025        }
1026    }
1027
1028    #[test]
1029    fn test_gps_invalid_dim_error() {
1030        let adj = triangle_csr();
1031        let config = GpsConfig {
1032            in_dim: 4,
1033            hidden_dim: 7, // not divisible by num_heads=4
1034            num_heads: 4,
1035            ..Default::default()
1036        };
1037        let result = GpsModel::new(config);
1038        assert!(result.is_err());
1039    }
1040
1041    #[test]
1042    fn test_gin_local_aggregation() {
1043        let adj = triangle_csr();
1044        let x = feats(3, 8);
1045        let gin = GinLocal::new(8);
1046        let out = gin.forward(&x, &adj);
1047        assert_eq!(out.dim(), (3, 8));
1048        for &v in out.iter() {
1049            assert!(v.is_finite());
1050        }
1051    }
1052}