1use crate::error::{GraphError, Result};
11
12use super::types::{GraphForTransformer, GraphTransformerConfig, GraphTransformerOutput};
13
#[inline]
fn gelu(x: f64) -> f64 {
    // Tanh approximation of the Gaussian Error Linear Unit:
    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    const SQRT_2_OVER_PI: f64 = 0.797_884_560_802_865_4;
    let t = (SQRT_2_OVER_PI * (x + 0.044_715 * x * x * x)).tanh();
    0.5 * x * (1.0 + t)
}
23
/// Normalises a vector to zero mean and unit variance (eps = 1e-6).
/// An empty input yields an empty output.
fn layer_norm(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    // Population variance (divide by n, not n - 1).
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n;
    let std = (var + 1e-6).sqrt();
    x.iter().map(|v| (v - mean) / std).collect()
}
35
/// Numerically stable softmax: shifts by the maximum before
/// exponentiating, and floors the denominator to avoid division by zero.
fn softmax(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let mut max_v = f64::NEG_INFINITY;
    for &v in xs {
        max_v = max_v.max(v);
    }
    let mut exps = Vec::with_capacity(xs.len());
    let mut total = 0.0_f64;
    for &v in xs {
        let e = (v - max_v).exp();
        total += e;
        exps.push(e);
    }
    let denom = total.max(1e-15);
    exps.into_iter().map(|e| e / denom).collect()
}
46
/// Dense matrix-vector product. `zip` truncates each dot product to the
/// shorter of the row and `x`, so mismatched widths never panic.
fn mv(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(w.len());
    for row in w {
        let mut acc = 0.0_f64;
        for (a, b) in row.iter().zip(x.iter()) {
            acc += a * b;
        }
        out.push(acc);
    }
    out
}
53
/// Element-wise vector sum, truncated to the shorter input.
fn vadd(a: &[f64], b: &[f64]) -> Vec<f64> {
    let len = a.len().min(b.len());
    let mut out = Vec::with_capacity(len);
    for i in 0..len {
        out.push(a[i] + b[i]);
    }
    out
}
58
/// Deterministic linear congruential generator used for reproducible
/// weight initialisation without an external RNG dependency.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Seeds the generator. XOR with a fixed constant so that seed 0
    /// still produces a non-trivial stream.
    fn new(seed: u64) -> Self {
        Self {
            state: seed ^ 0x5851_f42d_4c95_7f2d,
        }
    }

    /// Returns a pseudo-random value in [0.0, 1.0].
    ///
    /// Uses the Knuth MMIX multiplier/increment and keeps only the top
    /// 31 bits of the state, which are the highest-quality bits of an LCG.
    fn next_f64(&mut self) -> f64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let bits = (self.state >> 33) as i32;
        (bits as f64) / (i32::MAX as f64)
    }

    /// He/Kaiming-style uniform initialisation: entries drawn from
    /// [-limit, limit] with limit = sqrt(6 / fan_in), i.e. zero-mean as
    /// He initialisation requires.
    ///
    /// Bug fix: the previous version scaled the [0, 1] draw directly,
    /// which produced strictly non-negative (biased) weights. `cols` is
    /// clamped to 1 to avoid a division by zero for degenerate shapes.
    fn he_matrix(&mut self, rows: usize, cols: usize) -> Vec<Vec<f64>> {
        let limit = (6.0 / cols.max(1) as f64).sqrt();
        (0..rows)
            .map(|_| {
                (0..cols)
                    .map(|_| (2.0 * self.next_f64() - 1.0) * limit)
                    .collect()
            })
            .collect()
    }
}
92
/// One GPS (General, Powerful, Scalable) layer: a local message-passing
/// branch blended with a global multi-head attention branch.
struct GpsLayer {
    hidden_dim: usize,
    n_heads: usize,
    pe_dim: usize,

