1use burn::prelude::*;
8use burn::tensor::IndexingUpdateOp;
9
10pub fn scatter_add<B: Backend>(
21 src_features: Tensor<B, 2>,
22 dst_indices: Tensor<B, 1, Int>,
23 num_nodes: usize,
24) -> Tensor<B, 2> {
25 let device = src_features.device();
26 let num_edges = dst_indices.dims()[0];
27 let d = src_features.dims()[1];
28
29 let indices_2d: Tensor<B, 2, Int> = dst_indices.unsqueeze_dim::<2>(1).expand([num_edges, d]);
31
32 let output = Tensor::<B, 2>::zeros([num_nodes, d], &device);
33 output.scatter(0, indices_2d, src_features, IndexingUpdateOp::Add)
34}
35
36pub fn neighborhood_softmax<B: Backend>(
45 edge_scores: Tensor<B, 2>,
46 dst_indices: Tensor<B, 1, Int>,
47 num_nodes: usize,
48) -> Tensor<B, 2> {
49 let device = edge_scores.device();
50 let num_edges = edge_scores.dims()[0];
51
52 let indices_2d: Tensor<B, 2, Int> = dst_indices
53 .clone()
54 .unsqueeze_dim::<2>(1)
55 .expand([num_edges, 1]);
56
57 let clamped = edge_scores.clamp(-20.0, 20.0);
59 let exp_scores = clamped.exp();
60
61 let zeros = Tensor::<B, 2>::zeros([num_nodes, 1], &device);
63 let node_sum = zeros.scatter(0, indices_2d, exp_scores.clone(), IndexingUpdateOp::Add);
64
65 let edge_sum = node_sum.select(0, dst_indices);
67 exp_scores / (edge_sum + 1e-10)
68}
69
70pub struct BatchedGraph<B: Backend> {
72 pub node_features: Tensor<B, 2>,
73 pub src_indices: Tensor<B, 1, Int>,
74 pub dst_indices: Tensor<B, 1, Int>,
75 pub edge_types: Tensor<B, 1, Int>,
76 pub graph_ids: Tensor<B, 1, Int>,
77 pub num_nodes: usize,
78 pub num_graphs: usize,
79}
80
81pub fn batch_graphs<B: Backend>(
83 node_features_list: &[Tensor<B, 2>],
84 src_indices_list: &[Tensor<B, 1, Int>],
85 dst_indices_list: &[Tensor<B, 1, Int>],
86 edge_types_list: &[Tensor<B, 1, Int>],
87 device: &B::Device,
88) -> BatchedGraph<B> {
89 let num_graphs = node_features_list.len();
90 let mut total_nodes = 0usize;
91
92 let mut all_features = Vec::new();
93 let mut all_src = Vec::new();
94 let mut all_dst = Vec::new();
95 let mut all_edge_types = Vec::new();
96 let mut all_graph_ids = Vec::new();
97
98 for i in 0..num_graphs {
99 let n = node_features_list[i].dims()[0];
100 let offset = total_nodes as i64;
101
102 all_features.push(node_features_list[i].clone());
103
104 let offset_tensor =
105 Tensor::<B, 1, Int>::full([src_indices_list[i].dims()[0]], offset, device);
106 all_src.push(src_indices_list[i].clone() + offset_tensor.clone());
107 all_dst.push(dst_indices_list[i].clone() + offset_tensor);
108 all_edge_types.push(edge_types_list[i].clone());
109
110 let graph_id = Tensor::<B, 1, Int>::full([n], i as i64, device);
111 all_graph_ids.push(graph_id);
112
113 total_nodes += n;
114 }
115
116 BatchedGraph {
117 node_features: Tensor::cat(all_features, 0),
118 src_indices: Tensor::cat(all_src, 0),
119 dst_indices: Tensor::cat(all_dst, 0),
120 edge_types: Tensor::cat(all_edge_types, 0),
121 graph_ids: Tensor::cat(all_graph_ids, 0),
122 num_nodes: total_nodes,
123 num_graphs,
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray;

    /// Reads a 1-D integer tensor back as `i64`, whatever the backend's int element type is.
    fn int_values(t: Tensor<B, 1, Int>) -> Vec<i64> {
        t.into_data().iter::<i64>().collect()
    }

    #[test]
    fn scatter_add_basic() {
        let device = Default::default();
        let src = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([0, 0, 1], &device);

        let result = scatter_add(src, dst, 2);
        let data = result.to_data();
        assert_eq!(data.as_slice::<f32>().unwrap(), &[4.0, 6.0, 5.0, 6.0]);
    }

    #[test]
    fn neighborhood_softmax_sums_to_one() {
        let device = Default::default();
        let scores = Tensor::<B, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([0, 0, 1, 1], &device);

        let result = neighborhood_softmax(scores, dst, 2);
        let data = result.to_data();
        let vals = data.as_slice::<f32>().unwrap();

        let sum_node0 = vals[0] + vals[1];
        assert!((sum_node0 - 1.0).abs() < 1e-5, "node 0 sum: {}", sum_node0);

        let sum_node1 = vals[2] + vals[3];
        assert!((sum_node1 - 1.0).abs() < 1e-5, "node 1 sum: {}", sum_node1);
    }

    #[test]
    fn neighborhood_softmax_matches_per_node_softmax() {
        let device = Default::default();
        let scores = Tensor::<B, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([0, 0, 1, 1], &device);

        let result = neighborhood_softmax(scores, dst, 2);
        let data = result.to_data();
        let vals = data.as_slice::<f32>().unwrap();
        assert_eq!(vals.len(), 4);

        // softmax([a, a + 1]) = [1 / (1 + e), e / (1 + e)] for any a.
        let e = std::f32::consts::E;
        let expected = [1.0 / (1.0 + e), e / (1.0 + e), 1.0 / (1.0 + e), e / (1.0 + e)];
        for (got, want) in vals.iter().zip(expected) {
            assert!((got - want).abs() < 1e-5, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn batch_graphs_offsets_edges_and_tags_nodes() {
        let device = Default::default();
        let f0 = Tensor::<B, 2>::from_floats([[1.0], [2.0]], &device);
        let f1 = Tensor::<B, 2>::from_floats([[3.0], [4.0], [5.0]], &device);
        let src0 = Tensor::<B, 1, Int>::from_ints([0, 1], &device);
        let dst0 = Tensor::<B, 1, Int>::from_ints([1, 0], &device);
        let src1 = Tensor::<B, 1, Int>::from_ints([0, 2], &device);
        let dst1 = Tensor::<B, 1, Int>::from_ints([1, 1], &device);
        let et0 = Tensor::<B, 1, Int>::from_ints([0, 1], &device);
        let et1 = Tensor::<B, 1, Int>::from_ints([2, 3], &device);

        let batch = batch_graphs(
            &[f0, f1],
            &[src0, src1],
            &[dst0, dst1],
            &[et0, et1],
            &device,
        );

        assert_eq!(batch.num_nodes, 5);
        assert_eq!(batch.num_graphs, 2);

        let feats = batch.node_features.to_data();
        assert_eq!(feats.as_slice::<f32>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);

        // Graph 1's local indices are shifted by graph 0's two nodes.
        assert_eq!(int_values(batch.src_indices), vec![0, 1, 2, 4]);
        assert_eq!(int_values(batch.dst_indices), vec![1, 0, 3, 3]);
        assert_eq!(int_values(batch.edge_types), vec![0, 1, 2, 3]);
        assert_eq!(int_values(batch.graph_ids), vec![0, 0, 1, 1, 1]);
    }
}