ruvector_attention/graph/
dual_space.rs1use crate::error::{AttentionError, AttentionResult};
9use crate::hyperbolic::project_to_ball;
10use crate::traits::Attention;
11use crate::utils::stable_softmax;
12
/// Geodesic distance between `u` and `v` in the Poincaré ball of curvature `-|curvature|`.
///
/// Computes `d(u, v) = (1/√c) · acosh(1 + x)` with
/// `x = 2c‖u − v‖² / ((1 − c‖u‖²)(1 − c‖v‖²))`, evaluating `acosh(1 + x)` as
/// `ln_1p(x + √x·√(x + 2))`. Evaluating `acosh(1 + x)` directly would round `1 + x` to
/// `1.0` in `f32` whenever `x` is below machine epsilon, collapsing distances to zero for
/// small curvatures.
///
/// When `c` is zero (or too small to be a normal `f32`) the ball degenerates to flat
/// space, and the `c → 0` limit `2‖u − v‖` is returned instead of `0/0 = NaN`.
///
/// The conformal factors are clamped at `1e-7`, so points on or outside the boundary
/// still yield a finite distance. Only the common prefix of `u` and `v` contributes to
/// `‖u − v‖²` because the coordinates are zipped.
fn poincare_dist(u: &[f32], v: &[f32], curvature: f32) -> f32 {
    let c = curvature.abs();

    let diff_sq: f32 = u.iter().zip(v).map(|(a, b)| (a - b).powi(2)).sum();

    // Flat-space limit: (1/√c)·acosh(1 + 2c·δ² / …) → 2·δ as c → 0.
    if c < f32::MIN_POSITIVE {
        return 2.0 * diff_sq.sqrt();
    }

    let norm_u_sq: f32 = u.iter().map(|x| x * x).sum();
    let norm_v_sq: f32 = v.iter().map(|x| x * x).sum();

    // Clamp so boundary / out-of-ball points neither divide by zero nor flip sign.
    let denom = (1.0 - c * norm_u_sq).max(1e-7) * (1.0 - c * norm_v_sq).max(1e-7);
    // `max(0.0)` keeps x non-negative; like the old `arg.max(1.0)` it also maps NaN to 0.
    let x = (2.0 * c * diff_sq / denom).max(0.0);

    // acosh(1 + x) = ln(1 + x + sqrt(x·(x + 2))); the split square root avoids
    // overflowing x·(x + 2) for very large x.
    let acosh_1p = (x + x.sqrt() * (x + 2.0).sqrt()).ln_1p();

    acosh_1p / c.sqrt()
}
27
/// Configuration for [`DualSpaceAttention`].
///
/// Each key is scored as
/// `(euclidean_weight · euclidean + hyperbolic_weight · hyperbolic) / temperature`,
/// where the Euclidean term is a scaled dot product and the hyperbolic term is a
/// negative Poincaré-ball distance.
#[derive(Clone, Debug, PartialEq)]
pub struct DualSpaceConfig {
    /// Model dimension: required query/key length and side of the square projection matrices.
    pub dim: usize,
    /// Curvature of the Poincaré ball used for the hyperbolic projection and distance.
    pub curvature: f32,
    /// Weight of the Euclidean similarity in the combined score.
    pub euclidean_weight: f32,
    /// Weight of the hyperbolic similarity in the combined score.
    pub hyperbolic_weight: f32,
    /// Flag for learnable space weights; `DualSpaceAttention` does not currently read it.
    pub learn_weights: bool,
    /// Softmax temperature; combined scores are divided by this value.
    pub temperature: f32,
}

