// scirs2_graph/gnn/transformers/graphormer.rs

//! Graphormer-style Graph Transformer
//!
//! Implements the Graphormer architecture from Ying et al. (2021),
//! "Do Transformers Really Perform Bad for Graph Representation?"
//!
//! Key components:
//! - **Centrality encoding**: learnable embeddings based on node in-degree and out-degree
//! - **Spatial encoding**: shortest-path distance (SPD) between node pairs as attention bias
//! - **Edge encoding**: learnable edge feature encoding along shortest paths
//! - **Multi-head self-attention with graph structural bias**:
//!   `A_ij = softmax((Q_i * K_j / sqrt(d)) + spatial_bias[SPD(i,j)] + edge_bias(i,j))`
//! - **Full Graphormer layer**: attention -> FFN -> LayerNorm (pre-norm)
13
14use std::collections::VecDeque;
15
16use scirs2_core::ndarray::{Array1, Array2};
17use scirs2_core::random::{Rng, RngExt};
18
19use crate::error::{GraphError, Result};
20use crate::gnn::gcn::CsrMatrix;
21
22// ============================================================================
23// Centrality Encoding
24// ============================================================================
25
/// Learnable centrality encoding based on node in-degree and out-degree.
///
/// Adds `z_in[deg_in(v)] + z_out[deg_out(v)]` to each node embedding,
/// where `z_in` and `z_out` are learnable embedding tables indexed by degree.
/// Degrees larger than `max_degree` are clamped to the last table row.
#[derive(Debug, Clone)]
pub struct CentralityEncoding {
    /// Embedding table for in-degree: `[max_degree + 1, hidden_dim]`
    pub in_degree_embed: Array2<f64>,
    /// Embedding table for out-degree: `[max_degree + 1, hidden_dim]`
    pub out_degree_embed: Array2<f64>,
    /// Maximum degree supported; higher degrees reuse the last embedding row
    pub max_degree: usize,
    /// Hidden dimension (must match the node feature dimension in `forward`)
    pub hidden_dim: usize,
}
41
42impl CentralityEncoding {
43    /// Create a new centrality encoding module.
44    ///
45    /// # Arguments
46    /// * `max_degree` - Maximum node degree to encode (higher degrees are clamped)
47    /// * `hidden_dim` - Dimension of the centrality embedding
48    pub fn new(max_degree: usize, hidden_dim: usize) -> Self {
49        let mut rng = scirs2_core::random::rng();
50        let scale = (1.0 / hidden_dim as f64).sqrt();
51
52        let in_degree_embed = Array2::from_shape_fn((max_degree + 1, hidden_dim), |_| {
53            (rng.random::<f64>() * 2.0 - 1.0) * scale
54        });
55        let out_degree_embed = Array2::from_shape_fn((max_degree + 1, hidden_dim), |_| {
56            (rng.random::<f64>() * 2.0 - 1.0) * scale
57        });
58
59        CentralityEncoding {
60            in_degree_embed,
61            out_degree_embed,
62            max_degree,
63            hidden_dim,
64        }
65    }
66
67    /// Compute in-degree and out-degree for each node from adjacency.
68    pub fn compute_degrees(&self, adj: &CsrMatrix) -> (Vec<usize>, Vec<usize>) {
69        let n = adj.n_rows;
70        let mut in_deg = vec![0usize; n];
71        let mut out_deg = vec![0usize; n];
72
73        for (row, col, _) in adj.iter() {
74            out_deg[row] += 1;
75            if col < n {
76                in_deg[col] += 1;
77            }
78        }
79
80        (in_deg, out_deg)
81    }
82
83    /// Encode centrality information and add to node features.
84    ///
85    /// # Arguments
86    /// * `features` - Node features `[n_nodes, hidden_dim]`
87    /// * `adj` - Adjacency matrix
88    ///
89    /// # Returns
90    /// Updated features with centrality encoding added
91    pub fn forward(&self, features: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
92        let (n, dim) = features.dim();
93        if dim != self.hidden_dim {
94            return Err(GraphError::InvalidParameter {
95                param: "features".to_string(),
96                value: format!("dim={dim}"),
97                expected: format!("dim={}", self.hidden_dim),
98                context: "CentralityEncoding::forward".to_string(),
99            });
100        }
101
102        let (in_deg, out_deg) = self.compute_degrees(adj);
103        let mut output = features.clone();
104
105        for i in 0..n {
106            let in_d = in_deg[i].min(self.max_degree);
107            let out_d = out_deg[i].min(self.max_degree);
108            for j in 0..dim {
109                output[[i, j]] +=
110                    self.in_degree_embed[[in_d, j]] + self.out_degree_embed[[out_d, j]];
111            }
112        }
113
114        Ok(output)
115    }
116}
117
118// ============================================================================
119// Spatial Encoding
120// ============================================================================
121
/// Spatial encoding using shortest-path distances (SPD) between node pairs.
///
/// Computes the all-pairs shortest path distance matrix via BFS on unweighted
/// graphs, then provides learnable bias terms indexed by distance.
#[derive(Debug, Clone)]
pub struct SpatialEncoding {
    /// Learnable bias for each distance value: `[max_distance + 1]`.
    /// Index 0 = self-loop, index k = distance k. Distances beyond
    /// `max_distance` (including unreachable pairs) share the last entry.
    pub spatial_bias: Array1<f64>,
    /// Maximum distance to encode (larger distances use `max_distance` bias)
    pub max_distance: usize,
}
134
135impl SpatialEncoding {
136    /// Create a new spatial encoding module.
137    ///
138    /// # Arguments
139    /// * `max_distance` - Maximum SPD to encode distinctly
140    pub fn new(max_distance: usize) -> Self {
141        let mut rng = scirs2_core::random::rng();
142        let spatial_bias =
143            Array1::from_iter((0..=max_distance).map(|_| (rng.random::<f64>() * 2.0 - 1.0) * 0.1));
144
145        SpatialEncoding {
146            spatial_bias,
147            max_distance,
148        }
149    }
150
151    /// Compute all-pairs shortest path distances via BFS.
152    ///
153    /// Returns a matrix `[n, n]` where entry `(i, j)` is the shortest path
154    /// distance from node `i` to node `j`. Unreachable pairs get distance
155    /// `max_distance + 1`.
156    pub fn compute_spd_matrix(&self, adj: &CsrMatrix) -> Array2<usize> {
157        let n = adj.n_rows;
158        let unreachable = self.max_distance + 1;
159        let mut spd = Array2::from_elem((n, n), unreachable);
160
161        // Build adjacency list for BFS
162        let mut adj_list: Vec<Vec<usize>> = vec![Vec::new(); n];
163        for (row, col, _) in adj.iter() {
164            adj_list[row].push(col);
165        }
166
167        // BFS from each node
168        for src in 0..n {
169            spd[[src, src]] = 0;
170            let mut queue = VecDeque::new();
171            queue.push_back(src);
172            let mut visited = vec![false; n];
173            visited[src] = true;
174
175            while let Some(u) = queue.pop_front() {
176                let dist = spd[[src, u]];
177                if dist >= self.max_distance {
178                    continue;
179                }
180                for &v in &adj_list[u] {
181                    if !visited[v] {
182                        visited[v] = true;
183                        spd[[src, v]] = dist + 1;
184                        queue.push_back(v);
185                    }
186                }
187            }
188        }
189
190        spd
191    }
192
193    /// Get the spatial bias matrix `[n, n]` for attention.
194    ///
195    /// # Arguments
196    /// * `adj` - Adjacency matrix
197    ///
198    /// # Returns
199    /// Bias matrix where entry `(i, j)` is the learnable bias for SPD(i, j)
200    pub fn forward(&self, adj: &CsrMatrix) -> Array2<f64> {
201        let spd = self.compute_spd_matrix(adj);
202        let n = adj.n_rows;
203        let mut bias = Array2::zeros((n, n));
204
205        for i in 0..n {
206            for j in 0..n {
207                let d = spd[[i, j]].min(self.max_distance);
208                bias[[i, j]] = self.spatial_bias[d];
209            }
210        }
211
212        bias
213    }
214}
215
216// ============================================================================
217// Edge Encoding
218// ============================================================================
219
/// Edge encoding along shortest paths between node pairs.
///
/// For each pair `(i, j)`, takes the edges along the shortest path from `i` to `j`
/// and averages the learnable edge embeddings. This provides additional structural
/// information to the attention mechanism. Edge weights are discretized into
/// integer edge types that index the embedding table.
#[derive(Debug, Clone)]
pub struct EdgeEncoding {
    /// Learnable edge embedding: `[max_edge_types, hidden_dim]`
    pub edge_embed: Array2<f64>,
    /// Projection from hidden_dim to scalar bias (dot product with the
    /// averaged path embedding)
    pub projection: Array1<f64>,
    /// Maximum number of edge types; discretized weights are clamped to
    /// `max_edge_types - 1` (must be >= 1, see note in `forward`)
    pub max_edge_types: usize,
    /// Hidden dimension for edge embedding
    pub hidden_dim: usize,
}
236
impl EdgeEncoding {
    /// Create a new edge encoding module.
    ///
    /// Both the embedding table and the scalar projection vector are
    /// initialized uniformly in `[-s, s)` with `s = 1/sqrt(hidden_dim)`.
    ///
    /// # Arguments
    /// * `max_edge_types` - Maximum number of distinct edge feature types
    ///   (should be >= 1; see the underflow note on `forward`)
    /// * `hidden_dim` - Dimension for edge embeddings
    pub fn new(max_edge_types: usize, hidden_dim: usize) -> Self {
        let mut rng = scirs2_core::random::rng();
        let scale = (1.0 / hidden_dim as f64).sqrt();

        let edge_embed = Array2::from_shape_fn((max_edge_types, hidden_dim), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) * scale
        });
        let projection =
            Array1::from_iter((0..hidden_dim).map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale));

        EdgeEncoding {
            edge_embed,
            projection,
            max_edge_types,
            hidden_dim,
        }
    }

    /// Compute edge bias matrix.
    ///
    /// For simplicity, uses edge weights discretized to integer types.
    /// Each edge on the shortest path contributes its embedding, which is
    /// averaged and projected to a scalar.
    ///
    /// The path for each `(src, dst)` pair is reconstructed from a fresh BFS
    /// parent tree rooted at `src`; ties between equal-length shortest paths
    /// are broken by BFS visit order.
    ///
    /// NOTE(review): the BFS here has no distance cap, so a bias is computed
    /// even for pairs whose `spd` entry is the "unreachable" sentinel produced
    /// by a capped `SpatialEncoding::compute_spd_matrix` — confirm this
    /// asymmetry is intended.
    ///
    /// NOTE(review): `self.max_edge_types - 1` below underflows (debug panic)
    /// when `max_edge_types == 0`; `new` does not reject that value.
    ///
    /// # Arguments
    /// * `adj` - Adjacency matrix with edge weights
    /// * `spd` - Shortest path distance matrix from `SpatialEncoding::compute_spd_matrix`
    ///
    /// # Returns
    /// Edge bias matrix `[n, n]` (zero for self-pairs and unreachable targets)
    pub fn forward(&self, adj: &CsrMatrix, spd: &Array2<usize>) -> Array2<f64> {
        let n = adj.n_rows;
        let mut bias = Array2::zeros((n, n));

        // Build adjacency list with edge types for path reconstruction
        let mut adj_list: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n];
        for (row, col, val) in adj.iter() {
            // Discretize edge weight to type index (clamped to the table size)
            let edge_type = (val.abs() as usize).min(self.max_edge_types - 1);
            adj_list[row].push((col, edge_type));
        }

        // For each pair, reconstruct shortest path edges and compute embedding
        for src in 0..n {
            // BFS parent tracking from src
            let mut parent: Vec<Option<(usize, usize)>> = vec![None; n]; // (parent_node, edge_type)
            let mut visited = vec![false; n];
            visited[src] = true;
            let mut queue = VecDeque::new();
            queue.push_back(src);

            while let Some(u) = queue.pop_front() {
                for &(v, etype) in &adj_list[u] {
                    if !visited[v] {
                        visited[v] = true;
                        parent[v] = Some((u, etype));
                        queue.push_back(v);
                    }
                }
            }

            // For each target, trace back and compute average edge embedding
            for dst in 0..n {
                if src == dst || spd[[src, dst]] == 0 {
                    continue;
                }
                if parent[dst].is_none() {
                    continue; // unreachable
                }

                // Trace path and accumulate edge embeddings
                let mut avg_embed = vec![0.0f64; self.hidden_dim];
                let mut path_len = 0usize;
                let mut cur = dst;

                // Walk parent pointers dst -> src. `parent[src]` is `None`, so
                // the loop would also terminate without the explicit break.
                while let Some((p, etype)) = parent[cur] {
                    for k in 0..self.hidden_dim {
                        avg_embed[k] += self.edge_embed[[etype, k]];
                    }
                    path_len += 1;
                    cur = p;
                    if cur == src {
                        break;
                    }
                }

                if path_len > 0 {
                    // Average over the path, then project to a scalar bias.
                    let inv = 1.0 / path_len as f64;
                    let mut scalar = 0.0f64;
                    for k in 0..self.hidden_dim {
                        scalar += avg_embed[k] * inv * self.projection[k];
                    }
                    bias[[src, dst]] = scalar;
                }
            }
        }

        bias
    }
}
343
344// ============================================================================
345// Multi-head Self-Attention with Graph Structural Bias
346// ============================================================================
347
/// Numerically-stable softmax over a single attention row.
///
/// Subtracts the row maximum before exponentiating so large scores cannot
/// overflow; the denominator is floored at 1e-12 to avoid division by zero.
/// An empty row yields an empty vector.
fn softmax_row(row: &[f64]) -> Vec<f64> {
    let Some(&first) = row.first() else {
        return Vec::new();
    };
    let max_val = row.iter().skip(1).fold(first, |m, &x| m.max(x));
    let mut probs: Vec<f64> = row.iter().map(|&x| (x - max_val).exp()).collect();
    let denom = probs.iter().sum::<f64>().max(1e-12);
    for p in &mut probs {
        *p /= denom;
    }
    probs
}
358
/// In-place layer normalization over a feature vector.
///
/// Normalizes to zero mean and unit variance (population variance), with
/// `eps` added to the variance for numerical stability. An empty slice is
/// left untouched.
fn layer_norm(x: &mut [f64], eps: f64) {
    if x.is_empty() {
        return;
    }
    let len = x.len() as f64;
    let mean = x.iter().sum::<f64>() / len;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / len;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter_mut().for_each(|v| *v = (*v - mean) * inv_std);
}
372
373// ============================================================================
374// Graphormer Configuration
375// ============================================================================
376
/// Configuration for the Graphormer model.
#[derive(Debug, Clone)]
pub struct GraphormerConfig {
    /// Input feature dimension
    pub in_dim: usize,
    /// Hidden dimension (also the dimension for Q, K, V); must be divisible
    /// by `num_heads`
    pub hidden_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Number of Graphormer layers
    pub num_layers: usize,
    /// FFN intermediate dimension (typically 4 * hidden_dim)
    pub ffn_dim: usize,
    /// Maximum SPD for spatial encoding
    pub max_distance: usize,
    /// Maximum node degree for centrality encoding
    pub max_degree: usize,
    /// Maximum edge types for edge encoding
    pub max_edge_types: usize,
    /// Dropout rate. NOTE(review): stored for reference only — nothing in
    /// this file reads it, so no dropout is actually applied here.
    pub dropout: f64,
    /// Layer norm epsilon
    pub layer_norm_eps: f64,
}
401
402impl Default for GraphormerConfig {
403    fn default() -> Self {
404        GraphormerConfig {
405            in_dim: 64,
406            hidden_dim: 64,
407            num_heads: 4,
408            num_layers: 3,
409            ffn_dim: 256,
410            max_distance: 10,
411            max_degree: 50,
412            max_edge_types: 4,
413            dropout: 0.1,
414            layer_norm_eps: 1e-5,
415        }
416    }
417}
418
419// ============================================================================
420// Graphormer Layer
421// ============================================================================
422
/// A single Graphormer transformer layer.
///
/// Implements pre-norm architecture:
/// ```text
/// x = x + MHA(LayerNorm(x), spatial_bias, edge_bias)
/// x = x + FFN(LayerNorm(x))
/// ```
#[derive(Debug, Clone)]
pub struct GraphormerLayer {
    /// Query projection: `[hidden_dim, hidden_dim]`
    pub w_q: Array2<f64>,
    /// Key projection: `[hidden_dim, hidden_dim]`
    pub w_k: Array2<f64>,
    /// Value projection: `[hidden_dim, hidden_dim]`
    pub w_v: Array2<f64>,
    /// Output projection applied after head concatenation: `[hidden_dim, hidden_dim]`
    pub w_o: Array2<f64>,
    /// FFN first linear: `[hidden_dim, ffn_dim]`
    pub ffn_w1: Array2<f64>,
    /// FFN second linear: `[ffn_dim, hidden_dim]`
    pub ffn_w2: Array2<f64>,
    /// FFN hidden bias: `[ffn_dim]`
    pub ffn_b1: Array1<f64>,
    /// FFN output bias: `[hidden_dim]`
    pub ffn_b2: Array1<f64>,
    /// Number of attention heads
    pub num_heads: usize,
    /// Hidden dimension (equals `num_heads * head_dim`)
    pub hidden_dim: usize,
    /// Dimension per head (`hidden_dim / num_heads`)
    pub head_dim: usize,
    /// Layer norm epsilon
    pub layer_norm_eps: f64,
}
457
458impl GraphormerLayer {
459    /// Create a new Graphormer layer.
460    pub fn new(
461        hidden_dim: usize,
462        num_heads: usize,
463        ffn_dim: usize,
464        layer_norm_eps: f64,
465    ) -> Result<Self> {
466        if !hidden_dim.is_multiple_of(num_heads) {
467            return Err(GraphError::InvalidParameter {
468                param: "hidden_dim".to_string(),
469                value: format!("{hidden_dim}"),
470                expected: format!("divisible by num_heads={num_heads}"),
471                context: "GraphormerLayer::new".to_string(),
472            });
473        }
474
475        let head_dim = hidden_dim / num_heads;
476        let mut rng = scirs2_core::random::rng();
477        let w_scale = (6.0_f64 / (hidden_dim + hidden_dim) as f64).sqrt();
478        let ffn_scale = (6.0_f64 / (hidden_dim + ffn_dim) as f64).sqrt();
479
480        let mut init_w = |rows: usize, cols: usize, scale: f64| -> Array2<f64> {
481            Array2::from_shape_fn((rows, cols), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
482        };
483
484        Ok(GraphormerLayer {
485            w_q: init_w(hidden_dim, hidden_dim, w_scale),
486            w_k: init_w(hidden_dim, hidden_dim, w_scale),
487            w_v: init_w(hidden_dim, hidden_dim, w_scale),
488            w_o: init_w(hidden_dim, hidden_dim, w_scale),
489            ffn_w1: init_w(hidden_dim, ffn_dim, ffn_scale),
490            ffn_w2: init_w(ffn_dim, hidden_dim, ffn_scale),
491            ffn_b1: Array1::zeros(ffn_dim),
492            ffn_b2: Array1::zeros(hidden_dim),
493            num_heads,
494            hidden_dim,
495            head_dim,
496            layer_norm_eps,
497        })
498    }
499
500    /// Multi-head self-attention with spatial and edge bias.
501    ///
502    /// # Arguments
503    /// * `x` - Input features `[n, hidden_dim]`
504    /// * `spatial_bias` - Spatial bias matrix `[n, n]`
505    /// * `edge_bias` - Edge bias matrix `[n, n]`
506    ///
507    /// # Returns
508    /// Attention output `[n, hidden_dim]`
509    fn multi_head_attention(
510        &self,
511        x: &Array2<f64>,
512        spatial_bias: &Array2<f64>,
513        edge_bias: &Array2<f64>,
514    ) -> Array2<f64> {
515        let n = x.dim().0;
516        let d = self.hidden_dim;
517        let h = self.num_heads;
518        let dk = self.head_dim;
519        let scale = 1.0 / (dk as f64).sqrt();
520
521        // Compute Q, K, V: [n, hidden_dim]
522        let mut q = Array2::zeros((n, d));
523        let mut k = Array2::zeros((n, d));
524        let mut v = Array2::zeros((n, d));
525
526        for i in 0..n {
527            for j in 0..d {
528                let mut sq = 0.0;
529                let mut sk = 0.0;
530                let mut sv = 0.0;
531                for m in 0..d {
532                    let xi = x[[i, m]];
533                    sq += xi * self.w_q[[m, j]];
534                    sk += xi * self.w_k[[m, j]];
535                    sv += xi * self.w_v[[m, j]];
536                }
537                q[[i, j]] = sq;
538                k[[i, j]] = sk;
539                v[[i, j]] = sv;
540            }
541        }
542
543        // Multi-head attention with structural bias
544        let mut output = Array2::<f64>::zeros((n, d));
545
546        for head in 0..h {
547            let offset = head * dk;
548
549            // Compute attention scores for this head: [n, n]
550            let mut scores = vec![vec![0.0f64; n]; n];
551            for i in 0..n {
552                for j in 0..n {
553                    let mut dot = 0.0;
554                    for m in 0..dk {
555                        dot += q[[i, offset + m]] * k[[j, offset + m]];
556                    }
557                    // Add graph structural bias (shared across heads)
558                    scores[i][j] = dot * scale + spatial_bias[[i, j]] + edge_bias[[i, j]];
559                }
560            }
561
562            // Softmax per row and aggregate values
563            for i in 0..n {
564                let alphas = softmax_row(&scores[i]);
565                for j in 0..n {
566                    let a = alphas[j];
567                    for m in 0..dk {
568                        output[[i, offset + m]] += a * v[[j, offset + m]];
569                    }
570                }
571            }
572        }
573
574        // Output projection
575        let mut projected = Array2::zeros((n, d));
576        for i in 0..n {
577            for j in 0..d {
578                let mut s = 0.0;
579                for m in 0..d {
580                    s += output[[i, m]] * self.w_o[[m, j]];
581                }
582                projected[[i, j]] = s;
583            }
584        }
585
586        projected
587    }
588
589    /// Feed-forward network with GELU activation.
590    fn ffn(&self, x: &Array2<f64>) -> Array2<f64> {
591        let n = x.dim().0;
592        let ffn_dim = self.ffn_w1.dim().1;
593        let d = self.hidden_dim;
594
595        // First linear + GELU
596        let mut h = Array2::zeros((n, ffn_dim));
597        for i in 0..n {
598            for j in 0..ffn_dim {
599                let mut s = self.ffn_b1[j];
600                for m in 0..d {
601                    s += x[[i, m]] * self.ffn_w1[[m, j]];
602                }
603                // GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
604                let x3 = s * s * s;
605                let inner = std::f64::consts::FRAC_2_PI.sqrt() * (s + 0.044715 * x3);
606                h[[i, j]] = 0.5 * s * (1.0 + inner.tanh());
607            }
608        }
609
610        // Second linear
611        let mut out = Array2::zeros((n, d));
612        for i in 0..n {
613            for j in 0..d {
614                let mut s = self.ffn_b2[j];
615                for m in 0..ffn_dim {
616                    s += h[[i, m]] * self.ffn_w2[[m, j]];
617                }
618                out[[i, j]] = s;
619            }
620        }
621
622        out
623    }
624
625    /// Forward pass of one Graphormer layer.
626    ///
627    /// Pre-norm architecture:
628    /// ```text
629    /// x = x + MHA(LayerNorm(x))
630    /// x = x + FFN(LayerNorm(x))
631    /// ```
632    ///
633    /// # Arguments
634    /// * `x` - Input features `[n, hidden_dim]`
635    /// * `spatial_bias` - SPD-based attention bias `[n, n]`
636    /// * `edge_bias` - Edge encoding bias `[n, n]`
637    pub fn forward(
638        &self,
639        x: &Array2<f64>,
640        spatial_bias: &Array2<f64>,
641        edge_bias: &Array2<f64>,
642    ) -> Result<Array2<f64>> {
643        let (n, d) = x.dim();
644        if d != self.hidden_dim {
645            return Err(GraphError::InvalidParameter {
646                param: "x".to_string(),
647                value: format!("dim={d}"),
648                expected: format!("dim={}", self.hidden_dim),
649                context: "GraphormerLayer::forward".to_string(),
650            });
651        }
652
653        // Pre-norm for attention
654        let mut normed = x.clone();
655        for i in 0..n {
656            let mut row: Vec<f64> = (0..d).map(|j| normed[[i, j]]).collect();
657            layer_norm(&mut row, self.layer_norm_eps);
658            for j in 0..d {
659                normed[[i, j]] = row[j];
660            }
661        }
662
663        // Multi-head attention + residual
664        let attn_out = self.multi_head_attention(&normed, spatial_bias, edge_bias);
665        let mut out = x.clone();
666        for i in 0..n {
667            for j in 0..d {
668                out[[i, j]] += attn_out[[i, j]];
669            }
670        }
671
672        // Pre-norm for FFN
673        let mut normed2 = out.clone();
674        for i in 0..n {
675            let mut row: Vec<f64> = (0..d).map(|j| normed2[[i, j]]).collect();
676            layer_norm(&mut row, self.layer_norm_eps);
677            for j in 0..d {
678                normed2[[i, j]] = row[j];
679            }
680        }
681
682        // FFN + residual
683        let ffn_out = self.ffn(&normed2);
684        for i in 0..n {
685            for j in 0..d {
686                out[[i, j]] += ffn_out[[i, j]];
687            }
688        }
689
690        Ok(out)
691    }
692}
693
694// ============================================================================
695// Graphormer Model
696// ============================================================================
697
/// Full Graphormer model stacking multiple transformer layers with
/// centrality, spatial, and edge encodings.
#[derive(Debug, Clone)]
pub struct GraphormerModel {
    /// Input projection: `[in_dim, hidden_dim]`
    pub input_proj: Array2<f64>,
    /// Centrality (degree-based) encoding module
    pub centrality_encoding: CentralityEncoding,
    /// Spatial (shortest-path distance) encoding module
    pub spatial_encoding: SpatialEncoding,
    /// Edge encoding module
    pub edge_encoding: EdgeEncoding,
    /// Stack of Graphormer layers, applied in order
    pub layers: Vec<GraphormerLayer>,
    /// Configuration the model was built from
    pub config: GraphormerConfig,
}
715
716impl GraphormerModel {
717    /// Create a new Graphormer model from configuration.
718    pub fn new(config: GraphormerConfig) -> Result<Self> {
719        let mut rng = scirs2_core::random::rng();
720        let proj_scale = (6.0_f64 / (config.in_dim + config.hidden_dim) as f64).sqrt();
721        let input_proj = Array2::from_shape_fn((config.in_dim, config.hidden_dim), |_| {
722            (rng.random::<f64>() * 2.0 - 1.0) * proj_scale
723        });
724
725        let centrality_encoding = CentralityEncoding::new(config.max_degree, config.hidden_dim);
726        let spatial_encoding = SpatialEncoding::new(config.max_distance);
727        let edge_encoding = EdgeEncoding::new(config.max_edge_types, config.hidden_dim);
728
729        let mut layers = Vec::with_capacity(config.num_layers);
730        for _ in 0..config.num_layers {
731            layers.push(GraphormerLayer::new(
732                config.hidden_dim,
733                config.num_heads,
734                config.ffn_dim,
735                config.layer_norm_eps,
736            )?);
737        }
738
739        Ok(GraphormerModel {
740            input_proj,
741            centrality_encoding,
742            spatial_encoding,
743            edge_encoding,
744            layers,
745            config,
746        })
747    }
748
749    /// Forward pass of the full Graphormer model.
750    ///
751    /// # Arguments
752    /// * `features` - Input node features `[n_nodes, in_dim]`
753    /// * `adj` - Sparse adjacency matrix
754    ///
755    /// # Returns
756    /// Node embeddings `[n_nodes, hidden_dim]`
757    pub fn forward(&self, features: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
758        let (n, in_dim) = features.dim();
759        if in_dim != self.config.in_dim {
760            return Err(GraphError::InvalidParameter {
761                param: "features".to_string(),
762                value: format!("in_dim={in_dim}"),
763                expected: format!("in_dim={}", self.config.in_dim),
764                context: "GraphormerModel::forward".to_string(),
765            });
766        }
767        if adj.n_rows != n {
768            return Err(GraphError::InvalidParameter {
769                param: "adj".to_string(),
770                value: format!("n_rows={}", adj.n_rows),
771                expected: format!("n_rows={n}"),
772                context: "GraphormerModel::forward".to_string(),
773            });
774        }
775
776        // Project input to hidden dim
777        let d = self.config.hidden_dim;
778        let mut h = Array2::zeros((n, d));
779        for i in 0..n {
780            for j in 0..d {
781                let mut s = 0.0;
782                for m in 0..in_dim {
783                    s += features[[i, m]] * self.input_proj[[m, j]];
784                }
785                h[[i, j]] = s;
786            }
787        }
788
789        // Add centrality encoding
790        h = self.centrality_encoding.forward(&h, adj)?;
791
792        // Compute structural biases
793        let spatial_bias = self.spatial_encoding.forward(adj);
794        let spd = self.spatial_encoding.compute_spd_matrix(adj);
795        let edge_bias = self.edge_encoding.forward(adj, &spd);
796
797        // Apply Graphormer layers
798        for layer in &self.layers {
799            h = layer.forward(&h, &spatial_bias, &edge_bias)?;
800        }
801
802        Ok(h)
803    }
804}
805
806// ============================================================================
807// Tests
808// ============================================================================
809
810#[cfg(test)]
811mod tests {
812    use super::*;
813
814    fn triangle_csr() -> CsrMatrix {
815        let coo = vec![
816            (0, 1, 1.0),
817            (1, 0, 1.0),
818            (1, 2, 1.0),
819            (2, 1, 1.0),
820            (0, 2, 1.0),
821            (2, 0, 1.0),
822        ];
823        CsrMatrix::from_coo(3, 3, &coo).expect("triangle CSR")
824    }
825
826    fn path_csr() -> CsrMatrix {
827        // Path graph: 0 -- 1 -- 2 -- 3
828        let coo = vec![
829            (0, 1, 1.0),
830            (1, 0, 1.0),
831            (1, 2, 1.0),
832            (2, 1, 1.0),
833            (2, 3, 1.0),
834            (3, 2, 1.0),
835        ];
836        CsrMatrix::from_coo(4, 4, &coo).expect("path CSR")
837    }
838
839    fn feats(n: usize, d: usize) -> Array2<f64> {
840        Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 * 0.1)
841    }
842
843    #[test]
844    fn test_spatial_encoding_spd_matrix() {
845        let adj = path_csr();
846        let se = SpatialEncoding::new(10);
847        let spd = se.compute_spd_matrix(&adj);
848
849        // Self-distances should be 0
850        for i in 0..4 {
851            assert_eq!(spd[[i, i]], 0, "self-distance should be 0 for node {i}");
852        }
853
854        // Adjacent nodes should have distance 1
855        assert_eq!(spd[[0, 1]], 1);
856        assert_eq!(spd[[1, 2]], 1);
857        assert_eq!(spd[[2, 3]], 1);
858
859        // Path distances
860        assert_eq!(spd[[0, 2]], 2);
861        assert_eq!(spd[[0, 3]], 3);
862        assert_eq!(spd[[1, 3]], 2);
863
864        // Symmetry
865        for i in 0..4 {
866            for j in 0..4 {
867                assert_eq!(spd[[i, j]], spd[[j, i]], "SPD should be symmetric");
868            }
869        }
870    }
871
872    #[test]
873    fn test_centrality_encoding_degrees() {
874        let adj = triangle_csr();
875        let ce = CentralityEncoding::new(10, 8);
876        let (in_deg, out_deg) = ce.compute_degrees(&adj);
877
878        // Triangle: each node has degree 2 (2 outgoing edges in the symmetric representation)
879        for i in 0..3 {
880            assert_eq!(in_deg[i], 2, "in-degree of node {i}");
881            assert_eq!(out_deg[i], 2, "out-degree of node {i}");
882        }
883    }
884
885    #[test]
886    fn test_centrality_encoding_forward_shape() {
887        let adj = triangle_csr();
888        let ce = CentralityEncoding::new(10, 8);
889        let features = feats(3, 8);
890        let result = ce.forward(&features, &adj).expect("centrality forward");
891        assert_eq!(result.dim(), (3, 8));
892
893        // Output should differ from input (centrality added)
894        let mut differs = false;
895        for i in 0..3 {
896            for j in 0..8 {
897                if (result[[i, j]] - features[[i, j]]).abs() > 1e-12 {
898                    differs = true;
899                }
900            }
901        }
902        assert!(differs, "centrality encoding should modify features");
903    }
904
905    #[test]
906    fn test_graphormer_attention_with_bias_output_shape() {
907        let adj = triangle_csr();
908        let config = GraphormerConfig {
909            in_dim: 4,
910            hidden_dim: 8,
911            num_heads: 2,
912            num_layers: 1,
913            ffn_dim: 16,
914            max_distance: 5,
915            max_degree: 10,
916            max_edge_types: 2,
917            ..Default::default()
918        };
919
920        let layer = GraphormerLayer::new(8, 2, 16, 1e-5).expect("layer");
921        let x = feats(3, 8);
922        let se = SpatialEncoding::new(5);
923        let spatial_bias = se.forward(&adj);
924        let edge_bias = Array2::zeros((3, 3));
925
926        let out = layer
927            .forward(&x, &spatial_bias, &edge_bias)
928            .expect("forward");
929        assert_eq!(out.dim(), (3, 8));
930        for &v in out.iter() {
931            assert!(v.is_finite(), "output should be finite, got {v}");
932        }
933    }
934
935    #[test]
936    fn test_graphormer_model_forward() {
937        let adj = triangle_csr();
938        let config = GraphormerConfig {
939            in_dim: 4,
940            hidden_dim: 8,
941            num_heads: 2,
942            num_layers: 2,
943            ffn_dim: 16,
944            max_distance: 5,
945            max_degree: 10,
946            max_edge_types: 2,
947            ..Default::default()
948        };
949
950        let model = GraphormerModel::new(config).expect("model");
951        let features = feats(3, 4);
952        let out = model.forward(&features, &adj).expect("forward");
953        assert_eq!(out.dim(), (3, 8));
954        for &v in out.iter() {
955            assert!(v.is_finite(), "output should be finite, got {v}");
956        }
957    }
958
959    #[test]
960    fn test_graphormer_edge_encoding() {
961        let adj = path_csr();
962        let se = SpatialEncoding::new(5);
963        let spd = se.compute_spd_matrix(&adj);
964        let ee = EdgeEncoding::new(2, 4);
965        let bias = ee.forward(&adj, &spd);
966
967        assert_eq!(bias.dim(), (4, 4));
968        // Diagonal should be 0 (no self-path edges)
969        for i in 0..4 {
970            assert!(bias[[i, i]].abs() < 1e-12, "self edge bias should be 0");
971        }
972        // Off-diagonal should have values for connected pairs
973        for &v in bias.iter() {
974            assert!(v.is_finite(), "edge bias should be finite");
975        }
976    }
977
978    #[test]
979    fn test_graphormer_invalid_hidden_dim() {
980        // hidden_dim=7 not divisible by num_heads=2
981        let result = GraphormerLayer::new(7, 2, 16, 1e-5);
982        assert!(result.is_err());
983    }
984
985    #[test]
986    fn test_spatial_encoding_disconnected() {
987        // Two disconnected components: {0, 1} and {2, 3}
988        let coo = vec![(0, 1, 1.0), (1, 0, 1.0), (2, 3, 1.0), (3, 2, 1.0)];
989        let adj = CsrMatrix::from_coo(4, 4, &coo).expect("disconnected CSR");
990        let se = SpatialEncoding::new(5);
991        let spd = se.compute_spd_matrix(&adj);
992
993        // Within-component distances
994        assert_eq!(spd[[0, 1]], 1);
995        assert_eq!(spd[[2, 3]], 1);
996
997        // Cross-component: should be max_distance + 1 = 6
998        assert_eq!(spd[[0, 2]], 6);
999        assert_eq!(spd[[0, 3]], 6);
1000        assert_eq!(spd[[1, 2]], 6);
1001    }
1002}