Skip to main content

ruvector_attention/graph/
dual_space.rs

//! Dual-space attention combining Euclidean and hyperbolic geometries.
//!
//! This module implements attention that scores keys in both Euclidean and hyperbolic
//! (Poincaré-ball) space, combining their complementary strengths:
//! - Euclidean: suited to flat, local structure
//! - Hyperbolic: suited to hierarchical, tree-like structure
7
8use crate::error::{AttentionError, AttentionResult};
9use crate::hyperbolic::project_to_ball;
10use crate::traits::Attention;
11use crate::utils::stable_softmax;
12
/// Compute the geodesic distance between two points of the Poincaré ball.
///
/// With `c = |curvature|` this evaluates
///
/// ```text
/// d(u, v) = (1 / √c) · arcosh(1 + 2c·‖u − v‖² / ((1 − c‖u‖²)(1 − c‖v‖²)))
/// ```
///
/// * The conformal factors `1 − c‖·‖²` are clamped to at least `1e-7`, so points on or just
///   outside the ball boundary give a large finite distance instead of `inf`/`NaN`.
/// * `arcosh(1 + x)` is evaluated as `ln_1p(x + √x·√(x + 2))`. Forming `1 + x` explicitly
///   rounds small `x` away in `f32`, which would collapse the distance between nearby
///   points (or between any points at small curvature) to zero.
/// * At `c == 0` the general formula is `0 / 0`; its flat limit `2‖u − v‖` is returned.
///
/// `‖u − v‖²` only covers the first `min(u.len(), v.len())` coordinates.
fn poincare_dist(u: &[f32], v: &[f32], curvature: f32) -> f32 {
    let c = curvature.abs();
    let diff_sq: f32 = u.iter().zip(v).map(|(a, b)| (a - b).powi(2)).sum();

    if c == 0.0 {
        // lim_{c→0} d_c(u, v) = 2‖u − v‖.
        return 2.0 * diff_sq.sqrt();
    }

    let norm_u_sq: f32 = u.iter().map(|x| x * x).sum();
    let norm_v_sq: f32 = v.iter().map(|x| x * x).sum();

    let denom = (1.0 - c * norm_u_sq).max(1e-7) * (1.0 - c * norm_v_sq).max(1e-7);
    // x >= 0 mathematically; `max` additionally maps a NaN ratio to 0.
    let x = (2.0 * c * diff_sq / denom).max(0.0);

    // arcosh(1 + x) = ln(1 + x + √(x(x + 2))); splitting the square root avoids overflow
    // of x·(x + 2) for very large x.
    (x + x.sqrt() * (x + 2.0).sqrt()).ln_1p() / c.sqrt()
}
27
/// Configuration for [`DualSpaceAttention`].
///
/// Build one with [`DualSpaceConfig::builder`] or start from [`Default::default`].
#[derive(Clone, Debug, PartialEq)]
pub struct DualSpaceConfig {
    /// Model dimension: the required query length and the side of the square
    /// `dim × dim` projection matrices.
    pub dim: usize,
    /// Curvature of the Poincaré ball used by the hyperbolic branch. The distance
    /// computation only uses its magnitude; the value is also handed to the ball
    /// projection (TODO confirm how that treats negative values).
    pub curvature: f32,
    /// Weight multiplying the Euclidean (scaled dot-product) score.
    pub euclidean_weight: f32,
    /// Weight multiplying the hyperbolic (negative Poincaré distance) score.
    pub hyperbolic_weight: f32,
    /// NOTE(review): not read anywhere in this module; presumably reserved for making the
    /// space weights trainable — confirm before relying on it.
    pub learn_weights: bool,
    /// Softmax temperature: the blended scores are divided by it before the softmax, so
    /// smaller values give sharper attention.
    pub temperature: f32,
}
38
39impl Default for DualSpaceConfig {
40    fn default() -> Self {
41        Self {
42            dim: 256,
43            curvature: 1.0,
44            euclidean_weight: 0.5,
45            hyperbolic_weight: 0.5,
46            learn_weights: false,
47            temperature: 1.0,
48        }
49    }
50}
51
52impl DualSpaceConfig {
53    pub fn builder() -> DualSpaceConfigBuilder {
54        DualSpaceConfigBuilder::default()
55    }
56}
57
58#[derive(Default)]
59pub struct DualSpaceConfigBuilder {
60    config: DualSpaceConfig,
61}
62
63impl DualSpaceConfigBuilder {
64    pub fn dim(mut self, d: usize) -> Self {
65        self.config.dim = d;
66        self
67    }
68
69    pub fn curvature(mut self, c: f32) -> Self {
70        self.config.curvature = c;
71        self
72    }
73
74    pub fn euclidean_weight(mut self, w: f32) -> Self {
75        self.config.euclidean_weight = w;
76        self
77    }
78
79    pub fn hyperbolic_weight(mut self, w: f32) -> Self {
80        self.config.hyperbolic_weight = w;
81        self
82    }
83
84    pub fn temperature(mut self, t: f32) -> Self {
85        self.config.temperature = t;
86        self
87    }
88
89    pub fn build(self) -> DualSpaceConfig {
90        self.config
91    }
92}
93
94/// Dual-space attention layer
95pub struct DualSpaceAttention {
96    config: DualSpaceConfig,
97    scale: f32,
98    /// Linear projection for Euclidean space
99    w_euclidean: Vec<f32>,
100    /// Linear projection for hyperbolic space
101    w_hyperbolic: Vec<f32>,
102    /// Output projection
103    w_out: Vec<f32>,
104}
105
106impl DualSpaceAttention {
107    pub fn new(config: DualSpaceConfig) -> Self {
108        let dim = config.dim;
109        let scale = 1.0 / (dim as f32).sqrt();
110
111        // Xavier initialization
112        let w_scale = (2.0 / (dim + dim) as f32).sqrt();
113        let mut seed = 42u64;
114        let mut rand = || {
115            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
116            ((seed as f32) / (u64::MAX as f32) - 0.5) * 2.0 * w_scale
117        };
118
119        let w_euclidean: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();
120        let w_hyperbolic: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();
121        let w_out: Vec<f32> = (0..dim * dim).map(|_| rand()).collect();
122
123        Self {
124            config,
125            scale,
126            w_euclidean,
127            w_hyperbolic,
128            w_out,
129        }
130    }
131
132    /// Project to Euclidean representation
133    fn to_euclidean(&self, x: &[f32]) -> Vec<f32> {
134        let dim = self.config.dim;
135        (0..dim)
136            .map(|i| {
137                x.iter()
138                    .enumerate()
139                    .map(|(j, &xj)| xj * self.w_euclidean[i * dim + j])
140                    .sum()
141            })
142            .collect()
143    }
144
145    /// Project to hyperbolic representation (Poincaré ball)
146    fn to_hyperbolic(&self, x: &[f32]) -> Vec<f32> {
147        let dim = self.config.dim;
148        let projected: Vec<f32> = (0..dim)
149            .map(|i| {
150                x.iter()
151                    .enumerate()
152                    .map(|(j, &xj)| xj * self.w_hyperbolic[i * dim + j])
153                    .sum()
154            })
155            .collect();
156
157        // Project to ball with curvature
158        project_to_ball(&projected, self.config.curvature, 1e-5)
159    }
160
161    /// Compute Euclidean similarity (dot product)
162    fn euclidean_similarity(&self, q: &[f32], k: &[f32]) -> f32 {
163        q.iter().zip(k.iter()).map(|(a, b)| a * b).sum::<f32>() * self.scale
164    }
165
166    /// Compute hyperbolic similarity (negative Poincaré distance)
167    fn hyperbolic_similarity(&self, q: &[f32], k: &[f32]) -> f32 {
168        -poincare_dist(q, k, self.config.curvature)
169    }
170
171    /// Output projection
172    fn project_output(&self, x: &[f32]) -> Vec<f32> {
173        let dim = self.config.dim;
174        (0..dim)
175            .map(|i| {
176                x.iter()
177                    .enumerate()
178                    .map(|(j, &xj)| xj * self.w_out[i * dim + j])
179                    .sum()
180            })
181            .collect()
182    }
183
184    /// Get the contribution weights for analysis
185    pub fn get_space_contributions(&self, query: &[f32], keys: &[&[f32]]) -> (Vec<f32>, Vec<f32>) {
186        let q_euc = self.to_euclidean(query);
187        let q_hyp = self.to_hyperbolic(query);
188
189        let euc_scores: Vec<f32> = keys
190            .iter()
191            .map(|k| {
192                let k_euc = self.to_euclidean(k);
193                self.euclidean_similarity(&q_euc, &k_euc)
194            })
195            .collect();
196
197        let hyp_scores: Vec<f32> = keys
198            .iter()
199            .map(|k| {
200                let k_hyp = self.to_hyperbolic(k);
201                self.hyperbolic_similarity(&q_hyp, &k_hyp)
202            })
203            .collect();
204
205        (euc_scores, hyp_scores)
206    }
207}
208
209impl Attention for DualSpaceAttention {
210    fn compute(
211        &self,
212        query: &[f32],
213        keys: &[&[f32]],
214        values: &[&[f32]],
215    ) -> AttentionResult<Vec<f32>> {
216        if keys.is_empty() {
217            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
218        }
219        if query.len() != self.config.dim {
220            return Err(AttentionError::DimensionMismatch {
221                expected: self.config.dim,
222                actual: query.len(),
223            });
224        }
225
226        let n = keys.len();
227        let value_dim = values[0].len();
228        let temp = self.config.temperature;
229
230        // Project query to both spaces
231        let q_euc = self.to_euclidean(query);
232        let q_hyp = self.to_hyperbolic(query);
233
234        // Compute combined scores
235        let mut combined_scores = Vec::with_capacity(n);
236
237        for key in keys.iter() {
238            let k_euc = self.to_euclidean(key);
239            let k_hyp = self.to_hyperbolic(key);
240
241            let euc_score = self.euclidean_similarity(&q_euc, &k_euc);
242            let hyp_score = self.hyperbolic_similarity(&q_hyp, &k_hyp);
243
244            // Weighted combination
245            let combined = (self.config.euclidean_weight * euc_score
246                + self.config.hyperbolic_weight * hyp_score)
247                / temp;
248
249            combined_scores.push(combined);
250        }
251
252        // Softmax over combined scores
253        let weights = stable_softmax(&combined_scores);
254
255        // Weighted sum of values
256        let mut output = vec![0.0f32; value_dim];
257        for (w, v) in weights.iter().zip(values.iter()) {
258            for (o, &vi) in output.iter_mut().zip(v.iter()) {
259                *o += w * vi;
260            }
261        }
262
263        // Output projection
264        if value_dim == self.config.dim {
265            Ok(self.project_output(&output))
266        } else {
267            Ok(output)
268        }
269    }
270
271    fn compute_with_mask(
272        &self,
273        query: &[f32],
274        keys: &[&[f32]],
275        values: &[&[f32]],
276        mask: Option<&[bool]>,
277    ) -> AttentionResult<Vec<f32>> {
278        if let Some(m) = mask {
279            let filtered: Vec<(usize, bool)> = m
280                .iter()
281                .copied()
282                .enumerate()
283                .filter(|(_, keep)| *keep)
284                .collect();
285            let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
286            let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
287            self.compute(query, &filtered_keys, &filtered_values)
288        } else {
289            self.compute(query, keys, values)
290        }
291    }
292
293    fn dim(&self) -> usize {
294        self.config.dim
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow owned rows in the `&[&[f32]]` shape the attention API takes.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    /// Run `compute` on owned inputs and unwrap the result.
    fn run(
        attn: &DualSpaceAttention,
        query: &[f32],
        keys: &[Vec<f32>],
        values: &[Vec<f32>],
    ) -> Vec<f32> {
        attn.compute(query, &as_slices(keys), &as_slices(values))
            .unwrap()
    }

    #[test]
    fn test_dual_space_basic() {
        let attn = DualSpaceAttention::new(
            DualSpaceConfig::builder()
                .dim(64)
                .curvature(1.0)
                .euclidean_weight(0.5)
                .hyperbolic_weight(0.5)
                .build(),
        );

        let query: Vec<f32> = vec![0.1; 64];
        let keys: Vec<Vec<f32>> = vec![vec![0.1; 64]; 10];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 64]; 10];

        assert_eq!(run(&attn, &query, &keys, &values).len(), 64);
    }

    #[test]
    fn test_euclidean_dominant() {
        let attn = DualSpaceAttention::new(
            DualSpaceConfig::builder()
                .dim(32)
                .euclidean_weight(1.0)
                .hyperbolic_weight(0.0)
                .build(),
        );

        let query: Vec<f32> = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        assert_eq!(run(&attn, &query, &keys, &values).len(), 32);
    }

    #[test]
    fn test_hyperbolic_dominant() {
        let attn = DualSpaceAttention::new(
            DualSpaceConfig::builder()
                .dim(32)
                .curvature(0.5)
                .euclidean_weight(0.0)
                .hyperbolic_weight(1.0)
                .build(),
        );

        // Small values keep the inputs well inside the Poincaré ball.
        let query: Vec<f32> = vec![0.1; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.1; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        assert_eq!(run(&attn, &query, &keys, &values).len(), 32);
    }

    #[test]
    fn test_space_contributions() {
        let attn = DualSpaceAttention::new(
            DualSpaceConfig::builder()
                .dim(16)
                .euclidean_weight(0.5)
                .hyperbolic_weight(0.5)
                .build(),
        );

        let query: Vec<f32> = vec![0.2; 16];
        let keys: Vec<Vec<f32>> = vec![vec![0.2; 16]; 3];

        let (euc_scores, hyp_scores) = attn.get_space_contributions(&query, &as_slices(&keys));

        assert_eq!(euc_scores.len(), 3);
        assert_eq!(hyp_scores.len(), 3);
        // Every key equals the query: identical Euclidean scores, zero hyperbolic distance.
        assert!(euc_scores.windows(2).all(|w| w[0] == w[1]));
        assert!(hyp_scores.iter().all(|&s| s == 0.0));
    }

    #[test]
    fn test_temperature_scaling() {
        let attn_low =
            DualSpaceAttention::new(DualSpaceConfig::builder().dim(16).temperature(0.5).build());
        let attn_high =
            DualSpaceAttention::new(DualSpaceConfig::builder().dim(16).temperature(2.0).build());

        let query: Vec<f32> = vec![0.5; 16];
        let keys: Vec<Vec<f32>> = vec![vec![0.8; 16], vec![0.2; 16]];
        // One-wide values skip the output projection (width != dim), so the output is exactly
        // the attention weight given to the first key.
        let values: Vec<Vec<f32>> = vec![vec![1.0], vec![0.0]];

        let w_low = run(&attn_low, &query, &keys, &values);
        let w_high = run(&attn_high, &query, &keys, &values);
        assert_eq!(w_low.len(), 1);
        assert_eq!(w_high.len(), 1);

        // Both layers share the same seeded weights, hence the same raw scores; a lower
        // temperature can only push the two-way softmax further from uniform.
        assert!(
            (w_low[0] - 0.5).abs() + 1e-6 >= (w_high[0] - 0.5).abs(),
            "low = {}, high = {}",
            w_low[0],
            w_high[0]
        );
    }

    #[test]
    fn test_mask_matches_filtered_compute() {
        let attn = DualSpaceAttention::new(DualSpaceConfig::builder().dim(8).build());

        let query: Vec<f32> = vec![0.1; 8];
        let keys: Vec<Vec<f32>> = (0..3).map(|i| vec![0.1 * (i + 1) as f32; 8]).collect();
        let values: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32; 8]).collect();
        let mask = [true, false, true];

        let masked = attn
            .compute_with_mask(&query, &as_slices(&keys), &as_slices(&values), Some(&mask[..]))
            .unwrap();
        let kept_keys = [keys[0].clone(), keys[2].clone()];
        let kept_values = [values[0].clone(), values[2].clone()];
        assert_eq!(masked, run(&attn, &query, &kept_keys, &kept_values));

        // No mask is the same as a plain `compute`.
        let unmasked = attn
            .compute_with_mask(&query, &as_slices(&keys), &as_slices(&values), None)
            .unwrap();
        assert_eq!(unmasked, run(&attn, &query, &keys, &values));
    }

    #[test]
    fn test_poincare_dist_known_value() {
        // From the origin, d(0, r·e1) = 2·artanh(r) = ln((1 + r) / (1 - r)); r = 0.5 gives ln 3.
        let d = poincare_dist(&[0.0, 0.0], &[0.5, 0.0], 1.0);
        assert!((d - 3.0f32.ln()).abs() < 1e-5, "d = {}", d);
        assert_eq!(poincare_dist(&[0.2, -0.1], &[0.2, -0.1], 1.0), 0.0);
    }
}