Skip to main content

torsh_graph/
hypergraph.rs

1//! Hypergraph Neural Networks
2//!
3//! Advanced implementation of hypergraph neural networks for multi-relational learning.
4//! Hypergraphs generalize graphs by allowing edges (hyperedges) to connect any number of nodes,
5//! enabling modeling of complex multi-way relationships in data.
6//!
7//! # Features:
8//! - Hypergraph data structures with efficient storage
9//! - Multiple hypergraph convolution layers (HGCN, HyperGAT, HGNN)
10//! - Hypergraph attention mechanisms
11//! - Advanced pooling and coarsening operations
12//! - Spectral hypergraph methods
13//! - Dynamic hypergraph construction
14
15// Framework infrastructure - components designed for future use
16#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use torsh_tensor::{
20    creation::{from_vec, randn, zeros},
21    Tensor,
22};
23
/// Hypergraph data structure representing multi-way relationships
#[derive(Debug, Clone)]
pub struct HypergraphData {
    /// Node feature matrix (num_nodes x num_features)
    pub x: Tensor,
    /// Hyperedge incidence matrix (num_nodes x num_hyperedges); an entry > 0
    /// marks the node as a member of that hyperedge
    pub incidence_matrix: Tensor,
    /// Hyperedge weights (optional; not folded into the derived degrees)
    pub hyperedge_weights: Option<Tensor>,
    /// Hyperedge features (optional)
    pub hyperedge_features: Option<Tensor>,
    /// Node degrees: unweighted row sums of the incidence matrix, as
    /// computed by `new`
    pub node_degrees: Tensor,
    /// Hyperedge cardinalities (number of nodes per hyperedge): column sums
    /// of the incidence matrix
    pub hyperedge_cardinalities: Tensor,
    /// Number of nodes (rows of `x` / `incidence_matrix`)
    pub num_nodes: usize,
    /// Number of hyperedges (columns of `incidence_matrix`)
    pub num_hyperedges: usize,
}
44
45impl HypergraphData {
46    /// Create a new hypergraph from node features and incidence matrix
47    pub fn new(x: Tensor, incidence_matrix: Tensor) -> Self {
48        let num_nodes = x.shape().dims()[0];
49        let num_hyperedges = incidence_matrix.shape().dims()[1];
50
51        // Compute node degrees (sum over hyperedges - axis 1)
52        let node_degrees = incidence_matrix
53            .sum_dim(&[1], false)
54            .expect("sum_dim node_degrees should succeed");
55
56        // Compute hyperedge cardinalities (sum over nodes - axis 0)
57        let hyperedge_cardinalities = incidence_matrix
58            .sum_dim(&[0], false)
59            .expect("sum_dim hyperedge_cardinalities should succeed");
60
61        Self {
62            x,
63            incidence_matrix,
64            hyperedge_weights: None,
65            hyperedge_features: None,
66            node_degrees,
67            hyperedge_cardinalities,
68            num_nodes,
69            num_hyperedges,
70        }
71    }
72
73    /// Add hyperedge weights
74    pub fn with_hyperedge_weights(mut self, weights: Tensor) -> Self {
75        self.hyperedge_weights = Some(weights);
76        self
77    }
78
79    /// Add hyperedge features
80    pub fn with_hyperedge_features(mut self, features: Tensor) -> Self {
81        self.hyperedge_features = Some(features);
82        self
83    }
84
85    /// Convert to regular graph using clique expansion
86    pub fn to_graph_clique_expansion(&self) -> GraphData {
87        let incidence_data = self
88            .incidence_matrix
89            .to_vec()
90            .expect("conversion should succeed");
91        let mut edges = Vec::new();
92
93        // For each hyperedge, create clique (all pairs of nodes)
94        for e in 0..self.num_hyperedges {
95            let mut nodes_in_hyperedge = Vec::new();
96
97            // Find nodes in this hyperedge
98            for v in 0..self.num_nodes {
99                let idx = v * self.num_hyperedges + e;
100                if incidence_data[idx] > 0.0 {
101                    nodes_in_hyperedge.push(v as f32);
102                }
103            }
104
105            // Create all pairs within the hyperedge
106            for i in 0..nodes_in_hyperedge.len() {
107                for j in (i + 1)..nodes_in_hyperedge.len() {
108                    edges.extend_from_slice(&[nodes_in_hyperedge[i], nodes_in_hyperedge[j]]);
109                    edges.extend_from_slice(&[nodes_in_hyperedge[j], nodes_in_hyperedge[i]]);
110                }
111            }
112        }
113
114        let edge_index = if edges.is_empty() {
115            zeros(&[2, 0]).expect("zeros empty edge_index should succeed")
116        } else {
117            let num_edges = edges.len() / 2;
118            from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
119                .expect("from_vec edge_index should succeed")
120        };
121
122        GraphData::new(self.x.clone(), edge_index)
123    }
124
125    /// Convert to regular graph using star expansion
126    pub fn to_graph_star_expansion(&self) -> GraphData {
127        let incidence_data = self
128            .incidence_matrix
129            .to_vec()
130            .expect("conversion should succeed");
131        let mut edges = Vec::new();
132
133        // For each hyperedge, create a star with center at virtual node
134        let virtual_node_offset = self.num_nodes;
135
136        for e in 0..self.num_hyperedges {
137            let virtual_node = (virtual_node_offset + e) as f32;
138
139            // Connect all nodes in hyperedge to virtual center
140            for v in 0..self.num_nodes {
141                let idx = v * self.num_hyperedges + e;
142                if incidence_data[idx] > 0.0 {
143                    let node = v as f32;
144                    edges.extend_from_slice(&[node, virtual_node]);
145                    edges.extend_from_slice(&[virtual_node, node]);
146                }
147            }
148        }
149
150        let edge_index = if edges.is_empty() {
151            zeros(&[2, 0]).expect("zeros empty edge_index should succeed")
152        } else {
153            let num_edges = edges.len() / 2;
154            from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
155                .expect("from_vec edge_index should succeed")
156        };
157
158        // Extend node features with virtual nodes
159        let virtual_features: Tensor = randn(&[self.num_hyperedges, self.x.shape().dims()[1]])
160            .expect("randn virtual_features should succeed");
161        // Concatenate original and virtual node features
162        let node_data = self.x.to_vec().expect("conversion should succeed");
163        let virtual_data = virtual_features
164            .to_vec()
165            .expect("conversion should succeed");
166        let mut extended_data = node_data;
167        extended_data.extend(virtual_data);
168
169        let total_nodes = self.num_nodes + self.num_hyperedges;
170        let features_dim = self.x.shape().dims()[1];
171        let extended_x = from_vec(
172            extended_data,
173            &[total_nodes, features_dim],
174            torsh_core::device::DeviceType::Cpu,
175        )
176        .expect("from_vec extended_x should succeed");
177
178        GraphData::new(extended_x, edge_index)
179    }
180}
181
/// Hypergraph Convolutional Network (HGCN) layer
#[derive(Debug)]
pub struct HGCNConv {
    /// Input feature dimension
    in_features: usize,
    /// Output feature dimension
    out_features: usize,
    /// Linear projection weight, shape (in_features, out_features)
    weight: Parameter,
    /// Optional additive bias, shape (out_features)
    bias: Option<Parameter>,
    /// Whether an attention weight vector was allocated at construction
    use_attention: bool,
    /// Attention weight vector, shape (out_features); Some iff `use_attention`
    attention_weight: Option<Parameter>,
    /// Dropout probability (stored; not applied in the current simplified
    /// forward pass)
    dropout: f32,
}
193
194impl HGCNConv {
195    /// Create a new HGCN layer
196    pub fn new(
197        in_features: usize,
198        out_features: usize,
199        bias: bool,
200        use_attention: bool,
201        dropout: f32,
202    ) -> Self {
203        let weight = Parameter::new(
204            randn(&[in_features, out_features]).expect("randn weight should succeed"),
205        );
206        let bias = if bias {
207            Some(Parameter::new(
208                zeros(&[out_features]).expect("zeros bias should succeed"),
209            ))
210        } else {
211            None
212        };
213
214        let attention_weight = if use_attention {
215            Some(Parameter::new(
216                randn(&[out_features]).expect("randn attention_weight should succeed"),
217            ))
218        } else {
219            None
220        };
221
222        Self {
223            in_features,
224            out_features,
225            weight,
226            bias,
227            use_attention,
228            attention_weight,
229            dropout,
230        }
231    }
232
233    /// Forward pass through HGCN layer
234    pub fn forward(&self, hypergraph: &HypergraphData) -> HypergraphData {
235        // Simplified implementation for API compatibility
236        // Step 1: Transform node features
237        let node_features_transformed = hypergraph
238            .x
239            .matmul(&self.weight.clone_data())
240            .expect("operation should succeed");
241
242        // Step 2: Simplified hypergraph convolution (skip complex aggregation for now)
243        let output_features = if let Some(ref bias) = self.bias {
244            node_features_transformed
245                .add(&bias.clone_data())
246                .expect("operation should succeed")
247        } else {
248            node_features_transformed
249        };
250
251        // Create output hypergraph with updated node features
252        HypergraphData {
253            x: output_features,
254            incidence_matrix: hypergraph.incidence_matrix.clone(),
255            hyperedge_weights: hypergraph.hyperedge_weights.clone(),
256            hyperedge_features: hypergraph.hyperedge_features.clone(),
257            node_degrees: hypergraph.node_degrees.clone(),
258            hyperedge_cardinalities: hypergraph.hyperedge_cardinalities.clone(),
259            num_nodes: hypergraph.num_nodes,
260            num_hyperedges: hypergraph.num_hyperedges,
261        }
262    }
263
264    /// Apply attention mechanism to hyperedge features
265    fn apply_attention(&self, hyperedge_features: &Tensor, _hypergraph: &HypergraphData) -> Tensor {
266        if let Some(ref attention_weight) = self.attention_weight {
267            // Compute attention scores
268            let attention_scores = hyperedge_features
269                .matmul(&attention_weight.clone_data())
270                .expect("operation should succeed");
271            let attention_probs = attention_scores
272                .softmax(-1)
273                .expect("softmax should succeed");
274
275            // Apply attention to features
276            let attention_expanded = attention_probs
277                .unsqueeze(-1)
278                .expect("unsqueeze should succeed");
279            hyperedge_features
280                .mul(&attention_expanded)
281                .expect("operation should succeed")
282        } else {
283            hyperedge_features.clone()
284        }
285    }
286
287    /// Normalize aggregated features by node degrees
288    fn normalize_by_degrees(&self, features: &Tensor, hypergraph: &HypergraphData) -> Tensor {
289        let degrees = &hypergraph.node_degrees;
290        let epsilon = 1e-8;
291
292        // Add epsilon to prevent division by zero
293        let safe_degrees = degrees
294            .add_scalar(epsilon)
295            .expect("add_scalar should succeed");
296        let inv_degrees = safe_degrees
297            .reciprocal()
298            .expect("reciprocal should succeed");
299
300        // Expand inverse degrees to match feature dimensions
301        // First squeeze to ensure we have shape [num_nodes] rather than [num_nodes, 1]
302        let inv_degrees_squeezed = if inv_degrees.shape().dims().len() > 1 {
303            inv_degrees
304                .squeeze_tensor(1)
305                .expect("squeeze_tensor should succeed")
306        } else {
307            inv_degrees
308        };
309        let inv_degrees_expanded = inv_degrees_squeezed
310            .unsqueeze(-1)
311            .expect("unsqueeze should succeed");
312        features
313            .mul(&inv_degrees_expanded)
314            .expect("operation should succeed")
315    }
316}
317
318impl GraphLayer for HGCNConv {
319    fn forward(&self, graph: &GraphData) -> GraphData {
320        // Convert regular graph to hypergraph and back for compatibility
321        let hypergraph = graph_to_hypergraph(graph);
322        let output_hypergraph = self.forward(&hypergraph);
323        output_hypergraph.to_graph_clique_expansion()
324    }
325
326    fn parameters(&self) -> Vec<Tensor> {
327        let mut params = vec![self.weight.clone_data()];
328        if let Some(ref bias) = self.bias {
329            params.push(bias.clone_data());
330        }
331        if let Some(ref attention_weight) = self.attention_weight {
332            params.push(attention_weight.clone_data());
333        }
334        params
335    }
336}
337
/// Hypergraph Attention Network (HyperGAT) layer
#[derive(Debug)]
pub struct HyperGATConv {
    /// Input feature dimension
    in_features: usize,
    /// Output feature dimension (split evenly across `heads`)
    out_features: usize,
    /// Number of attention heads
    heads: usize,
    /// Query projection, shape (in_features, out_features)
    query_weight: Parameter,
    /// Key projection, shape (in_features, out_features)
    key_weight: Parameter,
    /// Value projection, shape (in_features, out_features)
    value_weight: Parameter,
    /// Per-head attention parameters, shape (heads, 2 * head_dim)
    /// (allocated but unused by the current simplified attention path)
    hyperedge_attention: Parameter,
    /// Output projection, shape (out_features, out_features)
    output_weight: Parameter,
    /// Optional additive bias, shape (out_features)
    bias: Option<Parameter>,
    /// Dropout probability (stored; not applied in the visible forward path)
    dropout: f32,
}
352
353impl HyperGATConv {
354    /// Create a new HyperGAT layer
355    pub fn new(
356        in_features: usize,
357        out_features: usize,
358        heads: usize,
359        dropout: f32,
360        bias: bool,
361    ) -> Self {
362        let head_dim = out_features / heads;
363
364        let query_weight = Parameter::new(
365            randn(&[in_features, out_features]).expect("randn query_weight should succeed"),
366        );
367        let key_weight = Parameter::new(
368            randn(&[in_features, out_features]).expect("randn key_weight should succeed"),
369        );
370        let value_weight = Parameter::new(
371            randn(&[in_features, out_features]).expect("randn value_weight should succeed"),
372        );
373        let hyperedge_attention = Parameter::new(
374            randn(&[heads, 2 * head_dim]).expect("randn hyperedge_attention should succeed"),
375        );
376        let output_weight = Parameter::new(
377            randn(&[out_features, out_features]).expect("randn output_weight should succeed"),
378        );
379
380        let bias = if bias {
381            Some(Parameter::new(
382                zeros(&[out_features]).expect("zeros bias should succeed"),
383            ))
384        } else {
385            None
386        };
387
388        Self {
389            in_features,
390            out_features,
391            heads,
392            query_weight,
393            key_weight,
394            value_weight,
395            hyperedge_attention,
396            output_weight,
397            bias,
398            dropout,
399        }
400    }
401
402    /// Forward pass through HyperGAT layer
403    pub fn forward(&self, hypergraph: &HypergraphData) -> HypergraphData {
404        let num_nodes = hypergraph.num_nodes;
405        let head_dim = self.out_features / self.heads;
406
407        // Linear transformations
408        let queries = hypergraph
409            .x
410            .matmul(&self.query_weight.clone_data())
411            .expect("operation should succeed");
412        let keys = hypergraph
413            .x
414            .matmul(&self.key_weight.clone_data())
415            .expect("operation should succeed");
416        let values = hypergraph
417            .x
418            .matmul(&self.value_weight.clone_data())
419            .expect("operation should succeed");
420
421        // Reshape for multi-head attention
422        let q = queries
423            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
424            .expect("view should succeed");
425        let k = keys
426            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
427            .expect("view should succeed");
428        let v = values
429            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
430            .expect("view should succeed");
431
432        // Perform hyperedge-based attention
433        let attended_features = self.hyperedge_attention_mechanism(&q, &k, &v, hypergraph);
434
435        // Reshape back and apply output transformation
436        let concatenated = attended_features
437            .view(&[num_nodes as i32, self.out_features as i32])
438            .expect("view should succeed");
439        let mut output = concatenated
440            .matmul(&self.output_weight.clone_data())
441            .expect("operation should succeed");
442
443        // Add bias if present
444        if let Some(ref bias) = self.bias {
445            output = output
446                .add(&bias.clone_data())
447                .expect("operation should succeed");
448        }
449
450        // Create output hypergraph
451        HypergraphData {
452            x: output,
453            incidence_matrix: hypergraph.incidence_matrix.clone(),
454            hyperedge_weights: hypergraph.hyperedge_weights.clone(),
455            hyperedge_features: hypergraph.hyperedge_features.clone(),
456            node_degrees: hypergraph.node_degrees.clone(),
457            hyperedge_cardinalities: hypergraph.hyperedge_cardinalities.clone(),
458            num_nodes: hypergraph.num_nodes,
459            num_hyperedges: hypergraph.num_hyperedges,
460        }
461    }
462
463    /// Hyperedge-based attention mechanism
464    fn hyperedge_attention_mechanism(
465        &self,
466        q: &Tensor,
467        k: &Tensor,
468        v: &Tensor,
469        hypergraph: &HypergraphData,
470    ) -> Tensor {
471        let num_nodes = hypergraph.num_nodes;
472        let head_dim = self.out_features / self.heads;
473
474        // Initialize output
475        let mut output =
476            zeros(&[num_nodes, self.heads, head_dim]).expect("zeros output should succeed");
477
478        let incidence_data = hypergraph
479            .incidence_matrix
480            .to_vec()
481            .expect("conversion should succeed");
482
483        // Process each hyperedge separately
484        for e in 0..hypergraph.num_hyperedges {
485            let mut nodes_in_hyperedge = Vec::new();
486
487            // Find nodes in this hyperedge
488            for v in 0..num_nodes {
489                let idx = v * hypergraph.num_hyperedges + e;
490                if incidence_data[idx] > 0.0 {
491                    nodes_in_hyperedge.push(v);
492                }
493            }
494
495            if nodes_in_hyperedge.len() < 2 {
496                continue; // Skip hyperedges with less than 2 nodes
497            }
498
499            // Compute attention within hyperedge for each head
500            for head in 0..self.heads {
501                self.compute_hyperedge_attention(head, &nodes_in_hyperedge, q, k, v, &mut output);
502            }
503        }
504
505        output
506    }
507
508    /// Compute attention for a specific hyperedge and head
509    fn compute_hyperedge_attention(
510        &self,
511        head: usize,
512        nodes: &[usize],
513        q: &Tensor,
514        k: &Tensor,
515        v: &Tensor,
516        output: &mut Tensor,
517    ) {
518        let head_dim = self.out_features / self.heads;
519        let scale = 1.0 / (head_dim as f32).sqrt();
520
521        // For simplicity, use mean pooling within hyperedge
522        // In practice, this would use more sophisticated attention
523        for &node_i in nodes {
524            let mut aggregated = zeros(&[head_dim]).expect("zeros aggregated should succeed");
525            let mut total_weight = 0.0;
526
527            for &node_j in nodes {
528                if node_i != node_j {
529                    // Get query and key for these nodes and head
530                    let q_i = q
531                        .slice_tensor(0, node_i, node_i + 1)
532                        .expect("slice_tensor q_i should succeed")
533                        .slice_tensor(1, head, head + 1)
534                        .expect("slice_tensor q_i head should succeed")
535                        .squeeze_tensor(0)
536                        .expect("squeeze_tensor should succeed")
537                        .squeeze_tensor(0)
538                        .expect("squeeze_tensor should succeed");
539
540                    let k_j = k
541                        .slice_tensor(0, node_j, node_j + 1)
542                        .expect("slice_tensor k_j should succeed")
543                        .slice_tensor(1, head, head + 1)
544                        .expect("slice_tensor k_j head should succeed")
545                        .squeeze_tensor(0)
546                        .expect("squeeze_tensor should succeed")
547                        .squeeze_tensor(0)
548                        .expect("squeeze_tensor should succeed");
549
550                    let v_j = v
551                        .slice_tensor(0, node_j, node_j + 1)
552                        .expect("slice_tensor v_j should succeed")
553                        .slice_tensor(1, head, head + 1)
554                        .expect("slice_tensor v_j head should succeed")
555                        .squeeze_tensor(0)
556                        .expect("squeeze_tensor should succeed")
557                        .squeeze_tensor(0)
558                        .expect("squeeze_tensor should succeed");
559
560                    // Compute attention weight (simplified)
561                    let attention_score = q_i
562                        .dot(&k_j)
563                        .expect("dot should succeed")
564                        .mul_scalar(scale)
565                        .expect("mul_scalar should succeed");
566                    let weight = attention_score
567                        .exp()
568                        .expect("exp should succeed")
569                        .item()
570                        .expect("tensor should have single item");
571
572                    // Aggregate values
573                    let weighted_value = v_j.mul_scalar(weight).expect("mul_scalar should succeed");
574                    aggregated = aggregated
575                        .add(&weighted_value)
576                        .expect("operation should succeed");
577                    total_weight += weight;
578                }
579            }
580
581            // Normalize and update output
582            if total_weight > 0.0 {
583                aggregated = aggregated
584                    .div_scalar(total_weight)
585                    .expect("div_scalar should succeed");
586
587                // Update output tensor (simplified assignment)
588                let aggregated_data = aggregated.to_vec().expect("conversion should succeed");
589                for (j, &val) in aggregated_data.iter().enumerate() {
590                    output
591                        .set_item(&[node_i, head, j], val)
592                        .expect("set_item should succeed");
593                }
594            }
595        }
596    }
597}
598
599impl GraphLayer for HyperGATConv {
600    fn forward(&self, graph: &GraphData) -> GraphData {
601        let hypergraph = graph_to_hypergraph(graph);
602        let output_hypergraph = self.forward(&hypergraph);
603        output_hypergraph.to_graph_clique_expansion()
604    }
605
606    fn parameters(&self) -> Vec<Tensor> {
607        let mut params = vec![
608            self.query_weight.clone_data(),
609            self.key_weight.clone_data(),
610            self.value_weight.clone_data(),
611            self.hyperedge_attention.clone_data(),
612            self.output_weight.clone_data(),
613        ];
614
615        if let Some(ref bias) = self.bias {
616            params.push(bias.clone_data());
617        }
618
619        params
620    }
621}
622
/// Hypergraph Neural Network (HGNN) layer based on spectral methods
#[derive(Debug)]
pub struct HGNNConv {
    /// Input feature dimension
    in_features: usize,
    /// Output feature dimension
    out_features: usize,
    /// Linear projection weight, shape (in_features, out_features)
    weight: Parameter,
    /// Optional additive bias, shape (out_features)
    bias: Option<Parameter>,
    /// Selects spectral (Laplacian-based) vs spatial (incidence-based)
    /// convolution in `forward`
    use_spectral: bool,
}
632
633impl HGNNConv {
634    /// Create a new HGNN layer
635    pub fn new(in_features: usize, out_features: usize, bias: bool, use_spectral: bool) -> Self {
636        let weight = Parameter::new(
637            randn(&[in_features, out_features]).expect("randn weight should succeed"),
638        );
639        let bias = if bias {
640            Some(Parameter::new(
641                zeros(&[out_features]).expect("zeros bias should succeed"),
642            ))
643        } else {
644            None
645        };
646
647        Self {
648            in_features,
649            out_features,
650            weight,
651            bias,
652            use_spectral,
653        }
654    }
655
656    /// Forward pass through HGNN layer
657    pub fn forward(&self, hypergraph: &HypergraphData) -> HypergraphData {
658        // Transform node features
659        let x_transformed = hypergraph
660            .x
661            .matmul(&self.weight.clone_data())
662            .expect("operation should succeed");
663
664        // Compute hypergraph Laplacian and apply convolution
665        let output_features = if self.use_spectral {
666            self.spectral_convolution(&x_transformed, hypergraph)
667        } else {
668            self.spatial_convolution(&x_transformed, hypergraph)
669        };
670
671        // Add bias if present
672        let final_features = if let Some(ref bias) = self.bias {
673            output_features
674                .add(&bias.clone_data())
675                .expect("operation should succeed")
676        } else {
677            output_features
678        };
679
680        HypergraphData {
681            x: final_features,
682            incidence_matrix: hypergraph.incidence_matrix.clone(),
683            hyperedge_weights: hypergraph.hyperedge_weights.clone(),
684            hyperedge_features: hypergraph.hyperedge_features.clone(),
685            node_degrees: hypergraph.node_degrees.clone(),
686            hyperedge_cardinalities: hypergraph.hyperedge_cardinalities.clone(),
687            num_nodes: hypergraph.num_nodes,
688            num_hyperedges: hypergraph.num_hyperedges,
689        }
690    }
691
692    /// Spectral convolution using hypergraph Laplacian
693    fn spectral_convolution(&self, features: &Tensor, hypergraph: &HypergraphData) -> Tensor {
694        // Compute normalized hypergraph Laplacian
695        let laplacian = self.compute_hypergraph_laplacian(hypergraph);
696
697        // Apply Laplacian: L @ X
698        laplacian
699            .matmul(features)
700            .expect("operation should succeed")
701    }
702
703    /// Spatial convolution using incidence matrix
704    fn spatial_convolution(&self, features: &Tensor, hypergraph: &HypergraphData) -> Tensor {
705        // Node-to-hyperedge aggregation
706        let incidence_t = hypergraph
707            .incidence_matrix
708            .transpose(0, 1)
709            .expect("transpose should succeed");
710        let hyperedge_features = incidence_t
711            .matmul(features)
712            .expect("operation should succeed");
713
714        // Hyperedge-to-node aggregation
715        let aggregated = hypergraph
716            .incidence_matrix
717            .matmul(&hyperedge_features)
718            .expect("operation should succeed");
719
720        // Normalize by node degrees
721        self.normalize_by_degrees(&aggregated, hypergraph)
722    }
723
724    /// Compute normalized hypergraph Laplacian
725    fn compute_hypergraph_laplacian(&self, hypergraph: &HypergraphData) -> Tensor {
726        let h = &hypergraph.incidence_matrix;
727        let num_nodes = hypergraph.num_nodes;
728
729        // Compute degree matrices
730        let node_degrees = h
731            .sum_dim(&[1], false)
732            .expect("sum_dim node_degrees should succeed");
733        let hyperedge_degrees = h
734            .sum_dim(&[0], false)
735            .expect("sum_dim hyperedge_degrees should succeed");
736
737        // Create diagonal degree matrices (simplified)
738        let mut d_v = zeros(&[num_nodes, num_nodes]).expect("zeros d_v should succeed");
739        let mut d_e = zeros(&[hypergraph.num_hyperedges, hypergraph.num_hyperedges])
740            .expect("zeros d_e should succeed");
741
742        let node_deg_data = node_degrees.to_vec().expect("conversion should succeed");
743        let hyperedge_deg_data = hyperedge_degrees
744            .to_vec()
745            .expect("conversion should succeed");
746
747        // Fill diagonal matrices
748        for i in 0..num_nodes {
749            let degree = node_deg_data[i].max(1e-8); // Avoid division by zero
750            d_v.set_item(&[i, i], degree.powf(-0.5))
751                .expect("set_item d_v should succeed");
752        }
753
754        for i in 0..hypergraph.num_hyperedges {
755            let degree = hyperedge_deg_data[i].max(1e-8);
756            d_e.set_item(&[i, i], degree.recip())
757                .expect("set_item d_e should succeed");
758        }
759
760        // Compute normalized Laplacian: I - D_v^{-1/2} H D_e H^T D_v^{-1/2}
761        let h_t = h.transpose(0, 1).expect("transpose should succeed");
762        let intermediate = d_v
763            .matmul(h)
764            .expect("operation should succeed")
765            .matmul(&d_e)
766            .expect("operation should succeed")
767            .matmul(&h_t)
768            .expect("operation should succeed")
769            .matmul(&d_v)
770            .expect("operation should succeed");
771
772        let identity = eye(num_nodes);
773        identity
774            .sub(&intermediate)
775            .expect("operation should succeed")
776    }
777
778    /// Normalize features by node degrees
779    fn normalize_by_degrees(&self, features: &Tensor, hypergraph: &HypergraphData) -> Tensor {
780        let degrees = &hypergraph.node_degrees;
781        let epsilon = 1e-8;
782
783        let safe_degrees = degrees
784            .add_scalar(epsilon)
785            .expect("add_scalar should succeed");
786        let inv_sqrt_degrees = safe_degrees
787            .pow_scalar(-0.5)
788            .expect("pow_scalar should succeed");
789
790        // First squeeze to ensure we have shape [num_nodes] rather than [num_nodes, 1]
791        let inv_degrees_squeezed = if inv_sqrt_degrees.shape().dims().len() > 1 {
792            inv_sqrt_degrees
793                .squeeze_tensor(1)
794                .expect("squeeze_tensor should succeed")
795        } else {
796            inv_sqrt_degrees
797        };
798        let inv_degrees_expanded = inv_degrees_squeezed
799            .unsqueeze(-1)
800            .expect("unsqueeze should succeed");
801
802        features
803            .mul(&inv_degrees_expanded)
804            .expect("operation should succeed")
805    }
806}
807
808impl GraphLayer for HGNNConv {
809    fn forward(&self, graph: &GraphData) -> GraphData {
810        let hypergraph = graph_to_hypergraph(graph);
811        let output_hypergraph = self.forward(&hypergraph);
812        output_hypergraph.to_graph_clique_expansion()
813    }
814
815    fn parameters(&self) -> Vec<Tensor> {
816        let mut params = vec![self.weight.clone_data()];
817        if let Some(ref bias) = self.bias {
818            params.push(bias.clone_data());
819        }
820        params
821    }
822}
823
824/// Hypergraph pooling operations
825pub mod pooling {
826    use super::*;
827
    /// Global hypergraph pooling: reduce all node features to a single
    /// feature vector.
    ///
    /// Mean/Max/Sum reduce over the node axis (dim 0) without keeping the
    /// reduced dimension; `Attention` delegates to `attention_pool`
    /// (defined elsewhere in this module).
    pub fn global_hypergraph_pool(hypergraph: &HypergraphData, method: PoolingMethod) -> Tensor {
        match method {
            // Average of node features over dim 0.
            PoolingMethod::Mean => hypergraph
                .x
                .mean(Some(&[0]), false)
                .expect("mean pooling should succeed"),
            // Element-wise maximum over nodes.
            PoolingMethod::Max => hypergraph
                .x
                .max(Some(0), false)
                .expect("max pooling should succeed"),
            // Element-wise sum over nodes.
            PoolingMethod::Sum => hypergraph
                .x
                .sum_dim(&[0], false)
                .expect("sum pooling should succeed"),
            // Learned attention-weighted pooling.
            PoolingMethod::Attention => attention_pool(hypergraph),
        }
    }
846
847    /// Hyperedge-aware pooling
848    pub fn hyperedge_pool(hypergraph: &HypergraphData, method: PoolingMethod) -> Tensor {
849        let incidence_t = hypergraph
850            .incidence_matrix
851            .transpose(0, 1)
852            .expect("transpose should succeed");
853
854        match method {
855            PoolingMethod::Mean => {
856                // Average pooling over hyperedges
857                let hyperedge_features = incidence_t
858                    .matmul(&hypergraph.x)
859                    .expect("operation should succeed");
860                hyperedge_features
861                    .mean(Some(&[0]), false)
862                    .expect("mean pooling should succeed")
863            }
864            PoolingMethod::Max => {
865                let hyperedge_features = incidence_t
866                    .matmul(&hypergraph.x)
867                    .expect("operation should succeed");
868                hyperedge_features
869                    .max(Some(0), false)
870                    .expect("max pooling should succeed")
871            }
872            PoolingMethod::Sum => {
873                let hyperedge_features = incidence_t
874                    .matmul(&hypergraph.x)
875                    .expect("operation should succeed");
876                hyperedge_features
877                    .sum_dim(&[0], false)
878                    .expect("sum pooling should succeed")
879            }
880            PoolingMethod::Attention => {
881                // Attention over hyperedges
882                attention_pool(hypergraph)
883            }
884        }
885    }
886
887    /// Hierarchical hypergraph pooling
888    pub fn hierarchical_hypergraph_pool(
889        hypergraph: &HypergraphData,
890        num_clusters: usize,
891    ) -> HypergraphData {
892        // Simplified clustering-based pooling
893        let cluster_assignments = cluster_nodes(hypergraph, num_clusters);
894        coarsen_hypergraph(hypergraph, &cluster_assignments)
895    }
896
897    /// Attention-based pooling
898    fn attention_pool(hypergraph: &HypergraphData) -> Tensor {
899        // Simplified attention pooling
900        let attention_scores = hypergraph
901            .x
902            .sum_dim(&[1], false)
903            .expect("sum_dim should succeed");
904        let attention_weights = attention_scores.softmax(0).expect("softmax should succeed");
905        let attention_expanded = attention_weights
906            .unsqueeze(-1)
907            .expect("unsqueeze should succeed");
908
909        let weighted_features = hypergraph
910            .x
911            .mul(&attention_expanded)
912            .expect("operation should succeed");
913        weighted_features
914            .sum_dim(&[0], false)
915            .expect("sum_dim should succeed")
916    }
917
918    /// Simple node clustering for hierarchical pooling
919    fn cluster_nodes(hypergraph: &HypergraphData, num_clusters: usize) -> Vec<usize> {
920        let num_nodes = hypergraph.num_nodes;
921        let mut assignments = vec![0; num_nodes];
922
923        // Simple clustering by node index (for demonstration)
924        for i in 0..num_nodes {
925            assignments[i] = i % num_clusters;
926        }
927
928        assignments
929    }
930
931    /// Coarsen hypergraph based on cluster assignments
932    fn coarsen_hypergraph(
933        hypergraph: &HypergraphData,
934        cluster_assignments: &[usize],
935    ) -> HypergraphData {
936        let num_clusters = cluster_assignments
937            .iter()
938            .max()
939            .expect("reduction should succeed")
940            + 1;
941        let original_features = hypergraph.x.shape().dims()[1];
942
943        // Average node features within clusters (simplified implementation)
944        let mut coarse_features_data = vec![0.0; num_clusters * original_features];
945        let mut cluster_counts = vec![0; num_clusters];
946
947        let node_data = hypergraph.x.to_vec().expect("conversion should succeed");
948
949        for (node, &cluster) in cluster_assignments.iter().enumerate() {
950            cluster_counts[cluster] += 1;
951            for feat in 0..original_features {
952                let node_feat_idx = node * original_features + feat;
953                let cluster_feat_idx = cluster * original_features + feat;
954                coarse_features_data[cluster_feat_idx] += node_data[node_feat_idx];
955            }
956        }
957
958        // Normalize by cluster size
959        for cluster in 0..num_clusters {
960            if cluster_counts[cluster] > 0 {
961                for feat in 0..original_features {
962                    let cluster_feat_idx = cluster * original_features + feat;
963                    coarse_features_data[cluster_feat_idx] /= cluster_counts[cluster] as f32;
964                }
965            }
966        }
967
968        let coarse_features = from_vec(
969            coarse_features_data,
970            &[num_clusters, original_features],
971            torsh_core::device::DeviceType::Cpu,
972        )
973        .expect("from_vec coarse_features should succeed");
974
975        // Create coarse incidence matrix (simplified)
976        let coarse_incidence = zeros(&[num_clusters, hypergraph.num_hyperedges])
977            .expect("zeros coarse_incidence should succeed");
978
979        HypergraphData::new(coarse_features, coarse_incidence)
980    }
981
982    /// Pooling methods
983    #[derive(Debug, Clone, Copy)]
984    pub enum PoolingMethod {
985        Mean,
986        Max,
987        Sum,
988        Attention,
989    }
990}
991
992/// Utility functions for hypergraph operations
993pub mod utils {
994    use super::*;
995
996    /// Convert edge list to hypergraph
997    pub fn edge_list_to_hypergraph(
998        edges: &[(Vec<usize>, f32)],
999        num_nodes: usize,
1000    ) -> HypergraphData {
1001        let num_hyperedges = edges.len();
1002        let mut incidence_data = vec![0.0; num_nodes * num_hyperedges];
1003        let mut weights = Vec::new();
1004
1005        for (e, (edge_nodes, weight)) in edges.iter().enumerate() {
1006            weights.push(*weight);
1007            for &node in edge_nodes {
1008                if node < num_nodes {
1009                    incidence_data[node * num_hyperedges + e] = 1.0;
1010                }
1011            }
1012        }
1013
1014        let features = randn(&[num_nodes, 16]).expect("randn features should succeed"); // Default features
1015        let incidence_matrix = from_vec(
1016            incidence_data,
1017            &[num_nodes, num_hyperedges],
1018            torsh_core::device::DeviceType::Cpu,
1019        )
1020        .expect("from_vec incidence_matrix should succeed");
1021
1022        let hyperedge_weights = from_vec(
1023            weights,
1024            &[num_hyperedges],
1025            torsh_core::device::DeviceType::Cpu,
1026        )
1027        .expect("from_vec hyperedge_weights should succeed");
1028
1029        HypergraphData::new(features, incidence_matrix).with_hyperedge_weights(hyperedge_weights)
1030    }
1031
1032    /// Generate random hypergraph
1033    pub fn random_hypergraph(
1034        num_nodes: usize,
1035        num_hyperedges: usize,
1036        edge_prob: f32,
1037        features_dim: usize,
1038    ) -> HypergraphData {
1039        let mut rng = scirs2_core::random::thread_rng();
1040        let mut incidence_data = vec![0.0; num_nodes * num_hyperedges];
1041
1042        // Generate random hyperedges
1043        for e in 0..num_hyperedges {
1044            for v in 0..num_nodes {
1045                if rng.gen_range(0.0..1.0) < edge_prob {
1046                    incidence_data[v * num_hyperedges + e] = 1.0;
1047                }
1048            }
1049        }
1050
1051        let features = randn(&[num_nodes, features_dim]).expect("randn features should succeed");
1052        let incidence_matrix = from_vec(
1053            incidence_data,
1054            &[num_nodes, num_hyperedges],
1055            torsh_core::device::DeviceType::Cpu,
1056        )
1057        .expect("from_vec incidence_matrix should succeed");
1058
1059        HypergraphData::new(features, incidence_matrix)
1060    }
1061
1062    /// Hypergraph metrics
1063    pub fn hypergraph_metrics(hypergraph: &HypergraphData) -> HypergraphMetrics {
1064        let node_degrees = hypergraph
1065            .node_degrees
1066            .to_vec()
1067            .expect("conversion should succeed");
1068        let hyperedge_cardinalities = hypergraph
1069            .hyperedge_cardinalities
1070            .to_vec()
1071            .expect("conversion should succeed");
1072
1073        let avg_node_degree = node_degrees.iter().sum::<f32>() / node_degrees.len() as f32;
1074        let avg_hyperedge_size =
1075            hyperedge_cardinalities.iter().sum::<f32>() / hyperedge_cardinalities.len() as f32;
1076
1077        let density = node_degrees.iter().sum::<f32>()
1078            / (hypergraph.num_nodes * hypergraph.num_hyperedges) as f32;
1079
1080        HypergraphMetrics {
1081            avg_node_degree,
1082            avg_hyperedge_size,
1083            density,
1084            num_nodes: hypergraph.num_nodes,
1085            num_hyperedges: hypergraph.num_hyperedges,
1086        }
1087    }
1088
1089    /// Hypergraph statistics
1090    #[derive(Debug, Clone)]
1091    pub struct HypergraphMetrics {
1092        pub avg_node_degree: f32,
1093        pub avg_hyperedge_size: f32,
1094        pub density: f32,
1095        pub num_nodes: usize,
1096        pub num_hyperedges: usize,
1097    }
1098}
1099
1100/// Convert regular graph to hypergraph (each edge becomes a hyperedge)
1101pub fn graph_to_hypergraph(graph: &GraphData) -> HypergraphData {
1102    let edge_data = crate::utils::tensor_to_vec2::<f32>(&graph.edge_index)
1103        .expect("tensor_to_vec2 should succeed");
1104    let num_edges = edge_data[0].len();
1105    let num_nodes = graph.num_nodes;
1106
1107    // Each edge becomes a hyperedge connecting two nodes
1108    let mut incidence_data = vec![0.0; num_nodes * num_edges];
1109
1110    for e in 0..num_edges {
1111        let src = edge_data[0][e] as usize;
1112        let dst = edge_data[1][e] as usize;
1113
1114        if src < num_nodes && dst < num_nodes {
1115            incidence_data[src * num_edges + e] = 1.0;
1116            incidence_data[dst * num_edges + e] = 1.0;
1117        }
1118    }
1119
1120    let incidence_matrix = from_vec(
1121        incidence_data,
1122        &[num_nodes, num_edges],
1123        torsh_core::device::DeviceType::Cpu,
1124    )
1125    .expect("from_vec incidence_matrix should succeed");
1126
1127    HypergraphData::new(graph.x.clone(), incidence_matrix)
1128}
1129
1130/// Create identity matrix
1131fn eye(n: usize) -> Tensor {
1132    let mut data = vec![0.0; n * n];
1133    for i in 0..n {
1134        data[i * n + i] = 1.0;
1135    }
1136    from_vec(data, &[n, n], torsh_core::device::DeviceType::Cpu)
1137        .expect("from_vec eye should succeed")
1138}
1139
1140#[cfg(test)]
1141mod tests {
1142    use super::*;
1143    use torsh_core::device::DeviceType;
1144
1145    #[test]
1146    fn test_hypergraph_creation() {
1147        let features = randn(&[4, 3]).unwrap();
1148        let incidence_data = vec![
1149            1.0, 0.0, 1.0, // Node 0 in hyperedges 0 and 2
1150            1.0, 1.0, 0.0, // Node 1 in hyperedges 0 and 1
1151            0.0, 1.0, 1.0, // Node 2 in hyperedges 1 and 2
1152            0.0, 0.0, 1.0, // Node 3 in hyperedge 2
1153        ];
1154        let incidence_matrix = from_vec(incidence_data, &[4, 3], DeviceType::Cpu).unwrap();
1155
1156        let hypergraph = HypergraphData::new(features, incidence_matrix);
1157
1158        assert_eq!(hypergraph.num_nodes, 4);
1159        assert_eq!(hypergraph.num_hyperedges, 3);
1160        assert_eq!(hypergraph.x.shape().dims(), &[4, 3]);
1161        assert_eq!(hypergraph.incidence_matrix.shape().dims(), &[4, 3]);
1162    }
1163
1164    #[test]
1165    fn test_hgcn_layer() {
1166        let features = randn(&[3, 4]).unwrap();
1167        let incidence_matrix =
1168            from_vec(vec![1.0, 0.0, 1.0, 1.0, 0.0, 1.0], &[3, 2], DeviceType::Cpu).unwrap();
1169        let hypergraph = HypergraphData::new(features, incidence_matrix);
1170
1171        let hgcn = HGCNConv::new(4, 8, true, false, 0.1);
1172        let output = hgcn.forward(&hypergraph);
1173
1174        assert_eq!(output.x.shape().dims(), &[3, 8]);
1175        assert_eq!(output.num_nodes, 3);
1176        assert_eq!(output.num_hyperedges, 2);
1177    }
1178
1179    #[test]
1180    fn test_hypergraph_to_graph_conversion() {
1181        let features = randn(&[3, 4]).unwrap();
1182        let incidence_matrix =
1183            from_vec(vec![1.0, 0.0, 1.0, 1.0, 0.0, 1.0], &[3, 2], DeviceType::Cpu).unwrap();
1184        let hypergraph = HypergraphData::new(features, incidence_matrix);
1185
1186        let graph = hypergraph.to_graph_clique_expansion();
1187        assert_eq!(graph.num_nodes, 3);
1188
1189        let star_graph = hypergraph.to_graph_star_expansion();
1190        assert_eq!(star_graph.num_nodes, 5); // 3 original + 2 virtual nodes
1191    }
1192
1193    #[test]
1194    fn test_hypergraph_pooling() {
1195        let features = randn(&[4, 6]).unwrap();
1196        let incidence_matrix = from_vec(
1197            vec![1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0],
1198            &[4, 2],
1199            DeviceType::Cpu,
1200        )
1201        .unwrap();
1202        let hypergraph = HypergraphData::new(features, incidence_matrix);
1203
1204        let pooled_mean =
1205            pooling::global_hypergraph_pool(&hypergraph, pooling::PoolingMethod::Mean);
1206        assert_eq!(pooled_mean.shape().dims(), &[6]);
1207
1208        let pooled_max = pooling::global_hypergraph_pool(&hypergraph, pooling::PoolingMethod::Max);
1209        assert_eq!(pooled_max.shape().dims(), &[6]);
1210    }
1211
1212    #[test]
1213    fn test_hypergraph_utils() {
1214        let edges = vec![
1215            (vec![0, 1, 2], 1.0),
1216            (vec![1, 3], 0.8),
1217            (vec![0, 2, 3], 1.2),
1218        ];
1219
1220        let hypergraph = utils::edge_list_to_hypergraph(&edges, 4);
1221        assert_eq!(hypergraph.num_nodes, 4);
1222        assert_eq!(hypergraph.num_hyperedges, 3);
1223
1224        let metrics = utils::hypergraph_metrics(&hypergraph);
1225        assert!(metrics.avg_node_degree > 0.0);
1226        assert!(metrics.avg_hyperedge_size > 0.0);
1227    }
1228
1229    #[test]
1230    fn test_random_hypergraph_generation() {
1231        let hypergraph = utils::random_hypergraph(5, 3, 0.6, 8);
1232        assert_eq!(hypergraph.num_nodes, 5);
1233        assert_eq!(hypergraph.num_hyperedges, 3);
1234        assert_eq!(hypergraph.x.shape().dims(), &[5, 8]);
1235    }
1236
1237    #[test]
1238    fn test_hypergat_layer() {
1239        let features = randn(&[4, 6]).unwrap();
1240        let incidence_matrix = from_vec(
1241            vec![1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0],
1242            &[4, 2],
1243            DeviceType::Cpu,
1244        )
1245        .unwrap();
1246        let hypergraph = HypergraphData::new(features, incidence_matrix);
1247
1248        let hypergat = HyperGATConv::new(6, 12, 3, 0.1, true);
1249        let output = hypergat.forward(&hypergraph);
1250
1251        assert_eq!(output.x.shape().dims(), &[4, 12]);
1252        assert_eq!(output.num_nodes, 4);
1253    }
1254}