ruvector_attention/graph/
rope.rs1use crate::error::{AttentionError, AttentionResult};
7use crate::traits::Attention;
8use crate::utils::stable_softmax;
9
/// Configuration for [`GraphRoPE`] rotary position embeddings.
///
/// Build it directly, through [`Default`], or fluently with [`RoPEConfig::builder`].
#[derive(Clone, Debug)]
pub struct RoPEConfig {
    /// Length of the query/key vectors that are rotated.
    pub dim: usize,
    /// Frequency base: rotation pair `i` uses inverse frequency `1 / base^(2i / dim)`.
    pub base: f32,
    /// Number of positions whose sin/cos tables are precomputed; positions at or
    /// beyond this are clamped to the last cached position.
    pub max_position: usize,
    /// Positions are divided by this factor before computing rotation angles, so
    /// values greater than 1 slow the rotation. Should be positive and finite.
    pub scaling_factor: f32,
}

impl Default for RoPEConfig {
    /// `dim = 256`, `base = 10000.0`, `max_position = 512`, `scaling_factor = 1.0`.
    fn default() -> Self {
        Self {
            dim: 256,
            base: 10000.0,
            max_position: 512,
            scaling_factor: 1.0,
        }
    }
}

impl RoPEConfig {
    /// Starts a [`RoPEConfigBuilder`] seeded with [`RoPEConfig::default`].
    pub fn builder() -> RoPEConfigBuilder {
        RoPEConfigBuilder::default()
    }
}

/// Fluent builder for [`RoPEConfig`]; fields that are never set keep their
/// [`Default`] values.
#[derive(Clone, Debug, Default)]
pub struct RoPEConfigBuilder {
    config: RoPEConfig,
}

impl RoPEConfigBuilder {
    /// Sets [`RoPEConfig::dim`].
    pub fn dim(mut self, d: usize) -> Self {
        self.config.dim = d;
        self
    }

    /// Sets [`RoPEConfig::base`].
    pub fn base(mut self, b: f32) -> Self {
        self.config.base = b;
        self
    }

    /// Sets [`RoPEConfig::max_position`].
    pub fn max_position(mut self, m: usize) -> Self {
        self.config.max_position = m;
        self
    }

    /// Sets [`RoPEConfig::scaling_factor`].
    pub fn scaling_factor(mut self, s: f32) -> Self {
        self.config.scaling_factor = s;
        self
    }

