// scirs2_graph/heterogeneous.rs

1//! Heterogeneous graphs with multiple node and edge types
2//!
3//! A *heterogeneous* (or *typed*) graph contains nodes and edges belonging to
4//! distinct *types*.  Each node type represents a different entity class (e.g.
5//! `"user"`, `"item"`, `"category"`) and each edge type captures a specific
6//! relation between two entity classes (e.g. `"user" --buys--> "item"`).
7//!
8//! This representation is standard in *relational* machine learning and
9//! *knowledge graphs*.  See also [`crate::knowledge_graph`] for specialised
10//! KGE embedding models.
11//!
12//! ## Architecture
13//!
14//! ```text
15//!   HeteroGraph
16//!   ├── node_types : HashMap<String, Vec<NodeId>>
17//!   └── edge_types : HashMap<HeteroEdgeType, Vec<(NodeId, NodeId)>>
18//!                         (src_type, relation, dst_type)
19//! ```
20//!
21//! Node identifiers are *global* (unique across all types in the same graph).
22//! The type membership is stored separately in `node_types`.
23//!
24//! ## Key operations
25//!
//! | Function | Description |
//! |----------|-------------|
//! | [`HeteroGraph::add_node`] | Register a node under a type |
//! | [`HeteroGraph::add_edge`] | Register a typed relation |
//! | [`type_adjacency`] | Build a [`CsrMatrix`] for one edge type |
//! | [`meta_path_adjacency`] | Chain edge types via meta-path multiplication |
//! | [`hetero_message_passing`] | Aggregate neighbour representations per type |
33
34use std::collections::{HashMap, HashSet};
35
36use scirs2_core::ndarray::Array2;
37
38use crate::error::{GraphError, Result};
39use crate::gnn::CsrMatrix;
40
41// Re-export NodeId from attributed_graph for convenience
42pub use crate::attributed_graph::NodeId;
43
44// ============================================================================
45// HeteroEdgeType
46// ============================================================================
47
/// Describes a typed, directed edge as a `(source_type, relation, destination_type)` triple.
///
/// This mirrors the *canonical form* used in heterogeneous GNN literature (HAN,
/// HGT, etc.) and knowledge-graph reasoning.
///
/// The type derives `Eq` and `Hash`, so it can be used directly as a map key.
///
/// # Example
///
/// ```
/// use scirs2_graph::heterogeneous::HeteroEdgeType;
///
/// let et = HeteroEdgeType::new("user", "rates", "item");
/// assert_eq!(et.relation, "rates");
/// assert_eq!(et.reversed().reversed(), et);
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HeteroEdgeType {
    /// Entity type of the source node.
    pub source_type: String,
    /// Name of the relation.
    pub relation: String,
    /// Entity type of the destination node.
    pub destination_type: String,
}

impl HeteroEdgeType {
    /// Prefix that marks a relation name as the reverse of another relation.
    const REVERSE_PREFIX: &'static str = "rev_";

    /// Create a new edge type from anything convertible into owned strings.
    pub fn new(
        source_type: impl Into<String>,
        relation: impl Into<String>,
        destination_type: impl Into<String>,
    ) -> Self {
        Self {
            source_type: source_type.into(),
            relation: relation.into(),
            destination_type: destination_type.into(),
        }
    }

    /// Canonical string key `"src_type/relation/dst_type"`.
    ///
    /// The three components are joined verbatim, so the key is only
    /// unambiguous when none of them contains a `/`.
    pub fn key(&self) -> String {
        format!(
            "{}/{}/{}",
            self.source_type, self.relation, self.destination_type
        )
    }

    /// Return the reversed edge type (swaps source and destination).
    ///
    /// The relation name gains a `rev_` prefix.  If the relation already
    /// carries that prefix (followed by a non-empty name) the prefix is removed
    /// instead, so that `et.reversed().reversed() == et` rather than producing
    /// `rev_rev_…` names.
    pub fn reversed(&self) -> Self {
        let relation = match self.relation.strip_prefix(Self::REVERSE_PREFIX) {
            // Already a reverse relation: undo it.  An empty remainder would
            // yield an empty relation name, so that case is prefixed instead.
            Some(base) if !base.is_empty() => base.to_string(),
            _ => format!("{}{}", Self::REVERSE_PREFIX, self.relation),
        };
        Self {
            source_type: self.destination_type.clone(),
            relation,
            destination_type: self.source_type.clone(),
        }
    }
}

