ruvector_attention/graph/
edge_featured.rs

1//! Edge-featured graph attention (GATv2 style)
2//!
3//! Extends standard graph attention with edge feature integration.
4
5use crate::error::{AttentionError, AttentionResult};
6use crate::traits::Attention;
7use crate::utils::stable_softmax;
8
/// Configuration for edge-featured attention
#[derive(Clone, Debug)]
pub struct EdgeFeaturedConfig {
    /// Node feature dimensionality; `head_dim()` assumes this is a multiple of `num_heads`.
    pub node_dim: usize,
    /// Edge feature dimensionality.
    pub edge_dim: usize,
    /// Number of parallel attention heads.
    pub num_heads: usize,
    /// Dropout probability. NOTE(review): stored but not applied anywhere in this file — confirm intended.
    pub dropout: f32,
    /// If true, head outputs are concatenated (output dim = node_dim);
    /// otherwise they are averaged (output dim = head_dim).
    pub concat_heads: bool,
    /// Whether self-loop edges should be added.
    /// NOTE(review): not consulted by the code visible here — confirm consumers.
    pub add_self_loops: bool,
    pub negative_slope: f32, // LeakyReLU slope
}
20
21impl Default for EdgeFeaturedConfig {
22    fn default() -> Self {
23        Self {
24            node_dim: 256,
25            edge_dim: 64,
26            num_heads: 4,
27            dropout: 0.0,
28            concat_heads: true,
29            add_self_loops: true,
30            negative_slope: 0.2,
31        }
32    }
33}
34
impl EdgeFeaturedConfig {
    /// Start a fluent builder initialized from `Default`.
    pub fn builder() -> EdgeFeaturedConfigBuilder {
        EdgeFeaturedConfigBuilder::default()
    }

