Skip to main content

scirs2_graph/gnn/
relation_message.rs

1//! Relation-aware message passing primitives
2//!
3//! This module provides:
4//!
5//! - [`RelationEmbedding`] – a learnable embedding table for relation types.
6//! - [`RotatEScoring`] – RotatE scoring function (Sun et al. 2019), which
7//!   models each relation as a **rotation** in complex space:
8//!   ```text
9//!     score(h, r, t) = − ‖ h_e ∘ r_e − t_e ‖
10//!   ```
11//!   where `∘` denotes element-wise complex multiplication
12//!   (`r_e = exp(i θ_r)` so `|r_e| = 1`).
13//! - [`HeterogeneousAdjacency`] – compact adjacency storage that organises
14//!   edges both by relation type and by source/destination node type.
15
16use std::collections::HashMap;
17
18use scirs2_core::ndarray::{Array1, Array2};
19use scirs2_core::random::{Rng, RngExt};
20
21use crate::gnn::rgcn::KgScorer;
22
23// ============================================================================
24// Helpers
25// ============================================================================
26
27/// Xavier uniform initialisation.
28fn xavier_uniform(rows: usize, cols: usize) -> Array2<f64> {
29    let mut rng = scirs2_core::random::rng();
30    let limit = (6.0_f64 / (rows + cols) as f64).sqrt();
31    Array2::from_shape_fn((rows, cols), |_| rng.random::<f64>() * 2.0 * limit - limit)
32}
33
34// ============================================================================
35// RelationEmbedding
36// ============================================================================
37
38/// Learnable embedding table for relation types.
39///
40/// Embeddings are initialised with Xavier uniform and can be updated via
41/// gradient-based methods (integration with an optimiser is the caller's
42/// responsibility; this struct just holds the parameter tensor).
43#[derive(Debug, Clone)]
44pub struct RelationEmbedding {
45    /// Embedding matrix `(n_relations, dim)`.
46    pub table: Array2<f64>,
47    /// Number of relation types.
48    pub n_relations: usize,
49    /// Embedding dimensionality.
50    pub dim: usize,
51}
52
53impl RelationEmbedding {
54    /// Create a new relation embedding table with Xavier initialisation.
55    ///
56    /// # Arguments
57    /// * `n_relations` – Number of distinct relation types.
58    /// * `dim`         – Embedding dimensionality.
59    pub fn new(n_relations: usize, dim: usize) -> Self {
60        Self {
61            table: xavier_uniform(n_relations, dim),
62            n_relations,
63            dim,
64        }
65    }
66
67    /// Look up the embedding vector for relation `r`.
68    ///
69    /// Returns `None` if `r >= n_relations`.
70    pub fn get(&self, r: usize) -> Option<Array1<f64>> {
71        if r < self.n_relations {
72            Some(self.table.row(r).to_owned())
73        } else {
74            None
75        }
76    }
77}
78
79// ============================================================================
80// RotatEScoring
81// ============================================================================
82
83/// RotatE scoring model (Sun et al. 2019).
84///
85/// Entity embeddings are complex-valued: each entity `e` has a real part
86/// `entity_re[e]` and an imaginary part `entity_im[e]`, both of shape
87/// `(n_entities, dim/2)` (so the full complex embedding has `dim` real
88/// parameters per entity).
89///
90/// Relation embeddings are phase angles `θ_r ∈ (−π, π]` of shape
91/// `(n_relations, dim/2)`.  The unit-modulus rotation vector is
92/// `r_e = cos(θ_r) + i sin(θ_r)`.
93///
94/// Score function (negated L2 in complex space):
95/// ```text
96///   score(h, r, t) = − ‖ h_re ⊙ cos(θ_r) − h_im ⊙ sin(θ_r) − t_re
97///                       + i ( h_re ⊙ sin(θ_r) + h_im ⊙ cos(θ_r) − t_im ) ‖
98/// ```
99/// Higher scores (less negative) imply a more plausible triple.
100#[derive(Debug, Clone)]
101pub struct RotatEScoring {
102    /// Real part of entity embeddings `(n_entities, dim)`.
103    pub entity_re: Array2<f64>,
104    /// Imaginary part of entity embeddings `(n_entities, dim)`.
105    pub entity_im: Array2<f64>,
106    /// Relation phase angles `(n_relations, dim)`.
107    pub relation_phase: Array2<f64>,
108    /// Number of entities.
109    pub n_entities: usize,
110    /// Number of relation types.
111    pub n_relations: usize,
112    /// Half-embedding dimension (complex degree of freedom per entity/relation).
113    pub dim: usize,
114}
115
116impl RotatEScoring {
117    /// Create a new RotatE model with random initialisation.
118    ///
119    /// Entity embeddings are Xavier-initialised; relation phases are uniform
120    /// in `(−π, π]`.
121    ///
122    /// # Arguments
123    /// * `n_entities`  – Number of distinct entities.
124    /// * `n_relations` – Number of distinct relation types.
125    /// * `dim`         – Complex embedding half-dimension (full parameter count
126    ///   per entity is `2 * dim`).
127    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Self {
128        let mut rng = scirs2_core::random::rng();
129        let entity_re = xavier_uniform(n_entities, dim);
130        let entity_im = xavier_uniform(n_entities, dim);
131        // Relation phases uniform in (−π, π]
132        let relation_phase = Array2::from_shape_fn((n_relations, dim), |_| {
133            (rng.random::<f64>() * 2.0 - 1.0) * std::f64::consts::PI
134        });
135        Self {
136            entity_re,
137            entity_im,
138            relation_phase,
139            n_entities,
140            n_relations,
141            dim,
142        }
143    }
144
145    /// Compute the RotatE score for triple `(h, r, t)`.
146    ///
147    /// Returns the **negated** L2 norm of `h ∘ r − t` in complex space.
148    /// Scores closer to 0 indicate a more plausible triple.
149    pub fn score_triple(&self, h: usize, r: usize, t: usize) -> f64 {
150        let h_re = self.entity_re.row(h);
151        let h_im = self.entity_im.row(h);
152        let t_re = self.entity_re.row(t);
153        let t_im = self.entity_im.row(t);
154        let phase = self.relation_phase.row(r);
155
156        let mut norm_sq = 0.0_f64;
157        for k in 0..self.dim {
158            let cos_r = phase[k].cos();
159            let sin_r = phase[k].sin();
160            // Real part of (h ∘ r − t)
161            let diff_re = h_re[k] * cos_r - h_im[k] * sin_r - t_re[k];
162            // Imaginary part of (h ∘ r − t)
163            let diff_im = h_re[k] * sin_r + h_im[k] * cos_r - t_im[k];
164            norm_sq += diff_re * diff_re + diff_im * diff_im;
165        }
166        -norm_sq.sqrt()
167    }
168}
169
170impl KgScorer for RotatEScoring {
171    fn score(&self, h: usize, r: usize, t: usize) -> f64 {
172        self.score_triple(h, r, t)
173    }
174}
175
176// ============================================================================
177// HeterogeneousAdjacency
178// ============================================================================
179
/// Compact adjacency representation for heterogeneous graphs.
///
/// Every accepted edge is indexed twice:
/// - Per relation type: `by_relation[r]` lists the `(src, dst)` pairs of
///   relation `r`, in input order.
/// - Per `(src_node_type, dst_node_type)` pair: `by_node_type[(ts, tt)]`
///   lists the source node index of each edge running from a node of type
///   `ts` to a node of type `tt`. There is one entry **per edge**, so a
///   source with several such edges appears several times.
#[derive(Debug, Clone)]
pub struct HeterogeneousAdjacency {
    /// Adjacency list per relation type: `by_relation[r]` = `Vec<(src, dst)>`.
    pub by_relation: Vec<Vec<(usize, usize)>>,
    /// Source nodes per `(src_type, dst_type)` pair.
    ///
    /// `by_node_type[(src_type, dst_type)]` contains the *source* node index
    /// of every edge going from a node of `src_type` to a node of `dst_type`
    /// (duplicates included).
    pub by_node_type: HashMap<(usize, usize), Vec<usize>>,
    /// Number of relation types
    pub n_relations: usize,
    /// Number of node types
    pub n_node_types: usize,
}