    /// Returns the configured [`RoPEConfig`].
    pub fn build(self) -> RoPEConfig {
        self.config
    }
}
66
67pub struct GraphRoPE {
69 config: RoPEConfig,
70 cos_cache: Vec<f32>,
72 sin_cache: Vec<f32>,
73 scale: f32,
74}
75
76impl GraphRoPE {
77 pub fn new(config: RoPEConfig) -> Self {
78 let dim = config.dim;
79 let max_pos = config.max_position;
80 let base = config.base;
81 let scaling = config.scaling_factor;
82
83 let half_dim = dim / 2;
85 let inv_freq: Vec<f32> = (0..half_dim)
86 .map(|i| 1.0 / (base.powf(2.0 * i as f32 / dim as f32)))
87 .collect();
88
89 let mut cos_cache = Vec::with_capacity(max_pos * dim);
91 let mut sin_cache = Vec::with_capacity(max_pos * dim);
92
93 for pos in 0..max_pos {
94 let scaled_pos = pos as f32 / scaling;
95 for i in 0..half_dim {
96 let theta = scaled_pos * inv_freq[i];
97 cos_cache.push(theta.cos());
98 sin_cache.push(theta.sin());
99 }
100 for i in 0..half_dim {
102 let theta = scaled_pos * inv_freq[i];
103 cos_cache.push(theta.cos());
104 sin_cache.push(theta.sin());
105 }
106 }
107
108 Self {
109 scale: 1.0 / (dim as f32).sqrt(),
110 config,
111 cos_cache,
112 sin_cache,
113 }
114 }
115
116 pub fn apply_rotary(&self, x: &[f32], position: usize) -> Vec<f32> {
118 let dim = self.config.dim;
119 let half = dim / 2;
120 let pos = position.min(self.config.max_position - 1);
121 let offset = pos * dim;
122
123 let mut result = vec![0.0f32; dim];
124
125 for i in 0..half {
127 let cos = self.cos_cache[offset + i];
128 let sin = self.sin_cache[offset + i];
129 result[i] = x[i] * cos - x[half + i] * sin;
130 result[half + i] = x[i] * sin + x[half + i] * cos;
131 }
132
133 result
134 }
135
136 pub fn compute_with_positions(
138 &self,
139 query: &[f32],
140 keys: &[&[f32]],
141 values: &[&[f32]],
142 query_pos: usize,
143 key_positions: &[usize],
144 ) -> AttentionResult<Vec<f32>> {
145 if keys.is_empty() {
146 return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
147 }
148 if keys.len() != key_positions.len() {
149 return Err(AttentionError::InvalidConfig(
150 "Keys and positions must have same length".to_string(),
151 ));
152 }
153 if query.len() != self.config.dim {
154 return Err(AttentionError::DimensionMismatch {
155 expected: self.config.dim,
156 actual: query.len(),
157 });
158 }
159
160 let q_rot = self.apply_rotary(query, query_pos);
162
163 let scores: Vec<f32> = keys
165 .iter()
166 .zip(key_positions.iter())
167 .map(|(key, &pos)| {
168 let k_rot = self.apply_rotary(key, pos);
169 q_rot
170 .iter()
171 .zip(k_rot.iter())
172 .map(|(q, k)| q * k)
173 .sum::<f32>()
174 * self.scale
175 })
176 .collect();
177
178 let weights = stable_softmax(&scores);
180
181 let value_dim = values[0].len();
183 let mut output = vec![0.0f32; value_dim];
184 for (w, v) in weights.iter().zip(values.iter()) {
185 for (o, &vi) in output.iter_mut().zip(v.iter()) {
186 *o += w * vi;
187 }
188 }
189
190 Ok(output)
191 }
192
193 pub fn distance_to_position(distance: usize, max_distance: usize) -> usize {
196 if distance <= 8 {
198 distance
199 } else {
200 let log_dist = (distance as f32).log2().ceil() as usize;
201 8 + log_dist.min(max_distance - 8)
202 }
203 }
204}
205
206impl Attention for GraphRoPE {
207 fn compute(
208 &self,
209 query: &[f32],
210 keys: &[&[f32]],
211 values: &[&[f32]],
212 ) -> AttentionResult<Vec<f32>> {
213 let query_pos = 0;
215 let key_positions: Vec<usize> = (0..keys.len()).collect();
216 self.compute_with_positions(query, keys, values, query_pos, &key_positions)
217 }
218
219 fn compute_with_mask(
220 &self,
221 query: &[f32],
222 keys: &[&[f32]],
223 values: &[&[f32]],
224 mask: Option<&[bool]>,
225 ) -> AttentionResult<Vec<f32>> {
226 if let Some(m) = mask {
227 let filtered: Vec<(usize, bool)> = m
228 .iter()
229 .copied()
230 .enumerate()
231 .filter(|(_, keep)| *keep)
232 .collect();
233 let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
234 let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
235 self.compute(query, &filtered_keys, &filtered_values)
236 } else {
237 self.compute(query, keys, values)
238 }
239 }
240
241 fn dim(&self) -> usize {
242 self.config.dim
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned row as a slice.
    fn refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    fn test_rope_basic() {
        let config = RoPEConfig::builder().dim(64).max_position(100).build();
        let rope = GraphRoPE::new(config);

        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = (0..10).map(|_| vec![0.3; 64]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|_| vec![1.0; 64]).collect();

        let result = rope.compute(&query, &refs(&keys), &refs(&values)).unwrap();
        assert_eq!(result.len(), 64);
        // Softmax weights sum to one, so averaging all-ones values yields ones.
        assert!(result.iter().all(|v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_rope_with_positions() {
        let config = RoPEConfig::builder().dim(32).max_position(50).build();
        let rope = GraphRoPE::new(config);

        let query = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];
        let key_positions = vec![1, 2, 3, 2, 4];

        let result = rope
            .compute_with_positions(&query, &refs(&keys), &refs(&values), 0, &key_positions)
            .unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_rotary_embedding() {
        let config = RoPEConfig::builder().dim(16).max_position(10).build();
        let rope = GraphRoPE::new(config);

        let x = vec![1.0; 16];
        let rotated = rope.apply_rotary(&x, 5);

        // A rotation preserves the Euclidean norm.
        let norm_orig: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();
        let norm_rot: f32 = rotated.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm_orig - norm_rot).abs() < 1e-5);
    }

    #[test]
    fn test_rotary_position_zero_is_identity() {
        let rope = GraphRoPE::new(RoPEConfig::builder().dim(8).max_position(4).build());
        let x: Vec<f32> = (0..8).map(|i| i as f32 - 3.5).collect();
        assert_eq!(rope.apply_rotary(&x, 0), x);
    }

    #[test]
    fn test_rotary_scores_depend_on_relative_position() {
        let rope = GraphRoPE::new(RoPEConfig::builder().dim(16).max_position(10).build());
        let q: Vec<f32> = (0..16).map(|i| (i as f32 * 0.37).sin()).collect();
        let k: Vec<f32> = (0..16).map(|i| (i as f32 * 0.91).cos()).collect();

        let near = dot(&rope.apply_rotary(&q, 1), &rope.apply_rotary(&k, 3));
        let shifted = dot(&rope.apply_rotary(&q, 5), &rope.apply_rotary(&k, 7));
        assert!((near - shifted).abs() < 1e-4);
    }

    #[test]
    fn test_mask_all_true_matches_unmasked() {
        let rope = GraphRoPE::new(RoPEConfig::builder().dim(8).max_position(16).build());
        let query: Vec<f32> = (0..8).map(|i| i as f32 * 0.1).collect();
        let keys: Vec<Vec<f32>> = (0..4)
            .map(|j| (0..8).map(|i| ((i + j) as f32 * 0.3).sin()).collect())
            .collect();
        let values: Vec<Vec<f32>> = (0..4).map(|j| vec![j as f32; 3]).collect();
        let (k, v) = (refs(&keys), refs(&values));
        let mask = [true; 4];

        let masked = rope.compute_with_mask(&query, &k, &v, Some(&mask[..])).unwrap();
        let plain = rope.compute(&query, &k, &v).unwrap();
        assert_eq!(masked, plain);
    }

    #[test]
    fn test_invalid_inputs_are_rejected() {
        let rope = GraphRoPE::new(RoPEConfig::builder().dim(4).max_position(8).build());
        let query = vec![1.0f32; 4];
        let key = vec![1.0f32; 4];
        let value = vec![1.0f32; 2];

        assert!(rope.compute_with_positions(&query, &[], &[], 0, &[]).is_err());
        assert!(rope
            .compute_with_positions(&query, &[key.as_slice()], &[value.as_slice()], 0, &[0, 1])
            .is_err());
        assert!(rope
            .compute_with_positions(&[1.0f32; 3], &[key.as_slice()], &[value.as_slice()], 0, &[0])
            .is_err());
    }

    #[test]
    fn test_distance_to_position() {
        assert_eq!(GraphRoPE::distance_to_position(0, 20), 0);
        assert_eq!(GraphRoPE::distance_to_position(5, 20), 5);
        assert_eq!(GraphRoPE::distance_to_position(8, 20), 8);

        let pos_16 = GraphRoPE::distance_to_position(16, 20);
        let pos_32 = GraphRoPE::distance_to_position(32, 20);
        assert!(pos_16 > 8);
        assert!(pos_32 > pos_16);

        // Buckets never decrease with distance and never exceed the cap.
        let mut prev = 0;
        for d in 0..200 {
            let p = GraphRoPE::distance_to_position(d, 20);
            assert!(p >= prev);
            assert!(p <= 20);
            prev = p;
        }
    }
}
318}