scirs2_graph/graph_transformer/
graphormer.rs1use super::positional_encoding::all_pairs_shortest_path;
14use super::types::{GraphForTransformer, GraphTransformerOutput, GraphormerConfig};
15use crate::error::Result;
16
/// Numerically stable softmax over a slice.
///
/// The running maximum is subtracted before exponentiating so large
/// logits cannot overflow, and the normaliser is clamped away from zero
/// so an all-`-inf` row cannot divide by zero. An empty input yields an
/// empty vector.
fn softmax(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let mut max_v = f64::NEG_INFINITY;
    for &v in xs {
        if v > max_v {
            max_v = v;
        }
    }
    let mut exps = Vec::with_capacity(xs.len());
    let mut sum = 0.0_f64;
    for &v in xs {
        let e = (v - max_v).exp();
        sum += e;
        exps.push(e);
    }
    let denom = sum.max(1e-15);
    for e in &mut exps {
        *e /= denom;
    }
    exps
}
31
/// Dense matrix–vector product: one dot product per row of `w`.
///
/// `zip` stops at the shorter side, so if `x` is shorter than a row the
/// extra row entries are silently ignored (and vice versa).
fn mv(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(w.len());
    for row in w {
        let mut acc = 0.0_f64;
        for (a, b) in row.iter().zip(x) {
            acc += a * b;
        }
        out.push(acc);
    }
    out
}
38
/// Layer normalisation without learnable affine parameters: subtracts
/// the mean and divides by the standard deviation (epsilon 1e-6 keeps a
/// constant input from dividing by zero). Empty input yields an empty
/// vector.
fn layer_norm(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }
    let count = x.len() as f64;
    let mut total = 0.0_f64;
    for &v in x {
        total += v;
    }
    let mean = total / count;
    let mut sq_dev = 0.0_f64;
    for &v in x {
        let d = v - mean;
        sq_dev += d * d;
    }
    let std = (sq_dev / count + 1e-6).sqrt();
    x.iter().map(|v| (v - mean) / std).collect()
}
50
51#[inline]
53fn gelu(x: f64) -> f64 {
54 0.5 * x * (1.0 + (0.797_884_560_802_865_4 * (x + 0.044_715 * x * x * x)).tanh())
55}
56
/// Minimal deterministic linear congruential generator used to create
/// reproducible model weights without an external RNG dependency.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// XORs the seed with a fixed constant so that a zero seed still
    /// produces a non-trivial stream.
    fn new(seed: u64) -> Self {
        Self {
            state: seed ^ 0x5851_f42d_4c95_7f2d,
        }
    }

    /// Advances the generator (Knuth's MMIX multiplier/increment) and
    /// maps the high bits of the state to an f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const INCREMENT: u64 = 1_442_695_040_888_963_407;
        self.state = self.state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        // state >> 33 fits in 31 bits, so the i32 cast is lossless and
        // the quotient lies in [0, 1).
        let hi = (self.state >> 33) as i32;
        f64::from(hi) / f64::from(i32::MAX)
    }

    /// `rows` x `cols` matrix with entries scaled by sqrt(2 / fan_in).
    /// NOTE(review): draws are uniform in [0, 1), not zero-mean — only
    /// the scale follows He initialisation; confirm this is intentional.
    fn he_matrix(&mut self, rows: usize, cols: usize) -> Vec<Vec<f64>> {
        let scale = (2.0 / cols.max(1) as f64).sqrt();
        let mut matrix = Vec::with_capacity(rows);
        for _ in 0..rows {
            let mut row = Vec::with_capacity(cols);
            for _ in 0..cols {
                row.push(self.next_f64() * scale);
            }
            matrix.push(row);
        }
        matrix
    }

    /// Length-`len` vector with entries scaled by sqrt(2 / len).
    fn he_vec(&mut self, len: usize) -> Vec<f64> {
        let scale = (2.0 / len.max(1) as f64).sqrt();
        let mut v = Vec::with_capacity(len);
        for _ in 0..len {
            v.push(self.next_f64() * scale);
        }
        v
    }
}
93
/// Simplified Graphormer encoder state: degree (centrality) embeddings,
/// a per-bucket spatial attention bias, and a stack of transformer
/// layers, all seeded deterministically from an internal LCG.
pub struct GraphormerModel {
    // Hyper-parameters, cloned from the caller's config at construction.
    config: GraphormerConfig,

