scirs2_graph/gnn/relation_message.rs
1//! Relation-aware message passing primitives
2//!
3//! This module provides:
4//!
5//! - [`RelationEmbedding`] – a learnable embedding table for relation types.
6//! - [`RotatEScoring`] – RotatE scoring function (Sun et al. 2019), which
7//! models each relation as a **rotation** in complex space:
8//! ```text
9//! score(h, r, t) = − ‖ h_e ∘ r_e − t_e ‖
10//! ```
11//! where `∘` denotes element-wise complex multiplication
12//! (`r_e = exp(i θ_r)` so `|r_e| = 1`).
13//! - [`HeterogeneousAdjacency`] – compact adjacency storage that organises
14//! edges both by relation type and by source/destination node type.
15
16use std::collections::HashMap;
17
18use scirs2_core::ndarray::{Array1, Array2};
19use scirs2_core::random::{Rng, RngExt};
20
21use crate::gnn::rgcn::KgScorer;
22
23// ============================================================================
24// Helpers
25// ============================================================================
26
27/// Xavier uniform initialisation.
28fn xavier_uniform(rows: usize, cols: usize) -> Array2<f64> {
29 let mut rng = scirs2_core::random::rng();
30 let limit = (6.0_f64 / (rows + cols) as f64).sqrt();
31 Array2::from_shape_fn((rows, cols), |_| rng.random::<f64>() * 2.0 * limit - limit)
32}
33
34// ============================================================================
35// RelationEmbedding
36// ============================================================================
37
38/// Learnable embedding table for relation types.
39///
40/// Embeddings are initialised with Xavier uniform and can be updated via
41/// gradient-based methods (integration with an optimiser is the caller's
42/// responsibility; this struct just holds the parameter tensor).
43#[derive(Debug, Clone)]
44pub struct RelationEmbedding {
45 /// Embedding matrix `(n_relations, dim)`.
46 pub table: Array2<f64>,
47 /// Number of relation types.
48 pub n_relations: usize,
49 /// Embedding dimensionality.
50 pub dim: usize,
51}
52
53impl RelationEmbedding {
54 /// Create a new relation embedding table with Xavier initialisation.
55 ///
56 /// # Arguments
57 /// * `n_relations` – Number of distinct relation types.
58 /// * `dim` – Embedding dimensionality.
59 pub fn new(n_relations: usize, dim: usize) -> Self {
60 Self {
61 table: xavier_uniform(n_relations, dim),
62 n_relations,
63 dim,
64 }
65 }
66
67 /// Look up the embedding vector for relation `r`.
68 ///
69 /// Returns `None` if `r >= n_relations`.
70 pub fn get(&self, r: usize) -> Option<Array1<f64>> {
71 if r < self.n_relations {
72 Some(self.table.row(r).to_owned())
73 } else {
74 None
75 }
76 }
77}
78
79// ============================================================================
80// RotatEScoring
81// ============================================================================
82
83/// RotatE scoring model (Sun et al. 2019).
84///
85/// Entity embeddings are complex-valued: each entity `e` has a real part
86/// `entity_re[e]` and an imaginary part `entity_im[e]`, both of shape
87/// `(n_entities, dim/2)` (so the full complex embedding has `dim` real
88/// parameters per entity).
89///
90/// Relation embeddings are phase angles `θ_r ∈ (−π, π]` of shape
91/// `(n_relations, dim/2)`. The unit-modulus rotation vector is
92/// `r_e = cos(θ_r) + i sin(θ_r)`.
93///
94/// Score function (negated L2 in complex space):
95/// ```text
96/// score(h, r, t) = − ‖ h_re ⊙ cos(θ_r) − h_im ⊙ sin(θ_r) − t_re
97/// + i ( h_re ⊙ sin(θ_r) + h_im ⊙ cos(θ_r) − t_im ) ‖
98/// ```
99/// Higher scores (less negative) imply a more plausible triple.
100#[derive(Debug, Clone)]
101pub struct RotatEScoring {
102 /// Real part of entity embeddings `(n_entities, dim)`.
103 pub entity_re: Array2<f64>,
104 /// Imaginary part of entity embeddings `(n_entities, dim)`.
105 pub entity_im: Array2<f64>,
106 /// Relation phase angles `(n_relations, dim)`.
107 pub relation_phase: Array2<f64>,
108 /// Number of entities.
109 pub n_entities: usize,
110 /// Number of relation types.
111 pub n_relations: usize,
112 /// Half-embedding dimension (complex degree of freedom per entity/relation).
113 pub dim: usize,
114}
115
116impl RotatEScoring {
117 /// Create a new RotatE model with random initialisation.
118 ///
119 /// Entity embeddings are Xavier-initialised; relation phases are uniform
120 /// in `(−π, π]`.
121 ///
122 /// # Arguments
123 /// * `n_entities` – Number of distinct entities.
124 /// * `n_relations` – Number of distinct relation types.
125 /// * `dim` – Complex embedding half-dimension (full parameter count
126 /// per entity is `2 * dim`).
127 pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Self {
128 let mut rng = scirs2_core::random::rng();
129 let entity_re = xavier_uniform(n_entities, dim);
130 let entity_im = xavier_uniform(n_entities, dim);
131 // Relation phases uniform in (−π, π]
132 let relation_phase = Array2::from_shape_fn((n_relations, dim), |_| {
133 (rng.random::<f64>() * 2.0 - 1.0) * std::f64::consts::PI
134 });
135 Self {
136 entity_re,
137 entity_im,
138 relation_phase,
139 n_entities,
140 n_relations,
141 dim,
142 }
143 }
144
145 /// Compute the RotatE score for triple `(h, r, t)`.
146 ///
147 /// Returns the **negated** L2 norm of `h ∘ r − t` in complex space.
148 /// Scores closer to 0 indicate a more plausible triple.
149 pub fn score_triple(&self, h: usize, r: usize, t: usize) -> f64 {
150 let h_re = self.entity_re.row(h);
151 let h_im = self.entity_im.row(h);
152 let t_re = self.entity_re.row(t);
153 let t_im = self.entity_im.row(t);
154 let phase = self.relation_phase.row(r);
155
156 let mut norm_sq = 0.0_f64;
157 for k in 0..self.dim {
158 let cos_r = phase[k].cos();
159 let sin_r = phase[k].sin();
160 // Real part of (h ∘ r − t)
161 let diff_re = h_re[k] * cos_r - h_im[k] * sin_r - t_re[k];
162 // Imaginary part of (h ∘ r − t)
163 let diff_im = h_re[k] * sin_r + h_im[k] * cos_r - t_im[k];
164 norm_sq += diff_re * diff_re + diff_im * diff_im;
165 }
166 -norm_sq.sqrt()
167 }
168}
169
170impl KgScorer for RotatEScoring {
171 fn score(&self, h: usize, r: usize, t: usize) -> f64 {
172 self.score_triple(h, r, t)
173 }
174}
175
176// ============================================================================
177// HeterogeneousAdjacency
178// ============================================================================
179
/// Compact adjacency representation for heterogeneous graphs.
///
/// Edges are indexed in two ways:
/// - Per relation type: `by_relation[r]` lists the `(src, dst)` pair of every
///   edge whose relation type is `r`.
/// - Per `(src_node_type, dst_node_type)` pair: `by_node_type[(ts, tt)]`
///   lists the source node index of every edge going from a node of type
///   `ts` to a node of type `tt`. There is one entry per edge, so a source
///   node with several such edges appears several times.
#[derive(Debug, Clone)]
pub struct HeterogeneousAdjacency {
    /// Adjacency list per relation type: `by_relation[r]` = `Vec<(src, dst)>`.
    pub by_relation: Vec<Vec<(usize, usize)>>,
    /// Source nodes per `(src_type, dst_type)` pair.
    ///
    /// `by_node_type[(src_type, dst_type)]` contains the *source* node index
    /// of each edge going from a node of `src_type` to a node of `dst_type`
    /// (one entry per edge, in input order).
    pub by_node_type: HashMap<(usize, usize), Vec<usize>>,
    /// Number of relation types
    pub n_relations: usize,
    /// Number of node types
    pub n_node_types: usize,
}