impl HeterogeneousAdjacency {
    /// Construct a [`HeterogeneousAdjacency`] from a list of typed edges.
    ///
    /// An edge whose relation index is `>= n_relations` is skipped entirely,
    /// so `by_relation` and `by_node_type` always describe the same edge set
    /// and [`n_edges`](Self::n_edges) counts every indexed edge.
    ///
    /// A node index that falls outside `node_types` is treated as node type
    /// `0`. `n_node_types` is stored as given and is not used for validation.
    ///
    /// # Arguments
    /// * `n_relations`  – Total number of distinct relation types.
    /// * `n_node_types` – Total number of distinct node types.
    /// * `node_types`   – Node type assignment for every node `(len = n_nodes)`.
    /// * `typed_edges`  – List of `(src, rel_type, dst)` directed edges.
    pub fn from_typed_edges(
        n_relations: usize,
        n_node_types: usize,
        node_types: &[usize],
        typed_edges: &[(usize, usize, usize)], // (src, rel, dst)
    ) -> Self {
        let mut by_relation: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n_relations];
        let mut by_node_type: HashMap<(usize, usize), Vec<usize>> = HashMap::new();

        // Node type lookup with the type-0 fallback for unknown node indices.
        let type_of = |node: usize| node_types.get(node).copied().unwrap_or(0);

