scirs2_graph/gnn/rgcn.rs
1//! Relational Graph Convolutional Network (R-GCN)
2//!
3//! Implements the encoder from Schlichtkrull et al. (2018),
4//! "Modeling Relational Data with Graph Convolutional Networks".
5//!
6//! R-GCN extends GCN to handle multi-relational graphs (knowledge graphs)
7//! by maintaining a separate weight matrix per relation type. To keep the
8//! parameter count manageable, **basis decomposition** is used:
9//!
10//! ```text
11//! W_r = Σ_b a_{r,b} V_b
12//! ```
13//!
14//! where `V_b` are `n_bases` shared basis matrices and `a_{r,b}` are
15//! per-relation scalar coefficients.
16//!
17//! The layer update rule is:
18//! ```text
19//! h_i^{(l+1)} = ReLU( Σ_r Σ_{j ∈ N_r(i)} (W_r h_j^{(l)}) / |N_r(i)|
20//! + W_0 h_i^{(l)} )
21//! ```
22
23use std::collections::HashMap;
24
25use scirs2_core::ndarray::{Array1, Array2};
26use scirs2_core::random::{Rng, RngExt};
27
28use crate::error::{GraphError, Result};
29
30// ============================================================================
31// Helpers
32// ============================================================================
33
34/// Xavier uniform initialisation: U[-gain*sqrt(6/(fan_in+fan_out)), ...]
35fn xavier_uniform(rows: usize, cols: usize) -> Array2<f64> {
36 let mut rng = scirs2_core::random::rng();
37 let limit = (6.0_f64 / (rows + cols) as f64).sqrt();
38 Array2::from_shape_fn((rows, cols), |_| rng.random::<f64>() * 2.0 * limit - limit)
39}
40
41/// ReLU activation applied element-wise to a 2-D array.
42fn relu2(x: &Array2<f64>) -> Array2<f64> {
43 x.mapv(|v| v.max(0.0))
44}
45
46// ============================================================================
47// RgcnConfig
48// ============================================================================
49
/// Configuration for an R-GCN encoder.
///
/// One config builds the entire layer stack via [`Rgcn::new`]; every layer
/// shares the same width, basis count, and self-loop setting.
#[derive(Debug, Clone)]
pub struct RgcnConfig {
    /// Hidden (output) dimensionality of each layer
    pub hidden_dim: usize,
    /// Number of basis matrices for weight decomposition
    pub n_bases: usize,
    /// Number of R-GCN layers to stack
    pub n_layers: usize,
    /// Dropout probability (applied on node features between layers)
    // NOTE(review): `dropout` is stored but never applied in this module's
    // forward passes — confirm whether external training code consumes it.
    pub dropout: f64,
    /// Whether to include a self-loop weight W_0
    pub self_loop: bool,
}
64
65impl Default for RgcnConfig {
66 fn default() -> Self {
67 Self {
68 hidden_dim: 64,
69 n_bases: 4,
70 n_layers: 2,
71 dropout: 0.1,
72 self_loop: true,
73 }
74 }
75}
76
77// ============================================================================
78// RgcnBasisDecomposition
79// ============================================================================
80
/// Basis-decomposition weight for a single R-GCN layer.
///
/// Stores `n_bases` shared matrices `V_b` (each `out_dim × in_dim`) and
/// per-relation coefficient vectors `a_r` (length `n_bases`). The effective
/// weight for relation `r` is:
/// ```text
/// W_r = Σ_b a_{r,b} V_b
/// ```
///
/// Invariant (established by `new`): each inner coefficient vector has
/// length `n_bases == basis_matrices.len()`. The fields are `pub`, so
/// external mutation can break this invariant.
#[derive(Debug, Clone)]
pub struct RgcnBasisDecomposition {
    /// Shared basis matrices V_b, each of shape `(out_dim, in_dim)`.
    pub basis_matrices: Vec<Array2<f64>>,
    /// Per-relation coefficients: `coefficients[r][b] = a_{r,b}`.
    pub coefficients: Vec<Vec<f64>>,
    /// Input feature dimension.
    pub in_dim: usize,
    /// Output feature dimension.
    pub out_dim: usize,
}
100
101impl RgcnBasisDecomposition {
102 /// Build a new basis decomposition for `n_relations` relation types.
103 ///
104 /// # Arguments
105 /// * `in_dim` – Input feature dimensionality.
106 /// * `out_dim` – Output feature dimensionality.
107 /// * `n_bases` – Number of shared basis matrices.
108 /// * `n_relations` – Number of distinct relation types.
109 pub fn new(in_dim: usize, out_dim: usize, n_bases: usize, n_relations: usize) -> Result<Self> {
110 if n_bases == 0 {
111 return Err(GraphError::InvalidParameter {
112 param: "n_bases".to_string(),
113 value: "0".to_string(),
114 expected: ">= 1".to_string(),
115 context: "RgcnBasisDecomposition::new".to_string(),
116 });
117 }
118
119 let basis_matrices: Vec<Array2<f64>> = (0..n_bases)
120 .map(|_| xavier_uniform(out_dim, in_dim))
121 .collect();
122
123 let mut rng = scirs2_core::random::rng();
124 let coefficients: Vec<Vec<f64>> = (0..n_relations)
125 .map(|_| (0..n_bases).map(|_| rng.random::<f64>() * 0.1).collect())
126 .collect();
127
128 Ok(Self {
129 basis_matrices,
130 coefficients,
131 in_dim,
132 out_dim,
133 })
134 }
135
136 /// Compute the effective weight matrix for `relation`.
137 ///
138 /// Returns `W_r = Σ_b a_{r,b} V_b` as an `(out_dim, in_dim)` matrix.
139 pub fn effective_weight(&self, relation: usize) -> Result<Array2<f64>> {
140 let coeffs =
141 self.coefficients
142 .get(relation)
143 .ok_or_else(|| GraphError::InvalidParameter {
144 param: "relation".to_string(),
145 value: relation.to_string(),
146 expected: format!("< {}", self.coefficients.len()),
147 context: "RgcnBasisDecomposition::effective_weight".to_string(),
148 })?;
149
150 let mut w = Array2::<f64>::zeros((self.out_dim, self.in_dim));
151 for (b, &coeff) in coeffs.iter().enumerate() {
152 w = w + coeff * &self.basis_matrices[b];
153 }
154 Ok(w)
155 }
156}
157
158// ============================================================================
159// RgcnLayer
160// ============================================================================
161
/// Single R-GCN layer.
///
/// # Forward pass
/// For each relation `r` and each node `i`:
/// 1. Aggregate messages: `M_{r,i} = mean_{j ∈ N_r(i)} W_r h_j`
/// 2. Combine: `h_i' = ReLU( Σ_r M_{r,i} + W_0 h_i + b )`
#[derive(Debug, Clone)]
pub struct RgcnLayer {
    /// Basis decomposition (handles all relation-specific weights)
    pub basis_decomp: RgcnBasisDecomposition,
    /// Self-loop weight W_0 (shape `out_dim × in_dim`), `None` when `self_loop=false`
    pub self_loop_weight: Option<Array2<f64>>,
    /// Bias term (length `out_dim`); zero-initialised by `new`
    pub bias: Array1<f64>,
    /// Number of relation types
    pub n_relations: usize,
    /// Output dimensionality
    pub out_dim: usize,
}
181
182impl RgcnLayer {
183 /// Construct a new R-GCN layer.
184 ///
185 /// # Arguments
186 /// * `in_dim` – Input feature dimensionality.
187 /// * `out_dim` – Output feature dimensionality.
188 /// * `n_relations` – Number of distinct relation types.
189 /// * `n_bases` – Number of basis matrices for weight decomposition.
190 /// * `self_loop` – Whether to include a self-loop weight.
191 pub fn new(
192 in_dim: usize,
193 out_dim: usize,
194 n_relations: usize,
195 n_bases: usize,
196 self_loop: bool,
197 ) -> Result<Self> {
198 let basis_decomp = RgcnBasisDecomposition::new(in_dim, out_dim, n_bases, n_relations)?;
199 let self_loop_weight = if self_loop {
200 Some(xavier_uniform(out_dim, in_dim))
201 } else {
202 None
203 };
204 Ok(Self {
205 basis_decomp,
206 self_loop_weight,
207 bias: Array1::zeros(out_dim),
208 n_relations,
209 out_dim,
210 })
211 }
212
213 /// Forward pass over a heterogeneous graph.
214 ///
215 /// # Arguments
216 /// * `node_feats` – Node feature matrix `(n_nodes, in_dim)`.
217 /// * `adj_by_relation` – For each relation `r`, a list of `(src, dst)` edges.
218 ///
219 /// # Returns
220 /// Updated node feature matrix `(n_nodes, out_dim)`.
221 pub fn forward(
222 &self,
223 node_feats: &Array2<f64>,
224 adj_by_relation: &[Vec<(usize, usize)>],
225 ) -> Result<Array2<f64>> {
226 let n_nodes = node_feats.nrows();
227 let in_dim = node_feats.ncols();
228
229 if in_dim != self.basis_decomp.in_dim {
230 return Err(GraphError::InvalidParameter {
231 param: "node_feats".to_string(),
232 value: format!("in_dim={}", in_dim),
233 expected: format!("in_dim={}", self.basis_decomp.in_dim),
234 context: "RgcnLayer::forward".to_string(),
235 });
236 }
237
238 // Accumulator for the combined relational aggregation
239 let mut combined = Array2::<f64>::zeros((n_nodes, self.out_dim));
240
241 // ---- Relational aggregation ----------------------------------------
242 for (r, edges) in adj_by_relation.iter().enumerate() {
243 if r >= self.n_relations {
244 break;
245 }
246 let w_r = self.basis_decomp.effective_weight(r)?;
247
248 // Count in-degree per destination for normalisation
249 let mut in_deg: Vec<usize> = vec![0usize; n_nodes];
250 for &(_, dst) in edges {
251 if dst < n_nodes {
252 in_deg[dst] += 1;
253 }
254 }
255
256 // Aggregate: sum W_r h_j for each destination
257 for &(src, dst) in edges {
258 if src >= n_nodes || dst >= n_nodes {
259 continue;
260 }
261 let h_j = node_feats.row(src);
262 // w_r has shape (out_dim, in_dim); multiply w_r @ h_j
263 let msg = w_r.dot(&h_j);
264 let deg = in_deg[dst].max(1) as f64;
265 let mut row = combined.row_mut(dst);
266 row.zip_mut_with(&msg, |acc, &m| *acc += m / deg);
267 }
268 }
269
270 // ---- Self-loop contribution ----------------------------------------
271 if let Some(ref w0) = self.self_loop_weight {
272 for i in 0..n_nodes {
273 let h_i = node_feats.row(i);
274 let self_msg = w0.dot(&h_i);
275 let mut row = combined.row_mut(i);
276 row.zip_mut_with(&self_msg, |acc, &v| *acc += v);
277 }
278 }
279
280 // ---- Bias + ReLU ---------------------------------------------------
281 for mut row in combined.rows_mut() {
282 row.zip_mut_with(&self.bias, |v, &b| *v += b);
283 }
284
285 Ok(relu2(&combined))
286 }
287}
288
289// ============================================================================
290// Rgcn — stacked R-GCN
291// ============================================================================
292
/// Multi-layer R-GCN encoder.
///
/// Stacks `n_layers` [`RgcnLayer`]s, projecting node features from
/// `in_dim` → `hidden_dim` → … → `hidden_dim`.
///
/// Note: the `dropout` value in [`RgcnConfig`] is not applied by
/// `Rgcn::forward` in this module.
#[derive(Debug, Clone)]
pub struct Rgcn {
    /// Ordered list of R-GCN layers
    pub layers: Vec<RgcnLayer>,
}
302
303impl Rgcn {
304 /// Build an R-GCN from a [`RgcnConfig`].
305 ///
306 /// # Arguments
307 /// * `in_dim` – Initial node feature dimensionality.
308 /// * `n_relations` – Number of relation types in the graph.
309 /// * `config` – Hyper-parameter configuration.
310 pub fn new(in_dim: usize, n_relations: usize, config: &RgcnConfig) -> Result<Self> {
311 let mut layers = Vec::with_capacity(config.n_layers);
312 for i in 0..config.n_layers {
313 let layer_in = if i == 0 { in_dim } else { config.hidden_dim };
314 let layer = RgcnLayer::new(
315 layer_in,
316 config.hidden_dim,
317 n_relations,
318 config.n_bases,
319 config.self_loop,
320 )?;
321 layers.push(layer);
322 }
323 Ok(Self { layers })
324 }
325
326 /// Forward pass through all layers.
327 ///
328 /// # Arguments
329 /// * `node_feats` – Initial node features `(n_nodes, in_dim)`.
330 /// * `adj_by_relation` – Per-relation edge lists `(src, dst)`.
331 ///
332 /// # Returns
333 /// Final node embeddings `(n_nodes, hidden_dim)`.
334 pub fn forward(
335 &self,
336 node_feats: &Array2<f64>,
337 adj_by_relation: &[Vec<(usize, usize)>],
338 ) -> Result<Array2<f64>> {
339 let mut h = node_feats.clone();
340 for layer in &self.layers {
341 h = layer.forward(&h, adj_by_relation)?;
342 }
343 Ok(h)
344 }
345}
346
347// ============================================================================
348// DistMult scoring
349// ============================================================================
350
/// DistMult bilinear scoring model (Yang et al. 2015).
///
/// Score function:
/// ```text
/// score(h, r, t) = Σ_k h_k · r_k · t_k
/// ```
///
/// This is a symmetric scoring function: `score(h,r,t) = score(t,r,h)`,
/// so it cannot distinguish the direction of asymmetric relations.
#[derive(Debug, Clone)]
pub struct DistMultScoring {
    /// Entity embedding table `(n_entities, dim)`; row `i` is entity `i`
    pub entity_embeds: Array2<f64>,
    /// Relation embedding table `(n_relations, dim)`; row `r` is relation `r`
    pub relation_embeds: Array2<f64>,
}
366
367impl DistMultScoring {
368 /// Create a new DistMult scorer with Xavier-initialised embeddings.
369 pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Self {
370 Self {
371 entity_embeds: xavier_uniform(n_entities, dim),
372 relation_embeds: xavier_uniform(n_relations, dim),
373 }
374 }
375
376 /// Compute DistMult score for triple `(h, r, t)`.
377 pub fn score(&self, h: usize, r: usize, t: usize) -> f64 {
378 let h_emb = self.entity_embeds.row(h);
379 let r_emb = self.relation_embeds.row(r);
380 let t_emb = self.entity_embeds.row(t);
381 h_emb
382 .iter()
383 .zip(r_emb.iter())
384 .zip(t_emb.iter())
385 .map(|((&hk, &rk), &tk)| hk * rk * tk)
386 .sum()
387 }
388}
389
390// ============================================================================
391// KgScorer trait (also used by kg_completion module)
392// ============================================================================
393
/// Trait for knowledge-graph triple scoring models.
///
/// The `Send + Sync` supertraits mean implementors must be safe to share
/// across threads.
pub trait KgScorer: Send + Sync {
    /// Return a scalar score for the triple `(h, r, t)`.
    ///
    /// Higher scores indicate more plausible triples.
    fn score(&self, h: usize, r: usize, t: usize) -> f64;
}
401
impl KgScorer for DistMultScoring {
    fn score(&self, h: usize, r: usize, t: usize) -> f64 {
        // Delegate to the inherent method; the fully-qualified form makes
        // the target of the call unambiguous.
        DistMultScoring::score(self, h, r, t)
    }
}
407
408// ============================================================================
409// RgcnLinkPredictor
410// ============================================================================
411
/// End-to-end R-GCN encoder + DistMult decoder for link prediction.
///
/// The encoder produces node embeddings with R-GCN; the decoder scores triples
/// using the DistMult bilinear form. Call `encode` before `score_triple` to
/// score against encoder embeddings; otherwise scoring falls back to the
/// decoder's own entity table.
#[derive(Debug, Clone)]
pub struct RgcnLinkPredictor {
    /// R-GCN encoder
    pub encoder: Rgcn,
    /// DistMult decoder
    pub decoder: DistMultScoring,
    /// Cached node embeddings (populated after calling `encode`)
    pub node_embeddings: Option<Array2<f64>>,
    /// Relation embedding dimension (same as hidden_dim)
    pub dim: usize,
    /// Number of relation types
    pub n_relations: usize,
}
429
430impl RgcnLinkPredictor {
431 /// Create a new link predictor.
432 ///
433 /// # Arguments
434 /// * `in_dim` – Raw input node feature dimensionality.
435 /// * `n_entities` – Number of entities in the KG.
436 /// * `n_relations` – Number of relation types.
437 /// * `config` – R-GCN hyper-parameters.
438 pub fn new(
439 in_dim: usize,
440 n_entities: usize,
441 n_relations: usize,
442 config: &RgcnConfig,
443 ) -> Result<Self> {
444 let encoder = Rgcn::new(in_dim, n_relations, config)?;
445 let decoder = DistMultScoring::new(n_entities, n_relations, config.hidden_dim);
446 Ok(Self {
447 encoder,
448 decoder,
449 node_embeddings: None,
450 dim: config.hidden_dim,
451 n_relations,
452 })
453 }
454
455 /// Run the R-GCN encoder and cache the node embeddings.
456 pub fn encode(
457 &mut self,
458 node_feats: &Array2<f64>,
459 adj_by_relation: &[Vec<(usize, usize)>],
460 ) -> Result<()> {
461 let h = self.encoder.forward(node_feats, adj_by_relation)?;
462 self.node_embeddings = Some(h);
463 Ok(())
464 }
465
466 /// Score a triple using encoder embeddings (falls back to DistMult table if
467 /// `encode` has not been called).
468 pub fn score_triple(&self, h: usize, r: usize, t: usize) -> f64 {
469 match &self.node_embeddings {
470 Some(embeds) => {
471 let h_emb = embeds.row(h);
472 let r_emb = self.decoder.relation_embeds.row(r);
473 let t_emb = embeds.row(t);
474 h_emb
475 .iter()
476 .zip(r_emb.iter())
477 .zip(t_emb.iter())
478 .map(|((&hk, &rk), &tk)| hk * rk * tk)
479 .sum()
480 }
481 None => self.decoder.score(h, r, t),
482 }
483 }
484}
485
impl KgScorer for RgcnLinkPredictor {
    fn score(&self, h: usize, r: usize, t: usize) -> f64 {
        // Delegate to the inherent scoring method, which uses cached encoder
        // embeddings when `encode` has been called.
        self.score_triple(h, r, t)
    }
}
491
492// ============================================================================
493// Tests
494// ============================================================================
495
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Identity-like feature matrix: ones on the diagonal, zeros elsewhere.
    fn eye_feats(n: usize, dim: usize) -> Array2<f64> {
        let mut m = Array2::<f64>::zeros((n, dim));
        for i in 0..n.min(dim) {
            m[(i, i)] = 1.0;
        }
        m
    }

    /// A directed ring graph with one relation type.
    fn single_relation_adj(n: usize) -> Vec<Vec<(usize, usize)>> {
        let edges: Vec<(usize, usize)> = (0..n).map(|i| (i, (i + 1) % n)).collect();
        vec![edges]
    }

    #[test]
    fn test_basis_decomp_single_basis_recovers_weight() {
        // With n_bases=1 and coefficient 1.0 the effective weight equals the
        // single basis matrix exactly.
        let mut decomp = RgcnBasisDecomposition::new(4, 4, 1, 1).expect("decomp");
        // Force coefficient to 1.0
        decomp.coefficients[0][0] = 1.0;
        let w = decomp.effective_weight(0).expect("w");
        let diff = (&w - &decomp.basis_matrices[0]).mapv(|v| v.abs()).sum();
        assert!(
            diff < 1e-10,
            "effective_weight should equal basis[0] when a=1"
        );
    }

    #[test]
    fn test_rgcn_layer_output_shape() {
        let feats = eye_feats(5, 8);
        let adj = single_relation_adj(5);
        let layer = RgcnLayer::new(8, 16, 1, 2, true).expect("layer");
        let out = layer.forward(&feats, &adj).expect("forward");
        assert_eq!(out.nrows(), 5);
        assert_eq!(out.ncols(), 16);
    }

    #[test]
    fn test_rgcn_layer_isolated_node_self_loop() {
        // Isolated nodes (no edges) still flow through the self-loop path;
        // the output must be well-formed and finite.
        // BUGFIX: the previous assertion `row_norm >= 0.0` was vacuous —
        // a norm is non-negative by construction, so it could never fail.
        let feats = eye_feats(3, 4);
        let adj: Vec<Vec<(usize, usize)>> = vec![vec![]]; // no edges
        let layer = RgcnLayer::new(4, 4, 1, 1, true).expect("layer");
        let out = layer.forward(&feats, &adj).expect("forward");
        assert_eq!(out.nrows(), 3);
        assert_eq!(out.ncols(), 4);
        assert!(
            out.iter().all(|v| v.is_finite()),
            "isolated node output must be finite"
        );
    }

    #[test]
    fn test_rgcn_stacked_output_shape() {
        let config = RgcnConfig {
            hidden_dim: 8,
            n_bases: 2,
            n_layers: 3,
            ..Default::default()
        };
        let feats = eye_feats(6, 4);
        let adj = single_relation_adj(6);
        let rgcn = Rgcn::new(4, 1, &config).expect("rgcn");
        let out = rgcn.forward(&feats, &adj).expect("forward");
        assert_eq!(out.nrows(), 6);
        assert_eq!(out.ncols(), 8);
    }

    #[test]
    fn test_distmult_symmetry() {
        // DistMult scoring is symmetric: score(h,r,t) == score(t,r,h)
        let dm = DistMultScoring::new(4, 2, 8);
        let s1 = dm.score(0, 0, 1);
        let s2 = dm.score(1, 0, 0);
        assert!((s1 - s2).abs() < 1e-10, "DistMult should be symmetric");
    }

    #[test]
    fn test_rgcn_link_predictor_encode() {
        let config = RgcnConfig::default();
        let mut predictor = RgcnLinkPredictor::new(4, 5, 2, &config).expect("predictor");
        let feats = eye_feats(5, 4);
        let adj: Vec<Vec<(usize, usize)>> = vec![vec![(0, 1), (1, 2)], vec![(2, 3)]];
        predictor.encode(&feats, &adj).expect("encode");
        // After encoding, node_embeddings should be populated
        assert!(predictor.node_embeddings.is_some());
        let embeds = predictor.node_embeddings.as_ref().expect("embeds");
        assert_eq!(embeds.nrows(), 5);
        assert_eq!(embeds.ncols(), config.hidden_dim);
        // Score should be finite
        let s = predictor.score_triple(0, 0, 1);
        assert!(s.is_finite());
    }
}
594}