impl HeterogeneousAdjacency {
    /// Construct a [`HeterogeneousAdjacency`] from a list of typed edges.
    ///
    /// Edges whose relation type is `>= n_relations` are skipped entirely, so
    /// that both indices (and [`Self::n_edges`]) describe the same edge set.
    /// A node index that has no entry in `node_types` is treated as node
    /// type `0`. `n_node_types` is stored as given and is not used to
    /// validate `node_types`.
    ///
    /// # Arguments
    /// * `n_relations`  – Total number of distinct relation types.
    /// * `n_node_types` – Total number of distinct node types.
    /// * `node_types`   – Node type assignment for every node `(len = n_nodes)`.
    /// * `typed_edges`  – List of `(src, rel_type, dst)` directed edges.
    pub fn from_typed_edges(
        n_relations: usize,
        n_node_types: usize,
        node_types: &[usize],
        typed_edges: &[(usize, usize, usize)], // (src, rel, dst)
    ) -> Self {
        let mut by_relation: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n_relations];
        let mut by_node_type: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
        // Nodes without a recorded type fall back to type 0.
        let type_of = |node: usize| node_types.get(node).copied().unwrap_or(0);

        for &(src, rel, dst) in typed_edges {
            // An out-of-range relation has no bucket; dropping the edge from
            // both indices keeps them consistent with each other.
            if let Some(bucket) = by_relation.get_mut(rel) {
                bucket.push((src, dst));
                by_node_type
                    .entry((type_of(src), type_of(dst)))
                    .or_default()
                    .push(src);
            }
        }

