// trident/neural/model/gnn_ops.rs
1//! Graph neural network operations for burn.
2//!
3//! Provides scatter-based message passing primitives that burn
4//! doesn't have natively. These ops enable GATv2 attention
5//! aggregation over graph neighborhoods.
6
7use burn::prelude::*;
8use burn::tensor::IndexingUpdateOp;
9
10/// Scatter-add: aggregate source features by destination index.
11///
12/// For each edge (src, dst), adds `src_features[edge]` into
13/// `output[dst_indices[edge]]`. Used for GNN message aggregation.
14///
15/// - `src_features`: [num_edges, d] — per-edge feature vectors
16/// - `dst_indices`: [num_edges] — destination node index per edge
17/// - `num_nodes`: total number of nodes (output rows)
18///
19/// Returns: [num_nodes, d] — aggregated features per node
20pub fn scatter_add<B: Backend>(
21    src_features: Tensor<B, 2>,
22    dst_indices: Tensor<B, 1, Int>,
23    num_nodes: usize,
24) -> Tensor<B, 2> {
25    let device = src_features.device();
26    let num_edges = dst_indices.dims()[0];
27    let d = src_features.dims()[1];
28
29    // Expand dst_indices [E] → [E, 1] → [E, d] for scatter
30    let indices_2d: Tensor<B, 2, Int> = dst_indices.unsqueeze_dim::<2>(1).expand([num_edges, d]);
31
32    let output = Tensor::<B, 2>::zeros([num_nodes, d], &device);
33    output.scatter(0, indices_2d, src_features, IndexingUpdateOp::Add)
34}
35
36/// Neighborhood softmax: softmax of edge scores grouped by destination node.
37///
38/// For attention: each destination node's incoming edge scores should
39/// sum to 1.0 after softmax. This implements the grouped softmax.
40///
41/// - `edge_scores`: [num_edges, 1] — raw attention logits
42/// - `dst_indices`: [num_edges] — destination node per edge
43/// - `num_nodes`: total number of nodes
44pub fn neighborhood_softmax<B: Backend>(
45    edge_scores: Tensor<B, 2>,
46    dst_indices: Tensor<B, 1, Int>,
47    num_nodes: usize,
48) -> Tensor<B, 2> {
49    let device = edge_scores.device();
50    let num_edges = edge_scores.dims()[0];
51
52    let indices_2d: Tensor<B, 2, Int> = dst_indices
53        .clone()
54        .unsqueeze_dim::<2>(1)
55        .expand([num_edges, 1]);
56
57    // Clamp scores to prevent overflow in exp (replaces max subtraction)
58    let clamped = edge_scores.clamp(-20.0, 20.0);
59    let exp_scores = clamped.exp();
60
61    // Sum exp per node via scatter-add
62    let zeros = Tensor::<B, 2>::zeros([num_nodes, 1], &device);
63    let node_sum = zeros.scatter(0, indices_2d, exp_scores.clone(), IndexingUpdateOp::Add);
64
65    // Gather sum back to edges and normalize
66    let edge_sum = node_sum.select(0, dst_indices);
67    exp_scores / (edge_sum + 1e-10)
68}
69
/// Batched graph for packing multiple graphs into one large disconnected graph.
pub struct BatchedGraph<B: Backend> {
    /// [total_nodes, d] — node features of all graphs, concatenated in order.
    pub node_features: Tensor<B, 2>,
    /// [total_edges] — edge source indices, offset into the batched node space.
    pub src_indices: Tensor<B, 1, Int>,
    /// [total_edges] — edge destination indices, offset into the batched node space.
    pub dst_indices: Tensor<B, 1, Int>,
    /// [total_edges] — edge type id per edge, concatenated in graph order.
    pub edge_types: Tensor<B, 1, Int>,
    /// [total_nodes] — for each node, the index of the graph it came from.
    pub graph_ids: Tensor<B, 1, Int>,
    /// Total node count across all graphs (rows of `node_features`).
    pub num_nodes: usize,
    /// Number of graphs packed into this batch.
    pub num_graphs: usize,
}
80
81/// Batch multiple graphs into a single large disconnected graph.
82pub fn batch_graphs<B: Backend>(
83    node_features_list: &[Tensor<B, 2>],
84    src_indices_list: &[Tensor<B, 1, Int>],
85    dst_indices_list: &[Tensor<B, 1, Int>],
86    edge_types_list: &[Tensor<B, 1, Int>],
87    device: &B::Device,
88) -> BatchedGraph<B> {
89    let num_graphs = node_features_list.len();
90    let mut total_nodes = 0usize;
91
92    let mut all_features = Vec::new();
93    let mut all_src = Vec::new();
94    let mut all_dst = Vec::new();
95    let mut all_edge_types = Vec::new();
96    let mut all_graph_ids = Vec::new();
97
98    for i in 0..num_graphs {
99        let n = node_features_list[i].dims()[0];
100        let offset = total_nodes as i64;
101
102        all_features.push(node_features_list[i].clone());
103
104        let offset_tensor =
105            Tensor::<B, 1, Int>::full([src_indices_list[i].dims()[0]], offset, device);
106        all_src.push(src_indices_list[i].clone() + offset_tensor.clone());
107        all_dst.push(dst_indices_list[i].clone() + offset_tensor);
108        all_edge_types.push(edge_types_list[i].clone());
109
110        let graph_id = Tensor::<B, 1, Int>::full([n], i as i64, device);
111        all_graph_ids.push(graph_id);
112
113        total_nodes += n;
114    }
115
116    BatchedGraph {
117        node_features: Tensor::cat(all_features, 0),
118        src_indices: Tensor::cat(all_src, 0),
119        dst_indices: Tensor::cat(all_dst, 0),
120        edge_types: Tensor::cat(all_edge_types, 0),
121        graph_ids: Tensor::cat(all_graph_ids, 0),
122        num_nodes: total_nodes,
123        num_graphs,
124    }
125}
126
127// ─── Tests ────────────────────────────────────────────────────────
128
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray;

    #[test]
    fn scatter_add_basic() {
        let device = Default::default();
        let features = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], &device);
        let targets = Tensor::<B, 1, Int>::from_ints([0, 0, 1], &device);

        // Edges 0 and 1 both land on node 0; edge 2 lands on node 1.
        let out = scatter_add(features, targets, 2).to_data();
        assert_eq!(out.as_slice::<f32>().unwrap(), &[4.0, 6.0, 5.0, 6.0]);
    }

    #[test]
    fn neighborhood_softmax_sums_to_one() {
        let device = Default::default();
        let logits = Tensor::<B, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]], &device);
        let targets = Tensor::<B, 1, Int>::from_ints([0, 0, 1, 1], &device);

        let data = neighborhood_softmax(logits, targets, 2).to_data();
        let probs = data.as_slice::<f32>().unwrap();

        // Each node's two incoming attention weights must form a distribution.
        for (node, pair) in probs.chunks(2).enumerate() {
            let total: f32 = pair.iter().sum();
            assert!((total - 1.0).abs() < 1e-5, "node {} sum: {}", node, total);
        }
    }
}
163}