impl Default for DualSpaceConfig {
    /// `dim = 256`, unit curvature, equal 0.5 / 0.5 space weights,
    /// `learn_weights = false`, `temperature = 1.0`.
    fn default() -> Self {
        Self {
            dim: 256,
            curvature: 1.0,
            euclidean_weight: 0.5,
            hyperbolic_weight: 0.5,
            learn_weights: false,
            temperature: 1.0,
        }
    }
}
51
52impl DualSpaceConfig {
53 pub fn builder() -> DualSpaceConfigBuilder {
54 DualSpaceConfigBuilder::default()
55 }
56}
57
58#[derive(Default)]
59pub struct DualSpaceConfigBuilder {
60 config: DualSpaceConfig,
61}
62
63impl DualSpaceConfigBuilder {
64 pub fn dim(mut self, d: usize) -> Self {
65 self.config.dim = d;
66 self
67 }
68
69 pub fn curvature(mut self, c: f32) -> Self {
70 self.config.curvature = c;
71 self
72 }
73
74 pub fn euclidean_weight(mut self, w: f32) -> Self {
75 self.config.euclidean_weight = w;
76 self
77 }
78
79 pub fn hyperbolic_weight(mut self, w: f32) -> Self {
80 self.config.hyperbolic_weight = w;
81 self
82 }
83
84 pub fn temperature(mut self, t: f32) -> Self {
85 self.config.temperature = t;
86 self
87 }
88
89 pub fn build(self) -> DualSpaceConfig {
90 self.config
91 }
92}
93
94pub struct DualSpaceAttention {
96 config: DualSpaceConfig,
97 scale: f32,
98 w_euclidean: Vec<f32>,
100 w_hyperbolic: Vec<f32>,
102 w_out: Vec<f32>,
104}
105
106impl DualSpaceAttention {
107 pub fn new(config: DualSpaceConfig) -> Self {
108 let dim = config.dim;
109 let scale = 1.0 / (dim as f32).sqrt();
110
111 let w_scale = (2.0 / (dim + dim) as f32).sqrt();
113 let mut seed = 42u64;
114 let mut rand = || {
115 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
116 ((seed as f32) / (u64::MAX as f32) - 0.5) * 2.0 * w_scale
117 };
118
119 let w_euclidean: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();
120 let w_hyperbolic: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();
121 let w_out: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();
122
123 Self {
124 config,
125 scale,
126 w_euclidean,
127 w_hyperbolic,
128 w_out,
129 }
130 }
131
132 fn to_euclidean(&self, x: &[f32]) -> Vec<f32> {
134 let dim = self.config.dim;
135 (0..dim)
136 .map(|i| {
137 x.iter()
138 .enumerate()
139 .map(|(j, &xj)| xj * self.w_euclidean[i * dim + j])
140 .sum()
141 })
142 .collect()
143 }
144
145 fn to_hyperbolic(&self, x: &[f32]) -> Vec<f32> {
147 let dim = self.config.dim;
148 let projected: Vec<f32> = (0..dim)
149 .map(|i| {
150 x.iter()
151 .enumerate()
152 .map(|(j, &xj)| xj * self.w_hyperbolic[i * dim + j])
153 .sum()
154 })
155 .collect();
156
157 project_to_ball(&projected, self.config.curvature, 1e-5)
159 }
160
161 fn euclidean_similarity(&self, q: &[f32], k: &[f32]) -> f32 {
163 q.iter().zip(k.iter()).map(|(a, b)| a * b).sum::<f32>() * self.scale
164 }
165
166 fn hyperbolic_similarity(&self, q: &[f32], k: &[f32]) -> f32 {
168 -poincare_dist(q, k, self.config.curvature)
169 }
170
171 fn project_output(&self, x: &[f32]) -> Vec<f32> {
173 let dim = self.config.dim;
174 (0..dim)
175 .map(|i| {
176 x.iter()
177 .enumerate()
178 .map(|(j, &xj)| xj * self.w_out[i * dim + j])
179 .sum()
180 })
181 .collect()
182 }
183
184 pub fn get_space_contributions(&self, query: &[f32], keys: &[&[f32]]) -> (Vec<f32>, Vec<f32>) {
186 let q_euc = self.to_euclidean(query);
187 let q_hyp = self.to_hyperbolic(query);
188
189 let euc_scores: Vec<f32> = keys
190 .iter()
191 .map(|k| {
192 let k_euc = self.to_euclidean(k);
193 self.euclidean_similarity(&q_euc, &k_euc)
194 })
195 .collect();
196
197 let hyp_scores: Vec<f32> = keys
198 .iter()
199 .map(|k| {
200 let k_hyp = self.to_hyperbolic(k);
201 self.hyperbolic_similarity(&q_hyp, &k_hyp)
202 })
203 .collect();
204
205 (euc_scores, hyp_scores)
206 }
207}
208
209impl Attention for DualSpaceAttention {
210 fn compute(
211 &self,
212 query: &[f32],
213 keys: &[&[f32]],
214 values: &[&[f32]],
215 ) -> AttentionResult<Vec<f32>> {
216 if keys.is_empty() {
217 return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
218 }
219 if query.len() != self.config.dim {
220 return Err(AttentionError::DimensionMismatch {
221 expected: self.config.dim,
222 actual: query.len(),
223 });
224 }
225
226 let n = keys.len();
227 let value_dim = values[0].len();
228 let temp = self.config.temperature;
229
230 let q_euc = self.to_euclidean(query);
232 let q_hyp = self.to_hyperbolic(query);
233
234 let mut combined_scores = Vec::with_capacity(n);
236
237 for key in keys.iter() {
238 let k_euc = self.to_euclidean(key);
239 let k_hyp = self.to_hyperbolic(key);
240
241 let euc_score = self.euclidean_similarity(&q_euc, &k_euc);
242 let hyp_score = self.hyperbolic_similarity(&q_hyp, &k_hyp);
243
244 let combined = (self.config.euclidean_weight * euc_score
246 + self.config.hyperbolic_weight * hyp_score)
247 / temp;
248
249 combined_scores.push(combined);
250 }
251
252 let weights = stable_softmax(&combined_scores);
254
255 let mut output = vec![0.0f32; value_dim];
257 for (w, v) in weights.iter().zip(values.iter()) {
258 for (o, &vi) in output.iter_mut().zip(v.iter()) {
259 *o += w * vi;
260 }
261 }
262
263 if value_dim == self.config.dim {
265 Ok(self.project_output(&output))
266 } else {
267 Ok(output)
268 }
269 }
270
271 fn compute_with_mask(
272 &self,
273 query: &[f32],
274 keys: &[&[f32]],
275 values: &[&[f32]],
276 mask: Option<&[bool]>,
277 ) -> AttentionResult<Vec<f32>> {
278 if let Some(m) = mask {
279 let filtered: Vec<(usize, bool)> = m
280 .iter()
281 .copied()
282 .enumerate()
283 .filter(|(_, keep)| *keep)
284 .collect();
285 let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
286 let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
287 self.compute(query, &filtered_keys, &filtered_values)
288 } else {
289 self.compute(query, keys, values)
290 }
291 }
292
293 fn dim(&self) -> usize {
294 self.config.dim
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each row as a slice, the shape `Attention::compute` expects.
    fn as_refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_dual_space_basic() {
        let config = DualSpaceConfig::builder()
            .dim(64)
            .curvature(1.0)
            .euclidean_weight(0.5)
            .hyperbolic_weight(0.5)
            .build();

        let attn = DualSpaceAttention::new(config);

        let query = vec![0.1; 64];
        let keys: Vec<Vec<f32>> = (0..10).map(|_| vec![0.1; 64]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|_| vec![1.0; 64]).collect();

        let result = attn
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        assert_eq!(result.len(), 64);
        assert!(result.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_euclidean_dominant() {
        let config = DualSpaceConfig::builder()
            .dim(32)
            .euclidean_weight(1.0)
            .hyperbolic_weight(0.0)
            .build();

        let attn = DualSpaceAttention::new(config);

        let query = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let result = attn
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        assert_eq!(result.len(), 32);
        assert!(result.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_hyperbolic_dominant() {
        let config = DualSpaceConfig::builder()
            .dim(32)
            .curvature(0.5)
            .euclidean_weight(0.0)
            .hyperbolic_weight(1.0)
            .build();

        let attn = DualSpaceAttention::new(config);

        let query = vec![0.1; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.1; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let result = attn
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        assert_eq!(result.len(), 32);
        assert!(result.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_space_contributions() {
        let config = DualSpaceConfig::builder()
            .dim(16)
            .euclidean_weight(0.5)
            .hyperbolic_weight(0.5)
            .build();

        let attn = DualSpaceAttention::new(config);

        let query = vec![0.2; 16];
        let keys: Vec<Vec<f32>> = vec![vec![0.2; 16]; 3];

        let (euc_scores, hyp_scores) = attn.get_space_contributions(&query, &as_refs(&keys));

        assert_eq!(euc_scores.len(), 3);
        assert_eq!(hyp_scores.len(), 3);
        // Identical keys must score identically.
        assert!(euc_scores.windows(2).all(|w| w[0] == w[1]));
        // Every key equals the query, so the hyperbolic distance (negated score) is zero.
        assert!(hyp_scores.iter().all(|s| s.abs() < 1e-5), "{hyp_scores:?}");
    }

    #[test]
    fn test_temperature_scaling() {
        let config_low_temp = DualSpaceConfig::builder().dim(16).temperature(0.5).build();
        let config_high_temp = DualSpaceConfig::builder().dim(16).temperature(2.0).build();

        let attn_low = DualSpaceAttention::new(config_low_temp);
        let attn_high = DualSpaceAttention::new(config_high_temp);

        let query = vec![0.5; 16];
        let keys: Vec<Vec<f32>> = vec![vec![0.8; 16], vec![0.2; 16]];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 16], vec![0.0; 16]];

        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);

        let result_low = attn_low.compute(&query, &keys_refs, &values_refs).unwrap();
        let result_high = attn_high.compute(&query, &keys_refs, &values_refs).unwrap();

        assert_eq!(result_low.len(), 16);
        assert_eq!(result_high.len(), 16);
        assert!(result_low.iter().chain(&result_high).all(|x| x.is_finite()));
        // Both modules share weights (fixed seed), so any difference comes from temperature.
        assert_ne!(
            result_low, result_high,
            "temperature had no effect on the attention output"
        );
    }
}