//! Graph Isomorphism Network (GIN) layer implementation
//! Based on the paper "How Powerful are Graph Neural Networks?"

4// Framework infrastructure - components designed for future use
5#![allow(dead_code)]
6use crate::parameter::Parameter;
7use crate::{GraphData, GraphLayer};
8use torsh_tensor::{
9    creation::{randn, zeros},
10    Tensor,
11};
12
/// Graph Isomorphism Network (GIN) layer.
///
/// Implements the GIN update rule `h_i' = MLP((1 + eps) * h_i + sum_{j in N(i)} h_j)`,
/// where the sum runs over the in-neighbors of node `i` (see `forward`).
#[derive(Debug)]
pub struct GINConv {
    /// Dimensionality of input node features.
    in_features: usize,
    /// Dimensionality of output node features.
    out_features: usize,
    /// Fixed epsilon, used when `train_eps` is false.
    eps: f64,
    /// Whether epsilon is a learnable parameter.
    train_eps: bool,
    /// Learnable epsilon (populated only when `train_eps` is true).
    eps_param: Option<Parameter>,
    mlp: Vec<Parameter>, // Simple MLP: Linear -> ReLU -> Linear
    /// Optional additive bias applied after the MLP.
    bias: Option<Parameter>,
}
24
25impl GINConv {
26    /// Create a new GIN convolution layer
27    pub fn new(
28        in_features: usize,
29        out_features: usize,
30        eps: f64,
31        train_eps: bool,
32        bias: bool,
33    ) -> Self {
34        let eps_param = if train_eps {
35            Some(Parameter::new(
36                torsh_tensor::creation::tensor_scalar(eps as f32)
37                    .expect("failed to create epsilon scalar"),
38            ))
39        } else {
40            None
41        };
42
43        // Create a simple 2-layer MLP
44        let hidden_dim = (in_features + out_features) / 2;
45        let mlp = vec![
46            Parameter::new(
47                randn(&[in_features, hidden_dim]).expect("failed to create MLP layer 1 weights"),
48            ),
49            Parameter::new(
50                randn(&[hidden_dim, out_features]).expect("failed to create MLP layer 2 weights"),
51            ),
52        ];
53
54        let bias = if bias {
55            Some(Parameter::new(
56                zeros(&[out_features]).expect("failed to create bias tensor"),
57            ))
58        } else {
59            None
60        };
61
62        Self {
63            in_features,
64            out_features,
65            eps,
66            train_eps,
67            eps_param,
68            mlp,
69            bias,
70        }
71    }
72
73    /// Apply GIN convolution
74    pub fn forward(&self, graph: &GraphData) -> GraphData {
75        let num_nodes = graph.num_nodes;
76        let edge_flat = graph
77            .edge_index
78            .to_vec()
79            .expect("conversion should succeed");
80        let num_edges = edge_flat.len() / 2;
81        let edge_data = vec![
82            edge_flat[0..num_edges].to_vec(),
83            edge_flat[num_edges..].to_vec(),
84        ];
85
86        // Build adjacency list for efficient neighbor aggregation
87        let mut adjacency_list: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
88        for j in 0..edge_data[0].len() {
89            let src = edge_data[0][j] as usize;
90            let dst = edge_data[1][j] as usize;
91            if src < num_nodes && dst < num_nodes {
92                adjacency_list[dst].push(src);
93            }
94        }
95
96        // Aggregate neighbor features (sum aggregation for GIN)
97        let neighbor_features = zeros(&[num_nodes, self.in_features])
98            .expect("failed to create neighbor features tensor");
99
100        for node in 0..num_nodes {
101            let mut aggregated =
102                zeros(&[self.in_features]).expect("failed to create aggregated features tensor");
103
104            // Sum all neighbor features
105            for &neighbor in &adjacency_list[node] {
106                let neighbor_feat = graph
107                    .x
108                    .slice_tensor(0, neighbor, neighbor + 1)
109                    .expect("failed to slice neighbor features")
110                    .squeeze_tensor(0)
111                    .expect("failed to squeeze neighbor features");
112                aggregated = aggregated
113                    .add(&neighbor_feat)
114                    .expect("operation should succeed");
115            }
116
117            let mut node_slice = neighbor_features
118                .slice_tensor(0, node, node + 1)
119                .expect("failed to slice neighbor features");
120            let _ = node_slice.copy_(
121                &aggregated
122                    .unsqueeze_tensor(0)
123                    .expect("failed to unsqueeze aggregated features"),
124            );
125        }
126
127        // Get epsilon value
128        let epsilon = if let Some(ref eps_param) = self.eps_param {
129            eps_param
130                .clone_data()
131                .to_vec()
132                .expect("conversion should succeed")[0] as f64
133        } else {
134            self.eps
135        };
136
137        // Combine self and neighbor features: (1 + eps) * h_i + sum(h_j)
138        let self_weighted = graph
139            .x
140            .mul_scalar((1.0 + epsilon) as f32)
141            .expect("failed to scale self features");
142        let combined_features = self_weighted
143            .add(&neighbor_features)
144            .expect("operation should succeed");
145
146        // Apply MLP
147        let mut output = combined_features
148            .matmul(&self.mlp[0].clone_data())
149            .expect("operation should succeed");
150
151        // Apply ReLU activation (using max with zero tensor)
152        let zero_tensor =
153            zeros(output.shape().dims()).expect("failed to create zero tensor for ReLU");
154        output = output
155            .maximum(&zero_tensor)
156            .expect("failed to apply ReLU activation");
157
158        // Second layer
159        output = output
160            .matmul(&self.mlp[1].clone_data())
161            .expect("operation should succeed");
162
163        // Add bias if present
164        if let Some(ref bias) = self.bias {
165            output = output
166                .add(&bias.clone_data())
167                .expect("operation should succeed");
168        }
169
170        // Create output graph
171        GraphData {
172            x: output,
173            edge_index: graph.edge_index.clone(),
174            edge_attr: graph.edge_attr.clone(),
175            batch: graph.batch.clone(),
176            num_nodes: graph.num_nodes,
177            num_edges: graph.num_edges,
178        }
179    }
180}
181
182impl GraphLayer for GINConv {
183    fn forward(&self, graph: &GraphData) -> GraphData {
184        self.forward(graph)
185    }
186
187    fn parameters(&self) -> Vec<Tensor> {
188        let mut params = vec![self.mlp[0].clone_data(), self.mlp[1].clone_data()];
189
190        if let Some(ref eps_param) = self.eps_param {
191            params.push(eps_param.clone_data());
192        }
193
194        if let Some(ref bias) = self.bias {
195            params.push(bias.clone_data());
196        }
197
198        params
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    #[test]
    fn test_gin_creation() {
        let layer = GINConv::new(8, 16, 0.5, true, true);
        let n_params = layer.parameters().len();
        // Two MLP weight matrices at minimum; eps and bias add at most two more.
        assert!((2..=4).contains(&n_params));
    }

    #[test]
    fn test_gin_forward() {
        let layer = GINConv::new(4, 6, 0.0, false, false);

        // Three nodes with 4 features each: 1.0, 2.0, ..., 12.0.
        let features: Vec<f32> = (1..=12).map(|v| v as f32).collect();
        let x = from_vec(features, &[3, 4], DeviceType::Cpu).unwrap();
        // Ring topology: edges 0->1, 1->2, 2->0, stored as [sources; targets].
        let edge_index =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();

        let result = layer.forward(&GraphData::new(x, edge_index));
        assert_eq!(result.x.shape().dims(), &[3, 6]);
        assert_eq!(result.num_nodes, 3);
    }

    #[test]
    fn test_gin_trainable_eps() {
        let with_fixed = GINConv::new(4, 8, 1.0, false, false);
        let with_learned = GINConv::new(4, 8, 1.0, true, false);

        // The learnable-epsilon variant exposes exactly one extra parameter.
        assert_eq!(
            with_learned.parameters().len(),
            with_fixed.parameters().len() + 1
        );
    }
}