    // One embedding row per degree in 0..=max_degree; higher degrees are
    // clamped to the last row on lookup.
    deg_emb: Vec<Vec<f64>>,

    // Scalar attention bias per shortest-path bucket: bucket 0 (self),
    // 1..=max_shortest_path, plus one extra bucket for unreachable pairs
    // (max_shortest_path + 2 entries in total).
    spatial_bias: Vec<f64>,

    // Input feature projection, created lazily once the node feature
    // dimension is known (see ensure_w_in).
    w_in: Option<Vec<Vec<f64>>>,
    // Feature dimension w_in was built for; w_in is rebuilt if it changes.
    feat_dim: usize,

    // Encoder layers, applied in sequence by forward().
    layers: Vec<TransformerLayerWeights>,
}
118
/// Weights for one transformer encoder layer: per-head Q/K/V
/// projections, an output projection over the concatenated heads, and a
/// two-layer position-wise feed-forward network.
struct TransformerLayerWeights {
    n_heads: usize,
    head_dim: usize,
    hidden_dim: usize,
    // Per-head projections; each inner matrix is head_dim x hidden_dim.
    wq: Vec<Vec<Vec<f64>>>,
    wk: Vec<Vec<Vec<f64>>>,
    wv: Vec<Vec<Vec<f64>>>,
    // hidden_dim x hidden_dim mixing matrix applied after head concat.
    w_out: Vec<Vec<f64>>,
    // Feed-forward: expand to 4*hidden_dim (w_ff1), project back (w_ff2).
    w_ff1: Vec<Vec<f64>>,
    w_ff2: Vec<Vec<f64>>,
}
130
impl TransformerLayerWeights {
    /// Draws one layer's weights from the shared LCG stream.
    ///
    /// `head_dim` is `hidden_dim / n_heads` floored (minimum 1); when
    /// `hidden_dim` is not divisible by `n_heads` the concatenated heads
    /// only cover the first `n_heads * head_dim` output channels — the
    /// clamps in `attention` keep that case in bounds.
    fn new(hidden_dim: usize, n_heads: usize, lcg: &mut Lcg) -> Self {
        let head_dim = (hidden_dim / n_heads).max(1);
        let wq = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, hidden_dim))
            .collect();
        let wk = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, hidden_dim))
            .collect();
        let wv = (0..n_heads)
            .map(|_| lcg.he_matrix(head_dim, hidden_dim))
            .collect();
        let w_out = lcg.he_matrix(hidden_dim, hidden_dim);
        // Standard transformer FFN expansion factor of 4.
        let w_ff1 = lcg.he_matrix(4 * hidden_dim, hidden_dim);
        let w_ff2 = lcg.he_matrix(hidden_dim, 4 * hidden_dim);
        Self {
            n_heads,
            head_dim,
            hidden_dim,
            wq,
            wk,
            wv,
            w_out,
            w_ff1,
            w_ff2,
        }
    }

    /// Multi-head scaled dot-product attention with an additive
    /// Graphormer-style spatial bias.
    ///
    /// `spatial[i][j]` is added to the (i, j) attention logit before the
    /// row-wise softmax; out-of-range lookups fall back to 0.0.
    fn attention(&self, tokens: &[Vec<f64>], spatial: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let seq_len = tokens.len();
        // 1/sqrt(d_k) logit scaling, guarded against a zero head_dim.
        let scale = (self.head_dim as f64).sqrt().max(1e-6);

        // Concatenated per-head outputs, mixed by w_out at the end.
        let mut concat = vec![vec![0.0_f64; self.hidden_dim]; seq_len];

        for hd in 0..self.n_heads {
            // Project every token into this head's Q/K/V space.
            let q: Vec<Vec<f64>> = tokens.iter().map(|t| mv(&self.wq[hd], t)).collect();
            let k: Vec<Vec<f64>> = tokens.iter().map(|t| mv(&self.wk[hd], t)).collect();
            let v: Vec<Vec<f64>> = tokens.iter().map(|t| mv(&self.wv[hd], t)).collect();

            // Logits = scaled dot product + spatial bias, then softmax
            // over each row.
            let mut attn = vec![vec![0.0_f64; seq_len]; seq_len];
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let dot: f64 = q[i].iter().zip(k[j].iter()).map(|(a, b)| a * b).sum();
                    let bias = spatial
                        .get(i)
                        .and_then(|r| r.get(j))
                        .copied()
                        .unwrap_or(0.0);
                    attn[i][j] = dot / scale + bias;
                }
                let sm = softmax(&attn[i]);
                attn[i] = sm;
            }

            // Accumulate attention-weighted values into this head's slice
            // of the concat buffer; the min/if clamps keep short v rows
            // and non-divisible hidden_dim from writing out of bounds.
            let head_start = hd * self.head_dim;
            let head_end = (head_start + self.head_dim).min(self.hidden_dim);
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let v_len = v[j].len().min(self.head_dim);
                    for d in 0..v_len {
                        let out_d = head_start + d;
                        if out_d < head_end {
                            concat[i][out_d] += attn[i][j] * v[j][d];
                        }
                    }
                }
            }
        }

        // Final output projection of the concatenated heads.
        concat.iter().map(|c| mv(&self.w_out, c)).collect()
    }

    /// Position-wise feed-forward applied independently to each token:
    /// expand with w_ff1, GELU, project back with w_ff2.
    fn ffn(&self, h: &[Vec<f64>]) -> Vec<Vec<f64>> {
        h.iter()
            .map(|x| {
                let mid: Vec<f64> = mv(&self.w_ff1, x).into_iter().map(gelu).collect();
                mv(&self.w_ff2, &mid)
            })
            .collect()
    }

    /// One full encoder layer in post-LN form:
    /// h1 = LN(tokens + attention(tokens)); out = LN(h1 + FFN(h1)).
    fn forward(&self, tokens: &[Vec<f64>], spatial: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let attn_out = self.attention(tokens, spatial);
        // Residual connection + layer norm around attention.
        let h1: Vec<Vec<f64>> = tokens
            .iter()
            .zip(attn_out.iter())
            .map(|(t, a)| {
                layer_norm(
                    &t.iter()
                        .zip(a.iter())
                        .map(|(x, y)| x + y)
                        .collect::<Vec<_>>(),
                )
            })
            .collect();
        let ffn_out = self.ffn(&h1);
        // Residual connection + layer norm around the feed-forward.
        h1.iter()
            .zip(ffn_out.iter())
            .map(|(t, f)| {
                layer_norm(
                    &t.iter()
                        .zip(f.iter())
                        .map(|(x, y)| x + y)
                        .collect::<Vec<_>>(),
                )
            })
            .collect()
    }
}
252
253impl GraphormerModel {
254 pub fn new(config: &GraphormerConfig) -> Self {
256 let hidden_dim = config.hidden_dim;
257 let n_heads = config.n_heads.max(1);
258 let mut lcg = Lcg::new(0x1234_5678_9abc_def0);
259
260 let deg_emb: Vec<Vec<f64>> = (0..=config.max_degree)
262 .map(|_| lcg.he_vec(hidden_dim))
263 .collect();
264
265 let n_buckets = config.max_shortest_path + 2;
267 let spatial_bias: Vec<f64> = (0..n_buckets).map(|_| lcg.next_f64() * 0.1).collect();
268
269 let layers: Vec<TransformerLayerWeights> = (0..config.n_layers)
270 .map(|_| TransformerLayerWeights::new(hidden_dim, n_heads, &mut lcg))
271 .collect();
272
273 Self {
274 config: config.clone(),
275 deg_emb,
276 spatial_bias,
277 w_in: None,
278 feat_dim: 0,
279 layers,
280 }
281 }
282
283 fn ensure_w_in(&mut self, feat_dim: usize) {
285 if self.w_in.is_none() || self.feat_dim != feat_dim {
286 let mut lcg = Lcg::new(0xfeed_face_dead_beef);
287 self.w_in = Some(lcg.he_matrix(self.config.hidden_dim, feat_dim.max(1)));
288 self.feat_dim = feat_dim;
289 }
290 }
291
292 fn degree_embedding(&self, degree: usize) -> &Vec<f64> {
294 let idx = degree.min(self.config.max_degree);
295 &self.deg_emb[idx]
296 }
297
298 fn spd_to_bucket(&self, spd: usize) -> usize {
300 if spd == 0 {
301 0
302 } else if spd == usize::MAX {
303 self.config.max_shortest_path + 1
305 } else {
306 spd.min(self.config.max_shortest_path)
307 }
308 }
309
310 pub fn forward(&mut self, graph: &GraphForTransformer) -> Result<GraphTransformerOutput> {
312 let n = graph.n_nodes;
313 let hidden_dim = self.config.hidden_dim;
314
315 if n == 0 {
316 return Ok(GraphTransformerOutput {
317 node_embeddings: Vec::new(),
318 graph_embedding: vec![0.0; hidden_dim],
319 });
320 }
321
322 let feat_dim = graph
323 .node_features
324 .first()
325 .map(|r| r.len())
326 .unwrap_or(1)
327 .max(1);
328 self.ensure_w_in(feat_dim);
329
330 let w_in = match self.w_in.as_ref() {
331 Some(w) => w.clone(),
332 None => {
333 return Err(crate::error::GraphError::InvalidParameter {
334 param: "w_in".to_string(),
335 value: "None".to_string(),
336 expected: "initialised projection matrix".to_string(),
337 context: "GraphormerModel::forward".to_string(),
338 })
339 }
340 };
341
342 let degrees: Vec<usize> = graph.adjacency.iter().map(|nbrs| nbrs.len()).collect();
344
345 let apsp = all_pairs_shortest_path(&graph.adjacency);
347
348 let seq_len = n + 1;
350
351 let mut tokens: Vec<Vec<f64>> = (0..n)
353 .map(|i| {
354 let proj = mv(&w_in, &graph.node_features[i]);
355 let deg_e = self.degree_embedding(degrees[i]);
356 proj.iter().zip(deg_e.iter()).map(|(a, b)| a + b).collect()
357 })
358 .collect();
359
360 let virtual_emb: Vec<f64> = {
362 let mut sum = vec![0.0_f64; hidden_dim];
363 for t in &tokens {
364 for (d, &v) in t.iter().enumerate() {
365 if d < hidden_dim {
366 sum[d] += v;
367 }
368 }
369 }
370 let inv = 1.0 / n as f64;
371 sum.iter().map(|v| v * inv).collect()
372 };
373 tokens.push(virtual_emb);
374
375 let spatial: Vec<Vec<f64>> = (0..seq_len)
378 .map(|i| {
379 (0..seq_len)
380 .map(|j| {
381 if i >= n || j >= n {
382 0.0
384 } else {
385 let bucket = self.spd_to_bucket(apsp[i][j]);
386 self.spatial_bias[bucket]
387 }
388 })
389 .collect()
390 })
391 .collect();
392
393 let mut h = tokens;
395 for layer in &self.layers {
396 h = layer.forward(&h, &spatial);
397 }
398
399 let node_embeddings: Vec<Vec<f64>> = h.iter().take(n).cloned().collect();
401 let graph_embedding: Vec<f64> = h.last().cloned().unwrap_or_else(|| vec![0.0; hidden_dim]);
402
403 Ok(GraphTransformerOutput {
404 node_embeddings,
405 graph_embedding,
406 })
407 }
408
409 pub fn get_degree_embedding(&self, degree: usize) -> Vec<f64> {
411 self.degree_embedding(degree).clone()
412 }
413
414 pub fn get_spatial_bias(&self, spd: usize) -> f64 {
416 let bucket = self.spd_to_bucket(spd);
417 self.spatial_bias[bucket]
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::super::types::{GraphForTransformer, GraphormerConfig};
    use super::*;

    /// Small shared config: 8-dim hidden, 2 heads, 1 layer.
    fn default_config() -> GraphormerConfig {
        GraphormerConfig {
            max_degree: 8,
            max_shortest_path: 10,
            n_heads: 2,
            hidden_dim: 8,
            n_layers: 1,
        }
    }

    /// Fully connected 3-node graph (K3) with distinct features.
    fn triangle_graph() -> GraphForTransformer {
        GraphForTransformer::new(
            vec![vec![1, 2], vec![0, 2], vec![0, 1]],
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
        )
        .expect("valid graph")
    }

    fn single_node_graph() -> GraphForTransformer {
        GraphForTransformer::new(vec![vec![]], vec![vec![1.0]]).expect("valid graph")
    }

    #[test]
    fn test_graphormer_output_shape() {
        let g = triangle_graph();
        let cfg = default_config();
        let mut model = GraphormerModel::new(&cfg);
        let out = model.forward(&g).expect("forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
        for row in &out.node_embeddings {
            assert_eq!(row.len(), 8);
        }
    }

    #[test]
    fn test_graphormer_degree_embedding() {
        let cfg = default_config();
        let model = GraphormerModel::new(&cfg);
        let e0 = model.get_degree_embedding(0);
        let e2 = model.get_degree_embedding(2);
        let diff: f64 = e0.iter().zip(e2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-9,
            "degree 0 and degree 2 embeddings identical, diff={diff}"
        );
    }

    #[test]
    fn test_graphormer_spatial_encoding() {
        let cfg = default_config();
        let model = GraphormerModel::new(&cfg);
        // The previous assertion here was a tautology; instead verify
        // that each lookup hits the expected bucket of the bias table.
        assert_eq!(model.get_spatial_bias(0), model.spatial_bias[0]);
        assert_eq!(model.get_spatial_bias(1), model.spatial_bias[1]);
        assert_eq!(model.get_spatial_bias(5), model.spatial_bias[5]);
        // Distances beyond the cap clamp to the cap bucket.
        assert_eq!(model.get_spatial_bias(99), model.spatial_bias[10]);
        // Unreachable pairs use the dedicated final bucket.
        assert_eq!(model.get_spatial_bias(usize::MAX), model.spatial_bias[11]);
        assert!(model.get_spatial_bias(1).is_finite());
        assert!(model.get_spatial_bias(5).is_finite());
    }

    #[test]
    fn test_graphormer_spatial_encoding_different() {
        let cfg = default_config();
        let model = GraphormerModel::new(&cfg);
        let b1 = model.spatial_bias[1];
        let b5 = model.spatial_bias[5];
        assert!(b1.is_finite());
        assert!(b5.is_finite());
        // Buckets come from distinct LCG draws, so the biases differ —
        // this is what the test name promises.
        assert_ne!(b1, b5, "bucket 1 and bucket 5 share the same bias");
    }

    #[test]
    fn test_graphormer_virtual_token() {
        let g = triangle_graph();
        let cfg = default_config();
        let mut model = GraphormerModel::new(&cfg);
        let out = model.forward(&g).expect("forward ok");
        let norm: f64 = out
            .graph_embedding
            .iter()
            .map(|v| v * v)
            .sum::<f64>()
            .sqrt();
        assert!(norm > 0.0, "virtual token embedding is zero");
        assert_eq!(out.graph_embedding.len(), 8);
    }

    #[test]
    fn test_graphormer_single_node() {
        let g = single_node_graph();
        let cfg = default_config();
        let mut model = GraphormerModel::new(&cfg);
        let out = model.forward(&g).expect("single node forward ok");
        assert_eq!(out.node_embeddings.len(), 1);
        assert_eq!(out.graph_embedding.len(), 8);
        for row in &out.node_embeddings {
            for &v in row {
                assert!(v.is_finite(), "non-finite node embedding");
            }
        }
    }

    #[test]
    fn test_graphormer_triangle() {
        let g = triangle_graph();
        let cfg = default_config();
        let mut model = GraphormerModel::new(&cfg);
        let out = model.forward(&g).expect("triangle forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
        for row in &out.node_embeddings {
            assert_eq!(row.len(), 8);
            for &v in row {
                assert!(v.is_finite(), "non-finite value in triangle output");
            }
        }
        for &v in &out.graph_embedding {
            assert!(v.is_finite());
        }
    }
}