// File: torsh_graph/conv/gcn.rs
1//! Graph Convolutional Network (GCN) layer implementation
2
3use crate::parameter::Parameter;
4use crate::{GraphData, GraphLayer};
5use torsh_tensor::{
6    creation::{randn, zeros},
7    Tensor,
8};
9
/// Graph Convolutional Network (GCN) layer.
///
/// Holds a dense weight matrix of shape `[in_features, out_features]` and an
/// optional bias vector of shape `[out_features]`; applied to node features
/// via `forward`.
#[derive(Debug)]
pub struct GCNConv {
    // Number of input features per node.
    in_features: usize,
    // Number of output features per node.
    out_features: usize,
    // Learnable weight matrix, shape `[in_features, out_features]` (see `new`).
    weight: Parameter,
    // Optional learnable bias, shape `[out_features]`; `None` when the layer
    // was constructed with `bias = false`.
    bias: Option<Parameter>,
}
18
19impl GCNConv {
20    /// Create a new GCN convolution layer
21    pub fn new(in_features: usize, out_features: usize, bias: bool) -> Self {
22        let weight = Parameter::new(
23            randn(&[in_features, out_features]).expect("failed to create weight tensor"),
24        );
25        let bias = if bias {
26            Some(Parameter::new(
27                zeros(&[out_features]).expect("failed to create bias tensor"),
28            ))
29        } else {
30            None
31        };
32
33        Self {
34            in_features,
35            out_features,
36            weight,
37            bias,
38        }
39    }
40
41    /// Get input feature dimension
42    pub fn in_features(&self) -> usize {
43        self.in_features
44    }
45
46    /// Get output feature dimension
47    pub fn out_features(&self) -> usize {
48        self.out_features
49    }
50
51    /// Apply graph convolution
52    pub fn forward(&self, graph: &GraphData) -> GraphData {
53        // Compute normalized Laplacian matrix
54        let laplacian = crate::utils::graph_laplacian(&graph.edge_index, graph.num_nodes, true);
55
56        // Apply graph convolution: L @ X @ W
57        let x_transformed = graph
58            .x
59            .matmul(&self.weight.clone_data())
60            .expect("operation should succeed");
61        let mut output_features = laplacian
62            .matmul(&x_transformed)
63            .expect("operation should succeed");
64
65        // Add bias if present
66        if let Some(ref bias) = self.bias {
67            output_features = output_features
68                .add(&bias.clone_data())
69                .expect("operation should succeed");
70        }
71
72        // Create output graph with transformed features
73        GraphData {
74            x: output_features,
75            edge_index: graph.edge_index.clone(),
76            edge_attr: graph.edge_attr.clone(),
77            batch: graph.batch.clone(),
78            num_nodes: graph.num_nodes,
79            num_edges: graph.num_edges,
80        }
81    }
82}
83
84impl GraphLayer for GCNConv {
85    fn forward(&self, graph: &GraphData) -> GraphData {
86        self.forward(graph)
87    }
88
89    fn parameters(&self) -> Vec<Tensor> {
90        let mut params = vec![self.weight.clone_data()];
91        if let Some(ref bias) = self.bias {
92            params.push(bias.clone_data());
93        }
94        params
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// A layer built with `bias = true` exposes two parameters.
    #[test]
    fn test_gcn_creation() {
        let layer = GCNConv::new(8, 16, true);
        assert_eq!(layer.parameters().len(), 2); // weight + bias
    }

    /// A forward pass over a two-node graph yields `[2, out_features]`
    /// features and preserves the node count.
    #[test]
    fn test_gcn_forward() {
        let layer = GCNConv::new(3, 8, false);

        // Two nodes with three features each, joined by a bidirectional edge.
        let features =
            from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], DeviceType::Cpu).unwrap();
        let edges = from_vec(vec![0.0, 1.0, 1.0, 0.0], &[2, 2], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edges);

        let out = layer.forward(&graph);
        assert_eq!(out.x.shape().dims(), &[2, 8]);
        assert_eq!(out.num_nodes, 2);
    }
}
124}