ruvector_attention/graph/
rope.rs

1//! Rotary Position Embeddings (RoPE) for Graph Attention
2//!
3//! Adapts RoPE for graph structures where positions are defined by graph topology
4//! (e.g., hop distance, shortest path length, or learned positional encodings).
5
6use crate::error::{AttentionError, AttentionResult};
7use crate::traits::Attention;
8use crate::utils::stable_softmax;
9
/// Settings that control how [`GraphRoPE`] builds its rotation tables.
#[derive(Clone, Debug)]
pub struct RoPEConfig {
    /// Length of the vectors being rotated; also the expected query length.
    pub dim: usize,
    /// Base of the geometric frequency schedule: band `i` uses
    /// `1 / base^(2i / dim)`.
    pub base: f32,
    /// Number of positions for which cos/sin values are precomputed.
    pub max_position: usize,
    /// Divisor applied to a position before it is turned into an angle.
    pub scaling_factor: f32,
}
18
19impl Default for RoPEConfig {
20    fn default() -> Self {
21        Self {
22            dim: 256,
23            base: 10000.0,
24            max_position: 512,
25            scaling_factor: 1.0,
26        }
27    }
28}
29
30impl RoPEConfig {
31    pub fn builder() -> RoPEConfigBuilder {
32        RoPEConfigBuilder::default()
33    }
34}
35
36#[derive(Default)]
37pub struct RoPEConfigBuilder {
38    config: RoPEConfig,
39}
40
41impl RoPEConfigBuilder {
42    pub fn dim(mut self, d: usize) -> Self {
43        self.config.dim = d;
44        self
45    }
46
47    pub fn base(mut self, b: f32) -> Self {
48        self.config.base = b;
49        self
50    }
51
52    pub fn max_position(mut self, m: usize) -> Self {
53        self.config.max_position = m;
54        self
55    }
56
57    pub fn scaling_factor(mut self, s: f32) -> Self {
58        self.config.scaling_factor = s;
59        self
60    }
61
62    pub fn build(self) -> RoPEConfig {
63        self.config
64    }
65}
66
67/// Graph attention with Rotary Position Embeddings
68pub struct GraphRoPE {
69    config: RoPEConfig,
70    /// Precomputed cos/sin tables: [max_position, dim]
71    cos_cache: Vec<f32>,
72    sin_cache: Vec<f32>,
73    scale: f32,
74}
75
76impl GraphRoPE {
77    pub fn new(config: RoPEConfig) -> Self {
78        let dim = config.dim;
79        let max_pos = config.max_position;
80        let base = config.base;
81        let scaling = config.scaling_factor;
82
83        // Compute frequency bands
84        let half_dim = dim / 2;
85        let inv_freq: Vec<f32> = (0..half_dim)
86            .map(|i| 1.0 / (base.powf(2.0 * i as f32 / dim as f32)))
87            .collect();
88
89        // Precompute cos/sin for all positions
90        let mut cos_cache = Vec::with_capacity(max_pos * dim);
91        let mut sin_cache = Vec::with_capacity(max_pos * dim);
92
93        for pos in 0..max_pos {
94            let scaled_pos = pos as f32 / scaling;
95            for i in 0..half_dim {
96                let theta = scaled_pos * inv_freq[i];
97                cos_cache.push(theta.cos());
98                sin_cache.push(theta.sin());
99            }
100            // Duplicate for both halves (interleaved format)
101            for i in 0..half_dim {
102                let theta = scaled_pos * inv_freq[i];
103                cos_cache.push(theta.cos());
104                sin_cache.push(theta.sin());
105            }
106        }
107
108        Self {
109            scale: 1.0 / (dim as f32).sqrt(),
110            config,
111            cos_cache,
112            sin_cache,
113        }
114    }
115
116    /// Apply rotary embedding to a vector at given position
117    pub fn apply_rotary(&self, x: &[f32], position: usize) -> Vec<f32> {
118        let dim = self.config.dim;
119        let half = dim / 2;
120        let pos = position.min(self.config.max_position - 1);
121        let offset = pos * dim;
122
123        let mut result = vec![0.0f32; dim];
124
125        // Apply rotation to first half
126        for i in 0..half {
127            let cos = self.cos_cache[offset + i];
128            let sin = self.sin_cache[offset + i];
129            result[i] = x[i] * cos - x[half + i] * sin;
130            result[half + i] = x[i] * sin + x[half + i] * cos;
131        }
132
133        result
134    }
135
136    /// Compute attention with positional encoding based on graph distances
137    pub fn compute_with_positions(
138        &self,
139        query: &[f32],
140        keys: &[&[f32]],
141        values: &[&[f32]],
142        query_pos: usize,
143        key_positions: &[usize],
144    ) -> AttentionResult<Vec<f32>> {
145        if keys.is_empty() {
146            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
147        }
148        if keys.len() != key_positions.len() {
149            return Err(AttentionError::InvalidConfig(
150                "Keys and positions must have same length".to_string(),
151            ));
152        }
153        if query.len() != self.config.dim {
154            return Err(AttentionError::DimensionMismatch {
155                expected: self.config.dim,
156                actual: query.len(),
157            });
158        }
159
160        // Apply rotary to query
161        let q_rot = self.apply_rotary(query, query_pos);
162
163        // Compute attention scores with rotary keys
164        let scores: Vec<f32> = keys
165            .iter()
166            .zip(key_positions.iter())
167            .map(|(key, &pos)| {
168                let k_rot = self.apply_rotary(key, pos);
169                q_rot
170                    .iter()
171                    .zip(k_rot.iter())
172                    .map(|(q, k)| q * k)
173                    .sum::<f32>()
174                    * self.scale
175            })
176            .collect();
177
178        // Softmax
179        let weights = stable_softmax(&scores);
180
181        // Weighted sum
182        let value_dim = values[0].len();
183        let mut output = vec![0.0f32; value_dim];
184        for (w, v) in weights.iter().zip(values.iter()) {
185            for (o, &vi) in output.iter_mut().zip(v.iter()) {
186                *o += w * vi;
187            }
188        }
189
190        Ok(output)
191    }
192
193    /// Get relative position for graph distance
194    /// Converts graph hop distance to position index
195    pub fn distance_to_position(distance: usize, max_distance: usize) -> usize {
196        // Bucketize distances logarithmically for larger graphs
197        if distance <= 8 {
198            distance
199        } else {
200            let log_dist = (distance as f32).log2().ceil() as usize;
201            8 + log_dist.min(max_distance - 8)
202        }
203    }
204}
205
206impl Attention for GraphRoPE {
207    fn compute(
208        &self,
209        query: &[f32],
210        keys: &[&[f32]],
211        values: &[&[f32]],
212    ) -> AttentionResult<Vec<f32>> {
213        // Default: use sequential positions (0, 1, 2, ...)
214        let query_pos = 0;
215        let key_positions: Vec<usize> = (0..keys.len()).collect();
216        self.compute_with_positions(query, keys, values, query_pos, &key_positions)
217    }
218
219    fn compute_with_mask(
220        &self,
221        query: &[f32],
222        keys: &[&[f32]],
223        values: &[&[f32]],
224        mask: Option<&[bool]>,
225    ) -> AttentionResult<Vec<f32>> {
226        if let Some(m) = mask {
227            let filtered: Vec<(usize, bool)> = m
228                .iter()
229                .copied()
230                .enumerate()
231                .filter(|(_, keep)| *keep)
232                .collect();
233            let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
234            let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
235            self.compute(query, &filtered_keys, &filtered_values)
236        } else {
237            self.compute(query, keys, values)
238        }
239    }
240
241    fn dim(&self) -> usize {
242        self.config.dim
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned row as a slice, the shape the attention API takes.
    fn as_refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(|row| row.as_slice()).collect()
    }

    /// Euclidean norm of a vector.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|e| e * e).sum::<f32>().sqrt()
    }

    #[test]
    fn test_rope_basic() {
        let rope = GraphRoPE::new(RoPEConfig::builder().dim(64).max_position(100).build());

        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 64]; 10];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 64]; 10];

        let output = rope
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_rope_with_positions() {
        let rope = GraphRoPE::new(RoPEConfig::builder().dim(32).max_position(50).build());

        let query = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        // Graph distances stand in for sequence positions.
        let key_positions = [1, 2, 3, 2, 4];

        let output = rope
            .compute_with_positions(
                &query,
                &as_refs(&keys),
                &as_refs(&values),
                0,
                &key_positions,
            )
            .unwrap();
        assert_eq!(output.len(), 32);
    }

    #[test]
    fn test_rotary_embedding() {
        let rope = GraphRoPE::new(RoPEConfig::builder().dim(16).max_position(10).build());

        let x = vec![1.0; 16];
        let rotated = rope.apply_rotary(&x, 5);

        // A rotation must leave the vector's length (approximately) unchanged.
        let norm_orig = l2_norm(&x);
        let norm_rot = l2_norm(&rotated);
        assert!((norm_orig - norm_rot).abs() < 1e-5);
    }

    #[test]
    fn test_distance_to_position() {
        // Small distances map to themselves.
        for &(distance, expected) in &[(0, 0), (5, 5), (8, 8)] {
            assert_eq!(GraphRoPE::distance_to_position(distance, 20), expected);
        }

        // Larger distances land in increasing logarithmic buckets.
        let pos_16 = GraphRoPE::distance_to_position(16, 20);
        let pos_32 = GraphRoPE::distance_to_position(32, 20);
        assert!(pos_16 > 8);
        assert!(pos_32 > pos_16);
    }
}