ruvector_attention/transport/
cached_projections.rs

1//! Cached Projections for Fast OT
2//!
3//! Pre-compute and cache random projections per window to avoid
4//! redundant computation across queries.
5
6use rand::prelude::*;
7use rand::rngs::StdRng;
8
/// Cache for random projection directions.
///
/// [`ProjectionCache::new`] fills `directions` with `num_projections` vectors of
/// length `dim`, rescaling each one to unit length when its norm exceeds `1e-8`.
#[derive(Debug, Clone)]
pub struct ProjectionCache {
    /// Random unit directions [P × dim]: one inner vector per projection.
    pub directions: Vec<Vec<f32>>,
    /// Number of projections (P). `WindowCache::build` reads this count rather
    /// than `directions.len()`, so the two are expected to agree.
    pub num_projections: usize,
    /// Dimension: the length of every direction vector.
    pub dim: usize,
}
19
20impl ProjectionCache {
21    /// Create new projection cache with P random directions
22    pub fn new(dim: usize, num_projections: usize, seed: u64) -> Self {
23        let mut rng = StdRng::seed_from_u64(seed);
24
25        let directions: Vec<Vec<f32>> = (0..num_projections)
26            .map(|_| {
27                let mut dir: Vec<f32> = (0..dim)
28                    .map(|_| rng.sample::<f32, _>(rand::distributions::Standard) * 2.0 - 1.0)
29                    .collect();
30                // Normalize to unit vector
31                let norm: f32 = dir.iter().map(|x| x * x).sum::<f32>().sqrt();
32                if norm > 1e-8 {
33                    for x in &mut dir {
34                        *x /= norm;
35                    }
36                }
37                dir
38            })
39            .collect();
40
41        Self {
42            directions,
43            num_projections,
44            dim,
45        }
46    }
47
48    /// Project a single vector onto all directions
49    /// Returns [P] projected values
50    #[inline]
51    pub fn project(&self, vector: &[f32]) -> Vec<f32> {
52        self.directions
53            .iter()
54            .map(|dir| Self::dot_product_simd(vector, dir))
55            .collect()
56    }
57
58    /// Project a single vector into pre-allocated buffer
59    #[inline]
60    pub fn project_into(&self, vector: &[f32], out: &mut [f32]) {
61        for (i, dir) in self.directions.iter().enumerate() {
62            out[i] = Self::dot_product_simd(vector, dir);
63        }
64    }
65
66    /// SIMD-friendly 4-way unrolled dot product
67    #[inline(always)]
68    fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
69        let len = a.len();
70        let chunks = len / 4;
71        let remainder = len % 4;
72
73        let mut sum0 = 0.0f32;
74        let mut sum1 = 0.0f32;
75        let mut sum2 = 0.0f32;
76        let mut sum3 = 0.0f32;
77
78        for i in 0..chunks {
79            let base = i * 4;
80            sum0 += a[base] * b[base];
81            sum1 += a[base + 1] * b[base + 1];
82            sum2 += a[base + 2] * b[base + 2];
83            sum3 += a[base + 3] * b[base + 3];
84        }
85
86        let base = chunks * 4;
87        for i in 0..remainder {
88            sum0 += a[base + i] * b[base + i];
89        }
90
91        sum0 + sum1 + sum2 + sum3
92    }
93}
94
/// Per-window cache containing sorted projections.
///
/// Produced by [`WindowCache::build`]; `histograms` and `cdfs` stay `None`
/// until [`WindowCache::build_histograms`] is called.
#[derive(Debug, Clone)]
pub struct WindowCache {
    /// Projected keys [num_keys × P]: row `i` holds key `i` projected onto
    /// every direction.
    pub key_projections: Vec<Vec<f32>>,
    /// Sorted indices per projection [P × num_keys]: key indices ordered by
    /// their projected value.
    pub sorted_indices: Vec<Vec<usize>>,
    /// Sorted values per projection [P × num_keys]: the projected values in
    /// sorted order, aligned element-for-element with `sorted_indices`.
    pub sorted_values: Vec<Vec<f32>>,
    /// Histogram bins per projection [P × num_bins]; each value adds
    /// `1 / num_keys` to its bin.
    pub histograms: Option<Vec<Vec<f32>>>,
    /// CDF per projection [P × num_bins]: running sum of the matching histogram.
    pub cdfs: Option<Vec<Vec<f32>>>,
    /// Number of keys in window
    pub num_keys: usize,
}
111
112impl WindowCache {
113    /// Build cache from keys using projection cache
114    pub fn build(keys: &[&[f32]], proj_cache: &ProjectionCache) -> Self {
115        let num_keys = keys.len();
116        let num_proj = proj_cache.num_projections;
117
118        // Project all keys
119        let key_projections: Vec<Vec<f32>> = keys
120            .iter()
121            .map(|k| proj_cache.project(k))
122            .collect();
123
124        // Sort indices and values for each projection
125        let mut sorted_indices = vec![Vec::with_capacity(num_keys); num_proj];
126        let mut sorted_values = vec![Vec::with_capacity(num_keys); num_proj];
127
128        for p in 0..num_proj {
129            let mut indexed: Vec<(usize, f32)> = key_projections
130                .iter()
131                .enumerate()
132                .map(|(i, projs)| (i, projs[p]))
133                .collect();
134            indexed.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
135
136            sorted_indices[p] = indexed.iter().map(|(i, _)| *i).collect();
137            sorted_values[p] = indexed.iter().map(|(_, v)| *v).collect();
138        }
139
140        Self {
141            key_projections,
142            sorted_indices,
143            sorted_values,
144            histograms: None,
145            cdfs: None,
146            num_keys,
147        }
148    }
149
150    /// Build histograms for ultra-fast CDF comparison
151    pub fn build_histograms(&mut self, num_bins: usize) {
152        let num_proj = self.sorted_values.len();
153
154        let mut histograms = vec![vec![0.0f32; num_bins]; num_proj];
155        let mut cdfs = vec![vec![0.0f32; num_bins]; num_proj];
156
157        for p in 0..num_proj {
158            let vals = &self.sorted_values[p];
159            if vals.is_empty() {
160                continue;
161            }
162
163            let min_val = vals[0];
164            let max_val = vals[vals.len() - 1];
165            let range = (max_val - min_val).max(1e-8);
166
167            // Build histogram
168            for &v in vals {
169                let bin = ((v - min_val) / range * (num_bins - 1) as f32)
170                    .clamp(0.0, (num_bins - 1) as f32) as usize;
171                histograms[p][bin] += 1.0 / self.num_keys as f32;
172            }
173
174            // Build CDF
175            let mut cumsum = 0.0f32;
176            for bin in 0..num_bins {
177                cumsum += histograms[p][bin];
178                cdfs[p][bin] = cumsum;
179            }
180        }
181
182        self.histograms = Some(histograms);
183        self.cdfs = Some(cdfs);
184    }
185
186    /// Get sorted values for a projection
187    #[inline]
188    pub fn get_sorted(&self, projection_idx: usize) -> &[f32] {
189        &self.sorted_values[projection_idx]
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` constant keys of length `dim`; key `i` is filled with `i * step`.
    fn constant_keys(count: usize, dim: usize, step: f32) -> Vec<Vec<f32>> {
        (0..count).map(|i| vec![i as f32 * step; dim]).collect()
    }

    /// Borrows every owned key as a slice, the form `WindowCache::build` takes.
    fn as_slices(keys: &[Vec<f32>]) -> Vec<&[f32]> {
        keys.iter().map(|k| k.as_slice()).collect()
    }

    #[test]
    fn test_projection_cache() {
        let cache = ProjectionCache::new(64, 8, 42);

        assert_eq!(cache.num_projections, 8);
        assert_eq!(cache.dim, 64);

        // Every stored direction must have unit Euclidean norm.
        for dir in &cache.directions {
            let squared: f32 = dir.iter().map(|x| x * x).sum();
            let norm = squared.sqrt();
            assert!((norm - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_window_cache() {
        let proj_cache = ProjectionCache::new(32, 4, 42);

        let keys = constant_keys(10, 32, 0.1);
        let keys_refs = as_slices(&keys);

        let window_cache = WindowCache::build(&keys_refs, &proj_cache);

        assert_eq!(window_cache.num_keys, 10);
        assert_eq!(window_cache.sorted_indices.len(), 4);
    }

    #[test]
    fn test_histograms() {
        let proj_cache = ProjectionCache::new(16, 2, 42);

        let keys = constant_keys(20, 16, 0.05);
        let keys_refs = as_slices(&keys);

        let mut window_cache = WindowCache::build(&keys_refs, &proj_cache);
        window_cache.build_histograms(10);

        assert!(window_cache.cdfs.is_some());

        // The final CDF entry accumulates the full mass and should be 1.0.
        let cdfs = window_cache.cdfs.as_ref().unwrap();
        for cdf in cdfs {
            let final_mass = cdf[cdf.len() - 1];
            assert!((final_mass - 1.0).abs() < 1e-5);
        }
    }
}