1use burn::config::Config;
9use burn::module::Module;
10use burn::nn::{Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig};
11use burn::prelude::*;
12use burn::tensor::activation::leaky_relu;
13
14use super::gnn_ops::{neighborhood_softmax, scatter_add};
15use crate::neural::data::tir_graph::NODE_FEATURE_DIM;
16
/// Hyperparameters for a single [`GatV2Layer`].
#[derive(Config, Debug)]
pub struct GatV2LayerConfig {
    /// Width of the incoming node features.
    pub d_in: usize,
    /// Width of the produced node features.
    pub d_out: usize,
    /// Width of the per-edge feature vectors fed to the edge projection.
    #[config(default = 32)]
    pub d_edge: usize,
    /// Number of distinct edge types.
    /// NOTE(review): not read by `GatV2LayerConfig::init` — the edge-type
    /// embedding table lives in the encoder; confirm this field is still needed.
    #[config(default = 3)]
    pub num_edge_types: usize,
    /// Negative slope passed to `leaky_relu` when scoring attention.
    #[config(default = 0.2)]
    pub leaky_relu_alpha: f64,
}
36
/// Hyperparameters for the full [`GnnEncoder`] stack.
#[derive(Config, Debug)]
pub struct GnnEncoderConfig {
    /// Hidden width used by every layer and the global readout.
    #[config(default = 256)]
    pub d_model: usize,
    /// Number of stacked GATv2 message-passing layers.
    #[config(default = 4)]
    pub num_layers: usize,
    /// Width of the learned edge-type embeddings.
    #[config(default = 32)]
    pub d_edge: usize,
}
50
/// One GATv2-style message-passing layer: attention logits are produced by a
/// linear scorer applied *after* a shared LeakyReLU over the combined
/// source/destination/edge representation.
#[derive(Module, Debug)]
pub struct GatV2Layer<B: Backend> {
    /// Projects node features for their role as message sources.
    w_src: Linear<B>,
    /// Projects node features for their role as message destinations.
    w_dst: Linear<B>,
    /// Projects per-edge features into the layer's output width.
    w_edge: Linear<B>,
    /// Attention scorer: maps a `d_out` edge representation to one logit.
    attn: Linear<B>,
    /// Output transform applied to the aggregated messages.
    ffn: Linear<B>,
    /// LayerNorm applied after the (optional) residual connection.
    norm: LayerNorm<B>,
    /// Negative slope for the LeakyReLU used in attention scoring.
    leaky_alpha: f64,
}
74
75impl GatV2LayerConfig {
76 pub fn init<B: Backend>(&self, device: &B::Device) -> GatV2Layer<B> {
78 GatV2Layer {
79 w_src: LinearConfig::new(self.d_in, self.d_out).init(device),
80 w_dst: LinearConfig::new(self.d_in, self.d_out).init(device),
81 w_edge: LinearConfig::new(self.d_edge, self.d_out).init(device),
82 attn: LinearConfig::new(self.d_out, 1).init(device),
83 ffn: LinearConfig::new(self.d_out, self.d_out).init(device),
84 norm: LayerNormConfig::new(self.d_out).init(device),
85 leaky_alpha: self.leaky_relu_alpha,
86 }
87 }
88}
89
90impl<B: Backend> GatV2Layer<B> {
91 pub fn forward(
101 &self,
102 node_features: Tensor<B, 2>,
103 src_indices: Tensor<B, 1, Int>,
104 dst_indices: Tensor<B, 1, Int>,
105 edge_embeddings: Tensor<B, 2>,
106 num_nodes: usize,
107 ) -> Tensor<B, 2> {
108 let num_edges = src_indices.dims()[0];
109 let d_out = self.ffn.weight.dims()[0];
110
111 let h_src = self.w_src.forward(node_features.clone());
113 let h_dst = self.w_dst.forward(node_features.clone());
114
115 let h_src_edge = h_src.select(0, src_indices.clone()); let h_dst_edge = h_dst.select(0, dst_indices.clone()); let e_proj = self.w_edge.forward(edge_embeddings); let combined = h_src_edge.clone() + h_dst_edge + e_proj;
122 let activated = leaky_relu(combined, self.leaky_alpha);
123 let attn_logits = self.attn.forward(activated); let attn_weights = neighborhood_softmax(attn_logits, dst_indices.clone(), num_nodes);
127
128 let attn_expanded = attn_weights.expand([num_edges, d_out]);
130 let messages = h_src_edge * attn_expanded;
131 let aggregated = scatter_add(messages, dst_indices, num_nodes);
132
133 let out = self.ffn.forward(aggregated);
135
136 let residual = if node_features.dims()[1] == d_out {
138 out + node_features
139 } else {
140 out
141 };
142
143 self.norm.forward(residual)
144 }
145}
146
/// Stacked GATv2 encoder producing per-node embeddings plus a pooled
/// graph-level embedding.
#[derive(Module, Debug)]
pub struct GnnEncoder<B: Backend> {
    /// Lifts raw `NODE_FEATURE_DIM` features into `d_model` space.
    node_proj: Linear<B>,
    /// Lookup table mapping edge-type ids (3 types) to `d_edge` vectors.
    edge_embed: Embedding<B>,
    /// The message-passing stack; every layer runs at `d_model`.
    layers: Vec<GatV2Layer<B>>,
    /// Projects the concatenated mean/max pooling (`2 * d_model`) back down
    /// to `d_model` for the global embedding.
    global_proj: Linear<B>,
}
164
165impl GnnEncoderConfig {
166 pub fn init<B: Backend>(&self, device: &B::Device) -> GnnEncoder<B> {
168 let mut layers = Vec::with_capacity(self.num_layers);
169
170 for i in 0..self.num_layers {
171 let d_in = if i == 0 { self.d_model } else { self.d_model };
172 let config = GatV2LayerConfig {
173 d_in,
174 d_out: self.d_model,
175 d_edge: self.d_edge,
176 num_edge_types: 3,
177 leaky_relu_alpha: 0.2,
178 };
179 layers.push(config.init(device));
180 }
181
182 GnnEncoder {
183 node_proj: LinearConfig::new(NODE_FEATURE_DIM, self.d_model).init(device),
184 edge_embed: EmbeddingConfig::new(3, self.d_edge).init(device),
185 layers,
186 global_proj: LinearConfig::new(self.d_model * 2, self.d_model).init(device),
187 }
188 }
189}
190
191impl<B: Backend> GnnEncoder<B> {
192 pub fn forward(
201 &self,
202 node_features: Tensor<B, 2>,
203 src_indices: Tensor<B, 1, Int>,
204 dst_indices: Tensor<B, 1, Int>,
205 edge_types: Tensor<B, 1, Int>,
206 ) -> (Tensor<B, 2>, Tensor<B, 1>) {
207 let num_nodes = node_features.dims()[0];
208
209 let mut h = self.node_proj.forward(node_features);
211
212 let edge_types_2d: Tensor<B, 2, Int> = edge_types.unsqueeze_dim::<2>(1);
214 let edge_emb_3d = self.edge_embed.forward(edge_types_2d); let edge_emb: Tensor<B, 2> = edge_emb_3d.squeeze_dim::<2>(1); for layer in &self.layers {
219 h = layer.forward(
220 h,
221 src_indices.clone(),
222 dst_indices.clone(),
223 edge_emb.clone(),
224 num_nodes,
225 );
226 }
227
228 let mean_pool: Tensor<B, 1> = h.clone().mean_dim(0).squeeze_dim::<1>(0);
230 let max_pool: Tensor<B, 1> = h.clone().max_dim(0).squeeze_dim::<1>(0);
231 let global_input = Tensor::cat(vec![mean_pool, max_pool], 0); let global: Tensor<B, 1> = self
233 .global_proj
234 .forward(global_input.unsqueeze_dim::<2>(0)) .squeeze_dim::<1>(0); (h, global)
238 }
239
240 pub fn num_params(&self) -> usize {
242 let mut total = NODE_FEATURE_DIM * self.global_proj.weight.dims()[0] / 2; for layer in &self.layers {
247 let d = layer.ffn.weight.dims()[0];
248 total += d * d * 3; total += d; total += d * layer.w_edge.weight.dims()[1]; }
252
253 total
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray;

    /// End-to-end shape check: 5 nodes, 4 edges, 2 layers at d_model = 32.
    #[test]
    fn gnn_encoder_forward_shape() {
        let device = Default::default();
        let encoder = GnnEncoderConfig {
            d_model: 32,
            num_layers: 2,
            d_edge: 8,
        }
        .init::<B>(&device);

        let features = Tensor::<B, 2>::zeros([5, NODE_FEATURE_DIM], &device);
        let src = Tensor::<B, 1, Int>::from_ints([0, 1, 2, 3], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let edge_types = Tensor::<B, 1, Int>::from_ints([0, 1, 1, 2], &device);

        let (node_emb, global): (Tensor<B, 2>, Tensor<B, 1>) =
            encoder.forward(features, src, dst, edge_types);

        assert_eq!(node_emb.dims(), [5, 32]);
        assert_eq!(global.dims(), [32]);
    }

    /// A single layer must return one row per node, at d_out width.
    #[test]
    fn gatv2_layer_preserves_node_count() {
        let device = Default::default();
        let layer = GatV2LayerConfig {
            d_in: 16,
            d_out: 16,
            d_edge: 8,
            num_edge_types: 3,
            leaky_relu_alpha: 0.2,
        }
        .init::<B>(&device);

        let features = Tensor::<B, 2>::zeros([3, 16], &device);
        let src = Tensor::<B, 1, Int>::from_ints([0, 1], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([1, 2], &device);
        let edge_emb = Tensor::<B, 2>::zeros([2, 8], &device);

        let output = layer.forward(features, src, dst, edge_emb, 3);
        assert_eq!(output.dims(), [3, 16]);
    }
}