1#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use torsh_tensor::{
20 creation::{from_vec, randn, zeros},
21 Tensor,
22};
23
/// Node features together with the incidence structure of a hypergraph.
///
/// The incidence matrix is addressed row-major as `node * num_hyperedges + hyperedge`
/// throughout this file, and an entry greater than `0.0` means the node belongs
/// to the hyperedge.
#[derive(Debug, Clone)]
pub struct HypergraphData {
    /// Node feature tensor; `new` reads `num_nodes` from its first dimension.
    pub x: Tensor,
    /// Node-by-hyperedge incidence matrix; `new` reads `num_hyperedges` from
    /// its second dimension.
    pub incidence_matrix: Tensor,
    /// Optional per-hyperedge weights (set via `with_hyperedge_weights`).
    pub hyperedge_weights: Option<Tensor>,
    /// Optional per-hyperedge features (set via `with_hyperedge_features`).
    pub hyperedge_features: Option<Tensor>,
    /// Sum of the incidence matrix along dimension 1 (one value per node).
    pub node_degrees: Tensor,
    /// Sum of the incidence matrix along dimension 0 (one value per hyperedge).
    pub hyperedge_cardinalities: Tensor,
    /// Number of nodes.
    pub num_nodes: usize,
    /// Number of hyperedges.
    pub num_hyperedges: usize,
}
44
45impl HypergraphData {
46 pub fn new(x: Tensor, incidence_matrix: Tensor) -> Self {
48 let num_nodes = x.shape().dims()[0];
49 let num_hyperedges = incidence_matrix.shape().dims()[1];
50
51 let node_degrees = incidence_matrix
53 .sum_dim(&[1], false)
54 .expect("sum_dim node_degrees should succeed");
55
56 let hyperedge_cardinalities = incidence_matrix
58 .sum_dim(&[0], false)
59 .expect("sum_dim hyperedge_cardinalities should succeed");
60
61 Self {
62 x,
63 incidence_matrix,
64 hyperedge_weights: None,
65 hyperedge_features: None,
66 node_degrees,
67 hyperedge_cardinalities,
68 num_nodes,
69 num_hyperedges,
70 }
71 }
72
73 pub fn with_hyperedge_weights(mut self, weights: Tensor) -> Self {
75 self.hyperedge_weights = Some(weights);
76 self
77 }
78
79 pub fn with_hyperedge_features(mut self, features: Tensor) -> Self {
81 self.hyperedge_features = Some(features);
82 self
83 }
84
85 pub fn to_graph_clique_expansion(&self) -> GraphData {
87 let incidence_data = self
88 .incidence_matrix
89 .to_vec()
90 .expect("conversion should succeed");
91 let mut edges = Vec::new();
92
93 for e in 0..self.num_hyperedges {
95 let mut nodes_in_hyperedge = Vec::new();
96
97 for v in 0..self.num_nodes {
99 let idx = v * self.num_hyperedges + e;
100 if incidence_data[idx] > 0.0 {
101 nodes_in_hyperedge.push(v as f32);
102 }
103 }
104
105 for i in 0..nodes_in_hyperedge.len() {
107 for j in (i + 1)..nodes_in_hyperedge.len() {
108 edges.extend_from_slice(&[nodes_in_hyperedge[i], nodes_in_hyperedge[j]]);
109 edges.extend_from_slice(&[nodes_in_hyperedge[j], nodes_in_hyperedge[i]]);
110 }
111 }
112 }
113
114 let edge_index = if edges.is_empty() {
115 zeros(&[2, 0]).expect("zeros empty edge_index should succeed")
116 } else {
117 let num_edges = edges.len() / 2;
118 from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
119 .expect("from_vec edge_index should succeed")
120 };
121
122 GraphData::new(self.x.clone(), edge_index)
123 }
124
125 pub fn to_graph_star_expansion(&self) -> GraphData {
127 let incidence_data = self
128 .incidence_matrix
129 .to_vec()
130 .expect("conversion should succeed");
131 let mut edges = Vec::new();
132
133 let virtual_node_offset = self.num_nodes;
135
136 for e in 0..self.num_hyperedges {
137 let virtual_node = (virtual_node_offset + e) as f32;
138
139 for v in 0..self.num_nodes {
141 let idx = v * self.num_hyperedges + e;
142 if incidence_data[idx] > 0.0 {
143 let node = v as f32;
144 edges.extend_from_slice(&[node, virtual_node]);
145 edges.extend_from_slice(&[virtual_node, node]);
146 }
147 }
148 }
149
150 let edge_index = if edges.is_empty() {
151 zeros(&[2, 0]).expect("zeros empty edge_index should succeed")
152 } else {
153 let num_edges = edges.len() / 2;
154 from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
155 .expect("from_vec edge_index should succeed")
156 };
157
158 let virtual_features: Tensor = randn(&[self.num_hyperedges, self.x.shape().dims()[1]])
160 .expect("randn virtual_features should succeed");
161 let node_data = self.x.to_vec().expect("conversion should succeed");
163 let virtual_data = virtual_features
164 .to_vec()
165 .expect("conversion should succeed");
166 let mut extended_data = node_data;
167 extended_data.extend(virtual_data);
168
169 let total_nodes = self.num_nodes + self.num_hyperedges;
170 let features_dim = self.x.shape().dims()[1];
171 let extended_x = from_vec(
172 extended_data,
173 &[total_nodes, features_dim],
174 torsh_core::device::DeviceType::Cpu,
175 )
176 .expect("from_vec extended_x should succeed");
177
178 GraphData::new(extended_x, edge_index)
179 }
180}
181
/// Hypergraph convolution layer holding a linear projection, an optional bias
/// and an optional attention vector.
#[derive(Debug)]
pub struct HGCNConv {
    /// Input feature width (first dimension of `weight`).
    in_features: usize,
    /// Output feature width (second dimension of `weight`).
    out_features: usize,
    /// Projection matrix created with shape `[in_features, out_features]`.
    weight: Parameter,
    /// Optional bias created with shape `[out_features]`.
    bias: Option<Parameter>,
    /// Whether an attention vector was requested at construction.
    use_attention: bool,
    /// Optional attention vector created with shape `[out_features]`.
    attention_weight: Option<Parameter>,
    /// Dropout rate; stored at construction and not read by the code in this file.
    dropout: f32,
}
193
194impl HGCNConv {
195 pub fn new(
197 in_features: usize,
198 out_features: usize,
199 bias: bool,
200 use_attention: bool,
201 dropout: f32,
202 ) -> Self {
203 let weight = Parameter::new(
204 randn(&[in_features, out_features]).expect("randn weight should succeed"),
205 );
206 let bias = if bias {
207 Some(Parameter::new(
208 zeros(&[out_features]).expect("zeros bias should succeed"),
209 ))
210 } else {
211 None
212 };
213
214 let attention_weight = if use_attention {
215 Some(Parameter::new(
216 randn(&[out_features]).expect("randn attention_weight should succeed"),
217 ))
218 } else {
219 None
220 };
221
222 Self {
223 in_features,
224 out_features,
225 weight,
226 bias,
227 use_attention,
228 attention_weight,
229 dropout,
230 }
231 }
232
233 pub fn forward(&self, hypergraph: &HypergraphData) -> HypergraphData {
235 let node_features_transformed = hypergraph
238 .x
239 .matmul(&self.weight.clone_data())
240 .expect("operation should succeed");
241
242 let output_features = if let Some(ref bias) = self.bias {
244 node_features_transformed
245 .add(&bias.clone_data())
246 .expect("operation should succeed")
247 } else {
248 node_features_transformed
249 };
250
251 HypergraphData {
253 x: output_features,
254 incidence_matrix: hypergraph.incidence_matrix.clone(),
255 hyperedge_weights: hypergraph.hyperedge_weights.clone(),
256 hyperedge_features: hypergraph.hyperedge_features.clone(),
257 node_degrees: hypergraph.node_degrees.clone(),
258 hyperedge_cardinalities: hypergraph.hyperedge_cardinalities.clone(),
259 num_nodes: hypergraph.num_nodes,
260 num_hyperedges: hypergraph.num_hyperedges,
261 }
262 }
263
264 fn apply_attention(&self, hyperedge_features: &Tensor, _hypergraph: &HypergraphData) -> Tensor {
266 if let Some(ref attention_weight) = self.attention_weight {
267 let attention_scores = hyperedge_features
269 .matmul(&attention_weight.clone_data())
270 .expect("operation should succeed");
271 let attention_probs = attention_scores
272 .softmax(-1)
273 .expect("softmax should succeed");
274
275 let attention_expanded = attention_probs
277 .unsqueeze(-1)
278 .expect("unsqueeze should succeed");
279 hyperedge_features
280 .mul(&attention_expanded)
281 .expect("operation should succeed")
282 } else {
283 hyperedge_features.clone()
284 }
285 }
286
287 fn normalize_by_degrees(&self, features: &Tensor, hypergraph: &HypergraphData) -> Tensor {
289 let degrees = &hypergraph.node_degrees;
290 let epsilon = 1e-8;
291
292 let safe_degrees = degrees
294 .add_scalar(epsilon)
295 .expect("add_scalar should succeed");
296 let inv_degrees = safe_degrees
297 .reciprocal()
298 .expect("reciprocal should succeed");
299
300 let inv_degrees_squeezed = if inv_degrees.shape().dims().len() > 1 {
303 inv_degrees
304 .squeeze_tensor(1)
305 .expect("squeeze_tensor should succeed")
306 } else {
307 inv_degrees
308 };
309 let inv_degrees_expanded = inv_degrees_squeezed
310 .unsqueeze(-1)
311 .expect("unsqueeze should succeed");
312 features
313 .mul(&inv_degrees_expanded)
314 .expect("operation should succeed")
315 }
316}
317
318impl GraphLayer for HGCNConv {
319 fn forward(&self, graph: &GraphData) -> GraphData {
320 let hypergraph = graph_to_hypergraph(graph);
322 let output_hypergraph = self.forward(&hypergraph);
323 output_hypergraph.to_graph_clique_expansion()
324 }
325
326 fn parameters(&self) -> Vec<Tensor> {
327 let mut params = vec![self.weight.clone_data()];
328 if let Some(ref bias) = self.bias {
329 params.push(bias.clone_data());
330 }
331 if let Some(ref attention_weight) = self.attention_weight {
332 params.push(attention_weight.clone_data());
333 }
334 params
335 }
336}
337
/// Multi-head attention layer over hypergraph nodes that share a hyperedge.
#[derive(Debug)]
pub struct HyperGATConv {
    /// Input feature width.
    in_features: usize,
    /// Output feature width; split across `heads` as `out_features / heads`.
    out_features: usize,
    /// Number of attention heads.
    heads: usize,
    /// Query projection created with shape `[in_features, out_features]`.
    query_weight: Parameter,
    /// Key projection created with shape `[in_features, out_features]`.
    key_weight: Parameter,
    /// Value projection created with shape `[in_features, out_features]`.
    value_weight: Parameter,
    /// Parameter created with shape `[heads, 2 * head_dim]`; exposed through
    /// `parameters()` but not read by the forward pass in this file.
    hyperedge_attention: Parameter,
    /// Output projection created with shape `[out_features, out_features]`.
    output_weight: Parameter,
    /// Optional bias created with shape `[out_features]`.
    bias: Option<Parameter>,
    /// Dropout rate; stored at construction and not read by the code in this file.
    dropout: f32,
}
352
353impl HyperGATConv {
354 pub fn new(
356 in_features: usize,
357 out_features: usize,
358 heads: usize,
359 dropout: f32,
360 bias: bool,
361 ) -> Self {
362 let head_dim = out_features / heads;
363
364 let query_weight = Parameter::new(
365 randn(&[in_features, out_features]).expect("randn query_weight should succeed"),
366 );
367 let key_weight = Parameter::new(
368 randn(&[in_features, out_features]).expect("randn key_weight should succeed"),
369 );
370 let value_weight = Parameter::new(
371 randn(&[in_features, out_features]).expect("randn value_weight should succeed"),
372 );
373 let hyperedge_attention = Parameter::new(
374 randn(&[heads, 2 * head_dim]).expect("randn hyperedge_attention should succeed"),
375 );
376 let output_weight = Parameter::new(
377 randn(&[out_features, out_features]).expect("randn output_weight should succeed"),
378 );
379
380 let bias = if bias {
381 Some(Parameter::new(
382 zeros(&[out_features]).expect("zeros bias should succeed"),
383 ))
384 } else {
385 None
386 };
387
388 Self {
389 in_features,
390 out_features,
391 heads,
392 query_weight,
393 key_weight,
394 value_weight,
395 hyperedge_attention,
396 output_weight,
397 bias,
398 dropout,
399 }
400 }
401
402 pub fn forward(&self, hypergraph: &HypergraphData) -> HypergraphData {
404 let num_nodes = hypergraph.num_nodes;
405 let head_dim = self.out_features / self.heads;
406
407 let queries = hypergraph
409 .x
410 .matmul(&self.query_weight.clone_data())
411 .expect("operation should succeed");
412 let keys = hypergraph
413 .x
414 .matmul(&self.key_weight.clone_data())
415 .expect("operation should succeed");
416 let values = hypergraph
417 .x
418 .matmul(&self.value_weight.clone_data())
419 .expect("operation should succeed");
420
421 let q = queries
423 .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
424 .expect("view should succeed");
425 let k = keys
426 .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
427 .expect("view should succeed");
428 let v = values
429 .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
430 .expect("view should succeed");
431
432 let attended_features = self.hyperedge_attention_mechanism(&q, &k, &v, hypergraph);
434
435 let concatenated = attended_features
437 .view(&[num_nodes as i32, self.out_features as i32])
438 .expect("view should succeed");
439 let mut output = concatenated
440 .matmul(&self.output_weight.clone_data())
441 .expect("operation should succeed");
442
443 if let Some(ref bias) = self.bias {
445 output = output
446 .add(&bias.clone_data())
447 .expect("operation should succeed");
448 }
449
450 HypergraphData {
452 x: output,
453 incidence_matrix: hypergraph.incidence_matrix.clone(),
454 hyperedge_weights: hypergraph.hyperedge_weights.clone(),
455 hyperedge_features: hypergraph.hyperedge_features.clone(),
456 node_degrees: hypergraph.node_degrees.clone(),
457 hyperedge_cardinalities: hypergraph.hyperedge_cardinalities.clone(),
458 num_nodes: hypergraph.num_nodes,
459 num_hyperedges: hypergraph.num_hyperedges,
460 }
461 }
462
463 fn hyperedge_attention_mechanism(
465 &self,
466 q: &Tensor,
467 k: &Tensor,
468 v: &Tensor,
469 hypergraph: &HypergraphData,
470 ) -> Tensor {
471 let num_nodes = hypergraph.num_nodes;
472 let head_dim = self.out_features / self.heads;
473
474 let mut output =
476 zeros(&[num_nodes, self.heads, head_dim]).expect("zeros output should succeed");
477
478 let incidence_data = hypergraph
479 .incidence_matrix
480 .to_vec()
481 .expect("conversion should succeed");
482
483 for e in 0..hypergraph.num_hyperedges {
485 let mut nodes_in_hyperedge = Vec::new();
486
487 for v in 0..num_nodes {
489 let idx = v * hypergraph.num_hyperedges + e;
490 if incidence_data[idx] > 0.0 {
491 nodes_in_hyperedge.push(v);
492 }
493 }
494
495 if nodes_in_hyperedge.len() < 2 {
496 continue; }
498
499 for head in 0..self.heads {
501 self.compute_hyperedge_attention(head, &nodes_in_hyperedge, q, k, v, &mut output);
502 }
503 }
504
505 output
506 }
507
508 fn compute_hyperedge_attention(
510 &self,
511 head: usize,
512 nodes: &[usize],
513 q: &Tensor,
514 k: &Tensor,
515 v: &Tensor,
516 output: &mut Tensor,
517 ) {
518 let head_dim = self.out_features / self.heads;
519 let scale = 1.0 / (head_dim as f32).sqrt();
520
521 for &node_i in nodes {
524 let mut aggregated = zeros(&[head_dim]).expect("zeros aggregated should succeed");
525 let mut total_weight = 0.0;
526
527 for &node_j in nodes {
528 if node_i != node_j {
529 let q_i = q
531 .slice_tensor(0, node_i, node_i + 1)
532 .expect("slice_tensor q_i should succeed")
533 .slice_tensor(1, head, head + 1)
534 .expect("slice_tensor q_i head should succeed")
535 .squeeze_tensor(0)
536 .expect("squeeze_tensor should succeed")
537 .squeeze_tensor(0)
538 .expect("squeeze_tensor should succeed");
539
540 let k_j = k
541 .slice_tensor(0, node_j, node_j + 1)
542 .expect("slice_tensor k_j should succeed")
543 .slice_tensor(1, head, head + 1)
544 .expect("slice_tensor k_j head should succeed")
545 .squeeze_tensor(0)
546 .expect("squeeze_tensor should succeed")
547 .squeeze_tensor(0)
548 .expect("squeeze_tensor should succeed");
549
550 let v_j = v
551 .slice_tensor(0, node_j, node_j + 1)
552 .expect("slice_tensor v_j should succeed")
553 .slice_tensor(1, head, head + 1)
554 .expect("slice_tensor v_j head should succeed")
555 .squeeze_tensor(0)
556 .expect("squeeze_tensor should succeed")
557 .squeeze_tensor(0)
558 .expect("squeeze_tensor should succeed");
559
560 let attention_score = q_i
562 .dot(&k_j)
563 .expect("dot should succeed")
564 .mul_scalar(scale)
565 .expect("mul_scalar should succeed");
566 let weight = attention_score
567 .exp()
568 .expect("exp should succeed")
569 .item()
570 .expect("tensor should have single item");
571
572 let weighted_value = v_j.mul_scalar(weight).expect("mul_scalar should succeed");
574 aggregated = aggregated
575 .add(&weighted_value)
576 .expect("operation should succeed");
577 total_weight += weight;
578 }
579 }
580
581 if total_weight > 0.0 {
583 aggregated = aggregated
584 .div_scalar(total_weight)
585 .expect("div_scalar should succeed");
586
587 let aggregated_data = aggregated.to_vec().expect("conversion should succeed");
589 for (j, &val) in aggregated_data.iter().enumerate() {
590 output
591 .set_item(&[node_i, head, j], val)
592 .expect("set_item should succeed");
593 }
594 }
595 }
596 }
597}
598
599impl GraphLayer for HyperGATConv {
600 fn forward(&self, graph: &GraphData) -> GraphData {
601 let hypergraph = graph_to_hypergraph(graph);
602 let output_hypergraph = self.forward(&hypergraph);
603 output_hypergraph.to_graph_clique_expansion()
604 }
605
606 fn parameters(&self) -> Vec<Tensor> {
607 let mut params = vec![
608 self.query_weight.clone_data(),
609 self.key_weight.clone_data(),
610 self.value_weight.clone_data(),
611 self.hyperedge_attention.clone_data(),
612 self.output_weight.clone_data(),
613 ];
614
615 if let Some(ref bias) = self.bias {
616 params.push(bias.clone_data());
617 }
618
619 params
620 }
621}
622
/// Hypergraph neural-network layer with a spectral or a spatial propagation
/// step after the linear projection.
#[derive(Debug)]
pub struct HGNNConv {
    /// Input feature width (first dimension of `weight`).
    in_features: usize,
    /// Output feature width (second dimension of `weight`).
    out_features: usize,
    /// Projection matrix created with shape `[in_features, out_features]`.
    weight: Parameter,
    /// Optional bias created with shape `[out_features]`.
    bias: Option<Parameter>,
    /// Selects `spectral_convolution` when true, `spatial_convolution` otherwise.
    use_spectral: bool,
}
632
633impl HGNNConv {
634 pub fn new(in_features: usize, out_features: usize, bias: bool, use_spectral: bool) -> Self {
636 let weight = Parameter::new(
637 randn(&[in_features, out_features]).expect("randn weight should succeed"),
638 );
639 let bias = if bias {
640 Some(Parameter::new(
641 zeros(&[out_features]).expect("zeros bias should succeed"),
642 ))
643 } else {
644 None
645 };
646
647 Self {
648 in_features,
649 out_features,
650 weight,
651 bias,
652 use_spectral,
653 }
654 }
655
656 pub fn forward(&self, hypergraph: &HypergraphData) -> HypergraphData {
658 let x_transformed = hypergraph
660 .x
661 .matmul(&self.weight.clone_data())
662 .expect("operation should succeed");
663
664 let output_features = if self.use_spectral {
666 self.spectral_convolution(&x_transformed, hypergraph)
667 } else {
668 self.spatial_convolution(&x_transformed, hypergraph)
669 };
670
671 let final_features = if let Some(ref bias) = self.bias {
673 output_features
674 .add(&bias.clone_data())
675 .expect("operation should succeed")
676 } else {
677 output_features
678 };
679
680 HypergraphData {
681 x: final_features,
682 incidence_matrix: hypergraph.incidence_matrix.clone(),
683 hyperedge_weights: hypergraph.hyperedge_weights.clone(),
684 hyperedge_features: hypergraph.hyperedge_features.clone(),
685 node_degrees: hypergraph.node_degrees.clone(),
686 hyperedge_cardinalities: hypergraph.hyperedge_cardinalities.clone(),
687 num_nodes: hypergraph.num_nodes,
688 num_hyperedges: hypergraph.num_hyperedges,
689 }
690 }
691
692 fn spectral_convolution(&self, features: &Tensor, hypergraph: &HypergraphData) -> Tensor {
694 let laplacian = self.compute_hypergraph_laplacian(hypergraph);
696
697 laplacian
699 .matmul(features)
700 .expect("operation should succeed")
701 }
702
703 fn spatial_convolution(&self, features: &Tensor, hypergraph: &HypergraphData) -> Tensor {
705 let incidence_t = hypergraph
707 .incidence_matrix
708 .transpose(0, 1)
709 .expect("transpose should succeed");
710 let hyperedge_features = incidence_t
711 .matmul(features)
712 .expect("operation should succeed");
713
714 let aggregated = hypergraph
716 .incidence_matrix
717 .matmul(&hyperedge_features)
718 .expect("operation should succeed");
719
720 self.normalize_by_degrees(&aggregated, hypergraph)
722 }
723
724 fn compute_hypergraph_laplacian(&self, hypergraph: &HypergraphData) -> Tensor {
726 let h = &hypergraph.incidence_matrix;
727 let num_nodes = hypergraph.num_nodes;
728
729 let node_degrees = h
731 .sum_dim(&[1], false)
732 .expect("sum_dim node_degrees should succeed");
733 let hyperedge_degrees = h
734 .sum_dim(&[0], false)
735 .expect("sum_dim hyperedge_degrees should succeed");
736
737 let mut d_v = zeros(&[num_nodes, num_nodes]).expect("zeros d_v should succeed");
739 let mut d_e = zeros(&[hypergraph.num_hyperedges, hypergraph.num_hyperedges])
740 .expect("zeros d_e should succeed");
741
742 let node_deg_data = node_degrees.to_vec().expect("conversion should succeed");
743 let hyperedge_deg_data = hyperedge_degrees
744 .to_vec()
745 .expect("conversion should succeed");
746
747 for i in 0..num_nodes {
749 let degree = node_deg_data[i].max(1e-8); d_v.set_item(&[i, i], degree.powf(-0.5))
751 .expect("set_item d_v should succeed");
752 }
753
754 for i in 0..hypergraph.num_hyperedges {
755 let degree = hyperedge_deg_data[i].max(1e-8);
756 d_e.set_item(&[i, i], degree.recip())
757 .expect("set_item d_e should succeed");
758 }
759
760 let h_t = h.transpose(0, 1).expect("transpose should succeed");
762 let intermediate = d_v
763 .matmul(h)
764 .expect("operation should succeed")
765 .matmul(&d_e)
766 .expect("operation should succeed")
767 .matmul(&h_t)
768 .expect("operation should succeed")
769 .matmul(&d_v)
770 .expect("operation should succeed");
771
772 let identity = eye(num_nodes);
773 identity
774 .sub(&intermediate)
775 .expect("operation should succeed")
776 }
777
778 fn normalize_by_degrees(&self, features: &Tensor, hypergraph: &HypergraphData) -> Tensor {
780 let degrees = &hypergraph.node_degrees;
781 let epsilon = 1e-8;
782
783 let safe_degrees = degrees
784 .add_scalar(epsilon)
785 .expect("add_scalar should succeed");
786 let inv_sqrt_degrees = safe_degrees
787 .pow_scalar(-0.5)
788 .expect("pow_scalar should succeed");
789
790 let inv_degrees_squeezed = if inv_sqrt_degrees.shape().dims().len() > 1 {
792 inv_sqrt_degrees
793 .squeeze_tensor(1)
794 .expect("squeeze_tensor should succeed")
795 } else {
796 inv_sqrt_degrees
797 };
798 let inv_degrees_expanded = inv_degrees_squeezed
799 .unsqueeze(-1)
800 .expect("unsqueeze should succeed");
801
802 features
803 .mul(&inv_degrees_expanded)
804 .expect("operation should succeed")
805 }
806}
807
808impl GraphLayer for HGNNConv {
809 fn forward(&self, graph: &GraphData) -> GraphData {
810 let hypergraph = graph_to_hypergraph(graph);
811 let output_hypergraph = self.forward(&hypergraph);
812 output_hypergraph.to_graph_clique_expansion()
813 }
814
815 fn parameters(&self) -> Vec<Tensor> {
816 let mut params = vec![self.weight.clone_data()];
817 if let Some(ref bias) = self.bias {
818 params.push(bias.clone_data());
819 }
820 params
821 }
822}
823
824pub mod pooling {
826 use super::*;
827
828 pub fn global_hypergraph_pool(hypergraph: &HypergraphData, method: PoolingMethod) -> Tensor {
830 match method {
831 PoolingMethod::Mean => hypergraph
832 .x
833 .mean(Some(&[0]), false)
834 .expect("mean pooling should succeed"),
835 PoolingMethod::Max => hypergraph
836 .x
837 .max(Some(0), false)
838 .expect("max pooling should succeed"),
839 PoolingMethod::Sum => hypergraph
840 .x
841 .sum_dim(&[0], false)
842 .expect("sum pooling should succeed"),
843 PoolingMethod::Attention => attention_pool(hypergraph),
844 }
845 }
846
847 pub fn hyperedge_pool(hypergraph: &HypergraphData, method: PoolingMethod) -> Tensor {
849 let incidence_t = hypergraph
850 .incidence_matrix
851 .transpose(0, 1)
852 .expect("transpose should succeed");
853
854 match method {
855 PoolingMethod::Mean => {
856 let hyperedge_features = incidence_t
858 .matmul(&hypergraph.x)
859 .expect("operation should succeed");
860 hyperedge_features
861 .mean(Some(&[0]), false)
862 .expect("mean pooling should succeed")
863 }
864 PoolingMethod::Max => {
865 let hyperedge_features = incidence_t
866 .matmul(&hypergraph.x)
867 .expect("operation should succeed");
868 hyperedge_features
869 .max(Some(0), false)
870 .expect("max pooling should succeed")
871 }
872 PoolingMethod::Sum => {
873 let hyperedge_features = incidence_t
874 .matmul(&hypergraph.x)
875 .expect("operation should succeed");
876 hyperedge_features
877 .sum_dim(&[0], false)
878 .expect("sum pooling should succeed")
879 }
880 PoolingMethod::Attention => {
881 attention_pool(hypergraph)
883 }
884 }
885 }
886
887 pub fn hierarchical_hypergraph_pool(
889 hypergraph: &HypergraphData,
890 num_clusters: usize,
891 ) -> HypergraphData {
892 let cluster_assignments = cluster_nodes(hypergraph, num_clusters);
894 coarsen_hypergraph(hypergraph, &cluster_assignments)
895 }
896
897 fn attention_pool(hypergraph: &HypergraphData) -> Tensor {
899 let attention_scores = hypergraph
901 .x
902 .sum_dim(&[1], false)
903 .expect("sum_dim should succeed");
904 let attention_weights = attention_scores.softmax(0).expect("softmax should succeed");
905 let attention_expanded = attention_weights
906 .unsqueeze(-1)
907 .expect("unsqueeze should succeed");
908
909 let weighted_features = hypergraph
910 .x
911 .mul(&attention_expanded)
912 .expect("operation should succeed");
913 weighted_features
914 .sum_dim(&[0], false)
915 .expect("sum_dim should succeed")
916 }
917
918 fn cluster_nodes(hypergraph: &HypergraphData, num_clusters: usize) -> Vec<usize> {
920 let num_nodes = hypergraph.num_nodes;
921 let mut assignments = vec![0; num_nodes];
922
923 for i in 0..num_nodes {
925 assignments[i] = i % num_clusters;
926 }
927
928 assignments
929 }
930
931 fn coarsen_hypergraph(
933 hypergraph: &HypergraphData,
934 cluster_assignments: &[usize],
935 ) -> HypergraphData {
936 let num_clusters = cluster_assignments
937 .iter()
938 .max()
939 .expect("reduction should succeed")
940 + 1;
941 let original_features = hypergraph.x.shape().dims()[1];
942
943 let mut coarse_features_data = vec![0.0; num_clusters * original_features];
945 let mut cluster_counts = vec![0; num_clusters];
946
947 let node_data = hypergraph.x.to_vec().expect("conversion should succeed");
948
949 for (node, &cluster) in cluster_assignments.iter().enumerate() {
950 cluster_counts[cluster] += 1;
951 for feat in 0..original_features {
952 let node_feat_idx = node * original_features + feat;
953 let cluster_feat_idx = cluster * original_features + feat;
954 coarse_features_data[cluster_feat_idx] += node_data[node_feat_idx];
955 }
956 }
957
958 for cluster in 0..num_clusters {
960 if cluster_counts[cluster] > 0 {
961 for feat in 0..original_features {
962 let cluster_feat_idx = cluster * original_features + feat;
963 coarse_features_data[cluster_feat_idx] /= cluster_counts[cluster] as f32;
964 }
965 }
966 }
967
968 let coarse_features = from_vec(
969 coarse_features_data,
970 &[num_clusters, original_features],
971 torsh_core::device::DeviceType::Cpu,
972 )
973 .expect("from_vec coarse_features should succeed");
974
975 let coarse_incidence = zeros(&[num_clusters, hypergraph.num_hyperedges])
977 .expect("zeros coarse_incidence should succeed");
978
979 HypergraphData::new(coarse_features, coarse_incidence)
980 }
981
    /// Reduction used by the pooling functions in this module.
    #[derive(Debug, Clone, Copy)]
    pub enum PoolingMethod {
        /// Mean over dimension 0.
        Mean,
        /// Maximum over dimension 0.
        Max,
        /// Sum over dimension 0.
        Sum,
        /// Softmax-weighted sum of node features (see `attention_pool`).
        Attention,
    }
990}
991
992pub mod utils {
994 use super::*;
995
996 pub fn edge_list_to_hypergraph(
998 edges: &[(Vec<usize>, f32)],
999 num_nodes: usize,
1000 ) -> HypergraphData {
1001 let num_hyperedges = edges.len();
1002 let mut incidence_data = vec![0.0; num_nodes * num_hyperedges];
1003 let mut weights = Vec::new();
1004
1005 for (e, (edge_nodes, weight)) in edges.iter().enumerate() {
1006 weights.push(*weight);
1007 for &node in edge_nodes {
1008 if node < num_nodes {
1009 incidence_data[node * num_hyperedges + e] = 1.0;
1010 }
1011 }
1012 }
1013
1014 let features = randn(&[num_nodes, 16]).expect("randn features should succeed"); let incidence_matrix = from_vec(
1016 incidence_data,
1017 &[num_nodes, num_hyperedges],
1018 torsh_core::device::DeviceType::Cpu,
1019 )
1020 .expect("from_vec incidence_matrix should succeed");
1021
1022 let hyperedge_weights = from_vec(
1023 weights,
1024 &[num_hyperedges],
1025 torsh_core::device::DeviceType::Cpu,
1026 )
1027 .expect("from_vec hyperedge_weights should succeed");
1028
1029 HypergraphData::new(features, incidence_matrix).with_hyperedge_weights(hyperedge_weights)
1030 }
1031
    /// Generates a random hypergraph.
    ///
    /// Each (node, hyperedge) pair becomes a membership when a draw from
    /// `0.0..1.0` is below `edge_prob`. Node features are drawn with `randn`
    /// at shape `[num_nodes, features_dim]`.
    ///
    /// # Panics
    ///
    /// Panics if tensor creation fails.
    pub fn random_hypergraph(
        num_nodes: usize,
        num_hyperedges: usize,
        edge_prob: f32,
        features_dim: usize,
    ) -> HypergraphData {
        let mut rng = scirs2_core::random::thread_rng();
        // Row-major [num_nodes, num_hyperedges] membership matrix.
        let mut incidence_data = vec![0.0; num_nodes * num_hyperedges];

        // Draws happen hyperedge by hyperedge, node by node within each.
        for e in 0..num_hyperedges {
            for v in 0..num_nodes {
                if rng.gen_range(0.0..1.0) < edge_prob {
                    incidence_data[v * num_hyperedges + e] = 1.0;
                }
            }
        }

        let features = randn(&[num_nodes, features_dim]).expect("randn features should succeed");
        let incidence_matrix = from_vec(
            incidence_data,
            &[num_nodes, num_hyperedges],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("from_vec incidence_matrix should succeed");

        HypergraphData::new(features, incidence_matrix)
    }
1061
1062 pub fn hypergraph_metrics(hypergraph: &HypergraphData) -> HypergraphMetrics {
1064 let node_degrees = hypergraph
1065 .node_degrees
1066 .to_vec()
1067 .expect("conversion should succeed");
1068 let hyperedge_cardinalities = hypergraph
1069 .hyperedge_cardinalities
1070 .to_vec()
1071 .expect("conversion should succeed");
1072
1073 let avg_node_degree = node_degrees.iter().sum::<f32>() / node_degrees.len() as f32;
1074 let avg_hyperedge_size =
1075 hyperedge_cardinalities.iter().sum::<f32>() / hyperedge_cardinalities.len() as f32;
1076
1077 let density = node_degrees.iter().sum::<f32>()
1078 / (hypergraph.num_nodes * hypergraph.num_hyperedges) as f32;
1079
1080 HypergraphMetrics {
1081 avg_node_degree,
1082 avg_hyperedge_size,
1083 density,
1084 num_nodes: hypergraph.num_nodes,
1085 num_hyperedges: hypergraph.num_hyperedges,
1086 }
1087 }
1088
    /// Summary statistics produced by `hypergraph_metrics`.
    #[derive(Debug, Clone)]
    pub struct HypergraphMetrics {
        /// Mean of the node degrees.
        pub avg_node_degree: f32,
        /// Mean of the hyperedge cardinalities.
        pub avg_hyperedge_size: f32,
        /// Sum of node degrees divided by `num_nodes * num_hyperedges`.
        pub density: f32,
        /// Number of nodes.
        pub num_nodes: usize,
        /// Number of hyperedges.
        pub num_hyperedges: usize,
    }
1098}
1099
1100pub fn graph_to_hypergraph(graph: &GraphData) -> HypergraphData {
1102 let edge_data = crate::utils::tensor_to_vec2::<f32>(&graph.edge_index)
1103 .expect("tensor_to_vec2 should succeed");
1104 let num_edges = edge_data[0].len();
1105 let num_nodes = graph.num_nodes;
1106
1107 let mut incidence_data = vec![0.0; num_nodes * num_edges];
1109
1110 for e in 0..num_edges {
1111 let src = edge_data[0][e] as usize;
1112 let dst = edge_data[1][e] as usize;
1113
1114 if src < num_nodes && dst < num_nodes {
1115 incidence_data[src * num_edges + e] = 1.0;
1116 incidence_data[dst * num_edges + e] = 1.0;
1117 }
1118 }
1119
1120 let incidence_matrix = from_vec(
1121 incidence_data,
1122 &[num_nodes, num_edges],
1123 torsh_core::device::DeviceType::Cpu,
1124 )
1125 .expect("from_vec incidence_matrix should succeed");
1126
1127 HypergraphData::new(graph.x.clone(), incidence_matrix)
1128}
1129
1130fn eye(n: usize) -> Tensor {
1132 let mut data = vec![0.0; n * n];
1133 for i in 0..n {
1134 data[i * n + i] = 1.0;
1135 }
1136 from_vec(data, &[n, n], torsh_core::device::DeviceType::Cpu)
1137 .expect("from_vec eye should succeed")
1138}
1139
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Builds a CPU incidence tensor from row-major data.
    fn incidence(data: Vec<f32>, shape: &[usize]) -> Tensor {
        from_vec(data, shape, DeviceType::Cpu).unwrap()
    }

    /// Four nodes and two hyperedges, shared by the pooling and HyperGAT tests.
    fn four_by_two_incidence() -> Tensor {
        incidence(vec![1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0], &[4, 2])
    }

    /// Three nodes and two hyperedges, shared by the HGCN and conversion tests.
    fn three_by_two_incidence() -> Tensor {
        incidence(vec![1.0, 0.0, 1.0, 1.0, 0.0, 1.0], &[3, 2])
    }

    #[test]
    fn test_hypergraph_creation() {
        let node_features = randn(&[4, 3]).unwrap();
        let membership = incidence(
            vec![
                1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0,
            ],
            &[4, 3],
        );

        let hg = HypergraphData::new(node_features, membership);

        assert_eq!(hg.num_nodes, 4);
        assert_eq!(hg.num_hyperedges, 3);
        assert_eq!(hg.x.shape().dims(), &[4, 3]);
        assert_eq!(hg.incidence_matrix.shape().dims(), &[4, 3]);
    }

    #[test]
    fn test_hgcn_layer() {
        let hg = HypergraphData::new(randn(&[3, 4]).unwrap(), three_by_two_incidence());

        let layer = HGCNConv::new(4, 8, true, false, 0.1);
        let result = layer.forward(&hg);

        assert_eq!(result.x.shape().dims(), &[3, 8]);
        assert_eq!(result.num_nodes, 3);
        assert_eq!(result.num_hyperedges, 2);
    }

    #[test]
    fn test_hypergraph_to_graph_conversion() {
        let hg = HypergraphData::new(randn(&[3, 4]).unwrap(), three_by_two_incidence());

        // Clique expansion keeps the original node set.
        let clique = hg.to_graph_clique_expansion();
        assert_eq!(clique.num_nodes, 3);

        // Star expansion adds one virtual node per hyperedge (3 + 2).
        let star = hg.to_graph_star_expansion();
        assert_eq!(star.num_nodes, 5);
    }

    #[test]
    fn test_hypergraph_pooling() {
        let hg = HypergraphData::new(randn(&[4, 6]).unwrap(), four_by_two_incidence());

        let mean = pooling::global_hypergraph_pool(&hg, pooling::PoolingMethod::Mean);
        assert_eq!(mean.shape().dims(), &[6]);

        let max = pooling::global_hypergraph_pool(&hg, pooling::PoolingMethod::Max);
        assert_eq!(max.shape().dims(), &[6]);
    }

    #[test]
    fn test_hypergraph_utils() {
        let edge_list = vec![
            (vec![0, 1, 2], 1.0),
            (vec![1, 3], 0.8),
            (vec![0, 2, 3], 1.2),
        ];

        let hg = utils::edge_list_to_hypergraph(&edge_list, 4);
        assert_eq!(hg.num_nodes, 4);
        assert_eq!(hg.num_hyperedges, 3);

        let summary = utils::hypergraph_metrics(&hg);
        assert!(summary.avg_node_degree > 0.0);
        assert!(summary.avg_hyperedge_size > 0.0);
    }

    #[test]
    fn test_random_hypergraph_generation() {
        let hg = utils::random_hypergraph(5, 3, 0.6, 8);
        assert_eq!(hg.num_nodes, 5);
        assert_eq!(hg.num_hyperedges, 3);
        assert_eq!(hg.x.shape().dims(), &[5, 8]);
    }

    #[test]
    fn test_hypergat_layer() {
        let hg = HypergraphData::new(randn(&[4, 6]).unwrap(), four_by_two_incidence());

        let layer = HyperGATConv::new(6, 12, 3, 0.1, true);
        let result = layer.forward(&hg);

        assert_eq!(result.x.shape().dims(), &[4, 12]);
        assert_eq!(result.num_nodes, 4);
    }
}
1254}