        Self {
            by_relation,
            by_node_type,
            n_relations,
            n_node_types,
        }
    }

    /// Return the `(src, dst)` pairs of all edges with relation `r` as a
    /// slice; the slice is empty when `r` is not a known relation type.
    pub fn edges_for_relation(&self, r: usize) -> &[(usize, usize)] {
        self.by_relation
            .get(r)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Return the source node index of every edge going from a node of type
    /// `src_type` to a node of type `dst_type` (one entry per edge); the
    /// slice is empty when no such edge exists.
    pub fn sources_by_type(&self, src_type: usize, dst_type: usize) -> &[usize] {
        match self.by_node_type.get(&(src_type, dst_type)) {
            Some(sources) => sources.as_slice(),
            None => &[],
        }
    }

    /// Total number of edges stored across all relation types.
    pub fn n_edges(&self) -> usize {
        self.by_relation.iter().map(Vec::len).sum()
    }
}
256
257// ============================================================================
258// Tests
259// ============================================================================
260
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedding table has one row per relation and `dim` columns.
    #[test]
    fn test_relation_embedding_shape() {
        let emb = RelationEmbedding::new(5, 16);
        assert_eq!(emb.table.nrows(), 5);
        assert_eq!(emb.table.ncols(), 16);
    }

    /// Lookups succeed for valid relation ids and fail just past the end.
    #[test]
    fn test_relation_embedding_get() {
        let emb = RelationEmbedding::new(3, 8);
        assert!(emb.get(0).is_some());
        assert!(emb.get(2).is_some());
        assert!(emb.get(3).is_none());
    }

    /// A freshly initialised model produces finite scores.
    #[test]
    fn test_rotate_score_is_finite() {
        let scorer = RotatEScoring::new(5, 3, 8);
        let s = scorer.score_triple(0, 0, 1);
        assert!(s.is_finite(), "RotatE score must be finite");
    }

    /// Under the identity rotation, `(h, r, h)` attains the maximum score.
    #[test]
    fn test_rotate_self_score_is_highest() {
        // Zeroing every phase turns each relation into the identity rotation
        // (cos 0 = 1, sin 0 = 0), so h ∘ r = h and the triple (h, r, h) must
        // score exactly 0, the largest value the score can take.
        let mut scorer = RotatEScoring::new(4, 2, 4);
        scorer.relation_phase.fill(0.0);
        let s_same = scorer.score_triple(0, 0, 0);
        let s_diff = scorer.score_triple(0, 0, 1);
        assert!(s_same.is_finite());
        assert!(s_diff.is_finite());
        assert!(
            s_same == 0.0,
            "identity rotation must map an entity onto itself"
        );
        assert!(
            s_diff <= s_same,
            "a different tail cannot score higher than a perfect match"
        );
    }

    /// `RotatEScoring` is usable through a `KgScorer` trait object.
    #[test]
    fn test_rotate_scorer_trait_object() {
        let scorer: Box<dyn KgScorer> = Box::new(RotatEScoring::new(3, 2, 4));
        let s = scorer.score(0, 0, 1);
        assert!(s.is_finite());
    }

    /// Edges are bucketed by their relation type.
    #[test]
    fn test_heterogeneous_adjacency_by_relation() {
        let node_types = vec![0usize, 0, 1, 1];
        let edges = vec![
            (0usize, 0usize, 2usize), // rel 0
            (1, 0, 3),                // rel 0
            (0, 1, 1),                // rel 1
        ];
        let adj = HeterogeneousAdjacency::from_typed_edges(2, 2, &node_types, &edges);
        assert_eq!(adj.by_relation.len(), 2);
        assert_eq!(adj.edges_for_relation(0).len(), 2);
        assert_eq!(adj.edges_for_relation(1).len(), 1);
    }

    /// Source nodes are bucketed by the `(src_type, dst_type)` of their edge.
    #[test]
    fn test_heterogeneous_adjacency_by_node_type() {
        let node_types = vec![0usize, 0, 1, 1];
        let edges = vec![(0usize, 0usize, 2usize), (1, 0, 3)];
        let adj = HeterogeneousAdjacency::from_typed_edges(1, 2, &node_types, &edges);
        // Both edges go from type 0 → type 1
        let srcs = adj.sources_by_type(0, 1);
        assert_eq!(srcs.len(), 2);
    }

    /// `n_edges` counts every stored edge.
    #[test]
    fn test_heterogeneous_adjacency_n_edges() {
        let node_types = vec![0usize; 5];
        let edges: Vec<(usize, usize, usize)> = (0..4).map(|i| (i, 0, i + 1)).collect();
        let adj = HeterogeneousAdjacency::from_typed_edges(1, 1, &node_types, &edges);
        assert_eq!(adj.n_edges(), 4);
    }
}