// torsh_graph/conv/gat.rs
//! Graph Attention Network (GAT) layer implementation

// Framework infrastructure - components designed for future use
#![allow(dead_code)]
use crate::parameter::Parameter;
use crate::{GraphData, GraphLayer};
use torsh_tensor::{
    creation::{randn, zeros},
    Tensor,
};

/// Graph Attention Network (GAT) layer
///
/// Projects node features with a shared weight matrix, scores each edge with a
/// per-head attention vector, and aggregates neighbor features weighted by the
/// softmax-normalized scores. Head outputs are concatenated, so the layer's
/// output feature dimension is `heads * out_features`.
#[derive(Debug)]
pub struct GATConv {
    // Input feature dimension per node.
    in_features: usize,
    // Output feature dimension per head.
    out_features: usize,
    // Number of independent attention heads.
    heads: usize,
    // Shared projection matrix, shape [in_features, heads * out_features].
    weight: Parameter,
    // Per-head attention vectors, shape [heads, 2 * out_features]
    // (applied to the concatenation [h_src || h_dst]).
    attention: Parameter,
    // Optional additive bias, shape [heads * out_features].
    bias: Option<Parameter>,
    // Dropout probability; currently stored but not applied in forward().
    dropout: f32,
}
24impl GATConv {
25    /// Create a new GAT convolution layer
26    pub fn new(
27        in_features: usize,
28        out_features: usize,
29        heads: usize,
30        dropout: f32,
31        bias: bool,
32    ) -> Self {
33        let weight = Parameter::new(
34            randn(&[in_features, heads * out_features]).expect("failed to create weight tensor"),
35        );
36        let attention = Parameter::new(
37            randn(&[heads, 2 * out_features]).expect("failed to create attention tensor"),
38        );
39        let bias = if bias {
40            Some(Parameter::new(
41                zeros(&[heads * out_features]).expect("failed to create bias tensor"),
42            ))
43        } else {
44            None
45        };
46
47        Self {
48            in_features,
49            out_features,
50            heads,
51            weight,
52            attention,
53            bias,
54            dropout,
55        }
56    }
58    /// Apply graph attention convolution
59    pub fn forward(&self, graph: &GraphData) -> GraphData {
60        let num_nodes = graph.num_nodes;
61
62        // Transform node features: X @ W
63        let x_transformed = graph
64            .x
65            .matmul(&self.weight.clone_data())
66            .expect("operation should succeed");
67
68        // Reshape to separate heads: [num_nodes, heads, out_features]
69        let x_reshaped = x_transformed
70            .view(&[
71                num_nodes as i32,
72                self.heads as i32,
73                self.out_features as i32,
74            ])
75            .expect("view should succeed");
76
77        // Get edge indices - flatten and interpret as pairs
78        let edge_flat = graph
79            .edge_index
80            .to_vec()
81            .expect("conversion should succeed");
82        let num_edges = graph.num_edges;
83
84        // Extract source and destination nodes (edge_index is [2, num_edges] stored row-major)
85        let src_nodes: Vec<usize> = (0..num_edges).map(|i| edge_flat[i] as usize).collect();
86        let dst_nodes: Vec<usize> = (0..num_edges)
87            .map(|i| edge_flat[i + num_edges] as usize)
88            .collect();
89
90        // Initialize output
91        let mut output = zeros(&[num_nodes, self.heads * self.out_features])
92            .expect("failed to create output tensor");
93
94        // Process each head independently
95        for head in 0..self.heads {
96            // Extract attention parameters for this head
97            let attention_head = self
98                .attention
99                .clone_data()
100                .slice_tensor(0, head, head + 1)
101                .expect("failed to slice attention tensor")
102                .squeeze_tensor(0)
103                .expect("failed to squeeze attention tensor");
104
105            // Compute attention scores for all edges
106            let mut attention_scores = Vec::with_capacity(num_edges);
107
108            for edge_idx in 0..num_edges {
109                let src = src_nodes[edge_idx];
110                let dst = dst_nodes[edge_idx];
111
112                // Get source and destination node features for this head
113                let src_feat = x_reshaped
114                    .slice_tensor(0, src, src + 1)
115                    .expect("failed to slice source node")
116                    .slice_tensor(1, head, head + 1)
117                    .expect("failed to slice head dimension")
118                    .squeeze_tensor(0)
119                    .expect("failed to squeeze node dimension")
120                    .squeeze_tensor(0)
121                    .expect("failed to squeeze head dimension");
122
123                let dst_feat = x_reshaped
124                    .slice_tensor(0, dst, dst + 1)
125                    .expect("failed to slice destination node")
126                    .slice_tensor(1, head, head + 1)
127                    .expect("failed to slice head dimension")
128                    .squeeze_tensor(0)
129                    .expect("failed to squeeze node dimension")
130                    .squeeze_tensor(0)
131                    .expect("failed to squeeze head dimension");
132
133                // Concatenate source and destination features
134                let concat_feat = Tensor::cat(&[&src_feat, &dst_feat], 0)
135                    .expect("failed to concatenate features");
136
137                // Compute attention coefficient: a^T [h_i || h_j]
138                // Element-wise multiplication and sum to get scalar
139                let attention_coeff = attention_head
140                    .mul(&concat_feat)
141                    .expect("operation should succeed")
142                    .sum()
143                    .expect("reduction should succeed");
144
145                // Apply LeakyReLU activation
146                let coeff_val =
147                    attention_coeff.to_vec().expect("conversion should succeed")[0] as f64;
148                let activated_val = if coeff_val > 0.0 {
149                    coeff_val
150                } else {
151                    0.2 * coeff_val // LeakyReLU with alpha=0.2
152                };
153
154                attention_scores.push((src, dst, activated_val));
155            }
156
157            // Apply softmax normalization for each destination node
158            let mut normalized_scores = vec![0.0; num_edges];
159            for node in 0..num_nodes {
160                // Find edges pointing to this node
161                let mut node_edge_indices = Vec::new();
162                let mut node_scores = Vec::new();
163
164                for (edge_idx, (_, dst, score)) in attention_scores.iter().enumerate() {
165                    if *dst == node {
166                        node_edge_indices.push(edge_idx);
167                        node_scores.push(*score);
168                    }
169                }
170
171                if !node_scores.is_empty() {
172                    // Apply softmax
173                    let max_score = node_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
174                    let exp_scores: Vec<f64> =
175                        node_scores.iter().map(|s| (*s - max_score).exp()).collect();
176                    let sum_exp: f64 = exp_scores.iter().sum();
177
178                    for (i, &edge_idx) in node_edge_indices.iter().enumerate() {
179                        normalized_scores[edge_idx] = exp_scores[i] / sum_exp;
180                    }
181                }
182            }
183
184            // Aggregate features using attention weights
185            let head_output = zeros(&[num_nodes, self.out_features])
186                .expect("failed to create head output tensor");
187
188            for node in 0..num_nodes {
189                let mut node_output =
190                    zeros(&[self.out_features]).expect("failed to create node output tensor");
191
192                for (edge_idx, (src, dst, _)) in attention_scores.iter().enumerate() {
193                    if *dst == node {
194                        let weight = normalized_scores[edge_idx];
195                        if weight > 0.0 {
196                            let src_feat = x_reshaped
197                                .slice_tensor(0, *src, *src + 1)
198                                .expect("failed to slice source node")
199                                .slice_tensor(1, head, head + 1)
200                                .expect("failed to slice head dimension")
201                                .squeeze_tensor(0)
202                                .expect("failed to squeeze node dimension")
203                                .squeeze_tensor(0)
204                                .expect("failed to squeeze head dimension");
205
206                            let weighted_feat = src_feat
207                                .mul_scalar(weight as f32)
208                                .expect("failed to scale features");
209                            node_output = node_output
210                                .add(&weighted_feat)
211                                .expect("operation should succeed");
212                        }
213                    }
214                }
215
216                // Set the aggregated features for this node
217                let mut node_slice = head_output
218                    .slice_tensor(0, node, node + 1)
219                    .expect("failed to slice node output");
220                let _ = node_slice.copy_(
221                    &node_output
222                        .unsqueeze_tensor(0)
223                        .expect("failed to unsqueeze node output"),
224                );
225            }
226
227            // Place head output into the appropriate slice of the final output
228            let start_feat = head * self.out_features;
229            let end_feat = (head + 1) * self.out_features;
230            let mut output_slice = output
231                .slice_tensor(1, start_feat, end_feat)
232                .expect("failed to slice output tensor");
233            let _ = output_slice.copy_(&head_output);
234        }
235
236        // Add bias if present
237        if let Some(ref bias) = self.bias {
238            output = output
239                .add(&bias.clone_data())
240                .expect("operation should succeed");
241        }
242
243        // Apply dropout if in training mode (placeholder for now)
244        if self.dropout > 0.0 {
245            // Note: For now, we'll skip dropout implementation to focus on core functionality
246            // In a complete implementation, this would apply dropout during training
247        }
248
249        GraphData {
250            x: output,
251            edge_index: graph.edge_index.clone(),
252            edge_attr: graph.edge_attr.clone(),
253            batch: graph.batch.clone(),
254            num_nodes: graph.num_nodes,
255            num_edges: graph.num_edges,
256        }
257    }
258}
260impl GraphLayer for GATConv {
261    fn forward(&self, graph: &GraphData) -> GraphData {
262        self.forward(graph)
263    }
264
265    fn parameters(&self) -> Vec<Tensor> {
266        let mut params = vec![self.weight.clone_data(), self.attention.clone_data()];
267        if let Some(ref bias) = self.bias {
268            params.push(bias.clone_data());
269        }
270        params
271    }
272}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// Construction registers weight + attention, plus bias when requested.
    #[test]
    fn test_gat_creation() {
        let gat = GATConv::new(16, 8, 4, 0.1, true);
        assert_eq!(gat.parameters().len(), 3); // weight + attention + bias
        assert_eq!(gat.heads, 4);
    }

    /// Forward pass on a 3-node directed cycle produces one row per node and
    /// concatenates all heads along the feature dimension.
    #[test]
    fn test_gat_forward() {
        let gat = GATConv::new(3, 4, 2, 0.0, false);

        // Three nodes with three features each.
        let node_features = from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            &[3, 3],
            DeviceType::Cpu,
        )
        .unwrap();
        // Directed cycle 0 -> 1 -> 2 -> 0 in [2, num_edges] layout
        // (first row: sources; second row: destinations).
        let edges =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();

        let result = gat.forward(&GraphData::new(node_features, edges));
        assert_eq!(result.x.shape().dims(), &[3, 8]); // 2 heads * 4 features
        assert_eq!(result.num_nodes, 3);
    }
}
307}