impl std::fmt::Display for HeteroEdgeType {
    /// Render as `(src_type) --relation--> (dst_type)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "({}) --{}--> ({})",
            self.source_type, self.relation, self.destination_type
        )
    }
}
112
113// ============================================================================
114// HeteroGraph
115// ============================================================================
116
117/// A heterogeneous graph with multiple node and edge types.
118///
119/// ## Node identifiers
120///
121/// All nodes share a single global [`NodeId`] namespace.  Use
122/// [`HeteroGraph::add_node`] to assign a node to a particular type; the
123/// returned [`NodeId`] is globally unique.
124///
125/// ## Edge types
126///
127/// Edges are grouped by [`HeteroEdgeType`].  Each group is an unordered list
128/// of `(src_id, dst_id)` pairs; duplicate edges *are* allowed (useful for
129/// multigraphs).
130///
131/// # Example
132///
133/// ```
134/// use scirs2_graph::heterogeneous::HeteroGraph;
135///
136/// let mut g = HeteroGraph::new();
137/// let u0 = g.add_node("user", 0).unwrap();
138/// let i0 = g.add_node("item", 0).unwrap();
139/// g.add_edge("user", "buys", "item", u0, i0).unwrap();
140///
141/// assert_eq!(g.node_count(), 2);
142/// assert_eq!(g.edge_count(), 1);
143/// ```
144#[derive(Debug, Clone)]
145pub struct HeteroGraph {
146    /// Next globally unique node id counter.
147    next_node_id: usize,
148    /// `node_types["user"]` → list of global NodeIds belonging to "user" type.
149    node_types: HashMap<String, Vec<NodeId>>,
150    /// Reverse mapping: NodeId → type name.
151    node_type_of: HashMap<NodeId, String>,
152    /// `edge_types[HeteroEdgeType]` → ordered list of (src_id, dst_id) pairs.
153    edge_types: HashMap<HeteroEdgeType, Vec<(NodeId, NodeId)>>,
154}
155
156impl Default for HeteroGraph {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162impl HeteroGraph {
163    /// Create an empty heterogeneous graph.
164    pub fn new() -> Self {
165        Self {
166            next_node_id: 0,
167            node_types: HashMap::new(),
168            node_type_of: HashMap::new(),
169            edge_types: HashMap::new(),
170        }
171    }
172
173    // -----------------------------------------------------------------------
174    // Mutation
175    // -----------------------------------------------------------------------
176
177    /// Register a new node of type `type_name`.
178    ///
179    /// `_hint` is an optional user-supplied integer label (e.g. a database
180    /// primary key); it is stored purely for user convenience and does **not**
181    /// affect the global [`NodeId`] that is returned.
182    ///
183    /// # Errors
184    ///
185    /// Returns [`GraphError::InvalidParameter`] if `type_name` is empty.
186    pub fn add_node(&mut self, type_name: impl Into<String>, _hint: usize) -> Result<NodeId> {
187        let type_name = type_name.into();
188        if type_name.is_empty() {
189            return Err(GraphError::invalid_parameter(
190                "type_name",
191                "<empty>",
192                "non-empty node type",
193            ));
194        }
195        let id = NodeId(self.next_node_id);
196        self.next_node_id += 1;
197        self.node_types
198            .entry(type_name.clone())
199            .or_default()
200            .push(id);
201        self.node_type_of.insert(id, type_name);
202        Ok(id)
203    }
204
205    /// Add a directed typed edge `src_id --relation--> dst_id`.
206    ///
207    /// Both nodes must already be present in the graph and must belong to
208    /// `src_type` and `dst_type` respectively.
209    ///
210    /// # Errors
211    ///
212    /// * [`GraphError::NodeNotFound`]  – node not registered.
213    /// * [`GraphError::InvalidParameter`] – node belongs to the wrong type.
214    pub fn add_edge(
215        &mut self,
216        src_type: impl Into<String>,
217        relation: impl Into<String>,
218        dst_type: impl Into<String>,
219        src_id: NodeId,
220        dst_id: NodeId,
221    ) -> Result<()> {
222        let src_type = src_type.into();
223        let dst_type = dst_type.into();
224        let relation = relation.into();
225
226        // Validate node existence
227        let actual_src_type = self
228            .node_type_of
229            .get(&src_id)
230            .ok_or_else(|| GraphError::node_not_found(src_id.0))?;
231
232        let actual_dst_type = self
233            .node_type_of
234            .get(&dst_id)
235            .ok_or_else(|| GraphError::node_not_found(dst_id.0))?;
236
237        // Validate type consistency
238        if actual_src_type != &src_type {
239            return Err(GraphError::InvalidParameter {
240                param: "src_type".to_string(),
241                value: format!("node {} has type '{}'", src_id.0, actual_src_type),
242                expected: format!("'{src_type}'"),
243                context: "HeteroGraph::add_edge".to_string(),
244            });
245        }
246        if actual_dst_type != &dst_type {
247            return Err(GraphError::InvalidParameter {
248                param: "dst_type".to_string(),
249                value: format!("node {} has type '{}'", dst_id.0, actual_dst_type),
250                expected: format!("'{dst_type}'"),
251                context: "HeteroGraph::add_edge".to_string(),
252            });
253        }
254
255        let et = HeteroEdgeType::new(src_type, relation, dst_type);
256        self.edge_types
257            .entry(et)
258            .or_default()
259            .push((src_id, dst_id));
260        Ok(())
261    }
262
263    // -----------------------------------------------------------------------
264    // Query
265    // -----------------------------------------------------------------------
266
267    /// Total number of registered nodes (across all types).
268    pub fn node_count(&self) -> usize {
269        self.next_node_id
270    }
271
272    /// Total number of registered edges (across all edge types).
273    pub fn edge_count(&self) -> usize {
274        self.edge_types.values().map(|v| v.len()).sum()
275    }
276
277    /// List all node types present in the graph.
278    pub fn node_type_names(&self) -> Vec<&str> {
279        self.node_types.keys().map(String::as_str).collect()
280    }
281
282    /// List all edge types present in the graph.
283    pub fn edge_type_list(&self) -> Vec<&HeteroEdgeType> {
284        self.edge_types.keys().collect()
285    }
286
287    /// Return the nodes belonging to `type_name`.
288    pub fn nodes_of_type(&self, type_name: &str) -> &[NodeId] {
289        self.node_types
290            .get(type_name)
291            .map(Vec::as_slice)
292            .unwrap_or(&[])
293    }
294
295    /// Return the type name of a node.
296    pub fn type_of(&self, node: NodeId) -> Option<&str> {
297        self.node_type_of.get(&node).map(String::as_str)
298    }
299
300    /// Return all edges of a specific [`HeteroEdgeType`].
301    ///
302    /// Returns an empty slice if the edge type has no edges.
303    pub fn edges_of_type(&self, et: &HeteroEdgeType) -> &[(NodeId, NodeId)] {
304        self.edge_types.get(et).map(Vec::as_slice).unwrap_or(&[])
305    }
306
307    /// Return the out-neighbours of `node` under a specific edge type.
308    ///
309    /// Complexity: O(number of edges of that type).
310    pub fn out_neighbors_typed(&self, node: NodeId, et: &HeteroEdgeType) -> Vec<NodeId> {
311        self.edge_types
312            .get(et)
313            .map(|edges| {
314                edges
315                    .iter()
316                    .filter_map(|&(s, d)| if s == node { Some(d) } else { None })
317                    .collect()
318            })
319            .unwrap_or_default()
320    }
321
322    /// Return all (edge_type, neighbours) pairs for `node`.
323    ///
324    /// Useful for heterogeneous message passing.
325    pub fn all_out_neighbors_typed(&self, node: NodeId) -> Vec<(&HeteroEdgeType, Vec<NodeId>)> {
326        self.edge_types
327            .iter()
328            .filter_map(|(et, edges)| {
329                let nbrs: Vec<NodeId> = edges
330                    .iter()
331                    .filter_map(|&(s, d)| if s == node { Some(d) } else { None })
332                    .collect();
333                if nbrs.is_empty() {
334                    None
335                } else {
336                    Some((et, nbrs))
337                }
338            })
339            .collect()
340    }
341
342    /// Check whether a node is registered in the graph.
343    pub fn contains_node(&self, node: NodeId) -> bool {
344        self.node_type_of.contains_key(&node)
345    }
346
347    /// Check whether a typed edge exists.
348    pub fn has_typed_edge(&self, et: &HeteroEdgeType, src: NodeId, dst: NodeId) -> bool {
349        self.edge_types
350            .get(et)
351            .map(|edges| edges.contains(&(src, dst)))
352            .unwrap_or(false)
353    }
354}
355
356// ============================================================================
357// type_adjacency
358// ============================================================================
359
360/// Build the adjacency matrix (as a [`CsrMatrix`]) for a specific edge type.
361///
362/// The matrix has dimensions `(|src_nodes|, |dst_nodes|)` where the row and
363/// column orderings follow the insertion order of the respective node-type
364/// lists.
365///
366/// # Arguments
367///
368/// * `graph` – the heterogeneous graph.
369/// * `edge_type` – which typed edge to materialise.
370///
371/// # Errors
372///
373/// Returns [`GraphError::InvalidParameter`] if the source or destination type
374/// referenced by `edge_type` does not exist in the graph.
375///
376/// # Example
377///
378/// ```
379/// use scirs2_graph::heterogeneous::{HeteroGraph, HeteroEdgeType, type_adjacency};
380///
381/// let mut g = HeteroGraph::new();
382/// let u0 = g.add_node("user", 0).unwrap();
383/// let i0 = g.add_node("item", 0).unwrap();
384/// let i1 = g.add_node("item", 1).unwrap();
385/// g.add_edge("user", "buys", "item", u0, i0).unwrap();
386/// g.add_edge("user", "buys", "item", u0, i1).unwrap();
387///
388/// let et = HeteroEdgeType::new("user", "buys", "item");
389/// let adj = type_adjacency(&g, &et).unwrap();
390/// assert_eq!(adj.n_rows, 1);   // 1 user
391/// assert_eq!(adj.n_cols, 2);   // 2 items
392/// ```
393pub fn type_adjacency(graph: &HeteroGraph, edge_type: &HeteroEdgeType) -> Result<CsrMatrix> {
394    let src_nodes = graph.nodes_of_type(&edge_type.source_type);
395    let dst_nodes = graph.nodes_of_type(&edge_type.destination_type);
396
397    if src_nodes.is_empty() {
398        return Err(GraphError::InvalidParameter {
399            param: "edge_type.source_type".to_string(),
400            value: edge_type.source_type.clone(),
401            expected: "a type with at least one node".to_string(),
402            context: "type_adjacency".to_string(),
403        });
404    }
405    if dst_nodes.is_empty() {
406        return Err(GraphError::InvalidParameter {
407            param: "edge_type.destination_type".to_string(),
408            value: edge_type.destination_type.clone(),
409            expected: "a type with at least one node".to_string(),
410            context: "type_adjacency".to_string(),
411        });
412    }
413
414    // Build local index maps for row and column lookups
415    let src_index: HashMap<NodeId, usize> = src_nodes
416        .iter()
417        .enumerate()
418        .map(|(i, &nid)| (nid, i))
419        .collect();
420    let dst_index: HashMap<NodeId, usize> = dst_nodes
421        .iter()
422        .enumerate()
423        .map(|(i, &nid)| (nid, i))
424        .collect();
425
426    let edges = graph.edges_of_type(edge_type);
427
428    // Collect COO triples
429    let mut coo: Vec<(usize, usize, f64)> = Vec::with_capacity(edges.len());
430    for &(src, dst) in edges {
431        let r = match src_index.get(&src) {
432            Some(&i) => i,
433            None => continue, // edge references node not in expected type; skip
434        };
435        let c = match dst_index.get(&dst) {
436            Some(&j) => j,
437            None => continue,
438        };
439        coo.push((r, c, 1.0));
440    }
441
442    CsrMatrix::from_coo(src_nodes.len(), dst_nodes.len(), &coo).map_err(|e| {
443        GraphError::InvalidGraph(format!("type_adjacency CsrMatrix::from_coo failed: {e}"))
444    })
445}
446
447// ============================================================================
448// meta_path_adjacency
449// ============================================================================
450
451/// Compute the meta-path similarity matrix for a sequence of node types.
452///
453/// A *meta-path* is a sequence of node types connected by edges, e.g.
454/// `["user", "item", "user"]` captures the "users who bought the same item"
455/// relation.  The resulting matrix contains path counts (before row
456/// normalisation) of shape `(|first_type|, |last_type|)`.
457///
458/// Internally, each consecutive pair of types yields an adjacency matrix which
459/// is multiplied together to produce the final result.
460///
461/// # Arguments
462///
463/// * `graph` – the heterogeneous graph.
464/// * `meta_path` – ordered sequence of node-type names of length ≥ 2.
465///
466/// # Errors
467///
468/// * [`GraphError::InvalidParameter`] – meta-path has fewer than 2 types.
469/// * [`GraphError::InvalidParameter`] – a required edge type is absent.
470/// * [`GraphError::InvalidParameter`] – no edge type connects adjacent types in the path.
471///
472/// # Notes
473///
474/// When the same source–destination type pair is connected by multiple
475/// relations, the function sums over **all** of them.  If you need to select a
476/// specific relation, restrict the graph to that edge type first.
477///
478/// # Example
479///
480/// ```no_run
481/// use scirs2_graph::heterogeneous::{HeteroGraph, meta_path_adjacency};
482///
483/// let mut g = HeteroGraph::new();
484/// let u0 = g.add_node("user", 0).unwrap();
485/// let u1 = g.add_node("user", 1).unwrap();
486/// let i0 = g.add_node("item", 0).unwrap();
487/// g.add_edge("user", "buys", "item", u0, i0).unwrap();
488/// g.add_edge("user", "buys", "item", u1, i0).unwrap();
489///
490/// // Meta-path user→item→user: both users bought item 0, so they share paths
491/// let sim = meta_path_adjacency(&g, &["user", "item", "user"]).unwrap();
492/// assert_eq!(sim.shape(), &[2, 2]);
493/// ```
494pub fn meta_path_adjacency(graph: &HeteroGraph, meta_path: &[&str]) -> Result<Array2<f64>> {
495    if meta_path.len() < 2 {
496        return Err(GraphError::InvalidParameter {
497            param: "meta_path".to_string(),
498            value: format!("length={}", meta_path.len()),
499            expected: "at least 2 node types".to_string(),
500            context: "meta_path_adjacency".to_string(),
501        });
502    }
503
504    // Collect all edge types, grouped by (src_type, dst_type)
505    let mut type_pair_edges: HashMap<(&str, &str), Vec<&HeteroEdgeType>> = HashMap::new();
506    for et in graph.edge_types.keys() {
507        type_pair_edges
508            .entry((et.source_type.as_str(), et.destination_type.as_str()))
509            .or_default()
510            .push(et);
511    }
512
513    // Build the adjacency matrix for one step of the meta-path.
514    // Returns a dense (n_src × n_dst) matrix, summing over all edge types
515    // that connect the given type pair.
516    let step_matrix = |src_type: &str, dst_type: &str| -> Result<Array2<f64>> {
517        let src_nodes = graph.nodes_of_type(src_type);
518        let dst_nodes = graph.nodes_of_type(dst_type);
519
520        if src_nodes.is_empty() || dst_nodes.is_empty() {
521            // Return zero matrix
522            return Ok(Array2::zeros((
523                src_nodes.len().max(1),
524                dst_nodes.len().max(1),
525            )));
526        }
527
528        let src_index: HashMap<NodeId, usize> =
529            src_nodes.iter().enumerate().map(|(i, &n)| (n, i)).collect();
530        let dst_index: HashMap<NodeId, usize> =
531            dst_nodes.iter().enumerate().map(|(i, &n)| (n, i)).collect();
532
533        let mut mat = Array2::<f64>::zeros((src_nodes.len(), dst_nodes.len()));
534
535        let ets = type_pair_edges.get(&(src_type, dst_type));
536        if let Some(edge_types) = ets {
537            for &et in edge_types {
538                for &(s, d) in graph.edges_of_type(et) {
539                    if let (Some(&r), Some(&c)) = (src_index.get(&s), dst_index.get(&d)) {
540                        mat[[r, c]] += 1.0;
541                    }
542                }
543            }
544        }
545        Ok(mat)
546    };
547
548    // Initial matrix for first step
549    let mut result = step_matrix(meta_path[0], meta_path[1])?;
550
551    // Multiply through remaining steps
552    for window in meta_path.windows(2).skip(1) {
553        let next = step_matrix(window[0], window[1])?;
554        // result: (n_first × n_mid), next: (n_mid × n_next)
555        // product: (n_first × n_next)
556        let (r_rows, r_cols) = (result.shape()[0], result.shape()[1]);
557        let (n_rows, n_cols) = (next.shape()[0], next.shape()[1]);
558        if r_cols != n_rows {
559            return Err(GraphError::InvalidParameter {
560                param: "meta_path".to_string(),
561                value: format!(
562                    "dimension mismatch: {} cols vs {} rows at step",
563                    r_cols, n_rows
564                ),
565                expected: "matching intermediate dimensions".to_string(),
566                context: "meta_path_adjacency matrix multiply".to_string(),
567            });
568        }
569        let mut product = Array2::<f64>::zeros((r_rows, n_cols));
570        for i in 0..r_rows {
571            for k in 0..r_cols {
572                let rv = result[[i, k]];
573                if rv == 0.0 {
574                    continue;
575                }
576                for j in 0..n_cols {
577                    product[[i, j]] += rv * next[[k, j]];
578                }
579            }
580        }
581        result = product;
582    }
583
584    Ok(result)
585}
586
587// ============================================================================
588// Heterogeneous message passing
589// ============================================================================
590
591/// Message-passing result for a single edge type.
592#[derive(Debug, Clone)]
593pub struct TypedMessageResult {
594    /// The edge type this result corresponds to.
595    pub edge_type: HeteroEdgeType,
596    /// Aggregated features for each destination node (in type-order).
597    /// Shape: `(n_dst_nodes, feature_dim)`.
598    pub aggregated: Array2<f64>,
599}
600
601/// Propagate and aggregate node feature vectors across all edge types in one
602/// message-passing step.
603///
604/// For every registered edge type `(src_type, rel, dst_type)`:
605///
606/// 1. Look up the feature rows for all source nodes (from `node_features`).
607/// 2. Aggregate (sum) incoming features for each destination node.
608/// 3. Return an [`Array2<f64>`] of shape `(n_dst_nodes, feature_dim)`.
609///
610/// # Arguments
611///
612/// * `graph` – the heterogeneous graph.
613/// * `node_features` – map from [`NodeId`] to a fixed-length feature vector.
614///   Nodes absent from this map contribute zero vectors.
615/// * `feature_dim` – length of each feature vector.
616///
617/// # Errors
618///
619/// Returns [`GraphError::InvalidParameter`] if any feature vector has a
620/// different length from `feature_dim`.
621///
622/// # Example
623///
624/// ```
625/// use std::collections::HashMap;
626/// use scirs2_graph::heterogeneous::{HeteroGraph, hetero_message_passing};
627/// use scirs2_graph::attributed_graph::NodeId;
628///
629/// let mut g = HeteroGraph::new();
630/// let u0 = g.add_node("user", 0).unwrap();
631/// let i0 = g.add_node("item", 0).unwrap();
632/// g.add_edge("user", "buys", "item", u0, i0).unwrap();
633///
634/// let mut features = HashMap::new();
635/// features.insert(u0, vec![1.0, 2.0]);
636///
637/// let results = hetero_message_passing(&g, &features, 2).unwrap();
638/// assert_eq!(results.len(), 1);
639/// // Aggregated features for "item" node 0: [1.0, 2.0] (from u0)
640/// assert!((results[0].aggregated[[0, 0]] - 1.0).abs() < 1e-9);
641/// ```
642pub fn hetero_message_passing(
643    graph: &HeteroGraph,
644    node_features: &HashMap<NodeId, Vec<f64>>,
645    feature_dim: usize,
646) -> Result<Vec<TypedMessageResult>> {
647    // Validate all provided feature vectors
648    for (nid, fv) in node_features {
649        if fv.len() != feature_dim {
650            return Err(GraphError::InvalidParameter {
651                param: format!("node_features[{}]", nid.0),
652                value: format!("len={}", fv.len()),
653                expected: format!("feature_dim={feature_dim}"),
654                context: "hetero_message_passing".to_string(),
655            });
656        }
657    }
658
659    let zero_feat = vec![0.0f64; feature_dim];
660    let mut results = Vec::new();
661
662    for (et, edges) in &graph.edge_types {
663        if edges.is_empty() {
664            continue;
665        }
666
667        let dst_nodes = graph.nodes_of_type(&et.destination_type);
668        if dst_nodes.is_empty() {
669            continue;
670        }
671
672        let dst_index: HashMap<NodeId, usize> =
673            dst_nodes.iter().enumerate().map(|(i, &n)| (n, i)).collect();
674
675        let mut aggregated = Array2::<f64>::zeros((dst_nodes.len(), feature_dim));
676
677        for &(src, dst) in edges {
678            let feat = node_features
679                .get(&src)
680                .map(Vec::as_slice)
681                .unwrap_or(zero_feat.as_slice());
682
683            if let Some(&dst_row) = dst_index.get(&dst) {
684                for (j, &v) in feat.iter().enumerate() {
685                    aggregated[[dst_row, j]] += v;
686                }
687            }
688        }
689
690        results.push(TypedMessageResult {
691            edge_type: et.clone(),
692            aggregated,
693        });
694    }
695
696    Ok(results)
697}
698
699// ============================================================================
700// Utility helpers
701// ============================================================================
702
703/// Compute the degree (number of outgoing edges) for each node under a
704/// specific edge type.
705///
706/// Returns a `HashMap<NodeId, usize>` containing only nodes with degree ≥ 1.
707pub fn typed_out_degree(graph: &HeteroGraph, edge_type: &HeteroEdgeType) -> HashMap<NodeId, usize> {
708    let mut deg: HashMap<NodeId, usize> = HashMap::new();
709    for &(src, _dst) in graph.edges_of_type(edge_type) {
710        *deg.entry(src).or_insert(0) += 1;
711    }
712    deg
713}
714
715/// Compute the in-degree for each node under a specific edge type.
716pub fn typed_in_degree(graph: &HeteroGraph, edge_type: &HeteroEdgeType) -> HashMap<NodeId, usize> {
717    let mut deg: HashMap<NodeId, usize> = HashMap::new();
718    for &(_src, dst) in graph.edges_of_type(edge_type) {
719        *deg.entry(dst).or_insert(0) += 1;
720    }
721    deg
722}
723
724/// Return all unique node types reachable from `start_type` following the
725/// given edge types.
726///
727/// This performs a BFS over the *type graph*.
728pub fn reachable_types(graph: &HeteroGraph, start_type: &str) -> HashSet<String> {
729    let mut visited: HashSet<String> = HashSet::new();
730    let mut queue = std::collections::VecDeque::new();
731    queue.push_back(start_type.to_string());
732
733    while let Some(current) = queue.pop_front() {
734        if !visited.insert(current.clone()) {
735            continue;
736        }
737        for et in graph.edge_types.keys() {
738            if et.source_type == current && !visited.contains(&et.destination_type) {
739                queue.push_back(et.destination_type.clone());
740            }
741        }
742    }
743    visited
744}
745
746/// Convert the heterogeneous graph to a homogeneous `Vec<(usize, usize)>` edge list
747/// by stripping type information.
748pub fn to_homogeneous_edge_list(graph: &HeteroGraph) -> Vec<(usize, usize)> {
749    graph
750        .edge_types
751        .values()
752        .flat_map(|edges| edges.iter().map(|&(s, d)| (s.0, d.0)))
753        .collect()
754}
755
756// ============================================================================
757// Tests
758// ============================================================================
759
760#[cfg(test)]
761mod tests {
762    use super::*;
763
764    // ------------------------------------------------------------------
765    // Helpers
766    // ------------------------------------------------------------------
767
    /// Build the canonical user-item-tag knowledge graph:
    ///
    /// ```text
    /// u0 --buys-->  i0 --has_tag--> t0
    /// u1 --buys-->  i0
    /// u0 --buys-->  i1
    /// ```
775    fn make_graph() -> (HeteroGraph, NodeId, NodeId, NodeId, NodeId, NodeId) {
776        let mut g = HeteroGraph::new();
777        let u0 = g.add_node("user", 0).unwrap();
778        let u1 = g.add_node("user", 1).unwrap();
779        let i0 = g.add_node("item", 0).unwrap();
780        let i1 = g.add_node("item", 1).unwrap();
781        let t0 = g.add_node("tag", 0).unwrap();
782        g.add_edge("user", "buys", "item", u0, i0).unwrap();
783        g.add_edge("user", "buys", "item", u1, i0).unwrap();
784        g.add_edge("user", "buys", "item", u0, i1).unwrap();
785        g.add_edge("item", "has_tag", "tag", i0, t0).unwrap();
786        (g, u0, u1, i0, i1, t0)
787    }
788
789    // ------------------------------------------------------------------
790    // HeteroEdgeType
791    // ------------------------------------------------------------------
792
793    #[test]
794    fn test_edge_type_key() {
795        let et = HeteroEdgeType::new("user", "buys", "item");
796        assert_eq!(et.key(), "user/buys/item");
797    }
798
799    #[test]
800    fn test_edge_type_reversed() {
801        let et = HeteroEdgeType::new("user", "buys", "item");
802        let rev = et.reversed();
803        assert_eq!(rev.source_type, "item");
804        assert_eq!(rev.relation, "rev_buys");
805        assert_eq!(rev.destination_type, "user");
806    }
807
808    // ------------------------------------------------------------------
809    // HeteroGraph construction
810    // ------------------------------------------------------------------
811
812    #[test]
813    fn test_basic_construction() {
814        let (g, _u0, _u1, _i0, _i1, _t0) = make_graph();
815        assert_eq!(g.node_count(), 5);
816        assert_eq!(g.edge_count(), 4);
817    }
818
819    #[test]
820    fn test_nodes_of_type() {
821        let (g, _u0, _u1, _i0, _i1, _t0) = make_graph();
822        assert_eq!(g.nodes_of_type("user").len(), 2);
823        assert_eq!(g.nodes_of_type("item").len(), 2);
824        assert_eq!(g.nodes_of_type("tag").len(), 1);
825        assert_eq!(g.nodes_of_type("nonexistent").len(), 0);
826    }
827
828    #[test]
829    fn test_add_node_empty_type_fails() {
830        let mut g = HeteroGraph::new();
831        assert!(g.add_node("", 0).is_err());
832    }
833
834    #[test]
835    fn test_add_edge_wrong_type_fails() {
836        let mut g = HeteroGraph::new();
837        let u0 = g.add_node("user", 0).unwrap();
838        let i0 = g.add_node("item", 0).unwrap();
839        // Declare wrong src type
840        assert!(g.add_edge("item", "buys", "item", u0, i0).is_err());
841    }
842
843    #[test]
844    fn test_add_edge_unknown_node_fails() {
845        let mut g = HeteroGraph::new();
846        g.add_node("user", 0).unwrap();
847        let ghost = NodeId(999);
848        assert!(g
849            .add_edge("user", "buys", "item", ghost, NodeId(0))
850            .is_err());
851    }
852
/// `type_of` must report the registered type of each fixture node and `None`
/// for an id that is not in the graph.
#[test]
fn test_type_of_node() {
    let (graph, u0, u1, i0, _, t0) = make_graph();
    let expectations = [
        (u0, Some("user")),
        (u1, Some("user")),
        (i0, Some("item")),
        (t0, Some("tag")),
        (NodeId(999), None),
    ];
    for (node, expected) in expectations {
        assert_eq!(graph.type_of(node), expected);
    }
}
862
/// `has_typed_edge` must confirm the three `buys` edges of the fixture and deny
/// the one user/item pair that was never connected.
#[test]
fn test_has_typed_edge() {
    let (graph, u0, u1, i0, i1, _) = make_graph();
    let buys = HeteroEdgeType::new("user", "buys", "item");
    for (buyer, item) in [(u0, i0), (u1, i0), (u0, i1)] {
        assert!(graph.has_typed_edge(&buys, buyer, item));
    }
    // u1 → i1 is the pair this test expects to be absent.
    let absent_pair_found = graph.has_typed_edge(&buys, u1, i1);
    assert!(!absent_pair_found);
}
872
/// The `buys` out-neighbours of `u0` must be exactly `i0` and `i1`.
#[test]
fn test_out_neighbors_typed() {
    let (graph, u0, _, i0, i1, _) = make_graph();
    let buys = HeteroEdgeType::new("user", "buys", "item");
    let mut targets = graph.out_neighbors_typed(u0, &buys);
    // Order by the raw id so the comparison does not depend on the order returned.
    targets.sort_by(|a, b| a.0.cmp(&b.0));
    assert_eq!(targets, vec![i0, i1]);
}
881
882    // ------------------------------------------------------------------
883    // type_adjacency
884    // ------------------------------------------------------------------
885
/// The `buys` adjacency matrix must be 2 × 2 (two users by two items).
#[test]
fn test_type_adjacency_shape() {
    let (graph, ..) = make_graph();
    let buys = HeteroEdgeType::new("user", "buys", "item");
    let matrix = type_adjacency(&graph, &buys).unwrap();
    let row_total = matrix.n_rows;
    assert_eq!(row_total, 2);
    let col_total = matrix.n_cols;
    assert_eq!(col_total, 2);
}
895
/// Row sums of the `buys` adjacency matrix must equal the number of purchases
/// recorded for each user.
#[test]
fn test_type_adjacency_values() {
    let (graph, ..) = make_graph();
    let buys = HeteroEdgeType::new("user", "buys", "item");
    let csr = type_adjacency(&graph, &buys).unwrap();

    // Expand the CSR arrays (`indptr` / `indices` / `data`) into a dense matrix.
    let mut expanded = Array2::<f64>::zeros((csr.n_rows, csr.n_cols));
    for r in 0..csr.n_rows {
        for k in csr.indptr[r]..csr.indptr[r + 1] {
            expanded[[r, csr.indices[k]]] += csr.data[k];
        }
    }

    // Row 0 (u0) should sum to 2 (i0 and i1); row 1 (u1) should sum to 1 (i0 only).
    for (r, expected) in [(0, 2.0), (1, 1.0)] {
        let total: f64 = expanded.row(r).sum();
        assert!((total - expected).abs() < 1e-12);
    }
}
921
/// Requesting an adjacency matrix whose source type is "ghost" must fail.
#[test]
fn test_type_adjacency_missing_type() {
    let (graph, ..) = make_graph();
    let unknown_source = HeteroEdgeType::new("ghost", "buys", "item");
    let outcome = type_adjacency(&graph, &unknown_source);
    assert!(outcome.is_err());
}
928
929    // ------------------------------------------------------------------
930    // meta_path_adjacency
931    // ------------------------------------------------------------------
932
/// A meta-path that names a single node type has no hop and must be rejected.
#[test]
fn test_meta_path_too_short() {
    let (graph, ..) = make_graph();
    let single_type_path = ["user"];
    let outcome = meta_path_adjacency(&graph, &single_type_path);
    assert!(outcome.is_err());
}
938
/// Checks the `user → item → user` meta-path after reverse edges are added.
///
/// Each existing `user --buys--> item` edge is mirrored as
/// `item --bought_by--> user`, then the meta-path matrix is expected to be
/// 2 × 2 with at least one path between the two users in each direction.
#[test]
fn test_meta_path_user_item_user() {
    // Take the fixture mutably instead of cloning it: the untouched original
    // was never read again, so the clone was redundant.
    let (mut g, _u0, _u1, _i0, _i1, _t0) = make_graph();

    // Copy out everything needed up front so no borrow of `g` is alive while
    // edges are being added.
    let users: Vec<NodeId> = g.nodes_of_type("user").to_vec();
    let items: Vec<NodeId> = g.nodes_of_type("item").to_vec();
    let buys_et = HeteroEdgeType::new("user", "buys", "item");
    let buys_edges: Vec<(NodeId, NodeId)> = g.edges_of_type(&buys_et).to_vec();

    for (src, dst) in buys_edges {
        // The forward user → item edge is already present; add item → user.
        if users.contains(&src) && items.contains(&dst) {
            g.add_edge("item", "bought_by", "user", dst, src).unwrap();
        }
    }

    // Meta-path: user --buys--> item --bought_by--> user
    let sim = meta_path_adjacency(&g, &["user", "item", "user"]).unwrap();
    assert_eq!(sim.shape()[0], 2); // 2 users
    assert_eq!(sim.shape()[1], 2); // 2 users

    // u0 and u1 both bought i0, so there is 1 path u0→i0→u1 and 1 path u1→i0→u0.
    let u0_u1 = sim[[0, 1]];
    let u1_u0 = sim[[1, 0]];
    assert!(u0_u1 >= 1.0, "Expected at least 1 shared path, got {u0_u1}");
    assert!(u1_u0 >= 1.0, "Expected at least 1 shared path, got {u1_u0}");
}
968
/// Checks the two-hop `user → item → tag` meta-path on the unmodified fixture.
///
/// The result must be 2 × 1 (two users, one tag) with exactly one path from
/// each user to the tag.
#[test]
fn test_meta_path_two_steps() {
    // The tag id is not needed here, so it is bound as `_t0` like the other
    // unused fixture handles instead of being discarded with `let _ = t0;`.
    let (g, _u0, _u1, _i0, _i1, _t0) = make_graph();
    // user→item→tag (three node types, two hops)
    let sim = meta_path_adjacency(&g, &["user", "item", "tag"]).unwrap();
    // 2 users, 1 tag
    assert_eq!(sim.shape(), &[2, 1]);
    // u0 bought i0 (has t0) and i1 (no tag) → 1 path to t0
    assert!((sim[[0, 0]] - 1.0).abs() < 1e-9, "u0 paths={}", sim[[0, 0]]);
    // u1 bought i0 (has t0) → 1 path to t0
    assert!((sim[[1, 0]] - 1.0).abs() < 1e-9, "u1 paths={}", sim[[1, 0]]);
}
983
984    // ------------------------------------------------------------------
985    // hetero_message_passing
986    // ------------------------------------------------------------------
987
/// Message passing with two-dimensional user features must produce a `buys`
/// result whose aggregated matrix is 2 × 2 (items × features).
#[test]
fn test_hetero_message_passing_basic() {
    let (graph, u0, u1, ..) = make_graph();
    let features: HashMap<NodeId, Vec<f64>> =
        HashMap::from([(u0, vec![1.0, 0.0]), (u1, vec![0.0, 1.0])]);

    let outputs = hetero_message_passing(&graph, &features, 2).unwrap();
    // Results are expected per edge type that has edges
    // ("buys": user→item, "has_tag": item→tag), so the list cannot be empty.
    assert!(!outputs.is_empty());

    // Pick out the result that belongs to the "buys" relation.
    let buys_output = outputs
        .iter()
        .find(|out| out.edge_type.relation == "buys")
        .expect("should have buys result");

    // One aggregated row per item and one column per feature.
    let dims = buys_output.aggregated.shape();
    assert_eq!(dims[0], 2);
    assert_eq!(dims[1], 2);
}
1010
/// A feature vector whose length differs from the declared feature dimension
/// must make message passing fail.
#[test]
fn test_hetero_message_passing_wrong_dim() {
    let (graph, u0, ..) = make_graph();
    // Three components are supplied while the declared dimension is two.
    let features: HashMap<NodeId, Vec<f64>> = HashMap::from([(u0, vec![1.0, 2.0, 3.0])]);
    let outcome = hetero_message_passing(&graph, &features, 2);
    assert!(outcome.is_err());
}
1018
1019    // ------------------------------------------------------------------
1020    // Utility helpers
1021    // ------------------------------------------------------------------
1022
/// Out-degrees along `buys` must be 2 for `u0` and 1 for `u1`.
#[test]
fn test_typed_out_degree() {
    let (graph, u0, u1, ..) = make_graph();
    let buys = HeteroEdgeType::new("user", "buys", "item");
    let out_degrees = typed_out_degree(&graph, &buys);
    // u0 buys two items, u1 buys one.
    for (user, expected) in [(u0, 2), (u1, 1)] {
        assert_eq!(out_degrees[&user], expected);
    }
}
1031
/// In-degrees along `buys` must be 2 for `i0` and 1 for `i1`.
#[test]
fn test_typed_in_degree() {
    let (graph, _, _, i0, i1, _) = make_graph();
    let buys = HeteroEdgeType::new("user", "buys", "item");
    let in_degrees = typed_in_degree(&graph, &buys);
    // i0 is bought by two users, i1 by one.
    for (item, expected) in [(i0, 2), (i1, 1)] {
        assert_eq!(in_degrees[&item], expected);
    }
}
1040
/// Starting from "user", the reachable set must contain "user", "item" and
/// "tag", and must not contain a type name that was never registered.
#[test]
fn test_reachable_types() {
    let (graph, ..) = make_graph();
    let seen = reachable_types(&graph, "user");
    for expected in ["user", "item", "tag"] {
        assert!(seen.contains(expected));
    }
    // "ghost" is not a node type in the fixture.
    let ghost_seen = seen.contains("ghost");
    assert!(!ghost_seen);
}
1051
/// Flattening the fixture into an untyped edge list must keep all four edges.
#[test]
fn test_to_homogeneous_edge_list() {
    let (graph, ..) = make_graph();
    let edge_total = to_homogeneous_edge_list(&graph).len();
    // 3 "buys" edges plus 1 "has_tag" edge.
    assert_eq!(edge_total, 4);
}
1058
/// `u0` has outgoing edges of a single type, so the grouped neighbour listing
/// must hold one `buys` entry with two targets.
#[test]
fn test_all_out_neighbors_typed() {
    let (graph, u0, ..) = make_graph();
    let grouped = graph.all_out_neighbors_typed(u0);
    // Only "buys" edges leave u0.
    assert_eq!(grouped.len(), 1);
    let only_group = &grouped[0];
    assert_eq!(only_group.0.relation, "buys");
    // u0 buys two items.
    assert_eq!(only_group.1.len(), 2);
}
1067
/// `contains_node` must be true for a fixture node and false for an id that was
/// never added.
#[test]
fn test_contains_node() {
    let (graph, u0, ..) = make_graph();
    for (node, expected) in [(u0, true), (NodeId(999), false)] {
        assert_eq!(graph.contains_node(node), expected);
    }
}
1074}