    /// Message transform for the local MPNN branch.
    w_msg: Vec<Vec<f64>>,
    /// Per-head query projections for the global attention branch.
    wq: Vec<Vec<Vec<f64>>>,
    /// Per-head key projections.
    wk: Vec<Vec<Vec<f64>>>,
    /// Per-head value projections.
    wv: Vec<Vec<Vec<f64>>>,
    /// Output projection applied after head concatenation.
    w_out: Vec<Vec<f64>>,
    /// Mixing coefficient between the local and global branches.
    alpha: f64,

    /// Feed-forward network: expand to 4x hidden, then project back.
    w_ff1: Vec<Vec<f64>>,
    w_ff2: Vec<Vec<f64>>,
}
119
120impl GpsLayer {
121 fn new(hidden_dim: usize, n_heads: usize, pe_dim: usize, seed: u64) -> Self {
122 let head_dim = (hidden_dim / n_heads).max(1);
123 let in_dim = hidden_dim + pe_dim;
124 let mut lcg = Lcg::new(seed);
125
126 let wq = (0..n_heads)
127 .map(|_| lcg.he_matrix(head_dim, in_dim))
128 .collect();
129 let wk = (0..n_heads)
130 .map(|_| lcg.he_matrix(head_dim, in_dim))
131 .collect();
132 let wv = (0..n_heads)
133 .map(|_| lcg.he_matrix(head_dim, hidden_dim))
134 .collect();
135 let w_out = lcg.he_matrix(hidden_dim, hidden_dim);
136 let w_msg = lcg.he_matrix(hidden_dim, hidden_dim);
137 let w_ff1 = lcg.he_matrix(4 * hidden_dim, hidden_dim);
138 let w_ff2 = lcg.he_matrix(hidden_dim, 4 * hidden_dim);
139
140 Self {
141 hidden_dim,
142 n_heads,
143 pe_dim,
144 w_msg,
145 wq,
146 wk,
147 wv,
148 w_out,
149 alpha: 0.5,
150 w_ff1,
151 w_ff2,
152 }
153 }
154
155 fn local_mpnn(&self, h: &[Vec<f64>], adj: &[Vec<usize>]) -> Vec<Vec<f64>> {
157 let n = h.len();
158 let mut out = vec![vec![0.0_f64; self.hidden_dim]; n];
159 for i in 0..n {
160 let nbrs = &adj[i];
161 let agg = if nbrs.is_empty() {
162 h[i].clone()
163 } else {
164 let mut sum = vec![0.0_f64; self.hidden_dim];
166 for &j in nbrs {
167 for d in 0..self.hidden_dim.min(h[j].len()) {
168 sum[d] += h[j][d];
169 }
170 }
171 let cnt = nbrs.len() as f64;
172 sum.iter().map(|v| v / cnt).collect()
173 };
174 let msg = mv(&self.w_msg, &agg);
175 out[i] = msg.into_iter().map(|v| v.max(0.0)).collect(); }
177 out
178 }
179
180 fn global_transformer(&self, h: &[Vec<f64>], pe: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<f64>) {
184 let n = h.len();
185 if n == 0 {
186 return (Vec::new(), Vec::new());
187 }
188 let head_dim = (self.hidden_dim / self.n_heads).max(1);
189 let scale = (head_dim as f64).sqrt().max(1e-6);
190
191 let aug: Vec<Vec<f64>> = (0..n)
193 .map(|i| {
194 let hi = if i < h.len() { &h[i] } else { &h[0] };
195 let pi = if i < pe.len() { &pe[i] } else { &pe[0] };
196 let mut v = hi.clone();
197 v.extend_from_slice(pi);
198 v
199 })
200 .collect();
201
202 let mut head_outputs: Vec<Vec<Vec<f64>>> = Vec::with_capacity(self.n_heads);
204 let mut all_attn: Vec<f64> = Vec::new();
205
206 for hd in 0..self.n_heads {
207 let q: Vec<Vec<f64>> = aug.iter().map(|a| mv(&self.wq[hd], a)).collect();
209 let k: Vec<Vec<f64>> = aug.iter().map(|a| mv(&self.wk[hd], a)).collect();
210 let v: Vec<Vec<f64>> = h.iter().map(|hi| mv(&self.wv[hd], hi)).collect();
211
212 let mut attn_logits = vec![vec![0.0_f64; n]; n];
214 for i in 0..n {
215 for j in 0..n {
216 let dot: f64 = q[i].iter().zip(k[j].iter()).map(|(a, b)| a * b).sum();
217 attn_logits[i][j] = dot / scale;
218 }
219 }
220
221 let attn_weights: Vec<Vec<f64>> = attn_logits.iter().map(|row| softmax(row)).collect();
223
224 if hd == 0 {
225 for row in &attn_weights {
227 all_attn.extend_from_slice(row);
228 }
229 }
230
231 let mut head_out = vec![vec![0.0_f64; head_dim.min(v[0].len())]; n];
233 for i in 0..n {
234 for j in 0..n {
235 let vj_len = v[j].len().min(head_dim);
236 for d in 0..vj_len {
237 head_out[i][d] += attn_weights[i][j] * v[j][d];
238 }
239 }
240 }
241 head_outputs.push(head_out);
242 }
243
244 let head_dim_out = (self.hidden_dim / self.n_heads).max(1);
246 let mut concat = vec![vec![0.0_f64; self.hidden_dim]; n];
247 for i in 0..n {
248 for hd in 0..self.n_heads {
249 let start = hd * head_dim_out;
250 let end = (start + head_dim_out).min(self.hidden_dim);
251 for d in start..end {
252 let local_d = d - start;
253 if local_d < head_outputs[hd][i].len() {
254 concat[i][d] = head_outputs[hd][i][local_d];
255 }
256 }
257 }
258 }
259
260 let out: Vec<Vec<f64>> = concat.iter().map(|c| mv(&self.w_out, c)).collect();
262 (out, all_attn)
263 }
264
265 fn ffn(&self, h: &[Vec<f64>]) -> Vec<Vec<f64>> {
267 h.iter()
268 .map(|x| {
269 let mid: Vec<f64> = mv(&self.w_ff1, x).into_iter().map(gelu).collect();
270 mv(&self.w_ff2, &mid)
271 })
272 .collect()
273 }
274
275 pub fn forward(
277 &self,
278 h: &[Vec<f64>],
279 adj: &[Vec<usize>],
280 pe: &[Vec<f64>],
281 ) -> (Vec<Vec<f64>>, Vec<f64>) {
282 let n = h.len();
283 if n == 0 {
284 return (Vec::new(), Vec::new());
285 }
286
287 let h_norm: Vec<Vec<f64>> = h
289 .iter()
290 .map(|row| {
291 let mut r = row.clone();
292 r.resize(self.hidden_dim, 0.0);
293 r
294 })
295 .collect();
296
297 let pe_norm: Vec<Vec<f64>> = pe
299 .iter()
300 .map(|row| {
301 let mut r = row.clone();
302 r.resize(self.pe_dim, 0.0);
303 r
304 })
305 .collect();
306
307 let h_local = self.local_mpnn(&h_norm, adj);
308 let (h_global, attn_weights) = self.global_transformer(&h_norm, &pe_norm);
309
310 let alpha = self.alpha.clamp(0.0, 1.0);
312 let combined: Vec<Vec<f64>> = (0..n)
313 .map(|i| {
314 let combined_raw: Vec<f64> = (0..self.hidden_dim)
315 .map(|d| h_norm[i][d] + alpha * h_local[i][d] + (1.0 - alpha) * h_global[i][d])
316 .collect();
317 layer_norm(&combined_raw)
318 })
319 .collect();
320
321 let h_ffn = self.ffn(&combined);
323 let h_out: Vec<Vec<f64>> = (0..n)
324 .map(|i| {
325 let res = vadd(&combined[i], &h_ffn[i]);
326 layer_norm(&res)
327 })
328 .collect();
329
330 (h_out, attn_weights)
331 }
332}
333
/// GPS model: a stack of [`GpsLayer`]s plus a lazily-created input
/// projection from raw node features into the hidden dimension.
pub struct GpsModel {
    layers: Vec<GpsLayer>,
    hidden_dim: usize,
    pe_dim: usize,
    // Input projection; built on the first forward pass once the feature
    // dimension is known, and rebuilt if that dimension changes.
    w_in: Option<Vec<Vec<f64>>>,
    // Feature dimension `w_in` was built for (0 before first use).
    feat_dim: usize,
}
347
348impl GpsModel {
349 pub fn new(config: &GraphTransformerConfig) -> Self {
351 let pe_dim = config.pe_dim;
352 let hidden_dim = config.hidden_dim;
353 let n_heads = config.n_heads.max(1);
354 let layers: Vec<GpsLayer> = (0..config.n_layers)
355 .map(|i| {
356 let seed = (i as u64)
357 .wrapping_add(1)
358 .wrapping_mul(0x9e37_79b9_7f4a_7c15_u64);
359 GpsLayer::new(hidden_dim, n_heads, pe_dim, seed)
360 })
361 .collect();
362
363 Self {
364 layers,
365 hidden_dim,
366 pe_dim,
367 w_in: None,
368 feat_dim: 0,
369 }
370 }
371
372 fn ensure_w_in(&mut self, feat_dim: usize) {
374 if self.w_in.is_none() || self.feat_dim != feat_dim {
375 let mut lcg = Lcg::new(0xdead_beef_cafe_babe);
376 self.w_in = Some(lcg.he_matrix(self.hidden_dim, feat_dim.max(1)));
377 self.feat_dim = feat_dim;
378 }
379 }
380
381 pub fn forward(
385 &mut self,
386 graph: &GraphForTransformer,
387 pe: &[Vec<f64>],
388 ) -> Result<(GraphTransformerOutput, Vec<f64>)> {
389 let n = graph.n_nodes;
390 if n == 0 {
391 return Ok((
392 GraphTransformerOutput {
393 node_embeddings: Vec::new(),
394 graph_embedding: vec![0.0; self.hidden_dim],
395 },
396 Vec::new(),
397 ));
398 }
399
400 let feat_dim = graph
401 .node_features
402 .first()
403 .map(|r| r.len())
404 .unwrap_or(1)
405 .max(1);
406 self.ensure_w_in(feat_dim);
407
408 let w_in = self
409 .w_in
410 .as_ref()
411 .ok_or_else(|| GraphError::InvalidParameter {
412 param: "w_in".to_string(),
413 value: "None".to_string(),
414 expected: "initialised weight matrix".to_string(),
415 context: "GpsModel forward".to_string(),
416 })?;
417
418 let mut h: Vec<Vec<f64>> = graph.node_features.iter().map(|f| mv(w_in, f)).collect();
420
421 let mut last_attn: Vec<f64> = Vec::new();
422 for layer in &self.layers {
423 let (h_new, attn) = layer.forward(&h, &graph.adjacency, pe);
424 h = h_new;
425 last_attn = attn;
426 }
427
428 let mut graph_emb = vec![0.0_f64; self.hidden_dim];
430 for row in &h {
431 for (d, &v) in row.iter().enumerate() {
432 if d < self.hidden_dim {
433 graph_emb[d] += v;
434 }
435 }
436 }
437 let inv_n = 1.0 / n as f64;
438 for v in graph_emb.iter_mut() {
439 *v *= inv_n;
440 }
441
442 Ok((
443 GraphTransformerOutput {
444 node_embeddings: h,
445 graph_embedding: graph_emb,
446 },
447 last_attn,
448 ))
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::super::positional_encoding::{laplacian_pe, rwpe};
    use super::super::types::{GraphForTransformer, GraphTransformerConfig};
    use super::*;

    // Fully connected 3-node graph (triangle) with 2-dim node features.
    fn triangle_graph() -> GraphForTransformer {
        GraphForTransformer::new(
            vec![vec![1, 2], vec![0, 2], vec![0, 1]],
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
        )
        .expect("valid graph")
    }

    // Degenerate graph: a single node with no edges.
    fn single_node_graph() -> GraphForTransformer {
        GraphForTransformer::new(vec![vec![]], vec![vec![1.0]]).expect("valid graph")
    }

    // Small 1-layer config shared by most tests below.
    fn default_config() -> GraphTransformerConfig {
        GraphTransformerConfig {
            n_heads: 2,
            hidden_dim: 8,
            n_layers: 1,
            dropout: 0.0,
            pe_type: super::super::types::PeType::LapPE,
            pe_dim: 4,
        }
    }

    // Same as default_config but with two stacked layers.
    fn two_layer_config() -> GraphTransformerConfig {
        GraphTransformerConfig {
            n_layers: 2,
            ..default_config()
        }
    }

    // Node embeddings keep shape (n_nodes x hidden_dim).
    #[test]
    fn test_gps_output_shape() {
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
        for row in &out.node_embeddings {
            assert_eq!(row.len(), 8);
        }
    }

    // The pooled graph embedding has hidden_dim entries.
    #[test]
    fn test_gps_graph_embedding_shape() {
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("forward ok");
        assert_eq!(out.graph_embedding.len(), 8);
    }

    // A single isolated node still produces valid shapes.
    #[test]
    fn test_gps_single_node() {
        let g = single_node_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("forward ok");
        assert_eq!(out.node_embeddings.len(), 1);
        assert_eq!(out.graph_embedding.len(), 8);
    }

    // An edgeless graph exercises the isolated-node path of the MPNN.
    #[test]
    fn test_gps_no_edges() {
        let g = GraphForTransformer::new(
            vec![vec![], vec![], vec![]],
            vec![vec![1.0], vec![2.0], vec![3.0]],
        )
        .expect("valid");
        let pe = rwpe(&g.adjacency, 4);
        let cfg = default_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
    }

    // Head-0 attention weights form valid probability rows (sum to 1).
    #[test]
    fn test_gps_attention_softmax() {
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = default_config();
        let layer = GpsLayer::new(cfg.hidden_dim, cfg.n_heads, cfg.pe_dim, 42);
        let h: Vec<Vec<f64>> = g
            .node_features
            .iter()
            .map(|f| {
                let mut r = f.clone();
                r.resize(cfg.hidden_dim, 0.0);
                r
            })
            .collect();
        let pe_norm: Vec<Vec<f64>> = pe
            .iter()
            .map(|p| {
                let mut r = p.clone();
                r.resize(cfg.pe_dim, 0.0);
                r
            })
            .collect();
        let (_out, attn) = layer.global_transformer(&h, &pe_norm);
        let n = g.n_nodes;
        for i in 0..n {
            let row_sum: f64 = (0..n).map(|j| attn[i * n + j]).sum();
            assert!((row_sum - 1.0).abs() < 1e-10, "row {i} sum={row_sum}");
        }
    }

    // Stacked layers compose: outputs stay the right shape and finite.
    #[test]
    fn test_gps_layers_stack() {
        let g = triangle_graph();
        let pe = laplacian_pe(&g.adjacency, 4);
        let cfg = two_layer_config();
        let mut model = GpsModel::new(&cfg);
        let (out, _) = model.forward(&g, &pe).expect("2-layer forward ok");
        assert_eq!(out.node_embeddings.len(), 3);
        for row in &out.node_embeddings {
            assert_eq!(row.len(), 8);
            for &v in row {
                assert!(v.is_finite(), "non-finite value in output");
            }
        }
    }
}