    /// Per-head feature width. Integer division: truncates silently if
    /// `node_dim` is not a multiple of `num_heads`, and panics if
    /// `num_heads == 0` — callers are expected to supply valid values.
    pub fn head_dim(&self) -> usize {
        self.node_dim / self.num_heads
    }
}
44
/// Fluent builder for [`EdgeFeaturedConfig`], starting from its `Default` values.
#[derive(Default)]
pub struct EdgeFeaturedConfigBuilder {
    // Accumulated configuration, handed out by `build()`.
    config: EdgeFeaturedConfig,
}
49
50impl EdgeFeaturedConfigBuilder {
51    pub fn node_dim(mut self, d: usize) -> Self {
52        self.config.node_dim = d;
53        self
54    }
55
56    pub fn edge_dim(mut self, d: usize) -> Self {
57        self.config.edge_dim = d;
58        self
59    }
60
61    pub fn num_heads(mut self, n: usize) -> Self {
62        self.config.num_heads = n;
63        self
64    }
65
66    pub fn dropout(mut self, d: f32) -> Self {
67        self.config.dropout = d;
68        self
69    }
70
71    pub fn concat_heads(mut self, c: bool) -> Self {
72        self.config.concat_heads = c;
73        self
74    }
75
76    pub fn negative_slope(mut self, s: f32) -> Self {
77        self.config.negative_slope = s;
78        self
79    }
80
81    pub fn build(self) -> EdgeFeaturedConfig {
82        self.config
83    }
84}
85
/// Edge-featured graph attention layer
///
/// All weight buffers are flat `Vec<f32>` in head-major, row-major order.
pub struct EdgeFeaturedAttention {
    config: EdgeFeaturedConfig,
    // Weight matrices (would be learnable in training)
    w_node: Vec<f32>, // [num_heads, head_dim, node_dim] — per-head node projection
    w_edge: Vec<f32>, // [num_heads, head_dim, edge_dim] — per-head edge projection
    a_src: Vec<f32>,  // [num_heads, head_dim] — attention vector, source term
    a_dst: Vec<f32>,  // [num_heads, head_dim] — attention vector, destination term
    a_edge: Vec<f32>, // [num_heads, head_dim] — attention vector, edge term
}
96
97impl EdgeFeaturedAttention {
98    pub fn new(config: EdgeFeaturedConfig) -> Self {
99        let head_dim = config.head_dim();
100        let num_heads = config.num_heads;
101
102        // Xavier initialization
103        let node_scale = (2.0 / (config.node_dim + head_dim) as f32).sqrt();
104        let edge_scale = (2.0 / (config.edge_dim + head_dim) as f32).sqrt();
105        let attn_scale = (1.0 / head_dim as f32).sqrt();
106
107        let mut seed = 42u64;
108        let mut rand = || {
109            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
110            (seed as f32) / (u64::MAX as f32) - 0.5
111        };
112
113        let w_node: Vec<f32> = (0..num_heads * head_dim * config.node_dim)
114            .map(|_| rand() * 2.0 * node_scale)
115            .collect();
116
117        let w_edge: Vec<f32> = (0..num_heads * head_dim * config.edge_dim)
118            .map(|_| rand() * 2.0 * edge_scale)
119            .collect();
120
121        let a_src: Vec<f32> = (0..num_heads * head_dim)
122            .map(|_| rand() * 2.0 * attn_scale)
123            .collect();
124
125        let a_dst: Vec<f32> = (0..num_heads * head_dim)
126            .map(|_| rand() * 2.0 * attn_scale)
127            .collect();
128
129        let a_edge: Vec<f32> = (0..num_heads * head_dim)
130            .map(|_| rand() * 2.0 * attn_scale)
131            .collect();
132
133        Self {
134            config,
135            w_node,
136            w_edge,
137            a_src,
138            a_dst,
139            a_edge,
140        }
141    }
142
143    /// Transform node features for a specific head
144    fn transform_node(&self, node: &[f32], head: usize) -> Vec<f32> {
145        let head_dim = self.config.head_dim();
146        let node_dim = self.config.node_dim;
147
148        (0..head_dim)
149            .map(|i| {
150                node.iter()
151                    .enumerate()
152                    .map(|(j, &nj)| nj * self.w_node[head * head_dim * node_dim + i * node_dim + j])
153                    .sum()
154            })
155            .collect()
156    }
157
158    /// Transform edge features for a specific head
159    fn transform_edge(&self, edge: &[f32], head: usize) -> Vec<f32> {
160        let head_dim = self.config.head_dim();
161        let edge_dim = self.config.edge_dim;
162
163        (0..head_dim)
164            .map(|i| {
165                edge.iter()
166                    .enumerate()
167                    .map(|(j, &ej)| ej * self.w_edge[head * head_dim * edge_dim + i * edge_dim + j])
168                    .sum()
169            })
170            .collect()
171    }
172
173    /// Compute attention coefficient with LeakyReLU
174    fn attention_coeff(&self, src: &[f32], dst: &[f32], edge: &[f32], head: usize) -> f32 {
175        let head_dim = self.config.head_dim();
176
177        let mut score = 0.0f32;
178        for i in 0..head_dim {
179            let offset = head * head_dim + i;
180            score += src[i] * self.a_src[offset];
181            score += dst[i] * self.a_dst[offset];
182            score += edge[i] * self.a_edge[offset];
183        }
184
185        // LeakyReLU
186        if score < 0.0 {
187            self.config.negative_slope * score
188        } else {
189            score
190        }
191    }
192}
193
194impl EdgeFeaturedAttention {
195    /// Compute attention with explicit edge features
196    pub fn compute_with_edges(
197        &self,
198        query: &[f32],
199        keys: &[&[f32]],
200        values: &[&[f32]],
201        edges: &[&[f32]],
202    ) -> AttentionResult<Vec<f32>> {
203        if keys.len() != edges.len() {
204            return Err(AttentionError::InvalidConfig(
205                "Keys and edges must have same length".to_string(),
206            ));
207        }
208
209        let num_heads = self.config.num_heads;
210        let head_dim = self.config.head_dim();
211        let n = keys.len();
212
213        // Transform query once per head
214        let query_transformed: Vec<Vec<f32>> = (0..num_heads)
215            .map(|h| self.transform_node(query, h))
216            .collect();
217
218        // Compute per-head outputs
219        let mut head_outputs: Vec<Vec<f32>> = Vec::with_capacity(num_heads);
220
221        for h in 0..num_heads {
222            // Transform all keys and edges
223            let keys_t: Vec<Vec<f32>> = keys.iter().map(|k| self.transform_node(k, h)).collect();
224            let edges_t: Vec<Vec<f32>> = edges.iter().map(|e| self.transform_edge(e, h)).collect();
225
226            // Compute attention coefficients
227            let coeffs: Vec<f32> = (0..n)
228                .map(|i| self.attention_coeff(&query_transformed[h], &keys_t[i], &edges_t[i], h))
229                .collect();
230
231            // Softmax
232            let weights = stable_softmax(&coeffs);
233
234            // Weighted sum of values
235            let mut head_out = vec![0.0f32; head_dim];
236            for (i, &w) in weights.iter().enumerate() {
237                let value_t = self.transform_node(values[i], h);
238                for (j, &vj) in value_t.iter().enumerate() {
239                    head_out[j] += w * vj;
240                }
241            }
242
243            head_outputs.push(head_out);
244        }
245
246        // Concatenate or average heads
247        if self.config.concat_heads {
248            Ok(head_outputs.into_iter().flatten().collect())
249        } else {
250            let mut output = vec![0.0f32; head_dim];
251            for head_out in &head_outputs {
252                for (i, &v) in head_out.iter().enumerate() {
253                    output[i] += v / num_heads as f32;
254                }
255            }
256            Ok(output)
257        }
258    }
259
260    /// Get the edge feature dimension
261    pub fn edge_dim(&self) -> usize {
262        self.config.edge_dim
263    }
264}
265
266impl Attention for EdgeFeaturedAttention {
267    fn compute(
268        &self,
269        query: &[f32],
270        keys: &[&[f32]],
271        values: &[&[f32]],
272    ) -> AttentionResult<Vec<f32>> {
273        if keys.is_empty() {
274            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
275        }
276        if query.len() != self.config.node_dim {
277            return Err(AttentionError::DimensionMismatch {
278                expected: self.config.node_dim,
279                actual: query.len(),
280            });
281        }
282
283        // Use zero edge features for basic attention
284        let zero_edge = vec![0.0f32; self.config.edge_dim];
285        let edges: Vec<&[f32]> = (0..keys.len()).map(|_| zero_edge.as_slice()).collect();
286
287        self.compute_with_edges(query, keys, values, &edges)
288    }
289
290    fn compute_with_mask(
291        &self,
292        query: &[f32],
293        keys: &[&[f32]],
294        values: &[&[f32]],
295        mask: Option<&[bool]>,
296    ) -> AttentionResult<Vec<f32>> {
297        // Apply mask by filtering keys/values
298        if let Some(m) = mask {
299            let filtered: Vec<(usize, bool)> = m
300                .iter()
301                .copied()
302                .enumerate()
303                .filter(|(_, keep)| *keep)
304                .collect();
305            let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
306            let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
307            self.compute(query, &filtered_keys, &filtered_values)
308        } else {
309            self.compute(query, keys, values)
310        }
311    }
312
313    fn dim(&self) -> usize {
314        if self.config.concat_heads {
315            self.config.node_dim
316        } else {
317            self.config.head_dim()
318        }
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow a list of owned vectors as a list of slices.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(|r| r.as_slice()).collect()
    }

    #[test]
    fn test_edge_featured_attention() {
        let attn = EdgeFeaturedAttention::new(
            EdgeFeaturedConfig::builder()
                .node_dim(64)
                .edge_dim(16)
                .num_heads(4)
                .build(),
        );

        let query = vec![0.5f32; 64];
        let key_rows = vec![vec![0.3f32; 64]; 10];
        let value_rows = vec![vec![1.0f32; 64]; 10];
        let edge_rows = vec![vec![0.2f32; 16]; 10];

        let out = attn
            .compute_with_edges(
                &query,
                &as_slices(&key_rows),
                &as_slices(&value_rows),
                &as_slices(&edge_rows),
            )
            .unwrap();

        // Concatenated heads: output width equals node_dim.
        assert_eq!(out.len(), 64);
    }

    #[test]
    fn test_without_edges() {
        let attn = EdgeFeaturedAttention::new(
            EdgeFeaturedConfig::builder()
                .node_dim(32)
                .edge_dim(8)
                .num_heads(2)
                .build(),
        );

        let query = vec![0.5f32; 32];
        let key_rows = vec![vec![0.3f32; 32]; 5];
        let value_rows = vec![vec![1.0f32; 32]; 5];

        let out = attn
            .compute(&query, &as_slices(&key_rows), &as_slices(&value_rows))
            .unwrap();

        assert_eq!(out.len(), 32);
    }

    #[test]
    fn test_leaky_relu() {
        let attn = EdgeFeaturedAttention::new(
            EdgeFeaturedConfig::builder()
                .node_dim(16)
                .edge_dim(4)
                .num_heads(1)
                .negative_slope(0.2)
                .build(),
        );

        // Negative inputs exercise the LeakyReLU branch; just verify it
        // computes without error.
        let query = vec![-1.0f32; 16];
        let key_rows = vec![vec![-0.5f32; 16]; 3];
        let value_rows = vec![vec![1.0f32; 16]; 3];

        let out = attn
            .compute(&query, &as_slices(&key_rows), &as_slices(&value_rows))
            .unwrap();

        assert_eq!(out.len(), 16);
    }
}