1use crate::error::{GraphError, Result};
29use scirs2_core::ndarray::Array2;
30use scirs2_core::random::{Rng, RngExt};
31
/// Minimal dense (fully connected) layer computing `y = W x + b`.
///
/// Weights are stored row-major: one `Vec<f64>` of length `in_dim` per
/// output unit.
#[derive(Debug, Clone)]
struct Linear {
    // weight[i][j] multiplies input component j for output component i.
    weight: Vec<Vec<f64>>,
    // Per-output bias; initialized to all zeros in `new`.
    bias: Vec<f64>,
    // Number of outputs. NOTE(review): stored but not read anywhere in this
    // file beyond construction — kept for introspection/debugging.
    out_dim: usize,
    // Number of inputs expected by `forward`. NOTE(review): also not read
    // after construction in this file.
    in_dim: usize,
}
44
45impl Linear {
46 fn new(in_dim: usize, out_dim: usize) -> Self {
47 let scale = (2.0 / in_dim as f64).sqrt();
48 let mut rng = scirs2_core::random::rng();
49 let weight: Vec<Vec<f64>> = (0..out_dim)
50 .map(|_| {
51 (0..in_dim)
52 .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
53 .collect()
54 })
55 .collect();
56 Linear {
57 weight,
58 bias: vec![0.0; out_dim],
59 out_dim,
60 in_dim,
61 }
62 }
63
64 fn forward(&self, x: &[f64]) -> Vec<f64> {
65 let mut out = self.bias.clone();
66 for (i, row) in self.weight.iter().enumerate() {
67 for (j, &w) in row.iter().enumerate() {
68 out[i] += w * x[j];
69 }
70 }
71 out
72 }
73}
74
/// Normalizes `x` in place to zero mean and unit variance.
///
/// A small epsilon (1e-8) is added to the variance so constant inputs map to
/// all zeros instead of dividing by zero. An empty slice is left untouched.
fn layer_norm(x: &mut [f64]) {
    let count = x.len() as f64;
    let mean = x.iter().sum::<f64>() / count;
    let variance = x
        .iter()
        .map(|&v| {
            let centered = v - mean;
            centered * centered
        })
        .sum::<f64>()
        / count;
    let std_dev = (variance + 1e-8).sqrt();
    x.iter_mut().for_each(|v| *v = (*v - mean) / std_dev);
}
88
/// Numerically stable softmax over `xs`; returns an empty vector for empty input.
///
/// The maximum is subtracted before exponentiation to avoid overflow, and the
/// denominator is floored at 1e-15 so all-`-inf` inputs cannot divide by zero.
fn softmax(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let max_val = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut out: Vec<f64> = xs.iter().map(|&x| (x - max_val).exp()).collect();
    let denom = out.iter().sum::<f64>().max(1e-15);
    for p in out.iter_mut() {
        *p /= denom;
    }
    out
}
98
/// Configuration shared by [`HypergraphAttentionLayer`] and
/// [`HypergraphAttentionNetwork`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HypergraphAttentionConfig {
    /// Width of the hidden attention/projection space.
    pub hidden_dim: usize,
    /// Number of attention heads. NOTE(review): not read anywhere in this
    /// file — the forward passes here are effectively single-head; confirm
    /// before relying on this field.
    pub n_heads: usize,
    /// Dropout probability. NOTE(review): also unused by the forward passes
    /// in this file (no dropout is applied).
    pub dropout: f64,
    /// Whether to apply layer normalization after each residual connection.
    pub use_layer_norm: bool,
}
116
117impl Default for HypergraphAttentionConfig {
118 fn default() -> Self {
119 HypergraphAttentionConfig {
120 hidden_dim: 64,
121 n_heads: 4,
122 dropout: 0.1,
123 use_layer_norm: true,
124 }
125 }
126}
127
/// One hypergraph attention layer.
///
/// Message passing happens in two phases (see `forward`): member nodes attend
/// into each hyperedge, then incident hyperedges attend back into each node,
/// followed by an output projection with a residual connection.
#[derive(Debug, Clone)]
pub struct HypergraphAttentionLayer {
    // Phase 1 (node -> edge): node query projection, in_dim -> hidden.
    w_q_node: Linear,
    // Phase 1: hyperedge key projection, hidden -> hidden.
    w_k_edge: Linear,
    // Phase 1: node value projection, in_dim -> hidden.
    w_v_node: Linear,
    // Phase 2 (edge -> node): hyperedge query projection, hidden -> hidden.
    w_q_edge: Linear,
    // Phase 2: node key projection, in_dim -> hidden.
    w_k_node: Linear,
    // Phase 2: hyperedge value projection, hidden -> hidden.
    w_v_edge: Linear,
    // Output projection back to node feature space, hidden -> in_dim.
    w_o: Linear,
    // Initial hyperedge embedding from mean member features, in_dim -> hidden.
    w_e: Linear,
    // Layer hyperparameters (hidden width, layer-norm flag, ...).
    config: HypergraphAttentionConfig,
    // Expected node feature dimensionality; validated in `forward`.
    in_dim: usize,
}
160
161impl HypergraphAttentionLayer {
162 pub fn new(in_dim: usize, config: HypergraphAttentionConfig) -> Self {
168 let h = config.hidden_dim;
169 HypergraphAttentionLayer {
170 w_q_node: Linear::new(in_dim, h),
171 w_k_edge: Linear::new(h, h),
172 w_v_node: Linear::new(in_dim, h),
173 w_q_edge: Linear::new(h, h),
174 w_k_node: Linear::new(in_dim, h),
175 w_v_edge: Linear::new(h, h),
176 w_o: Linear::new(h, in_dim),
177 w_e: Linear::new(in_dim, h),
178 config,
179 in_dim,
180 }
181 }
182
183 pub fn forward(
192 &self,
193 node_feats: &Array2<f64>,
194 incidence_matrix: &Array2<f64>,
195 ) -> Result<Array2<f64>> {
196 let n_nodes = node_feats.nrows();
197 let in_d = node_feats.ncols();
198
199 if in_d != self.in_dim {
200 return Err(GraphError::InvalidParameter {
201 param: "node_feats".to_string(),
202 value: format!("ncols={in_d}"),
203 expected: format!("ncols={}", self.in_dim),
204 context: "HypergraphAttentionLayer::forward".to_string(),
205 });
206 }
207 if incidence_matrix.nrows() != n_nodes {
208 return Err(GraphError::InvalidParameter {
209 param: "incidence_matrix".to_string(),
210 value: format!("nrows={}", incidence_matrix.nrows()),
211 expected: format!("nrows={n_nodes}"),
212 context: "HypergraphAttentionLayer::forward".to_string(),
213 });
214 }
215
216 let n_edges = incidence_matrix.ncols();
217 let h_dim = self.config.hidden_dim;
218 let scale = (h_dim as f64).sqrt();
219
220 let mut edge_feats: Vec<Vec<f64>> = Vec::with_capacity(n_edges);
223 for edge_h in 0..n_edges {
224 let members: Vec<usize> = (0..n_nodes)
225 .filter(|&i| incidence_matrix[[i, edge_h]] > 0.5)
226 .collect();
227 let mean_feat = if members.is_empty() {
228 vec![0.0_f64; in_d]
229 } else {
230 let inv_n = 1.0 / members.len() as f64;
231 let mut mean = vec![0.0_f64; in_d];
232 for &i in &members {
233 for d in 0..in_d {
234 mean[d] += node_feats[[i, d]] * inv_n;
235 }
236 }
237 mean
238 };
239 edge_feats.push(self.w_e.forward(&mean_feat));
240 }
241
242 let mut edge_feats_new: Vec<Vec<f64>> = vec![vec![0.0_f64; h_dim]; n_edges];
245
246 for edge_h in 0..n_edges {
247 let members: Vec<usize> = (0..n_nodes)
248 .filter(|&i| incidence_matrix[[i, edge_h]] > 0.5)
249 .collect();
250 if members.is_empty() {
251 edge_feats_new[edge_h] = edge_feats[edge_h].clone();
252 continue;
253 }
254
255 let k_e = self.w_k_edge.forward(&edge_feats[edge_h]);
256
257 let scores: Vec<f64> = members
259 .iter()
260 .map(|&i| {
261 let q_i = self.w_q_node.forward(
262 node_feats
263 .row(i)
264 .as_slice()
265 .unwrap_or(&[])
266 .to_vec()
267 .as_slice(),
268 );
269 let dot: f64 = q_i.iter().zip(k_e.iter()).map(|(a, b)| a * b).sum();
270 dot / scale
271 })
272 .collect();
273
274 let alphas = softmax(&scores);
275
276 let e_new = &mut edge_feats_new[edge_h];
278 for (k, &i) in members.iter().enumerate() {
279 let v_i = self.w_v_node.forward(
280 node_feats
281 .row(i)
282 .as_slice()
283 .unwrap_or(&[])
284 .to_vec()
285 .as_slice(),
286 );
287 for d in 0..h_dim {
288 e_new[d] += alphas[k] * v_i[d];
289 }
290 }
291 }
292
293 let mut node_feats_new = Array2::zeros((n_nodes, in_d));
295 let mut residual_used = vec![false; n_nodes];
296
297 for node_i in 0..n_nodes {
298 let incident_edges: Vec<usize> = (0..n_edges)
299 .filter(|&h| incidence_matrix[[node_i, h]] > 0.5)
300 .collect();
301 if incident_edges.is_empty() {
302 for d in 0..in_d {
304 node_feats_new[[node_i, d]] = node_feats[[node_i, d]];
305 }
306 residual_used[node_i] = true;
307 continue;
308 }
309
310 let k_i = self.w_k_node.forward(
311 node_feats
312 .row(node_i)
313 .as_slice()
314 .unwrap_or(&[])
315 .to_vec()
316 .as_slice(),
317 );
318
319 let scores: Vec<f64> = incident_edges
321 .iter()
322 .map(|&h| {
323 let q_h = self.w_q_edge.forward(&edge_feats_new[h]);
324 let dot: f64 = q_h.iter().zip(k_i.iter()).map(|(a, b)| a * b).sum();
325 dot / scale
326 })
327 .collect();
328
329 let betas = softmax(&scores);
330
331 let mut x_new_h = vec![0.0_f64; h_dim];
333 for (k, &h) in incident_edges.iter().enumerate() {
334 let v_h = self.w_v_edge.forward(&edge_feats_new[h]);
335 for d in 0..h_dim {
336 x_new_h[d] += betas[k] * v_h[d];
337 }
338 }
339
340 let projected = self.w_o.forward(&x_new_h);
342 let mut out_i: Vec<f64> = projected
343 .iter()
344 .enumerate()
345 .map(|(d, &p)| p + node_feats[[node_i, d]])
346 .collect();
347
348 if self.config.use_layer_norm {
350 layer_norm(&mut out_i);
351 }
352
353 for d in 0..in_d {
354 node_feats_new[[node_i, d]] = out_i[d];
355 }
356 }
357
358 Ok(node_feats_new)
359 }
360}
361
/// A stack of hypergraph attention layers, each followed by a position-wise
/// feed-forward block (Linear -> ReLU -> Linear) with a residual connection.
#[derive(Debug, Clone)]
pub struct HypergraphAttentionNetwork {
    /// The attention layers, applied in order.
    pub layers: Vec<HypergraphAttentionLayer>,
    // One (expand, project) feed-forward pair per attention layer:
    // in_dim -> hidden_dim -> in_dim.
    ff_layers: Vec<(Linear, Linear)>,
    /// Node feature dimensionality expected by `forward`.
    pub in_dim: usize,
    /// Shared hyperparameters (hidden width, layer-norm flag, ...).
    pub config: HypergraphAttentionConfig,
}
378
379impl HypergraphAttentionNetwork {
380 pub fn new(in_dim: usize, n_layers: usize, config: HypergraphAttentionConfig) -> Self {
387 let h = config.hidden_dim;
388 let layers = (0..n_layers)
389 .map(|_| HypergraphAttentionLayer::new(in_dim, config.clone()))
390 .collect();
391 let ff_layers = (0..n_layers)
393 .map(|_| (Linear::new(in_dim, h), Linear::new(h, in_dim)))
394 .collect();
395 HypergraphAttentionNetwork {
396 layers,
397 ff_layers,
398 in_dim,
399 config,
400 }
401 }
402
403 pub fn forward(
412 &self,
413 node_feats: &Array2<f64>,
414 incidence_matrix: &Array2<f64>,
415 ) -> Result<Array2<f64>> {
416 let mut x = node_feats.clone();
417 for (layer, (ff1, ff2)) in self.layers.iter().zip(self.ff_layers.iter()) {
418 let x_att = layer.forward(&x, incidence_matrix)?;
419 let mut x_ff = Array2::zeros(x_att.dim());
421 for i in 0..x_att.nrows() {
422 let row: Vec<f64> = x_att.row(i).to_vec();
423 let mut h_mid = ff1.forward(&row);
424 for v in h_mid.iter_mut() {
425 *v = v.max(0.0); }
427 let projected = ff2.forward(&h_mid);
428 let mut out: Vec<f64> = projected
429 .iter()
430 .zip(row.iter())
431 .map(|(p, r)| p + r)
432 .collect();
433 if self.config.use_layer_norm {
434 layer_norm(&mut out);
435 }
436 for d in 0..self.in_dim {
437 x_ff[[i, d]] = out[d];
438 }
439 }
440 x = x_ff;
441 }
442 Ok(x)
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Deterministic feature matrix: entry k (row-major) is (k + 1) * 0.1.
    fn make_node_feats(n_nodes: usize, in_dim: usize) -> Array2<f64> {
        let values: Vec<f64> = (0..n_nodes * in_dim)
            .map(|k| (k as f64 + 1.0) * 0.1)
            .collect();
        Array2::from_shape_vec((n_nodes, in_dim), values).expect("node feats")
    }

    /// Incidence matrix with hyperedge 0 = {0, 1, 2} and hyperedge 1 =
    /// {2, 3, 4}, when the requested sizes allow each edge.
    fn make_incidence_matrix(n_nodes: usize, n_edges: usize) -> Array2<f64> {
        let mut incidence = Array2::zeros((n_nodes, n_edges));
        if n_nodes >= 3 && n_edges >= 1 {
            for node in 0..3 {
                incidence[[node, 0]] = 1.0;
            }
        }
        if n_nodes >= 5 && n_edges >= 2 {
            for node in 2..5 {
                incidence[[node, 1]] = 1.0;
            }
        }
        incidence
    }

    #[test]
    fn test_attention_layer_output_shape() {
        // A layer must preserve (n_nodes, in_dim) regardless of hidden_dim.
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            n_heads: 2,
            ..Default::default()
        };
        let layer = HypergraphAttentionLayer::new(4, config);
        let feats = make_node_feats(5, 4);
        let incidence = make_incidence_matrix(5, 2);
        let out = layer.forward(&feats, &incidence).expect("forward");
        assert_eq!(out.nrows(), 5, "output node count");
        assert_eq!(out.ncols(), 4, "output feature dim");
    }

    #[test]
    fn test_attention_handles_varying_hyperedge_sizes() {
        // One singleton hyperedge plus one 4-node hyperedge.
        let mut incidence = Array2::zeros((5, 2));
        incidence[[0, 0]] = 1.0;
        for node in 1..5 {
            incidence[[node, 1]] = 1.0;
        }

        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            n_heads: 2,
            ..Default::default()
        };
        let layer = HypergraphAttentionLayer::new(4, config);
        let feats = make_node_feats(5, 4);
        let out = layer.forward(&feats, &incidence).expect("varying sizes");
        assert_eq!(out.shape(), &[5, 4]);
    }

    #[test]
    fn test_attention_output_is_finite() {
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            ..Default::default()
        };
        let layer = HypergraphAttentionLayer::new(4, config);
        let feats = make_node_feats(5, 4);
        let incidence = make_incidence_matrix(5, 2);
        let out = layer.forward(&feats, &incidence).expect("forward");
        for v in out.iter() {
            assert!(v.is_finite(), "output must be finite, got {v}");
        }
    }

    #[test]
    fn test_network_stacked_output_shape() {
        // Stacking layers must also preserve the (n_nodes, in_dim) shape.
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            n_heads: 2,
            ..Default::default()
        };
        let net = HypergraphAttentionNetwork::new(4, 3, config);
        let feats = make_node_feats(5, 4);
        let incidence = make_incidence_matrix(5, 2);
        let out = net.forward(&feats, &incidence).expect("net forward");
        assert_eq!(out.shape(), &[5, 4]);
    }

    #[test]
    fn test_network_output_is_finite() {
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            ..Default::default()
        };
        let net = HypergraphAttentionNetwork::new(4, 2, config);
        let feats = make_node_feats(5, 4);
        let incidence = make_incidence_matrix(5, 2);
        let out = net.forward(&feats, &incidence).expect("forward");
        for v in out.iter() {
            assert!(v.is_finite(), "network output must be finite");
        }
    }

    #[test]
    fn test_empty_hyperedge() {
        // With no hyperedges at all, every node is isolated and the layer
        // must act as the identity (pure residual, layer norm disabled).
        let incidence = Array2::zeros((3, 2));
        let config = HypergraphAttentionConfig {
            hidden_dim: 8,
            use_layer_norm: false,
            ..Default::default()
        };
        let layer = HypergraphAttentionLayer::new(4, config);
        let feats = make_node_feats(3, 4);
        let out = layer.forward(&feats, &incidence).expect("empty hyperedge");
        for i in 0..3 {
            for d in 0..4 {
                assert!(
                    (out[[i, d]] - feats[[i, d]]).abs() < 1e-12,
                    "residual mismatch at ({i},{d})"
                );
            }
        }
    }
}
584}