1#![allow(dead_code)]
use crate::parameter::Parameter;
use crate::{GraphData, GraphLayer};
use torsh_core::device::DeviceType;
use torsh_tensor::{
    creation::{from_vec, randn, zeros},
    Tensor,
};
12
13#[derive(Debug)]
15pub struct GINConv {
16 in_features: usize,
17 out_features: usize,
18 eps: f64,
19 train_eps: bool,
20 eps_param: Option<Parameter>,
21 mlp: Vec<Parameter>, bias: Option<Parameter>,
23}
24
25impl GINConv {
26 pub fn new(
28 in_features: usize,
29 out_features: usize,
30 eps: f64,
31 train_eps: bool,
32 bias: bool,
33 ) -> Self {
34 let eps_param = if train_eps {
35 Some(Parameter::new(
36 torsh_tensor::creation::tensor_scalar(eps as f32)
37 .expect("failed to create epsilon scalar"),
38 ))
39 } else {
40 None
41 };
42
43 let hidden_dim = (in_features + out_features) / 2;
45 let mlp = vec![
46 Parameter::new(
47 randn(&[in_features, hidden_dim]).expect("failed to create MLP layer 1 weights"),
48 ),
49 Parameter::new(
50 randn(&[hidden_dim, out_features]).expect("failed to create MLP layer 2 weights"),
51 ),
52 ];
53
54 let bias = if bias {
55 Some(Parameter::new(
56 zeros(&[out_features]).expect("failed to create bias tensor"),
57 ))
58 } else {
59 None
60 };
61
62 Self {
63 in_features,
64 out_features,
65 eps,
66 train_eps,
67 eps_param,
68 mlp,
69 bias,
70 }
71 }
72
73 pub fn forward(&self, graph: &GraphData) -> GraphData {
75 let num_nodes = graph.num_nodes;
76 let edge_flat = graph
77 .edge_index
78 .to_vec()
79 .expect("conversion should succeed");
80 let num_edges = edge_flat.len() / 2;
81 let edge_data = vec![
82 edge_flat[0..num_edges].to_vec(),
83 edge_flat[num_edges..].to_vec(),
84 ];
85
86 let mut adjacency_list: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
88 for j in 0..edge_data[0].len() {
89 let src = edge_data[0][j] as usize;
90 let dst = edge_data[1][j] as usize;
91 if src < num_nodes && dst < num_nodes {
92 adjacency_list[dst].push(src);
93 }
94 }
95
96 let neighbor_features = zeros(&[num_nodes, self.in_features])
98 .expect("failed to create neighbor features tensor");
99
100 for node in 0..num_nodes {
101 let mut aggregated =
102 zeros(&[self.in_features]).expect("failed to create aggregated features tensor");
103
104 for &neighbor in &adjacency_list[node] {
106 let neighbor_feat = graph
107 .x
108 .slice_tensor(0, neighbor, neighbor + 1)
109 .expect("failed to slice neighbor features")
110 .squeeze_tensor(0)
111 .expect("failed to squeeze neighbor features");
112 aggregated = aggregated
113 .add(&neighbor_feat)
114 .expect("operation should succeed");
115 }
116
117 let mut node_slice = neighbor_features
118 .slice_tensor(0, node, node + 1)
119 .expect("failed to slice neighbor features");
120 let _ = node_slice.copy_(
121 &aggregated
122 .unsqueeze_tensor(0)
123 .expect("failed to unsqueeze aggregated features"),
124 );
125 }
126
127 let epsilon = if let Some(ref eps_param) = self.eps_param {
129 eps_param
130 .clone_data()
131 .to_vec()
132 .expect("conversion should succeed")[0] as f64
133 } else {
134 self.eps
135 };
136
137 let self_weighted = graph
139 .x
140 .mul_scalar((1.0 + epsilon) as f32)
141 .expect("failed to scale self features");
142 let combined_features = self_weighted
143 .add(&neighbor_features)
144 .expect("operation should succeed");
145
146 let mut output = combined_features
148 .matmul(&self.mlp[0].clone_data())
149 .expect("operation should succeed");
150
151 let zero_tensor =
153 zeros(output.shape().dims()).expect("failed to create zero tensor for ReLU");
154 output = output
155 .maximum(&zero_tensor)
156 .expect("failed to apply ReLU activation");
157
158 output = output
160 .matmul(&self.mlp[1].clone_data())
161 .expect("operation should succeed");
162
163 if let Some(ref bias) = self.bias {
165 output = output
166 .add(&bias.clone_data())
167 .expect("operation should succeed");
168 }
169
170 GraphData {
172 x: output,
173 edge_index: graph.edge_index.clone(),
174 edge_attr: graph.edge_attr.clone(),
175 batch: graph.batch.clone(),
176 num_nodes: graph.num_nodes,
177 num_edges: graph.num_edges,
178 }
179 }
180}
181
182impl GraphLayer for GINConv {
183 fn forward(&self, graph: &GraphData) -> GraphData {
184 self.forward(graph)
185 }
186
187 fn parameters(&self) -> Vec<Tensor> {
188 let mut params = vec![self.mlp[0].clone_data(), self.mlp[1].clone_data()];
189
190 if let Some(ref eps_param) = self.eps_param {
191 params.push(eps_param.clone_data());
192 }
193
194 if let Some(ref bias) = self.bias {
195 params.push(bias.clone_data());
196 }
197
198 params
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    #[test]
    fn test_gin_creation() {
        let gin = GINConv::new(8, 16, 0.5, true, true);
        let params = gin.parameters();
        // Two MLP weight matrices + trainable epsilon + bias.
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_gin_forward() {
        let gin = GINConv::new(4, 6, 0.0, false, false);

        let x = from_vec(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[3, 4],
            DeviceType::Cpu,
        )
        .unwrap();
        let edge_index =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(x, edge_index);

        let output = gin.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[3, 6]);
        assert_eq!(output.num_nodes, 3);
    }

    #[test]
    fn test_gin_trainable_eps() {
        let gin_fixed = GINConv::new(4, 8, 1.0, false, false);
        let gin_trainable = GINConv::new(4, 8, 1.0, true, false);

        let fixed_params = gin_fixed.parameters();
        let trainable_params = gin_trainable.parameters();

        assert_eq!(trainable_params.len(), fixed_params.len() + 1);
    }

    #[test]
    fn test_gin_forward_sums_incoming_neighbors() {
        // 2 -> hidden (2 + 2) / 2 = 2 -> 2. Identity weights make the output
        // relu((1 + eps) * x + neighbour_sum), which is exact for positive inputs.
        let mut gin = GINConv::new(2, 2, 0.0, false, false);
        gin.mlp = vec![
            Parameter::new(
                from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2], DeviceType::Cpu).unwrap(),
            ),
            Parameter::new(
                from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2], DeviceType::Cpu).unwrap(),
            ),
        ];

        let x = from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], DeviceType::Cpu).unwrap();
        // A single edge 0 -> 1: node 1 receives node 0's features.
        let edge_index = from_vec(vec![0.0, 1.0], &[2, 1], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(x, edge_index);

        let output = gin.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[2, 2]);
        assert_eq!(output.x.to_vec().unwrap(), vec![1.0, 2.0, 4.0, 6.0]);
    }
}