
scirs2_graph/hypergraph/attention.rs

//! Hypergraph Attention Network (HAN).
//!
//! Implements a dual-attention mechanism over hypergraphs:
//! 1. **Node-to-hyperedge attention**: each hyperedge aggregates member node features
//! 2. **Hyperedge-to-node attention**: each node aggregates features from its hyperedges
//!
//! ## Architecture
//!
//! Given incidence matrix `H ∈ {0,1}^{N×M}` (N nodes, M hyperedges):
//!
//! ### Node → Hyperedge
//! ```text
//! a_{ih} = softmax_{i ∈ h} [ (W_Q x_i · W_K ê_h) / sqrt(d) ]
//! e_h^new = sum_{i ∈ h} a_{ih} W_V x_i
//! ```
//!
//! ### Hyperedge → Node
//! ```text
//! b_{hi} = softmax_{h ∋ i} [ (W_Q ê_h^new · W_K x_i) / sqrt(d) ]
//! x_i^new = sum_{h ∋ i} b_{hi} W_V ê_h^new
//! ```
//!
//! ## References
//!
//! - Dong et al. (2020). "HNHN: Hypergraph Networks with Hyperedge Neurons."
//! - Bai et al. (2021). "Hypergraph Convolution and Hypergraph Attention."
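//!
//! ## Example
//!
//! A minimal usage sketch (not compiled as a doctest):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array2;
//!
//! // 4 nodes with 8-dim features; hyperedges {0, 1, 2} and {2, 3}.
//! let node_feats = Array2::from_elem((4, 8), 0.1);
//! let mut incidence = Array2::<f64>::zeros((4, 2));
//! for &(i, h) in &[(0, 0), (1, 0), (2, 0), (2, 1), (3, 1)] {
//!     incidence[[i, h]] = 1.0;
//! }
//!
//! let net = HypergraphAttentionNetwork::new(8, 2, HypergraphAttentionConfig::default());
//! let out = net.forward(&node_feats, &incidence).expect("forward");
//! assert_eq!(out.shape(), &[4, 8]);
//! ```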

use crate::error::{GraphError, Result};
use scirs2_core::ndarray::Array2;
use scirs2_core::random::{Rng, RngExt};

// ============================================================================
// Linear layer (local to this module, same pattern as egnn.rs)
// ============================================================================

/// A simple linear layer: y = W x + b.
#[derive(Debug, Clone)]
struct Linear {
    weight: Vec<Vec<f64>>,
    bias: Vec<f64>,
    out_dim: usize,
    in_dim: usize,
}

impl Linear {
    fn new(in_dim: usize, out_dim: usize) -> Self {
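        // Weights are drawn uniformly from [-scale, scale] with
        // scale = sqrt(2 / in_dim); biases start at zero.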
        let scale = (2.0 / in_dim as f64).sqrt();
        let mut rng = scirs2_core::random::rng();
        let weight: Vec<Vec<f64>> = (0..out_dim)
            .map(|_| {
                (0..in_dim)
                    .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
                    .collect()
            })
            .collect();
        Linear {
            weight,
            bias: vec![0.0; out_dim],
            out_dim,
            in_dim,
        }
    }

    fn forward(&self, x: &[f64]) -> Vec<f64> {
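        // out[i] = bias[i] + Σ_j weight[i][j] · x[j]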
        let mut out = self.bias.clone();
        for (i, row) in self.weight.iter().enumerate() {
            for (j, &w) in row.iter().enumerate() {
                out[i] += w * x[j];
            }
        }
        out
    }
}

// ============================================================================
// Layer Norm helper
// ============================================================================

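/// Normalises `x` in place to zero mean and unit variance, with a small
/// epsilon added to the variance for numerical stability.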
fn layer_norm(x: &mut [f64]) {
    let n = x.len() as f64;
    let mean: f64 = x.iter().sum::<f64>() / n;
    let var: f64 = x.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>() / n;
    let std_dev = (var + 1e-8).sqrt();
    for v in x.iter_mut() {
        *v = (*v - mean) / std_dev;
    }
}

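/// Numerically stable softmax: shifts by the max before exponentiating so
/// large scores cannot overflow, and clamps the normaliser away from zero.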
fn softmax(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let max_val = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|x| (x - max_val).exp()).collect();
    let sum: f64 = exps.iter().sum::<f64>().max(1e-15);
    exps.iter().map(|e| e / sum).collect()
}

// ============================================================================
// Configuration
// ============================================================================

/// Configuration for a HypergraphAttentionLayer.
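///
/// A minimal construction sketch:
///
/// ```ignore
/// let config = HypergraphAttentionConfig {
///     hidden_dim: 32,
///     ..Default::default()
/// };
/// ```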
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HypergraphAttentionConfig {
    /// Node feature (and hyperedge feature) dimension.
    pub hidden_dim: usize,
    /// Number of attention heads (currently not used by the forward pass,
    /// which computes a single head).
    pub n_heads: usize,
    /// Dropout rate (0.0 = disabled; currently not applied by the forward
    /// pass, matching inference/test behaviour).
    pub dropout: f64,
    /// Whether to apply layer normalisation.
    pub use_layer_norm: bool,
}

impl Default for HypergraphAttentionConfig {
    fn default() -> Self {
        HypergraphAttentionConfig {
            hidden_dim: 64,
            n_heads: 4,
            dropout: 0.1,
            use_layer_norm: true,
        }
    }
}

// ============================================================================
// HypergraphAttentionLayer
// ============================================================================

/// A single Hypergraph Attention layer.
///
/// Processes node features with dual-direction attention:
/// 1. Nodes → Hyperedges (aggregate node info into hyperedge representations)
/// 2. Hyperedges → Nodes (aggregate hyperedge info back to node representations)
#[derive(Debug, Clone)]
pub struct HypergraphAttentionLayer {
    /// Query projection for node features.
    w_q_node: Linear,
    /// Key projection for hyperedge features.
    w_k_edge: Linear,
    /// Value projection for node features.
    w_v_node: Linear,
    /// Query projection for hyperedge features.
    w_q_edge: Linear,
    /// Key projection for node features.
    w_k_node: Linear,
    /// Value projection for hyperedge features.
    w_v_edge: Linear,
    /// Output projection back to node feature space.
    w_o: Linear,
    /// Hyperedge initial feature projection W_e.
    w_e: Linear,
    /// Configuration.
    config: HypergraphAttentionConfig,
    /// Input feature dimension.
    in_dim: usize,
}

impl HypergraphAttentionLayer {
    /// Create a new HypergraphAttentionLayer.
    ///
    /// # Arguments
    /// - `in_dim`: input node feature dimension
    /// - `config`: layer configuration
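    ///
    /// A construction sketch (mirrors the unit tests below):
    ///
    /// ```ignore
    /// let layer = HypergraphAttentionLayer::new(4, HypergraphAttentionConfig::default());
    /// ```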
    pub fn new(in_dim: usize, config: HypergraphAttentionConfig) -> Self {
        let h = config.hidden_dim;
        HypergraphAttentionLayer {
            w_q_node: Linear::new(in_dim, h),
            w_k_edge: Linear::new(h, h),
            w_v_node: Linear::new(in_dim, h),
            w_q_edge: Linear::new(h, h),
            w_k_node: Linear::new(in_dim, h),
            w_v_edge: Linear::new(h, h),
            w_o: Linear::new(h, in_dim),
            w_e: Linear::new(in_dim, h),
            config,
            in_dim,
        }
    }

    /// Forward pass.
    ///
    /// # Arguments
    /// - `node_feats`: node features, shape [N × in_dim]
    /// - `incidence_matrix`: H ∈ {0,1}^{N×M}, shape [N × M]
    ///
    /// # Returns
    /// Updated node features, shape [N × in_dim].
    pub fn forward(
        &self,
        node_feats: &Array2<f64>,
        incidence_matrix: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let n_nodes = node_feats.nrows();
        let in_d = node_feats.ncols();

        if in_d != self.in_dim {
            return Err(GraphError::InvalidParameter {
                param: "node_feats".to_string(),
                value: format!("ncols={in_d}"),
                expected: format!("ncols={}", self.in_dim),
                context: "HypergraphAttentionLayer::forward".to_string(),
            });
        }
        if incidence_matrix.nrows() != n_nodes {
            return Err(GraphError::InvalidParameter {
                param: "incidence_matrix".to_string(),
                value: format!("nrows={}", incidence_matrix.nrows()),
                expected: format!("nrows={n_nodes}"),
                context: "HypergraphAttentionLayer::forward".to_string(),
            });
        }

        let n_edges = incidence_matrix.ncols();
        let h_dim = self.config.hidden_dim;
        let scale = (h_dim as f64).sqrt();
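        // Note: this implementation computes a single attention head;
        // `config.n_heads` and `config.dropout` are not applied here.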

        // ── Step 1: compute initial hyperedge features as mean of member nodes ─
        // e_h = W_e * (mean_{i ∈ h} x_i)
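        // e.g. for hyperedge h = {0, 1, 2}: e_h = W_e * (x_0 + x_1 + x_2) / 3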
        let mut edge_feats: Vec<Vec<f64>> = Vec::with_capacity(n_edges);
        for edge_h in 0..n_edges {
            let members: Vec<usize> = (0..n_nodes)
                .filter(|&i| incidence_matrix[[i, edge_h]] > 0.5)
                .collect();
            let mean_feat = if members.is_empty() {
                vec![0.0_f64; in_d]
            } else {
                let inv_n = 1.0 / members.len() as f64;
                let mut mean = vec![0.0_f64; in_d];
                for &i in &members {
                    for d in 0..in_d {
                        mean[d] += node_feats[[i, d]] * inv_n;
                    }
                }
                mean
            };
            edge_feats.push(self.w_e.forward(&mean_feat));
        }

        // ── Step 2: node → hyperedge attention ─────────────────────────────
        // For each hyperedge h, attend over its member nodes
        let mut edge_feats_new: Vec<Vec<f64>> = vec![vec![0.0_f64; h_dim]; n_edges];

        for edge_h in 0..n_edges {
            let members: Vec<usize> = (0..n_nodes)
                .filter(|&i| incidence_matrix[[i, edge_h]] > 0.5)
                .collect();
            if members.is_empty() {
                edge_feats_new[edge_h] = edge_feats[edge_h].clone();
                continue;
            }

            let k_e = self.w_k_edge.forward(&edge_feats[edge_h]);

            // Attention scores: Q_i · K_h / sqrt(d)
            let scores: Vec<f64> = members
                .iter()
                .map(|&i| {
                    // `to_vec` copies the row, which also works for
                    // non-contiguous views (unlike `as_slice`).
                    let q_i = self.w_q_node.forward(&node_feats.row(i).to_vec());
                    let dot: f64 = q_i.iter().zip(k_e.iter()).map(|(a, b)| a * b).sum();
                    dot / scale
                })
                .collect();

            let alphas = softmax(&scores);

            // Aggregate: e_h_new = sum_i alpha_i * V_i
            let e_new = &mut edge_feats_new[edge_h];
            for (k, &i) in members.iter().enumerate() {
                let v_i = self.w_v_node.forward(&node_feats.row(i).to_vec());
                for d in 0..h_dim {
                    e_new[d] += alphas[k] * v_i[d];
                }
            }
        }

        // ── Step 3: hyperedge → node attention ─────────────────────────────
        let mut node_feats_new = Array2::zeros((n_nodes, in_d));

        for node_i in 0..n_nodes {
            let incident_edges: Vec<usize> = (0..n_edges)
                .filter(|&h| incidence_matrix[[node_i, h]] > 0.5)
                .collect();
            if incident_edges.is_empty() {
                // Residual: copy input
                for d in 0..in_d {
                    node_feats_new[[node_i, d]] = node_feats[[node_i, d]];
                }
                continue;
            }

            let k_i = self.w_k_node.forward(&node_feats.row(node_i).to_vec());

            // Attention scores: Q_h · K_i / sqrt(d)
            let scores: Vec<f64> = incident_edges
                .iter()
                .map(|&h| {
                    let q_h = self.w_q_edge.forward(&edge_feats_new[h]);
                    let dot: f64 = q_h.iter().zip(k_i.iter()).map(|(a, b)| a * b).sum();
                    dot / scale
                })
                .collect();

            let betas = softmax(&scores);

            // Aggregate: x_i_new = sum_h beta_h * W_V e_h_new
            let mut x_new_h = vec![0.0_f64; h_dim];
            for (k, &h) in incident_edges.iter().enumerate() {
                let v_h = self.w_v_edge.forward(&edge_feats_new[h]);
                for d in 0..h_dim {
                    x_new_h[d] += betas[k] * v_h[d];
                }
            }

            // Project back to input dim + residual
            let projected = self.w_o.forward(&x_new_h);
            let mut out_i: Vec<f64> = projected
                .iter()
                .enumerate()
                .map(|(d, &p)| p + node_feats[[node_i, d]])
                .collect();

            // Layer norm
            if self.config.use_layer_norm {
                layer_norm(&mut out_i);
            }

            for d in 0..in_d {
                node_feats_new[[node_i, d]] = out_i[d];
            }
        }

        Ok(node_feats_new)
    }
}

// ============================================================================
// HypergraphAttentionNetwork
// ============================================================================

/// Multi-layer Hypergraph Attention Network.
#[derive(Debug, Clone)]
pub struct HypergraphAttentionNetwork {
    /// Stacked hypergraph attention layers.
    pub layers: Vec<HypergraphAttentionLayer>,
    /// Inter-layer MLPs (feedforward block after each attention layer).
    ff_layers: Vec<(Linear, Linear)>,
    /// Input dimension.
    pub in_dim: usize,
    /// Configuration (shared across all layers).
    pub config: HypergraphAttentionConfig,
}

impl HypergraphAttentionNetwork {
    /// Create a multi-layer Hypergraph Attention Network.
    ///
    /// # Arguments
    /// - `in_dim`: input node feature dimension
    /// - `n_layers`: number of stacked attention layers
    /// - `config`: configuration (shared across layers)
    pub fn new(in_dim: usize, n_layers: usize, config: HypergraphAttentionConfig) -> Self {
        let h = config.hidden_dim;
        let layers = (0..n_layers)
            .map(|_| HypergraphAttentionLayer::new(in_dim, config.clone()))
            .collect();
        // Feedforward block: in_dim → h → in_dim
        let ff_layers = (0..n_layers)
            .map(|_| (Linear::new(in_dim, h), Linear::new(h, in_dim)))
            .collect();
        HypergraphAttentionNetwork {
            layers,
            ff_layers,
            in_dim,
            config,
        }
    }

    /// Forward pass through all layers.
    ///
    /// # Arguments
    /// - `node_feats`: initial node features [N × in_dim]
    /// - `incidence_matrix`: H ∈ {0,1}^{N×M}
    ///
    /// # Returns
    /// Final node features [N × in_dim].
    pub fn forward(
        &self,
        node_feats: &Array2<f64>,
        incidence_matrix: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let mut x = node_feats.clone();
        for (layer, (ff1, ff2)) in self.layers.iter().zip(self.ff_layers.iter()) {
            let x_att = layer.forward(&x, incidence_matrix)?;
            // Feedforward: apply per-node, with ReLU + residual
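            // (Transformer-style position-wise FFN: in_dim → hidden_dim → in_dim,
            // followed by a residual connection and optional layer norm.)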
            let mut x_ff = Array2::zeros(x_att.dim());
            for i in 0..x_att.nrows() {
                let row: Vec<f64> = x_att.row(i).to_vec();
                let mut h_mid = ff1.forward(&row);
                for v in h_mid.iter_mut() {
                    *v = v.max(0.0); // ReLU
                }
                let projected = ff2.forward(&h_mid);
                let mut out: Vec<f64> = projected
                    .iter()
                    .zip(row.iter())
                    .map(|(p, r)| p + r)
                    .collect();
                if self.config.use_layer_norm {
                    layer_norm(&mut out);
                }
                for d in 0..self.in_dim {
                    x_ff[[i, d]] = out[d];
                }
            }
            x = x_ff;
        }
        Ok(x)
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    fn make_node_feats(n_nodes: usize, in_dim: usize) -> Array2<f64> {
        let data: Vec<f64> = (0..n_nodes * in_dim)
            .map(|i| (i as f64 + 1.0) * 0.1)
            .collect();
        Array2::from_shape_vec((n_nodes, in_dim), data).expect("node feats")
    }

    fn make_incidence_matrix(n_nodes: usize, n_edges: usize) -> Array2<f64> {
        // Two overlapping hyperedges: hyperedge 0 = {0,1,2}, hyperedge 1 = {2,3,4}
        let mut h = Array2::zeros((n_nodes, n_edges));
        if n_nodes >= 3 && n_edges >= 1 {
            h[[0, 0]] = 1.0;
            h[[1, 0]] = 1.0;
            h[[2, 0]] = 1.0;
        }
        if n_nodes >= 5 && n_edges >= 2 {
            h[[2, 1]] = 1.0;
            h[[3, 1]] = 1.0;
            h[[4, 1]] = 1.0;
        }
        h
    }

    #[test]
    fn test_attention_layer_output_shape() {
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            n_heads: 2,
            ..Default::default()
        };
        let layer = HypergraphAttentionLayer::new(4, config);
        let node_feats = make_node_feats(5, 4);
        let incidence = make_incidence_matrix(5, 2);
        let out = layer.forward(&node_feats, &incidence).expect("forward");
        assert_eq!(out.nrows(), 5, "output node count");
        assert_eq!(out.ncols(), 4, "output feature dim");
    }

    #[test]
    fn test_attention_handles_varying_hyperedge_sizes() {
        // Hyperedge 0: {0} (size 1), hyperedge 1: {1,2,3,4} (size 4)
        let mut incidence = Array2::zeros((5, 2));
        incidence[[0, 0]] = 1.0;
        incidence[[1, 1]] = 1.0;
        incidence[[2, 1]] = 1.0;
        incidence[[3, 1]] = 1.0;
        incidence[[4, 1]] = 1.0;

        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            n_heads: 2,
            ..Default::default()
        };
        let layer = HypergraphAttentionLayer::new(4, config);
        let node_feats = make_node_feats(5, 4);
        let out = layer
            .forward(&node_feats, &incidence)
            .expect("varying sizes");
        assert_eq!(out.shape(), &[5, 4]);
    }

    #[test]
    fn test_attention_output_is_finite() {
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            ..Default::default()
        };
        let layer = HypergraphAttentionLayer::new(4, config);
        let node_feats = make_node_feats(5, 4);
        let incidence = make_incidence_matrix(5, 2);
        let out = layer.forward(&node_feats, &incidence).expect("forward");
        for v in out.iter() {
            assert!(v.is_finite(), "output must be finite, got {v}");
        }
    }

    #[test]
    fn test_network_stacked_output_shape() {
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            n_heads: 2,
            ..Default::default()
        };
        let net = HypergraphAttentionNetwork::new(4, 3, config);
        let node_feats = make_node_feats(5, 4);
        let incidence = make_incidence_matrix(5, 2);
        let out = net.forward(&node_feats, &incidence).expect("net forward");
        assert_eq!(out.shape(), &[5, 4]);
    }

    #[test]
    fn test_network_output_is_finite() {
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            ..Default::default()
        };
        let net = HypergraphAttentionNetwork::new(4, 2, config);
        let node_feats = make_node_feats(5, 4);
        let incidence = make_incidence_matrix(5, 2);
        let out = net.forward(&node_feats, &incidence).expect("forward");
        for v in out.iter() {
            assert!(v.is_finite(), "network output must be finite");
        }
    }

    #[test]
    fn test_empty_hyperedge() {
        // Node with no hyperedge membership gets residual connection
        let incidence = Array2::zeros((3, 2)); // all zeros → no memberships
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            use_layer_norm: false,
            ..Default::default()
        };
        let layer = HypergraphAttentionLayer::new(4, config);
        let node_feats = make_node_feats(3, 4);
        let out = layer
            .forward(&node_feats, &incidence)
            .expect("empty hyperedge");
        // Should equal input (residual path)
        for i in 0..3 {
            for d in 0..4 {
                assert!(
                    (out[[i, d]] - node_feats[[i, d]]).abs() < 1e-12,
                    "residual mismatch at ({i},{d})"
                );
            }
        }
    }
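
    // Added sanity checks for the module-local helpers: softmax should produce
    // a probability distribution, and layer_norm should centre its input.
    #[test]
    fn test_softmax_and_layer_norm_helpers() {
        let probs = softmax(&[0.5, -1.0, 2.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12, "softmax must sum to 1, got {sum}");
        assert!(probs.iter().all(|&p| p > 0.0), "softmax outputs are positive");
        assert!(softmax(&[]).is_empty(), "empty input yields empty output");

        let mut xs = vec![1.0, 2.0, 3.0, 4.0];
        layer_norm(&mut xs);
        let mean: f64 = xs.iter().sum::<f64>() / xs.len() as f64;
        assert!(mean.abs() < 1e-9, "layer_norm output should have ~zero mean");
    }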
583    }
584}