// trident/neural/model/encoder.rs
//! GNN Encoder — GATv2 (Graph Attention Network v2) in burn.
//!
//! Encodes a TirGraph into node embeddings + global context vector.
//! 3-4 GATv2 layers, d=256, ~3M parameters.
//!
//! CPU for single-graph inference, GPU for batched training.

8use burn::config::Config;
9use burn::module::Module;
10use burn::nn::{Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig};
11use burn::prelude::*;
12use burn::tensor::activation::leaky_relu;
13
14use super::gnn_ops::{neighborhood_softmax, scatter_add};
15use crate::neural::data::tir_graph::NODE_FEATURE_DIM;
16
// ─── Configuration ────────────────────────────────────────────────

19/// GATv2 layer configuration.
20#[derive(Config, Debug)]
21pub struct GatV2LayerConfig {
22    /// Input feature dimension.
23    pub d_in: usize,
24    /// Output feature dimension.
25    pub d_out: usize,
26    /// Edge embedding dimension.
27    #[config(default = 32)]
28    pub d_edge: usize,
29    /// Number of edge types.
30    #[config(default = 3)]
31    pub num_edge_types: usize,
32    /// Negative slope for LeakyReLU.
33    #[config(default = 0.2)]
34    pub leaky_relu_alpha: f64,
35}
36
37/// GNN Encoder configuration.
38#[derive(Config, Debug)]
39pub struct GnnEncoderConfig {
40    /// Model dimension (node embedding size).
41    #[config(default = 256)]
42    pub d_model: usize,
43    /// Number of GATv2 layers.
44    #[config(default = 4)]
45    pub num_layers: usize,
46    /// Edge embedding dimension.
47    #[config(default = 32)]
48    pub d_edge: usize,
49}
50
// ─── GATv2 Layer ──────────────────────────────────────────────────

53/// Single GATv2 attention layer.
54///
55/// Implements: a^T · LeakyReLU(W_src·h_i + W_dst·h_j + W_edge·e_ij)
56/// with softmax per neighborhood, followed by FFN + residual + LayerNorm.
57#[derive(Module, Debug)]
58pub struct GatV2Layer<B: Backend> {
59    /// Source node projection.
60    w_src: Linear<B>,
61    /// Destination node projection.
62    w_dst: Linear<B>,
63    /// Edge type projection.
64    w_edge: Linear<B>,
65    /// Attention scoring vector (projects concatenated features to scalar).
66    attn: Linear<B>,
67    /// Output FFN.
68    ffn: Linear<B>,
69    /// Layer normalization.
70    norm: LayerNorm<B>,
71    /// LeakyReLU negative slope.
72    leaky_alpha: f64,
73}
74
75impl GatV2LayerConfig {
76    /// Initialize a GATv2 layer.
77    pub fn init<B: Backend>(&self, device: &B::Device) -> GatV2Layer<B> {
78        GatV2Layer {
79            w_src: LinearConfig::new(self.d_in, self.d_out).init(device),
80            w_dst: LinearConfig::new(self.d_in, self.d_out).init(device),
81            w_edge: LinearConfig::new(self.d_edge, self.d_out).init(device),
82            attn: LinearConfig::new(self.d_out, 1).init(device),
83            ffn: LinearConfig::new(self.d_out, self.d_out).init(device),
84            norm: LayerNormConfig::new(self.d_out).init(device),
85            leaky_alpha: self.leaky_relu_alpha,
86        }
87    }
88}
89
90impl<B: Backend> GatV2Layer<B> {
91    /// Forward pass: GATv2 message passing.
92    ///
93    /// - `node_features`: [N, d_in] — node feature matrix
94    /// - `src_indices`: [E] — source node index per edge
95    /// - `dst_indices`: [E] — destination node index per edge
96    /// - `edge_embeddings`: [E, d_edge] — edge type embeddings
97    /// - `num_nodes`: N
98    ///
99    /// Returns: [N, d_out] — updated node features
100    pub fn forward(
101        &self,
102        node_features: Tensor<B, 2>,
103        src_indices: Tensor<B, 1, Int>,
104        dst_indices: Tensor<B, 1, Int>,
105        edge_embeddings: Tensor<B, 2>,
106        num_nodes: usize,
107    ) -> Tensor<B, 2> {
108        let num_edges = src_indices.dims()[0];
109        let d_out = self.ffn.weight.dims()[0];
110
111        // Project source and destination features
112        let h_src = self.w_src.forward(node_features.clone());
113        let h_dst = self.w_dst.forward(node_features.clone());
114
115        // Gather per-edge features
116        let h_src_edge = h_src.select(0, src_indices.clone()); // [E, d_out]
117        let h_dst_edge = h_dst.select(0, dst_indices.clone()); // [E, d_out]
118        let e_proj = self.w_edge.forward(edge_embeddings); // [E, d_out]
119
120        // GATv2 attention: a^T · LeakyReLU(h_src + h_dst + e)
121        let combined = h_src_edge.clone() + h_dst_edge + e_proj;
122        let activated = leaky_relu(combined, self.leaky_alpha);
123        let attn_logits = self.attn.forward(activated); // [E, 1]
124
125        // Neighborhood softmax
126        let attn_weights = neighborhood_softmax(attn_logits, dst_indices.clone(), num_nodes);
127
128        // Weighted message aggregation: broadcast [E, 1] to [E, d_out]
129        let attn_expanded = attn_weights.expand([num_edges, d_out]);
130        let messages = h_src_edge * attn_expanded;
131        let aggregated = scatter_add(messages, dst_indices, num_nodes);
132
133        // FFN + residual + norm
134        let out = self.ffn.forward(aggregated);
135
136        // Residual connection (only if dimensions match)
137        let residual = if node_features.dims()[1] == d_out {
138            out + node_features
139        } else {
140            out
141        };
142
143        self.norm.forward(residual)
144    }
145}
146
// ─── GNN Encoder ──────────────────────────────────────────────────

