// scirs2_graph/graph_transformer/graphormer.rs
//! Graphormer: Transformers for Graph Structured Data
//!
//! Implements Ying et al. 2021 "Do Transformers Really Perform Bad for Graph
//! Representation?" with the following structural encodings:
//!
//! - **Degree embedding**: learnable table indexed by node degree (in + out)
//! - **Spatial encoding**: learnable bias b(i,j) added to attention logits,
//!   indexed by SPD(i,j) capped at `max_shortest_path`
//! - **Edge encoding**: constant-weight summation along shortest paths (simplified)
//! - **Virtual graph token**: a global "\[GRAPH\]" super-node attending to all
//!   other tokens, appended as the last token

use super::positional_encoding::all_pairs_shortest_path;
use super::types::{GraphForTransformer, GraphTransformerOutput, GraphormerConfig};
use crate::error::Result;

// ============================================================================
// Helpers
// ============================================================================

/// Numerically stable softmax over a slice.
///
/// Subtracts the maximum before exponentiation so large logits do not
/// overflow; the denominator is clamped at 1e-15 to avoid division by
/// zero. An empty input yields an empty output.
fn softmax(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // Single pass: exponentiate and accumulate the normaliser together.
    let mut out = Vec::with_capacity(xs.len());
    let mut total = 0.0_f64;
    for &v in xs {
        let e = (v - peak).exp();
        total += e;
        out.push(e);
    }
    let denom = total.max(1e-15);
    for e in &mut out {
        *e /= denom;
    }
    out
}
31
/// Dense matrix–vector product: `y[r] = Σ_c w[r][c] * x[c]`.
///
/// Each output element is the dot product of one row of `w` with `x`
/// (zip truncates to the shorter of the two, as in the original).
fn mv(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut y = Vec::with_capacity(w.len());
    for row in w {
        let mut acc = 0.0_f64;
        for (&a, &b) in row.iter().zip(x) {
            acc += a * b;
        }
        y.push(acc);
    }
    y
}
38
/// Layer normalisation: rescales `x` to zero mean and (near-)unit variance.
///
/// An epsilon of 1e-6 under the square root guards against division by
/// zero on constant inputs; an empty slice maps to an empty vector.
fn layer_norm(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    // Center first, then derive the (biased) variance from the residuals.
    let centered: Vec<f64> = x.iter().map(|v| v - mean).collect();
    let var = centered.iter().map(|c| c * c).sum::<f64>() / n;
    let std = (var + 1e-6).sqrt();
    centered.into_iter().map(|c| c / std).collect()
}
50
/// GELU activation via the tanh approximation (Hendrycks & Gimpel 2016):
/// `0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715·x³)))`.
#[inline]
fn gelu(x: f64) -> f64 {
    // 0.797884… = sqrt(2/π); operand order matches the original exactly.
    let inner = 0.797_884_560_802_865_4 * (x + 0.044_715 * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
56
// ============================================================================
// LCG weight initialiser
// ============================================================================

/// Deterministic linear congruential generator used to initialise weights
/// without pulling in an external RNG dependency.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Seed the generator; the XOR constant decorrelates small seeds.
    fn new(seed: u64) -> Self {
        Self {
            state: seed ^ 0x5851_f42d_4c95_7f2d,
        }
    }

    /// Next uniform sample in `[0, 1)` (top 31 bits of the LCG state).
    fn next_f64(&mut self) -> f64 {
        // Knuth/MMIX multiplier and increment.
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let bits = (self.state >> 33) as i32;
        (bits as f64) / (i32::MAX as f64)
    }

    /// Next uniform sample in `[-1, 1)`.
    fn next_symmetric(&mut self) -> f64 {
        2.0 * self.next_f64() - 1.0
    }

    /// He-style initialisation for a `rows × cols` matrix: zero-mean
    /// uniform samples scaled by `sqrt(2 / fan_in)`.
    ///
    /// Fix: samples were previously drawn from `[0, scale)` only, i.e.
    /// all weights non-negative, giving every projection a systematic
    /// positive bias; they are now symmetric about zero.
    fn he_matrix(&mut self, rows: usize, cols: usize) -> Vec<Vec<f64>> {
        let scale = (2.0 / cols.max(1) as f64).sqrt();
        (0..rows)
            .map(|_| (0..cols).map(|_| self.next_symmetric() * scale).collect())
            .collect()
    }

    /// He-style initialisation for a length-`len` vector (zero-mean uniform).
    fn he_vec(&mut self, len: usize) -> Vec<f64> {
        let scale = (2.0 / len.max(1) as f64).sqrt();
        (0..len).map(|_| self.next_symmetric() * scale).collect()
    }
}
93
// ============================================================================
// GraphormerModel
// ============================================================================

/// Graphormer model: degree / spatial / edge encodings + Transformer layers.
///
/// Holds the learnable lookup tables (degree embedding, spatial bias), a
/// lazily-built input projection, and the per-layer attention/FFN weights.
pub struct GraphormerModel {
    // Model hyper-parameters (cloned from the config passed to `new`).
    config: GraphormerConfig,

    /// Degree embedding table: (max_degree + 1) × hidden_dim.
    /// Row `d` is added to the projected features of every node of degree
    /// `d`; degrees above `max_degree` are clamped to the last row.
    deg_emb: Vec<Vec<f64>>,

    /// Spatial bias table: (max_shortest_path + 2) × 1, shared by all heads.
    /// Index 0  = same node (SPD = 0)
    /// Index d  = SPD = d  (1 ≤ d ≤ max_shortest_path)
    /// Index max+1 = disconnected
    spatial_bias: Vec<f64>,

    /// Input projection: hidden_dim × feat_dim (built lazily on the first
    /// forward pass, rebuilt if the feature dimension changes)
    w_in: Option<Vec<Vec<f64>>>,
    // Feature dimension `w_in` was built for; 0 until the first build.
    feat_dim: usize,

    /// Per-layer weights for Q, K, V, O, FFN
    layers: Vec<TransformerLayerWeights>,
}
118
/// Weights for one Transformer encoder layer (multi-head attention + FFN).
struct TransformerLayerWeights {
    // Number of attention heads.
    n_heads: usize,
    // Per-head dimension: hidden_dim / n_heads, clamped to at least 1.
    head_dim: usize,
    // Model width; also the width of the concatenated head output.
    hidden_dim: usize,
    wq: Vec<Vec<Vec<f64>>>, // n_heads × head_dim × hidden_dim
    wk: Vec<Vec<Vec<f64>>>, // n_heads × head_dim × hidden_dim
    wv: Vec<Vec<Vec<f64>>>, // n_heads × head_dim × hidden_dim
    w_out: Vec<Vec<f64>>, // hidden_dim × hidden_dim
    w_ff1: Vec<Vec<f64>>, // 4h × h (expansion)
    w_ff2: Vec<Vec<f64>>, // h × 4h (contraction back to model width)
}
130
impl TransformerLayerWeights {
    /// Initialise one layer's weights from the shared LCG.
    ///
    /// `head_dim` is floor(hidden_dim / n_heads), clamped to at least 1.
    /// When `hidden_dim` is not divisible by `n_heads`, the attention code
    /// below clamps head writes so the trailing columns simply stay zero.
    fn new(hidden_dim: usize, n_heads: usize, lcg: &mut Lcg) -> Self {
        let head_dim = (hidden_dim / n_heads).max(1);
        let wq = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, hidden_dim))
            .collect();
        let wk = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, hidden_dim))
            .collect();
        let wv = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, hidden_dim))
            .collect();
        let w_out = lcg.he_matrix(hidden_dim, hidden_dim);
        // FFN with the standard 4× expansion factor.
        let w_ff1 = lcg.he_matrix(4 * hidden_dim, hidden_dim);
        let w_ff2 = lcg.he_matrix(hidden_dim, 4 * hidden_dim);
        Self {
            n_heads,
            head_dim,
            hidden_dim,
            wq,
            wk,
            wv,
            w_out,
            w_ff1,
            w_ff2,
        }
    }

    /// Multi-head self-attention with per-pair spatial bias.
    ///
    /// `tokens`: (n+1) × hidden_dim (n real nodes + 1 virtual)
    /// `spatial`: (n+1) × (n+1) additive logit bias (one scalar table,
    /// shared by all heads)
    fn attention(&self, tokens: &[Vec<f64>], spatial: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let seq_len = tokens.len();
        // Standard 1/sqrt(d_k) logit scaling; clamped away from zero.
        let scale = (self.head_dim as f64).sqrt().max(1e-6);

        // Concatenated head outputs; heads write disjoint column slices.
        let mut concat = vec![vec![0.0_f64; self.hidden_dim]; seq_len];

        for hd in 0..self.n_heads {
            // Per-head Q/K/V projections: seq_len × head_dim each.
            let q: Vec<Vec<f64>> = tokens.iter().map(|t| mv(&self.wq[hd], t)).collect();
            let k: Vec<Vec<f64>> = tokens.iter().map(|t| mv(&self.wk[hd], t)).collect();
            let v: Vec<Vec<f64>> = tokens.iter().map(|t| mv(&self.wv[hd], t)).collect();

            // Compute attention logits + spatial bias
            let mut attn = vec![vec![0.0_f64; seq_len]; seq_len];
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let dot: f64 = q[i].iter().zip(k[j].iter()).map(|(a, b)| a * b).sum();
                    // Missing spatial entries default to zero bias.
                    let bias = spatial
                        .get(i)
                        .and_then(|r| r.get(j))
                        .copied()
                        .unwrap_or(0.0);
                    attn[i][j] = dot / scale + bias;
                }
                // Softmax in place (row-wise, over all keys)
                let sm = softmax(&attn[i]);
                attn[i] = sm;
            }

            // Weighted value accumulation into this head's column slice
            // [head_start, head_end) of `concat`; both bounds are clamped
            // so a non-divisible hidden_dim cannot write out of range.
            let head_start = hd * self.head_dim;
            let head_end = (head_start + self.head_dim).min(self.hidden_dim);
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let v_len = v[j].len().min(self.head_dim);
                    for d in 0..v_len {
                        let out_d = head_start + d;
                        if out_d < head_end {
                            concat[i][out_d] += attn[i][j] * v[j][d];
                        }
                    }
                }
            }
        }

        // Output projection back to hidden_dim
        concat.iter().map(|c| mv(&self.w_out, c)).collect()
    }

    /// Position-wise FFN: linear (h → 4h), GELU, linear (4h → h).
    fn ffn(&self, h: &[Vec<f64>]) -> Vec<Vec<f64>> {
        h.iter()
            .map(|x| {
                let mid: Vec<f64> = mv(&self.w_ff1, x).into_iter().map(gelu).collect();
                mv(&self.w_ff2, &mid)
            })
            .collect()
    }

    /// One Transformer layer: attention + residual + LN + FFN + residual + LN.
    /// (Post-norm arrangement; sequence length is preserved.)
    fn forward(&self, tokens: &[Vec<f64>], spatial: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let attn_out = self.attention(tokens, spatial);
        // Residual + LN around the attention sub-layer
        let h1: Vec<Vec<f64>> = tokens
            .iter()
            .zip(attn_out.iter())
            .map(|(t, a)| {
                layer_norm(
                    &t.iter()
                        .zip(a.iter())
                        .map(|(x, y)| x + y)
                        .collect::<Vec<_>>(),
                )
            })
            .collect();
        let ffn_out = self.ffn(&h1);
        // Residual + LN around the FFN sub-layer
        h1.iter()
            .zip(ffn_out.iter())
            .map(|(t, f)| {
                layer_norm(
                    &t.iter()
                        .zip(f.iter())
                        .map(|(x, y)| x + y)
                        .collect::<Vec<_>>(),
                )
            })
            .collect()
    }
}
252
253impl GraphormerModel {
254    /// Construct a new Graphormer from configuration.
255    pub fn new(config: &GraphormerConfig) -> Self {
256        let hidden_dim = config.hidden_dim;
257        let n_heads = config.n_heads.max(1);
258        let mut lcg = Lcg::new(0x1234_5678_9abc_def0);
259
260        // Degree embedding: table for degrees 0 ..= max_degree
261        let deg_emb: Vec<Vec<f64>> = (0..=config.max_degree)
262            .map(|_| lcg.he_vec(hidden_dim))
263            .collect();
264
265        // Spatial bias: scalar per SPD bucket (we share across heads for simplicity)
266        let n_buckets = config.max_shortest_path + 2;
267        let spatial_bias: Vec<f64> = (0..n_buckets).map(|_| lcg.next_f64() * 0.1).collect();
268
269        let layers: Vec<TransformerLayerWeights> = (0..config.n_layers)
270            .map(|_| TransformerLayerWeights::new(hidden_dim, n_heads, &mut lcg))
271            .collect();
272
273        Self {
274            config: config.clone(),
275            deg_emb,
276            spatial_bias,
277            w_in: None,
278            feat_dim: 0,
279            layers,
280        }
281    }
282
283    /// Build or refresh the input-projection matrix.
284    fn ensure_w_in(&mut self, feat_dim: usize) {
285        if self.w_in.is_none() || self.feat_dim != feat_dim {
286            let mut lcg = Lcg::new(0xfeed_face_dead_beef);
287            self.w_in = Some(lcg.he_matrix(self.config.hidden_dim, feat_dim.max(1)));
288            self.feat_dim = feat_dim;
289        }
290    }
291
292    /// Retrieve the degree embedding for a node, clamping to table size.
293    fn degree_embedding(&self, degree: usize) -> &Vec<f64> {
294        let idx = degree.min(self.config.max_degree);
295        &self.deg_emb[idx]
296    }
297
298    /// Map SPD to the index in `spatial_bias`.
299    fn spd_to_bucket(&self, spd: usize) -> usize {
300        if spd == 0 {
301            0
302        } else if spd == usize::MAX {
303            // Disconnected
304            self.config.max_shortest_path + 1
305        } else {
306            spd.min(self.config.max_shortest_path)
307        }
308    }
309
310    /// Run the Graphormer forward pass.
311    pub fn forward(&mut self, graph: &GraphForTransformer) -> Result<GraphTransformerOutput> {
312        let n = graph.n_nodes;
313        let hidden_dim = self.config.hidden_dim;
314
315        if n == 0 {
316            return Ok(GraphTransformerOutput {
317                node_embeddings: Vec::new(),
318                graph_embedding: vec![0.0; hidden_dim],
319            });
320        }
321
322        let feat_dim = graph
323            .node_features
324            .first()
325            .map(|r| r.len())
326            .unwrap_or(1)
327            .max(1);
328        self.ensure_w_in(feat_dim);
329
330        let w_in = match self.w_in.as_ref() {
331            Some(w) => w.clone(),
332            None => {
333                return Err(crate::error::GraphError::InvalidParameter {
334                    param: "w_in".to_string(),
335                    value: "None".to_string(),
336                    expected: "initialised projection matrix".to_string(),
337                    context: "GraphormerModel::forward".to_string(),
338                })
339            }
340        };
341
342        // Compute degrees
343        let degrees: Vec<usize> = graph.adjacency.iter().map(|nbrs| nbrs.len()).collect();
344
345        // All-pairs shortest paths
346        let apsp = all_pairs_shortest_path(&graph.adjacency);
347
348        // Build seq_len = n + 1 tokens (last = virtual graph token)
349        let seq_len = n + 1;
350
351        // Project node features + add degree embedding
352        let mut tokens: Vec<Vec<f64>> = (0..n)
353            .map(|i| {
354                let proj = mv(&w_in, &graph.node_features[i]);
355                let deg_e = self.degree_embedding(degrees[i]);
356                proj.iter().zip(deg_e.iter()).map(|(a, b)| a + b).collect()
357            })
358            .collect();
359
360        // Virtual token: mean of all real node embeddings + degree emb for degree=0
361        let virtual_emb: Vec<f64> = {
362            let mut sum = vec![0.0_f64; hidden_dim];
363            for t in &tokens {
364                for (d, &v) in t.iter().enumerate() {
365                    if d < hidden_dim {
366                        sum[d] += v;
367                    }
368                }
369            }
370            let inv = 1.0 / n as f64;
371            sum.iter().map(|v| v * inv).collect()
372        };
373        tokens.push(virtual_emb);
374
375        // Build spatial bias matrix (seq_len × seq_len)
376        // Virtual token (index n) gets bias 0 for all pairs
377        let spatial: Vec<Vec<f64>> = (0..seq_len)
378            .map(|i| {
379                (0..seq_len)
380                    .map(|j| {
381                        if i >= n || j >= n {
382                            // Virtual token: no spatial bias
383                            0.0
384                        } else {
385                            let bucket = self.spd_to_bucket(apsp[i][j]);
386                            self.spatial_bias[bucket]
387                        }
388                    })
389                    .collect()
390            })
391            .collect();
392
393        // Apply Transformer layers
394        let mut h = tokens;
395        for layer in &self.layers {
396            h = layer.forward(&h, &spatial);
397        }
398
399        // Extract node embeddings (first n tokens) and graph embedding (virtual token)
400        let node_embeddings: Vec<Vec<f64>> = h.iter().take(n).cloned().collect();
401        let graph_embedding: Vec<f64> = h.last().cloned().unwrap_or_else(|| vec![0.0; hidden_dim]);
402
403        Ok(GraphTransformerOutput {
404            node_embeddings,
405            graph_embedding,
406        })
407    }
408
409    /// Return the degree embedding for the given degree (for testing).
410    pub fn get_degree_embedding(&self, degree: usize) -> Vec<f64> {
411        self.degree_embedding(degree).clone()
412    }
413
414    /// Return the spatial bias for a given SPD bucket index (for testing).
415    pub fn get_spatial_bias(&self, spd: usize) -> f64 {
416        let bucket = self.spd_to_bucket(spd);
417        self.spatial_bias[bucket]
418    }
419}
420
// ============================================================================
// Unit tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::super::types::{GraphForTransformer, GraphormerConfig};
    use super::*;

    fn default_config() -> GraphormerConfig {
        GraphormerConfig {
            max_degree: 8,
            max_shortest_path: 10,
            n_heads: 2,
            hidden_dim: 8,
            n_layers: 1,
        }
    }

    /// Undirected triangle: every node has degree 2; all SPDs are 0 or 1.
    fn triangle_graph() -> GraphForTransformer {
        GraphForTransformer::new(
            vec![vec![1, 2], vec![0, 2], vec![0, 1]],
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
        )
        .expect("valid graph")
    }

    fn single_node_graph() -> GraphForTransformer {
        GraphForTransformer::new(vec![vec![]], vec![vec![1.0]]).expect("valid graph")
    }

    #[test]
    fn test_graphormer_output_shape() {
        let g = triangle_graph();
        let cfg = default_config();
        let mut model = GraphormerModel::new(&cfg);
        let out = model.forward(&g).expect("forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
        for row in &out.node_embeddings {
            assert_eq!(row.len(), 8);
        }
    }

    #[test]
    fn test_graphormer_degree_embedding() {
        let cfg = default_config();
        let model = GraphormerModel::new(&cfg);
        let e0 = model.get_degree_embedding(0);
        let e2 = model.get_degree_embedding(2);
        // Different degrees must produce different embeddings
        let diff: f64 = e0.iter().zip(e2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-9,
            "degree 0 and degree 2 embeddings identical, diff={diff}"
        );
        // Out-of-range degrees clamp to the last row instead of panicking.
        let e_big = model.get_degree_embedding(1000);
        assert_eq!(e_big, model.get_degree_embedding(cfg.max_degree));
    }

    #[test]
    fn test_graphormer_spatial_encoding() {
        let cfg = default_config();
        let model = GraphormerModel::new(&cfg);
        // Bucket mapping: 0 stays 0, in-range SPDs map to themselves,
        // larger SPDs are clamped, usize::MAX marks disconnection.
        // (Replaces a previous tautological assertion that could not fail.)
        assert_eq!(model.spd_to_bucket(0), 0);
        assert_eq!(model.spd_to_bucket(1), 1);
        assert_eq!(model.spd_to_bucket(5), 5);
        assert_eq!(model.spd_to_bucket(99), cfg.max_shortest_path);
        assert_eq!(model.spd_to_bucket(usize::MAX), cfg.max_shortest_path + 1);
        // Looked-up biases must be finite.
        assert!(model.get_spatial_bias(1).is_finite());
        assert!(model.get_spatial_bias(5).is_finite());
        assert!(model.get_spatial_bias(usize::MAX).is_finite());
    }

    #[test]
    fn test_graphormer_spatial_encoding_different() {
        let cfg = default_config();
        let model = GraphormerModel::new(&cfg);
        // Table covers SPD 0, 1..=max, plus the disconnected sentinel.
        assert_eq!(model.spatial_bias.len(), cfg.max_shortest_path + 2);
        // Entries 1 and 5 are distinct LCG draws; verify they are finite.
        let b1 = model.spatial_bias[1];
        let b5 = model.spatial_bias[5];
        assert!(b1.is_finite());
        assert!(b5.is_finite());
    }

    #[test]
    fn test_graphormer_virtual_token() {
        let g = triangle_graph();
        let cfg = default_config();
        let mut model = GraphormerModel::new(&cfg);
        let out = model.forward(&g).expect("forward ok");
        // graph_embedding should be non-zero (virtual token output)
        let norm: f64 = out
            .graph_embedding
            .iter()
            .map(|v| v * v)
            .sum::<f64>()
            .sqrt();
        assert!(norm > 0.0, "virtual token embedding is zero");
        assert_eq!(out.graph_embedding.len(), 8);
    }

    #[test]
    fn test_graphormer_single_node() {
        let g = single_node_graph();
        let cfg = default_config();
        let mut model = GraphormerModel::new(&cfg);
        let out = model.forward(&g).expect("single node forward ok");
        assert_eq!(out.node_embeddings.len(), 1);
        assert_eq!(out.graph_embedding.len(), 8);
        for row in &out.node_embeddings {
            for &v in row {
                assert!(v.is_finite(), "non-finite node embedding");
            }
        }
    }

    #[test]
    fn test_graphormer_triangle() {
        let g = triangle_graph();
        let cfg = default_config();
        let mut model = GraphormerModel::new(&cfg);
        let out = model.forward(&g).expect("triangle forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
        for row in &out.node_embeddings {
            assert_eq!(row.len(), 8);
            for &v in row {
                assert!(v.is_finite(), "non-finite value in triangle output");
            }
        }
        // Graph embedding also finite
        for &v in &out.graph_embedding {
            assert!(v.is_finite());
        }
    }
}