1#![allow(dead_code)]
5use crate::parameter::Parameter;
6use crate::{GraphData, GraphLayer};
7use torsh_tensor::{
8 creation::{randn, zeros},
9 Tensor,
10};
11
/// Graph Attention Network (GAT) convolution layer with multi-head attention.
///
/// Projects node features with a shared linear map, scores each edge per head
/// with a learned attention vector, softmax-normalizes scores over each
/// destination node's incoming edges, and aggregates weighted neighbor
/// features. Head outputs are concatenated, so the layer emits
/// `heads * out_features` features per node.
#[derive(Debug)]
pub struct GATConv {
    // Size of each input node feature vector.
    in_features: usize,
    // Size of the output feature vector produced by each head.
    out_features: usize,
    // Number of independent attention heads; their outputs are concatenated.
    heads: usize,
    // Linear projection weights, shape [in_features, heads * out_features].
    weight: Parameter,
    // Per-head attention vectors, shape [heads, 2 * out_features].
    attention: Parameter,
    // Optional bias of shape [heads * out_features], added after aggregation.
    bias: Option<Parameter>,
    // Dropout probability; stored but not applied in the visible `forward`.
    dropout: f32,
}
23
24impl GATConv {
25 pub fn new(
27 in_features: usize,
28 out_features: usize,
29 heads: usize,
30 dropout: f32,
31 bias: bool,
32 ) -> Self {
33 let weight = Parameter::new(
34 randn(&[in_features, heads * out_features]).expect("failed to create weight tensor"),
35 );
36 let attention = Parameter::new(
37 randn(&[heads, 2 * out_features]).expect("failed to create attention tensor"),
38 );
39 let bias = if bias {
40 Some(Parameter::new(
41 zeros(&[heads * out_features]).expect("failed to create bias tensor"),
42 ))
43 } else {
44 None
45 };
46
47 Self {
48 in_features,
49 out_features,
50 heads,
51 weight,
52 attention,
53 bias,
54 dropout,
55 }
56 }
57
58 pub fn forward(&self, graph: &GraphData) -> GraphData {
60 let num_nodes = graph.num_nodes;
61
62 let x_transformed = graph
64 .x
65 .matmul(&self.weight.clone_data())
66 .expect("operation should succeed");
67
68 let x_reshaped = x_transformed
70 .view(&[
71 num_nodes as i32,
72 self.heads as i32,
73 self.out_features as i32,
74 ])
75 .expect("view should succeed");
76
77 let edge_flat = graph
79 .edge_index
80 .to_vec()
81 .expect("conversion should succeed");
82 let num_edges = graph.num_edges;
83
84 let src_nodes: Vec<usize> = (0..num_edges).map(|i| edge_flat[i] as usize).collect();
86 let dst_nodes: Vec<usize> = (0..num_edges)
87 .map(|i| edge_flat[i + num_edges] as usize)
88 .collect();
89
90 let mut output = zeros(&[num_nodes, self.heads * self.out_features])
92 .expect("failed to create output tensor");
93
94 for head in 0..self.heads {
96 let attention_head = self
98 .attention
99 .clone_data()
100 .slice_tensor(0, head, head + 1)
101 .expect("failed to slice attention tensor")
102 .squeeze_tensor(0)
103 .expect("failed to squeeze attention tensor");
104
105 let mut attention_scores = Vec::with_capacity(num_edges);
107
108 for edge_idx in 0..num_edges {
109 let src = src_nodes[edge_idx];
110 let dst = dst_nodes[edge_idx];
111
112 let src_feat = x_reshaped
114 .slice_tensor(0, src, src + 1)
115 .expect("failed to slice source node")
116 .slice_tensor(1, head, head + 1)
117 .expect("failed to slice head dimension")
118 .squeeze_tensor(0)
119 .expect("failed to squeeze node dimension")
120 .squeeze_tensor(0)
121 .expect("failed to squeeze head dimension");
122
123 let dst_feat = x_reshaped
124 .slice_tensor(0, dst, dst + 1)
125 .expect("failed to slice destination node")
126 .slice_tensor(1, head, head + 1)
127 .expect("failed to slice head dimension")
128 .squeeze_tensor(0)
129 .expect("failed to squeeze node dimension")
130 .squeeze_tensor(0)
131 .expect("failed to squeeze head dimension");
132
133 let concat_feat = Tensor::cat(&[&src_feat, &dst_feat], 0)
135 .expect("failed to concatenate features");
136
137 let attention_coeff = attention_head
140 .mul(&concat_feat)
141 .expect("operation should succeed")
142 .sum()
143 .expect("reduction should succeed");
144
145 let coeff_val =
147 attention_coeff.to_vec().expect("conversion should succeed")[0] as f64;
148 let activated_val = if coeff_val > 0.0 {
149 coeff_val
150 } else {
151 0.2 * coeff_val };
153
154 attention_scores.push((src, dst, activated_val));
155 }
156
157 let mut normalized_scores = vec![0.0; num_edges];
159 for node in 0..num_nodes {
160 let mut node_edge_indices = Vec::new();
162 let mut node_scores = Vec::new();
163
164 for (edge_idx, (_, dst, score)) in attention_scores.iter().enumerate() {
165 if *dst == node {
166 node_edge_indices.push(edge_idx);
167 node_scores.push(*score);
168 }
169 }
170
171 if !node_scores.is_empty() {
172 let max_score = node_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
174 let exp_scores: Vec<f64> =
175 node_scores.iter().map(|s| (*s - max_score).exp()).collect();
176 let sum_exp: f64 = exp_scores.iter().sum();
177
178 for (i, &edge_idx) in node_edge_indices.iter().enumerate() {
179 normalized_scores[edge_idx] = exp_scores[i] / sum_exp;
180 }
181 }
182 }
183
184 let head_output = zeros(&[num_nodes, self.out_features])
186 .expect("failed to create head output tensor");
187
188 for node in 0..num_nodes {
189 let mut node_output =
190 zeros(&[self.out_features]).expect("failed to create node output tensor");
191
192 for (edge_idx, (src, dst, _)) in attention_scores.iter().enumerate() {
193 if *dst == node {
194 let weight = normalized_scores[edge_idx];
195 if weight > 0.0 {
196 let src_feat = x_reshaped
197 .slice_tensor(0, *src, *src + 1)
198 .expect("failed to slice source node")
199 .slice_tensor(1, head, head + 1)
200 .expect("failed to slice head dimension")
201 .squeeze_tensor(0)
202 .expect("failed to squeeze node dimension")
203 .squeeze_tensor(0)
204 .expect("failed to squeeze head dimension");
205
206 let weighted_feat = src_feat
207 .mul_scalar(weight as f32)
208 .expect("failed to scale features");
209 node_output = node_output
210 .add(&weighted_feat)
211 .expect("operation should succeed");
212 }
213 }
214 }
215
216 let mut node_slice = head_output
218 .slice_tensor(0, node, node + 1)
219 .expect("failed to slice node output");
220 let _ = node_slice.copy_(
221 &node_output
222 .unsqueeze_tensor(0)
223 .expect("failed to unsqueeze node output"),
224 );
225 }
226
227 let start_feat = head * self.out_features;
229 let end_feat = (head + 1) * self.out_features;
230 let mut output_slice = output
231 .slice_tensor(1, start_feat, end_feat)
232 .expect("failed to slice output tensor");
233 let _ = output_slice.copy_(&head_output);
234 }
235
236 if let Some(ref bias) = self.bias {
238 output = output
239 .add(&bias.clone_data())
240 .expect("operation should succeed");
241 }
242
243 if self.dropout > 0.0 {
245 }
248
249 GraphData {
250 x: output,
251 edge_index: graph.edge_index.clone(),
252 edge_attr: graph.edge_attr.clone(),
253 batch: graph.batch.clone(),
254 num_nodes: graph.num_nodes,
255 num_edges: graph.num_edges,
256 }
257 }
258}
259
260impl GraphLayer for GATConv {
261 fn forward(&self, graph: &GraphData) -> GraphData {
262 self.forward(graph)
263 }
264
265 fn parameters(&self) -> Vec<Tensor> {
266 let mut params = vec![self.weight.clone_data(), self.attention.clone_data()];
267 if let Some(ref bias) = self.bias {
268 params.push(bias.clone_data());
269 }
270 params
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    #[test]
    fn test_gat_creation() {
        // 16 -> 8 features, 4 heads, bias enabled:
        // parameters are weight + attention + bias = 3 tensors.
        let layer = GATConv::new(16, 8, 4, 0.1, true);
        let params = layer.parameters();
        assert_eq!(params.len(), 3);
        assert_eq!(layer.heads, 4);
    }

    #[test]
    fn test_gat_forward() {
        let layer = GATConv::new(3, 4, 2, 0.0, false);

        // Three nodes with three features each.
        let node_features = from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            &[3, 3],
            DeviceType::Cpu,
        )
        .unwrap();
        // Edges 0->1, 1->2, 2->0 in [2, num_edges] layout.
        let edges =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(node_features, edges);

        let result = layer.forward(&graph);
        // Two heads of four features are concatenated: [3, 8].
        assert_eq!(result.x.shape().dims(), &[3, 8]);
        assert_eq!(result.num_nodes, 3);
    }
}
307}