// ruvector_attention/graph/edge_featured.rs

use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;
use crate::utils::stable_softmax;

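/// Configuration for edge-featured graph attention: GAT-style multi-head
/// attention whose logits also depend on a learned projection of per-edge
/// feature vectors.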
#[derive(Clone, Debug)]
pub struct EdgeFeaturedConfig {
    pub node_dim: usize,
    pub edge_dim: usize,
    pub num_heads: usize,
    /// Stored for configuration; dropout is not applied in this forward pass.
    pub dropout: f32,
    pub concat_heads: bool,
    /// Stored for configuration; `compute_with_edges` does not add self-loops.
    pub add_self_loops: bool,
    /// LeakyReLU negative slope applied to raw attention logits.
    pub negative_slope: f32,
}

impl Default for EdgeFeaturedConfig {
    fn default() -> Self {
        Self {
            node_dim: 256,
            edge_dim: 64,
            num_heads: 4,
            dropout: 0.0,
            concat_heads: true,
            add_self_loops: true,
            negative_slope: 0.2,
        }
    }
}

impl EdgeFeaturedConfig {
    pub fn builder() -> EdgeFeaturedConfigBuilder {
        EdgeFeaturedConfigBuilder::default()
    }

    /// Per-head dimensionality. Integer division: `node_dim` should be a
    /// multiple of `num_heads`, otherwise the remainder is silently dropped.
    pub fn head_dim(&self) -> usize {
        self.node_dim / self.num_heads
    }
}

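/// Fluent builder for [`EdgeFeaturedConfig`]; fields left unset keep the
/// `Default` values above.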
#[derive(Default)]
pub struct EdgeFeaturedConfigBuilder {
    config: EdgeFeaturedConfig,
}

impl EdgeFeaturedConfigBuilder {
    pub fn node_dim(mut self, d: usize) -> Self {
        self.config.node_dim = d;
        self
    }

    pub fn edge_dim(mut self, d: usize) -> Self {
        self.config.edge_dim = d;
        self
    }

    pub fn num_heads(mut self, n: usize) -> Self {
        self.config.num_heads = n;
        self
    }

    pub fn dropout(mut self, d: f32) -> Self {
        self.config.dropout = d;
        self
    }

    pub fn concat_heads(mut self, c: bool) -> Self {
        self.config.concat_heads = c;
        self
    }

    pub fn negative_slope(mut self, s: f32) -> Self {
        self.config.negative_slope = s;
        self
    }

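    // Assumption: setter added for completeness; the original builder exposed
    // every config field except `add_self_loops`.
    pub fn add_self_loops(mut self, a: bool) -> Self {
        self.config.add_self_loops = a;
        self
    }
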
    pub fn build(self) -> EdgeFeaturedConfig {
        self.config
    }
}

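/// Multi-head graph attention with learned node, edge, and attention-vector
/// weights, initialized deterministically in [`EdgeFeaturedAttention::new`].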
pub struct EdgeFeaturedAttention {
    config: EdgeFeaturedConfig,
    w_node: Vec<f32>, // [num_heads * head_dim * node_dim] node projection
    w_edge: Vec<f32>, // [num_heads * head_dim * edge_dim] edge projection
    a_src: Vec<f32>,  // [num_heads * head_dim] source attention vector
    a_dst: Vec<f32>,  // [num_heads * head_dim] destination attention vector
    a_edge: Vec<f32>, // [num_heads * head_dim] edge attention vector
}

impl EdgeFeaturedAttention {
    pub fn new(config: EdgeFeaturedConfig) -> Self {
        let head_dim = config.head_dim();
        let num_heads = config.num_heads;

        // Xavier/Glorot-style initialization scales.
        let node_scale = (2.0 / (config.node_dim + head_dim) as f32).sqrt();
        let edge_scale = (2.0 / (config.edge_dim + head_dim) as f32).sqrt();
        let attn_scale = (1.0 / head_dim as f32).sqrt();

        // Fixed-seed LCG keeps initialization deterministic without an RNG
        // dependency; yields values in roughly [-0.5, 0.5).
        let mut seed = 42u64;
        let mut rand = || {
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            (seed as f32) / (u64::MAX as f32) - 0.5
        };

        let w_node: Vec<f32> = (0..num_heads * head_dim * config.node_dim)
            .map(|_| rand() * 2.0 * node_scale)
            .collect();

        let w_edge: Vec<f32> = (0..num_heads * head_dim * config.edge_dim)
            .map(|_| rand() * 2.0 * edge_scale)
            .collect();

        let a_src: Vec<f32> = (0..num_heads * head_dim)
            .map(|_| rand() * 2.0 * attn_scale)
            .collect();

        let a_dst: Vec<f32> = (0..num_heads * head_dim)
            .map(|_| rand() * 2.0 * attn_scale)
            .collect();

        let a_edge: Vec<f32> = (0..num_heads * head_dim)
            .map(|_| rand() * 2.0 * attn_scale)
            .collect();

        Self {
            config,
            w_node,
            w_edge,
            a_src,
            a_dst,
            a_edge,
        }
    }

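    /// Projects a node feature vector into one head's subspace: a row-major
    /// mat-vec against that head's `head_dim x node_dim` slice of `w_node`.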
    fn transform_node(&self, node: &[f32], head: usize) -> Vec<f32> {
        let head_dim = self.config.head_dim();
        let node_dim = self.config.node_dim;

        (0..head_dim)
            .map(|i| {
                node.iter()
                    .enumerate()
                    .map(|(j, &nj)| nj * self.w_node[head * head_dim * node_dim + i * node_dim + j])
                    .sum()
            })
            .collect()
    }

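    /// Same projection as `transform_node`, but for edge features via `w_edge`.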
    fn transform_edge(&self, edge: &[f32], head: usize) -> Vec<f32> {
        let head_dim = self.config.head_dim();
        let edge_dim = self.config.edge_dim;

        (0..head_dim)
            .map(|i| {
                edge.iter()
                    .enumerate()
                    .map(|(j, &ej)| ej * self.w_edge[head * head_dim * edge_dim + i * edge_dim + j])
                    .sum()
            })
            .collect()
    }

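    /// Raw (unnormalized) attention logit for one (src, dst, edge) triple:
    /// `LeakyReLU(a_src · src + a_dst · dst + a_edge · edge)`, GAT-style.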
    fn attention_coeff(&self, src: &[f32], dst: &[f32], edge: &[f32], head: usize) -> f32 {
        let head_dim = self.config.head_dim();

        let mut score = 0.0f32;
        for i in 0..head_dim {
            let offset = head * head_dim + i;
            score += src[i] * self.a_src[offset];
            score += dst[i] * self.a_dst[offset];
            score += edge[i] * self.a_edge[offset];
        }

        if score < 0.0 {
            self.config.negative_slope * score
        } else {
            score
        }
    }
}

impl EdgeFeaturedAttention {
    /// Edge-aware attention over a neighborhood: per head, scores each
    /// (query, key, edge) triple, softmaxes over the neighbors, and
    /// aggregates the projected values. Heads are concatenated or averaged
    /// according to `concat_heads`.
    pub fn compute_with_edges(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        edges: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        if keys.len() != edges.len() {
            return Err(AttentionError::InvalidConfig(
                "Keys and edges must have same length".to_string(),
            ));
        }
        // Guard against out-of-bounds indexing in the aggregation loop below.
        if keys.len() != values.len() {
            return Err(AttentionError::InvalidConfig(
                "Keys and values must have same length".to_string(),
            ));
        }

        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim();
        let n = keys.len();

        // Project the query once per head.
        let query_transformed: Vec<Vec<f32>> = (0..num_heads)
            .map(|h| self.transform_node(query, h))
            .collect();

        let mut head_outputs: Vec<Vec<f32>> = Vec::with_capacity(num_heads);

        for h in 0..num_heads {
            let keys_t: Vec<Vec<f32>> = keys.iter().map(|k| self.transform_node(k, h)).collect();
            let edges_t: Vec<Vec<f32>> = edges.iter().map(|e| self.transform_edge(e, h)).collect();

            // LeakyReLU logits over the neighborhood, then a numerically
            // stable softmax.
            let coeffs: Vec<f32> = (0..n)
                .map(|i| self.attention_coeff(&query_transformed[h], &keys_t[i], &edges_t[i], h))
                .collect();

            let weights = stable_softmax(&coeffs);

            // Weighted sum of the projected values.
            let mut head_out = vec![0.0f32; head_dim];
            for (i, &w) in weights.iter().enumerate() {
                let value_t = self.transform_node(values[i], h);
                for (j, &vj) in value_t.iter().enumerate() {
                    head_out[j] += w * vj;
                }
            }

            head_outputs.push(head_out);
        }

        if self.config.concat_heads {
            Ok(head_outputs.into_iter().flatten().collect())
        } else {
            let mut output = vec![0.0f32; head_dim];
            for head_out in &head_outputs {
                for (i, &v) in head_out.iter().enumerate() {
                    output[i] += v / num_heads as f32;
                }
            }
            Ok(output)
        }
    }

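    /// Edge feature dimensionality expected by `compute_with_edges`.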
    pub fn edge_dim(&self) -> usize {
        self.config.edge_dim
    }
}

impl Attention for EdgeFeaturedAttention {
    /// Plain attention without edge features: validates the inputs, then
    /// delegates to `compute_with_edges` with all-zero edge vectors, whose
    /// projection contributes nothing to the attention logits.
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        if keys.is_empty() {
            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
        }
        if query.len() != self.config.node_dim {
            return Err(AttentionError::DimensionMismatch {
                expected: self.config.node_dim,
                actual: query.len(),
            });
        }

        let zero_edge = vec![0.0f32; self.config.edge_dim];
        let edges: Vec<&[f32]> = (0..keys.len()).map(|_| zero_edge.as_slice()).collect();

        self.compute_with_edges(query, keys, values, &edges)
    }

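    /// Applies a boolean keep-mask by densely filtering keys and values
    /// before attending; assumes one mask entry per key. `None` keeps
    /// everything.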
    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        if let Some(m) = mask {
            let filtered: Vec<(usize, bool)> = m
                .iter()
                .copied()
                .enumerate()
                .filter(|(_, keep)| *keep)
                .collect();
            let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
            let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
            self.compute(query, &filtered_keys, &filtered_values)
        } else {
            self.compute(query, keys, values)
        }
    }

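    /// Output dimensionality: `node_dim` when heads are concatenated,
    /// otherwise a single head's `head_dim`.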
    fn dim(&self) -> usize {
        if self.config.concat_heads {
            self.config.node_dim
        } else {
            self.config.head_dim()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_edge_featured_attention() {
        let config = EdgeFeaturedConfig::builder()
            .node_dim(64)
            .edge_dim(16)
            .num_heads(4)
            .build();

        let attn = EdgeFeaturedAttention::new(config);

        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = (0..10).map(|_| vec![0.3; 64]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|_| vec![1.0; 64]).collect();
        let edges: Vec<Vec<f32>> = (0..10).map(|_| vec![0.2; 16]).collect();

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
        let edges_refs: Vec<&[f32]> = edges.iter().map(|e| e.as_slice()).collect();

        let result = attn
            .compute_with_edges(&query, &keys_refs, &values_refs, &edges_refs)
            .unwrap();
        assert_eq!(result.len(), 64);
    }

    #[test]
    fn test_without_edges() {
        let config = EdgeFeaturedConfig::builder()
            .node_dim(32)
            .edge_dim(8)
            .num_heads(2)
            .build();

        let attn = EdgeFeaturedAttention::new(config);

        let query = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attn.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_leaky_relu() {
        let config = EdgeFeaturedConfig::builder()
            .node_dim(16)
            .edge_dim(4)
            .num_heads(1)
            .negative_slope(0.2)
            .build();

        let attn = EdgeFeaturedAttention::new(config);

        // Negative features are intended to drive the logits negative and
        // exercise the LeakyReLU slope in `attention_coeff`.
        let query = vec![-1.0; 16];
        let keys: Vec<Vec<f32>> = vec![vec![-0.5; 16]; 3];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 16]; 3];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attn.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 16);
    }

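    // An extra illustrative test (not in the original): exercises the masked
    // path and the mean-over-heads output (`concat_heads = false`).
    #[test]
    fn test_masked_and_mean_heads() {
        let config = EdgeFeaturedConfig::builder()
            .node_dim(32)
            .edge_dim(8)
            .num_heads(2)
            .concat_heads(false)
            .build();

        let attn = EdgeFeaturedAttention::new(config);

        let query = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 4];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 4];
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        // Drop the second neighbor; three keys remain.
        let mask = vec![true, false, true, true];
        let result = attn
            .compute_with_mask(&query, &keys_refs, &values_refs, Some(&mask))
            .unwrap();
        // Mean aggregation keeps head_dim = node_dim / num_heads = 16.
        assert_eq!(result.len(), 16);
    }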
}