//! GraphGPS: General, Powerful, Scalable Graph Transformers
//!
//! Implements Rampášek et al. 2022, "Recipe for a General, Powerful, Scalable
//! Graph Transformer". Each GPS layer combines:
//! - **Local MPNN**: message-passing over edges (mean aggregation + linear)
//! - **Global Transformer**: full self-attention over all nodes with PE bias
//! - **Combination gate**: a gate α (fixed at 0.5 here) balances the local and global branches
//! - **FFN**: two-layer MLP with GELU activation
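//!
//! A minimal usage sketch, mirroring the unit tests below (`laplacian_pe` is
//! from the sibling `positional_encoding` module):
//!
//! ```ignore
//! let graph = GraphForTransformer::new(adjacency, node_features)?;
//! let pe = laplacian_pe(&graph.adjacency, config.pe_dim);
//! let mut model = GpsModel::new(&config);
//! let (output, first_head_attn) = model.forward(&graph, &pe)?;
//! ```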

use crate::error::{GraphError, Result};

use super::types::{GraphForTransformer, GraphTransformerConfig, GraphTransformerOutput};

// ============================================================================
// Activation helpers
// ============================================================================

/// GELU activation (tanh approximation)
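/// GELU(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))); the 0.79788... constant below is √(2/π).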
#[inline]
fn gelu(x: f64) -> f64 {
    0.5 * x * (1.0 + (0.797_884_560_802_865_4 * (x + 0.044_715 * x * x * x)).tanh())
}

/// Layer normalisation: (x - mean) / std with fixed γ=1, β=0 (no learned affine)
fn layer_norm(x: &[f64]) -> Vec<f64> {
    let n = x.len() as f64;
    if n == 0.0 {
        return Vec::new();
    }
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n;
    let std = (var + 1e-6).sqrt();
    x.iter().map(|v| (v - mean) / std).collect()
}

/// Numerically stable softmax over a slice (shifts by the max before exponentiating)
fn softmax(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let max_v = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|&v| (v - max_v).exp()).collect();
    let sum = exps.iter().sum::<f64>().max(1e-15);
    exps.iter().map(|e| e / sum).collect()
}

/// Dense matrix–vector multiply: returns w · x (rows are zipped with x, so
/// extra entries on either side are ignored)
fn mv(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    w.iter()
        .map(|row| row.iter().zip(x.iter()).map(|(a, b)| a * b).sum())
        .collect()
}

/// Element-wise vector addition (truncates to the shorter operand)
fn vadd(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}

// ============================================================================
// Learnable weight matrices — initialised with deterministic LCG values
// ============================================================================

/// A simple LCG pseudo-random number generator for weight initialisation.
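/// Uses Knuth's MMIX multiplier/increment constants, so weights are
/// reproducible across runs.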
struct Lcg {
    state: u64,
}

impl Lcg {
    fn new(seed: u64) -> Self {
        Self {
            state: seed ^ 0x5851_f42d_4c95_7f2d,
        }
    }

    /// Advance the generator and map the high 31 bits to a uniform value in [-1, 1).
    fn next_f64(&mut self) -> f64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let bits = (self.state >> 33) as i32;
        // Centre around zero: an all-positive draw would not be a valid He-style init.
        2.0 * (bits as f64) / (i32::MAX as f64) - 1.0
    }

    /// He-style initialisation: entries uniform in [-scale, scale) with scale = sqrt(2 / fan_in)
    fn he_matrix(&mut self, rows: usize, cols: usize) -> Vec<Vec<f64>> {
        let scale = (2.0 / cols as f64).sqrt();
        (0..rows)
            .map(|_| (0..cols).map(|_| self.next_f64() * scale).collect())
            .collect()
    }
}

// ============================================================================
// GpsLayer
// ============================================================================

/// A single GPS layer: parallel MPNN + Transformer + FFN.
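/// Per-node data flow: h' = LN(h + α·MPNN(h) + (1−α)·Attn(h, pe)); out = LN(h' + FFN(h')).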
struct GpsLayer {
    hidden_dim: usize,
    n_heads: usize,
    pe_dim: usize,

    // Local MPNN branch
    w_msg: Vec<Vec<f64>>, // hidden_dim × hidden_dim

    // Global Transformer: Q, K, V projections per head
    wq: Vec<Vec<Vec<f64>>>, // n_heads × head_dim × (hidden_dim + pe_dim)
    wk: Vec<Vec<Vec<f64>>>,
    wv: Vec<Vec<Vec<f64>>>, // n_heads × head_dim × hidden_dim (V sees node features only)
    w_out: Vec<Vec<f64>>, // hidden_dim × hidden_dim  (concat head outputs → hidden)

    // Combination gate parameter (α fixed at 0.5; not trained in this implementation)
    alpha: f64,

    // FFN
    w_ff1: Vec<Vec<f64>>, // 4*hidden_dim × hidden_dim
    w_ff2: Vec<Vec<f64>>, // hidden_dim × 4*hidden_dim
}

impl GpsLayer {
    fn new(hidden_dim: usize, n_heads: usize, pe_dim: usize, seed: u64) -> Self {
        let head_dim = (hidden_dim / n_heads).max(1);
        let in_dim = hidden_dim + pe_dim;
        let mut lcg = Lcg::new(seed);

        let wq = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, in_dim))
            .collect();
        let wk = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, in_dim))
            .collect();
        let wv = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, hidden_dim))
            .collect();
        let w_out = lcg.he_matrix(hidden_dim, hidden_dim);
        let w_msg = lcg.he_matrix(hidden_dim, hidden_dim);
        let w_ff1 = lcg.he_matrix(4 * hidden_dim, hidden_dim);
        let w_ff2 = lcg.he_matrix(hidden_dim, 4 * hidden_dim);

        Self {
            hidden_dim,
            n_heads,
            pe_dim,
            w_msg,
            wq,
            wk,
            wv,
            w_out,
            alpha: 0.5,
            w_ff1,
            w_ff2,
        }
    }

    /// Local MPNN: mean-aggregate each node's neighbour features, then apply W_msg and ReLU.
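    /// h_i ← ReLU(W_msg · mean_{j∈N(i)} h_j); isolated nodes fall back to their own features.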
    fn local_mpnn(&self, h: &[Vec<f64>], adj: &[Vec<usize>]) -> Vec<Vec<f64>> {
        let n = h.len();
        let mut out = vec![vec![0.0_f64; self.hidden_dim]; n];
        for i in 0..n {
            let nbrs = &adj[i];
            let agg = if nbrs.is_empty() {
                h[i].clone()
            } else {
                // Mean aggregation
                let mut sum = vec![0.0_f64; self.hidden_dim];
                for &j in nbrs {
                    for d in 0..self.hidden_dim.min(h[j].len()) {
                        sum[d] += h[j][d];
                    }
                }
                let cnt = nbrs.len() as f64;
                sum.iter().map(|v| v / cnt).collect()
            };
            let msg = mv(&self.w_msg, &agg);
            out[i] = msg.into_iter().map(|v| v.max(0.0)).collect(); // ReLU
        }
        out
    }

    /// Global Transformer: full self-attention with PE bias on Q and K.
    ///
    /// Returns node embeddings and the first head's attention weights,
    /// flattened row-major into n·n values (exposed for testing).
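    /// Per head: Attn(Q, K, V) = softmax(Q·Kᵀ / √head_dim)·V, with Q and K taken
    /// from the concatenation [h ‖ pe] and V from h alone.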
    fn global_transformer(&self, h: &[Vec<f64>], pe: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<f64>) {
        let n = h.len();
        if n == 0 {
            return (Vec::new(), Vec::new());
        }
        let head_dim = (self.hidden_dim / self.n_heads).max(1);
        let scale = (head_dim as f64).sqrt().max(1e-6);

        // Augmented input: concat h[i] and pe[i]. Nodes without a PE row get
        // zeros, so an empty or short `pe` cannot cause an out-of-bounds access.
        let zero_pe = vec![0.0_f64; self.pe_dim];
        let aug: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                let mut v = h[i].clone();
                v.extend_from_slice(pe.get(i).unwrap_or(&zero_pe));
                v
            })
            .collect();

        // Multi-head attention
        let mut head_outputs: Vec<Vec<Vec<f64>>> = Vec::with_capacity(self.n_heads);
        let mut all_attn: Vec<f64> = Vec::new();

        for hd in 0..self.n_heads {
            // Compute Q, K for augmented input; V for h only
            let q: Vec<Vec<f64>> = aug.iter().map(|a| mv(&self.wq[hd], a)).collect();
            let k: Vec<Vec<f64>> = aug.iter().map(|a| mv(&self.wk[hd], a)).collect();
            let v: Vec<Vec<f64>> = h.iter().map(|hi| mv(&self.wv[hd], hi)).collect();

            // Attention matrix n×n
            let mut attn_logits = vec![vec![0.0_f64; n]; n];
            for i in 0..n {
                for j in 0..n {
                    let dot: f64 = q[i].iter().zip(k[j].iter()).map(|(a, b)| a * b).sum();
                    attn_logits[i][j] = dot / scale;
                }
            }

            // Softmax over each row
            let attn_weights: Vec<Vec<f64>> = attn_logits.iter().map(|row| softmax(row)).collect();

            if hd == 0 {
                // Collect first-head weights for test inspection
                for row in &attn_weights {
                    all_attn.extend_from_slice(row);
                }
            }

            // Weighted sum of V
            let mut head_out = vec![vec![0.0_f64; head_dim.min(v[0].len())]; n];
            for i in 0..n {
                for j in 0..n {
                    let vj_len = v[j].len().min(head_dim);
                    for d in 0..vj_len {
                        head_out[i][d] += attn_weights[i][j] * v[j][d];
                    }
                }
            }
            head_outputs.push(head_out);
        }

        // Concatenate heads (each head contributes head_dim dims, clipped to hidden_dim)
        let mut concat = vec![vec![0.0_f64; self.hidden_dim]; n];
        for i in 0..n {
            for hd in 0..self.n_heads {
                let start = hd * head_dim;
                let end = (start + head_dim).min(self.hidden_dim);
                for d in start..end {
                    let local_d = d - start;
                    if local_d < head_outputs[hd][i].len() {
                        concat[i][d] = head_outputs[hd][i][local_d];
                    }
                }
            }
        }

        // Final output projection
        let out: Vec<Vec<f64>> = concat.iter().map(|c| mv(&self.w_out, c)).collect();
        (out, all_attn)
    }

    /// FFN: 2-layer MLP with GELU.
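    /// FFN(x) = W_ff2 · GELU(W_ff1 · x), expanding to 4·hidden_dim internally.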
    fn ffn(&self, h: &[Vec<f64>]) -> Vec<Vec<f64>> {
        h.iter()
            .map(|x| {
                let mid: Vec<f64> = mv(&self.w_ff1, x).into_iter().map(gelu).collect();
                mv(&self.w_ff2, &mid)
            })
            .collect()
    }

    /// Full GPS layer forward pass.
    pub fn forward(
        &self,
        h: &[Vec<f64>],
        adj: &[Vec<usize>],
        pe: &[Vec<f64>],
    ) -> (Vec<Vec<f64>>, Vec<f64>) {
        let n = h.len();
        if n == 0 {
            return (Vec::new(), Vec::new());
        }

        // Pad or truncate h rows to exactly hidden_dim
        let h_norm: Vec<Vec<f64>> = h
            .iter()
            .map(|row| {
                let mut r = row.clone();
                r.resize(self.hidden_dim, 0.0);
                r
            })
            .collect();

        // Pad or truncate pe rows to exactly pe_dim
        let pe_norm: Vec<Vec<f64>> = pe
            .iter()
            .map(|row| {
                let mut r = row.clone();
                r.resize(self.pe_dim, 0.0);
                r
            })
            .collect();

        let h_local = self.local_mpnn(&h_norm, adj);
        let (h_global, attn_weights) = self.global_transformer(&h_norm, &pe_norm);

        // Combine: h_out = LN(h + α·h_local + (1-α)·h_global)
        let alpha = self.alpha.clamp(0.0, 1.0);
        let combined: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                let combined_raw: Vec<f64> = (0..self.hidden_dim)
                    .map(|d| h_norm[i][d] + alpha * h_local[i][d] + (1.0 - alpha) * h_global[i][d])
                    .collect();
                layer_norm(&combined_raw)
            })
            .collect();

        // FFN + residual + LayerNorm
        let h_ffn = self.ffn(&combined);
        let h_out: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                let res = vadd(&combined[i], &h_ffn[i]);
                layer_norm(&res)
            })
            .collect();

        (h_out, attn_weights)
    }
}

// ============================================================================
// GpsModel
// ============================================================================

/// Full GPS model: stack of GPS layers + mean-pooling.
pub struct GpsModel {
    layers: Vec<GpsLayer>,
    hidden_dim: usize,
    pe_dim: usize,
    /// Input projection: hidden_dim × feat_dim  (built lazily on first forward)
    w_in: Option<Vec<Vec<f64>>>,
    feat_dim: usize,
}

impl GpsModel {
    /// Create a new GPS model from configuration.
    pub fn new(config: &GraphTransformerConfig) -> Self {
        let pe_dim = config.pe_dim;
        let hidden_dim = config.hidden_dim;
        let n_heads = config.n_heads.max(1);
        let layers: Vec<GpsLayer> = (0..config.n_layers)
            .map(|i| {
                // Distinct deterministic seed per layer
                let seed = (i as u64)
                    .wrapping_add(1)
                    .wrapping_mul(0x9e37_79b9_7f4a_7c15_u64);
                GpsLayer::new(hidden_dim, n_heads, pe_dim, seed)
            })
            .collect();

        Self {
            layers,
            hidden_dim,
            pe_dim,
            w_in: None,
            feat_dim: 0,
        }
    }

    /// Build (or rebuild, if the feature dimension changed) the input-projection matrix.
    fn ensure_w_in(&mut self, feat_dim: usize) {
        if self.w_in.is_none() || self.feat_dim != feat_dim {
            let mut lcg = Lcg::new(0xdead_beef_cafe_babe);
            self.w_in = Some(lcg.he_matrix(self.hidden_dim, feat_dim.max(1)));
            self.feat_dim = feat_dim;
        }
    }

    /// Run a forward pass through the GPS model.
    ///
    /// `pe` should be an `n × pe_dim` matrix produced by `laplacian_pe` or `rwpe`.
    pub fn forward(
        &mut self,
        graph: &GraphForTransformer,
        pe: &[Vec<f64>],
    ) -> Result<(GraphTransformerOutput, Vec<f64>)> {
        let n = graph.n_nodes;
        if n == 0 {
            return Ok((
                GraphTransformerOutput {
                    node_embeddings: Vec::new(),
                    graph_embedding: vec![0.0; self.hidden_dim],
                },
                Vec::new(),
            ));
        }

        let feat_dim = graph
            .node_features
            .first()
            .map(|r| r.len())
            .unwrap_or(1)
            .max(1);
        self.ensure_w_in(feat_dim);

        let w_in = self
            .w_in
            .as_ref()
            .ok_or_else(|| GraphError::InvalidParameter {
                param: "w_in".to_string(),
                value: "None".to_string(),
                expected: "initialised weight matrix".to_string(),
                context: "GpsModel forward".to_string(),
            })?;

        // Project features to hidden_dim
        let mut h: Vec<Vec<f64>> = graph.node_features.iter().map(|f| mv(w_in, f)).collect();

        let mut last_attn: Vec<f64> = Vec::new();
        for layer in &self.layers {
            let (h_new, attn) = layer.forward(&h, &graph.adjacency, pe);
            h = h_new;
            last_attn = attn;
        }

        // Mean-pool for graph embedding
        let mut graph_emb = vec![0.0_f64; self.hidden_dim];
        for row in &h {
            for (d, &v) in row.iter().enumerate() {
                if d < self.hidden_dim {
                    graph_emb[d] += v;
                }
            }
        }
        let inv_n = 1.0 / n as f64;
        for v in graph_emb.iter_mut() {
            *v *= inv_n;
        }

        Ok((
            GraphTransformerOutput {
                node_embeddings: h,
                graph_embedding: graph_emb,
            },
            last_attn,
        ))
    }
}

// ============================================================================
// Unit tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::super::positional_encoding::{laplacian_pe, rwpe};
    use super::super::types::{GraphForTransformer, GraphTransformerConfig};
    use super::*;

    fn triangle_graph() -> GraphForTransformer {
        GraphForTransformer::new(
            vec![vec![1, 2], vec![0, 2], vec![0, 1]],
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
        )
        .expect("valid graph")
    }

    fn single_node_graph() -> GraphForTransformer {
        GraphForTransformer::new(vec![vec![]], vec![vec![1.0]]).expect("valid graph")
    }

    fn default_config() -> GraphTransformerConfig {
        GraphTransformerConfig {
            n_heads: 2,
            hidden_dim: 8,
            n_layers: 1,
            dropout: 0.0,
            pe_type: super::super::types::PeType::LapPE,
            pe_dim: 4,
        }
    }

    fn two_layer_config() -> GraphTransformerConfig {
        GraphTransformerConfig {
            n_layers: 2,
            ..default_config()
        }
    }

    #[test]
    fn test_gps_output_shape() {
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
        for row in &out.node_embeddings {
            assert_eq!(row.len(), 8);
        }
    }

    #[test]
    fn test_gps_graph_embedding_shape() {
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("forward ok");
        assert_eq!(out.graph_embedding.len(), 8);
    }

    #[test]
    fn test_gps_single_node() {
        let g = single_node_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("forward ok");
        assert_eq!(out.node_embeddings.len(), 1);
        assert_eq!(out.graph_embedding.len(), 8);
    }

    #[test]
    fn test_gps_no_edges() {
        // 3 isolated nodes
        let g = GraphForTransformer::new(
            vec![vec![], vec![], vec![]],
            vec![vec![1.0], vec![2.0], vec![3.0]],
        )
        .expect("valid");
        let pe = rwpe(&g.adjacency, 4);
        let cfg = default_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
    }

    #[test]
    fn test_gps_attention_softmax() {
        // Attention weights returned for first head; each row should sum ≈ 1
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let layer = GpsLayer::new(cfg.hidden_dim, cfg.n_heads, cfg.pe_dim, 42);
        let h: Vec<Vec<f64>> = g
            .node_features
            .iter()
            .map(|f| {
                let mut r = f.clone();
                r.resize(cfg.hidden_dim, 0.0);
                r
            })
            .collect();
        let pe_norm: Vec<Vec<f64>> = pe
            .iter()
            .map(|p| {
                let mut r = p.clone();
                r.resize(cfg.pe_dim, 0.0);
                r
            })
            .collect();
        let (_out, attn) = layer.global_transformer(&h, &pe_norm);
        // attn contains n*n values for the first head; each row of n should sum to ~1
        let n = g.n_nodes;
        for i in 0..n {
            let row_sum: f64 = (0..n).map(|j| attn[i * n + j]).sum();
            assert!((row_sum - 1.0).abs() < 1e-10, "row {i} sum={row_sum}");
        }
    }

    #[test]
    fn test_gps_layers_stack() {
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = two_layer_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("2-layer forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
        for row in &out.node_embeddings {
            assert_eq!(row.len(), 8);
            // Values should be finite
            for &v in row {
                assert!(v.is_finite(), "non-finite value in output");
            }
        }
    }
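
    // Property checks: layer_norm standardises its input, and all weight
    // initialisation is seeded by the deterministic Lcg, so repeated forward
    // passes with the same config must agree exactly.

    #[test]
    fn test_layer_norm_standardises() {
        let out = layer_norm(&[1.0, 2.0, 3.0, 4.0]);
        let n = out.len() as f64;
        let mean = out.iter().sum::<f64>() / n;
        let var = out.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n;
        assert!(mean.abs() < 1e-12);
        // The 1e-6 term in layer_norm keeps the output variance just under 1
        assert!((var - 1.0).abs() < 1e-3);
    }

    #[test]
    fn test_gps_forward_deterministic() {
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let mut m1 = GpsModel::new(&cfg);
        let mut m2 = GpsModel::new(&cfg);
        let (o1, _) = m1.forward(&g, &pe).expect("forward ok");
        let (o2, _) = m2.forward(&g, &pe).expect("forward ok");
        assert_eq!(o1.node_embeddings, o2.node_embeddings);
        assert_eq!(o1.graph_embedding, o2.graph_embedding);
    }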
}