149/// GNN Encoder: stack of GATv2 layers with global pooling.
150///
151/// Input: TirGraph node features + edge structure
152/// Output: (node_embeddings [N, d], global_context [d])
153#[derive(Module, Debug)]
154pub struct GnnEncoder<B: Backend> {
155    /// Initial node feature projection: NODE_FEATURE_DIM → d_model
156    node_proj: Linear<B>,
157    /// Edge type embedding: 3 types → d_edge
158    edge_embed: Embedding<B>,
159    /// Stack of GATv2 layers
160    layers: Vec<GatV2Layer<B>>,
161    /// Global pooling projection: 2*d_model → d_model (mean+max concatenated)
162    global_proj: Linear<B>,
163}
164
165impl GnnEncoderConfig {
166    /// Initialize a GNN encoder.
167    pub fn init<B: Backend>(&self, device: &B::Device) -> GnnEncoder<B> {
168        let mut layers = Vec::with_capacity(self.num_layers);
169
170        for i in 0..self.num_layers {
171            let d_in = if i == 0 { self.d_model } else { self.d_model };
172            let config = GatV2LayerConfig {
173                d_in,
174                d_out: self.d_model,
175                d_edge: self.d_edge,
176                num_edge_types: 3,
177                leaky_relu_alpha: 0.2,
178            };
179            layers.push(config.init(device));
180        }
181
182        GnnEncoder {
183            node_proj: LinearConfig::new(NODE_FEATURE_DIM, self.d_model).init(device),
184            edge_embed: EmbeddingConfig::new(3, self.d_edge).init(device),
185            layers,
186            global_proj: LinearConfig::new(self.d_model * 2, self.d_model).init(device),
187        }
188    }
189}
190
191impl<B: Backend> GnnEncoder<B> {
192    /// Encode a graph into node embeddings and a global context vector.
193    ///
194    /// - `node_features`: [N, NODE_FEATURE_DIM] — raw node feature vectors
195    /// - `src_indices`: [E] — source node index per edge
196    /// - `dst_indices`: [E] — destination node index per edge
197    /// - `edge_types`: [E] — edge type (0=DataDep, 1=ControlFlow, 2=MemOrder)
198    ///
199    /// Returns: (node_embeddings [N, d_model], global_context [d_model])
200    pub fn forward(
201        &self,
202        node_features: Tensor<B, 2>,
203        src_indices: Tensor<B, 1, Int>,
204        dst_indices: Tensor<B, 1, Int>,
205        edge_types: Tensor<B, 1, Int>,
206    ) -> (Tensor<B, 2>, Tensor<B, 1>) {
207        let num_nodes = node_features.dims()[0];
208
209        // Project node features to model dimension
210        let mut h = self.node_proj.forward(node_features);
211
212        // Embed edge types: [E] → [E, 1] → Embedding → [E, 1, d_edge] → [E, d_edge]
213        let edge_types_2d: Tensor<B, 2, Int> = edge_types.unsqueeze_dim::<2>(1);
214        let edge_emb_3d = self.edge_embed.forward(edge_types_2d); // [E, 1, d_edge]
215        let edge_emb: Tensor<B, 2> = edge_emb_3d.squeeze_dim::<2>(1); // [E, d_edge]
216
217        // GATv2 layers
218        for layer in &self.layers {
219            h = layer.forward(
220                h,
221                src_indices.clone(),
222                dst_indices.clone(),
223                edge_emb.clone(),
224                num_nodes,
225            );
226        }
227
228        // Global pooling: mean + max → project to d_model
229        let mean_pool: Tensor<B, 1> = h.clone().mean_dim(0).squeeze_dim::<1>(0);
230        let max_pool: Tensor<B, 1> = h.clone().max_dim(0).squeeze_dim::<1>(0);
231        let global_input = Tensor::cat(vec![mean_pool, max_pool], 0); // [2*d_model]
232        let global: Tensor<B, 1> = self
233            .global_proj
234            .forward(global_input.unsqueeze_dim::<2>(0)) // [2*d] → [1, 2*d] → [1, d]
235            .squeeze_dim::<1>(0); // [1, d] → [d]
236
237        (h, global)
238    }
239
240    /// Count total parameters.
241    pub fn num_params(&self) -> usize {
242        // node_proj
243        let mut total = NODE_FEATURE_DIM * self.global_proj.weight.dims()[0] / 2; // approximate
244
245        // Each GATv2 layer
246        for layer in &self.layers {
247            let d = layer.ffn.weight.dims()[0];
248            total += d * d * 3; // w_src, w_dst, ffn
249            total += d; // attn
250            total += d * layer.w_edge.weight.dims()[1]; // w_edge
251        }
252
253        total
254    }
255}
256
// ─── Tests ────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray;

    #[test]
    fn gnn_encoder_forward_shape() {
        let device = Default::default();
        // Deliberately tiny model so the test stays fast.
        let encoder = GnnEncoderConfig {
            d_model: 32,
            num_layers: 2,
            d_edge: 8,
        }
        .init::<B>(&device);

        // A 5-node chain connected by 4 edges of mixed types.
        let features = Tensor::<B, 2>::zeros([5, NODE_FEATURE_DIM], &device);
        let src = Tensor::<B, 1, Int>::from_ints([0, 1, 2, 3], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let edge_types = Tensor::<B, 1, Int>::from_ints([0, 1, 1, 2], &device);

        let (node_emb, global): (Tensor<B, 2>, Tensor<B, 1>) =
            encoder.forward(features, src, dst, edge_types);

        assert_eq!(node_emb.dims(), [5, 32]);
        assert_eq!(global.dims(), [32]);
    }

    #[test]
    fn gatv2_layer_preserves_node_count() {
        let device = Default::default();
        let layer = GatV2LayerConfig {
            d_in: 16,
            d_out: 16,
            d_edge: 8,
            num_edge_types: 3,
            leaky_relu_alpha: 0.2,
        }
        .init::<B>(&device);

        // Three nodes linked by two edges.
        let features = Tensor::<B, 2>::zeros([3, 16], &device);
        let src = Tensor::<B, 1, Int>::from_ints([0, 1], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([1, 2], &device);
        let edge_emb = Tensor::<B, 2>::zeros([2, 8], &device);

        let out = layer.forward(features, src, dst, edge_emb, 3);
        assert_eq!(out.dims(), [3, 16]);
    }
}
309}