        for &(src, rel, dst) in typed_edges {
            // Only edges with a known relation are indexed; recording them in
            // one index but not the other would leave the two out of sync.
            if let Some(bucket) = by_relation.get_mut(rel) {
                bucket.push((src, dst));
                by_node_type
                    .entry((type_of(src), type_of(dst)))
                    .or_default()
                    .push(src);
            }
        }

        Self {
            by_relation,
            by_node_type,
            n_relations,
            n_node_types,
        }
    }

    /// Return the `(src, dst)` pairs of relation `r` as a slice.
    ///
    /// The slice is empty when `r` is not a known relation index.
    pub fn edges_for_relation(&self, r: usize) -> &[(usize, usize)] {
        self.by_relation
            .get(r)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Return the source node index of every edge going from a node of type
    /// `src_type` to a node of type `dst_type`.
    ///
    /// The slice is empty when no edge connects that pair of node types.
    pub fn sources_by_type(&self, src_type: usize, dst_type: usize) -> &[usize] {
        self.by_node_type
            .get(&(src_type, dst_type))
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Total number of edges in the graph.
    pub fn n_edges(&self) -> usize {
        self.by_relation.iter().map(Vec::len).sum()
    }
}
256
257// ============================================================================
258// Tests
259// ============================================================================
260
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedding table has one row per relation and `dim` columns.
    #[test]
    fn test_relation_embedding_shape() {
        let emb = RelationEmbedding::new(5, 16);
        assert_eq!(emb.table.nrows(), 5);
        assert_eq!(emb.table.ncols(), 16);
    }

    /// `get` succeeds for valid relation indices and rejects the first
    /// out-of-range one.
    #[test]
    fn test_relation_embedding_get() {
        let emb = RelationEmbedding::new(3, 8);
        assert!(emb.get(0).is_some());
        assert!(emb.get(2).is_some());
        assert!(emb.get(3).is_none());
    }

    /// A freshly initialised model produces a finite score.
    #[test]
    fn test_rotate_score_is_finite() {
        let scorer = RotatEScoring::new(5, 3, 8);
        let s = scorer.score_triple(0, 0, 1);
        assert!(s.is_finite(), "RotatE score must be finite");
    }

    /// Under an identity rotation, `(e, r, e)` attains the maximum score 0.
    #[test]
    fn test_rotate_self_score_is_highest() {
        let mut scorer = RotatEScoring::new(4, 2, 4);
        // Zero phases make relation 0 the identity rotation (cos = 1, sin = 0),
        // so h ∘ r = h and the distance to the same entity is exactly zero.
        scorer.relation_phase.row_mut(0).fill(0.0);
        let s_same = scorer.score_triple(0, 0, 0);
        let s_diff = scorer.score_triple(0, 0, 1);
        assert!(s_same.is_finite());
        assert!(s_diff.is_finite());
        assert_eq!(s_same, 0.0, "identity rotation must map an entity onto itself");
        assert!(
            s_diff <= s_same,
            "a mismatched tail must not outscore the matching tail"
        );
    }

    /// `RotatEScoring` is usable through a `KgScorer` trait object.
    #[test]
    fn test_rotate_scorer_trait_object() {
        let scorer: Box<dyn KgScorer> = Box::new(RotatEScoring::new(3, 2, 4));
        let s = scorer.score(0, 0, 1);
        assert!(s.is_finite());
    }

    /// Edges are bucketed by their relation index.
    #[test]
    fn test_heterogeneous_adjacency_by_relation() {
        let node_types = vec![0usize, 0, 1, 1];
        let edges = vec![
            (0usize, 0usize, 2usize), // rel 0
            (1, 0, 3),                // rel 0
            (0, 1, 1),                // rel 1
        ];
        let adj = HeterogeneousAdjacency::from_typed_edges(2, 2, &node_types, &edges);
        assert_eq!(adj.by_relation.len(), 2);
        assert_eq!(adj.edges_for_relation(0).len(), 2);
        assert_eq!(adj.edges_for_relation(1).len(), 1);
    }

    /// Sources are bucketed by the `(src_type, dst_type)` pair of their edge.
    #[test]
    fn test_heterogeneous_adjacency_by_node_type() {
        let node_types = vec![0usize, 0, 1, 1];
        let edges = vec![(0usize, 0usize, 2usize), (1, 0, 3)];
        let adj = HeterogeneousAdjacency::from_typed_edges(1, 2, &node_types, &edges);
        // Both edges go from type 0 → type 1
        let srcs = adj.sources_by_type(0, 1);
        assert_eq!(srcs.len(), 2);
    }

    /// `n_edges` counts every edge of a simple path graph.
    #[test]
    fn test_heterogeneous_adjacency_n_edges() {
        let node_types = vec![0usize; 5];
        let edges: Vec<(usize, usize, usize)> = (0..4).map(|i| (i, 0, i + 1)).collect();
        let adj = HeterogeneousAdjacency::from_typed_edges(1, 1, &node_types, &edges);
        assert_eq!(adj.n_edges(), 4);
    }
}