// ruvector_attention/transport/cached_projections.rs
use rand::prelude::*;
use rand::rngs::StdRng;

/// A reusable set of projection directions onto which vectors are projected to 1-D.
///
/// When built with `ProjectionCache::new`, every direction is a seeded, L2-normalised
/// vector of length `dim`. The fields are public, so hand-built values carry no such
/// guarantee.
#[derive(Debug, Clone)]
pub struct ProjectionCache {
    /// Projection directions, one `Vec<f32>` of length `dim` per projection.
    pub directions: Vec<Vec<f32>>,
    /// How many directions `directions` holds.
    pub num_projections: usize,
    /// Length of each direction, i.e. the expected length of projected vectors.
    pub dim: usize,
}

20impl ProjectionCache {
21 pub fn new(dim: usize, num_projections: usize, seed: u64) -> Self {
23 let mut rng = StdRng::seed_from_u64(seed);
24
25 let directions: Vec<Vec<f32>> = (0..num_projections)
26 .map(|_| {
27 let mut dir: Vec<f32> = (0..dim)
28 .map(|_| rng.sample::<f32, _>(rand::distributions::Standard) * 2.0 - 1.0)
29 .collect();
30 let norm: f32 = dir.iter().map(|x| x * x).sum::<f32>().sqrt();
32 if norm > 1e-8 {
33 for x in &mut dir {
34 *x /= norm;
35 }
36 }
37 dir
38 })
39 .collect();
40
41 Self {
42 directions,
43 num_projections,
44 dim,
45 }
46 }
47
48 #[inline]
51 pub fn project(&self, vector: &[f32]) -> Vec<f32> {
52 self.directions
53 .iter()
54 .map(|dir| Self::dot_product_simd(vector, dir))
55 .collect()
56 }
57
58 #[inline]
60 pub fn project_into(&self, vector: &[f32], out: &mut [f32]) {
61 for (i, dir) in self.directions.iter().enumerate() {
62 out[i] = Self::dot_product_simd(vector, dir);
63 }
64 }
65
66 #[inline(always)]
68 fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
69 let len = a.len();
70 let chunks = len / 4;
71 let remainder = len % 4;
72
73 let mut sum0 = 0.0f32;
74 let mut sum1 = 0.0f32;
75 let mut sum2 = 0.0f32;
76 let mut sum3 = 0.0f32;
77
78 for i in 0..chunks {
79 let base = i * 4;
80 sum0 += a[base] * b[base];
81 sum1 += a[base + 1] * b[base + 1];
82 sum2 += a[base + 2] * b[base + 2];
83 sum3 += a[base + 3] * b[base + 3];
84 }
85
86 let base = chunks * 4;
87 for i in 0..remainder {
88 sum0 += a[base + i] * b[base + i];
89 }
90
91 sum0 + sum1 + sum2 + sum3
92 }
93}
94
/// Projections of a window of keys, pre-sorted per projection, with optional histograms.
///
/// Created by `WindowCache::build`; `histograms` and `cdfs` stay `None` until
/// `WindowCache::build_histograms` is called.
#[derive(Debug, Clone)]
pub struct WindowCache {
    /// `key_projections[k][p]` holds key `k` projected onto direction `p`.
    pub key_projections: Vec<Vec<f32>>,
    /// Per projection `p`: key indices ordered by ascending projected value.
    pub sorted_indices: Vec<Vec<usize>>,
    /// Per projection `p`: the ascending projected values, parallel to `sorted_indices[p]`.
    pub sorted_values: Vec<Vec<f32>>,
    /// Per-projection normalised histograms, once built.
    pub histograms: Option<Vec<Vec<f32>>>,
    /// Per-projection cumulative histograms, once built.
    pub cdfs: Option<Vec<Vec<f32>>>,
    /// Number of keys the cache was built from.
    pub num_keys: usize,
}

112impl WindowCache {
113 pub fn build(keys: &[&[f32]], proj_cache: &ProjectionCache) -> Self {
115 let num_keys = keys.len();
116 let num_proj = proj_cache.num_projections;
117
118 let key_projections: Vec<Vec<f32>> = keys
120 .iter()
121 .map(|k| proj_cache.project(k))
122 .collect();
123
124 let mut sorted_indices = vec![Vec::with_capacity(num_keys); num_proj];
126 let mut sorted_values = vec![Vec::with_capacity(num_keys); num_proj];
127
128 for p in 0..num_proj {
129 let mut indexed: Vec<(usize, f32)> = key_projections
130 .iter()
131 .enumerate()
132 .map(|(i, projs)| (i, projs[p]))
133 .collect();
134 indexed.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
135
136 sorted_indices[p] = indexed.iter().map(|(i, _)| *i).collect();
137 sorted_values[p] = indexed.iter().map(|(_, v)| *v).collect();
138 }
139
140 Self {
141 key_projections,
142 sorted_indices,
143 sorted_values,
144 histograms: None,
145 cdfs: None,
146 num_keys,
147 }
148 }
149
150 pub fn build_histograms(&mut self, num_bins: usize) {
152 let num_proj = self.sorted_values.len();
153
154 let mut histograms = vec![vec![0.0f32; num_bins]; num_proj];
155 let mut cdfs = vec![vec![0.0f32; num_bins]; num_proj];
156
157 for p in 0..num_proj {
158 let vals = &self.sorted_values[p];
159 if vals.is_empty() {
160 continue;
161 }
162
163 let min_val = vals[0];
164 let max_val = vals[vals.len() - 1];
165 let range = (max_val - min_val).max(1e-8);
166
167 for &v in vals {
169 let bin = ((v - min_val) / range * (num_bins - 1) as f32)
170 .clamp(0.0, (num_bins - 1) as f32) as usize;
171 histograms[p][bin] += 1.0 / self.num_keys as f32;
172 }
173
174 let mut cumsum = 0.0f32;
176 for bin in 0..num_bins {
177 cumsum += histograms[p][bin];
178 cdfs[p][bin] = cumsum;
179 }
180 }
181
182 self.histograms = Some(histograms);
183 self.cdfs = Some(cdfs);
184 }
185
186 #[inline]
188 pub fn get_sorted(&self, projection_idx: usize) -> &[f32] {
189 &self.sorted_values[projection_idx]
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` keys of length `dim`; key `i` has every component equal to `i * step`.
    fn constant_keys(n: usize, dim: usize, step: f32) -> Vec<Vec<f32>> {
        (0..n).map(|i| vec![i as f32 * step; dim]).collect()
    }

    #[test]
    fn test_projection_cache() {
        let cache = ProjectionCache::new(64, 8, 42);

        assert_eq!(cache.num_projections, 8);
        assert_eq!(cache.dim, 64);
        assert_eq!(cache.directions.len(), 8);

        for dir in &cache.directions {
            assert_eq!(dir.len(), 64);
            let norm: f32 = dir.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_projection_cache_is_deterministic() {
        let a = ProjectionCache::new(16, 4, 7);
        let b = ProjectionCache::new(16, 4, 7);
        assert_eq!(a.directions, b.directions);
    }

    #[test]
    fn test_project_matches_naive_dot_and_project_into() {
        // dim = 7 exercises both the 4-wide chunk path and the tail path of the dot product.
        let cache = ProjectionCache::new(7, 3, 1);
        let v: Vec<f32> = (0..7).map(|i| i as f32 - 3.0).collect();

        let projected = cache.project(&v);
        assert_eq!(projected.len(), 3);
        for (p, dir) in projected.iter().zip(&cache.directions) {
            let naive: f32 = v.iter().zip(dir).map(|(a, b)| a * b).sum();
            assert!((p - naive).abs() < 1e-4);
        }

        let mut out = vec![0.0f32; 3];
        cache.project_into(&v, &mut out);
        assert_eq!(out, projected);
    }

    #[test]
    fn test_window_cache() {
        let proj_cache = ProjectionCache::new(32, 4, 42);

        let keys = constant_keys(10, 32, 0.1);
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();

        let window_cache = WindowCache::build(&keys_refs, &proj_cache);

        assert_eq!(window_cache.num_keys, 10);
        assert_eq!(window_cache.key_projections.len(), 10);
        assert_eq!(window_cache.sorted_indices.len(), 4);
        assert_eq!(window_cache.sorted_values.len(), 4);
        assert!(window_cache.histograms.is_none());
        assert!(window_cache.cdfs.is_none());

        for p in 0..4 {
            let idx = &window_cache.sorted_indices[p];
            let vals = window_cache.get_sorted(p);

            // Values are in ascending order.
            assert!(vals.windows(2).all(|w| w[0] <= w[1]));

            // Indices are a permutation of 0..num_keys.
            let mut seen = idx.clone();
            seen.sort_unstable();
            assert_eq!(seen, (0..10).collect::<Vec<usize>>());

            // Each sorted value is the projection of the key at the same position.
            for (&i, &v) in idx.iter().zip(vals) {
                assert_eq!(window_cache.key_projections[i][p], v);
            }
        }
    }

    #[test]
    fn test_histograms() {
        let proj_cache = ProjectionCache::new(16, 2, 42);

        let keys = constant_keys(20, 16, 0.05);
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();

        let mut window_cache = WindowCache::build(&keys_refs, &proj_cache);
        window_cache.build_histograms(10);

        let histograms = window_cache
            .histograms
            .as_ref()
            .expect("build_histograms must populate histograms");
        let cdfs = window_cache
            .cdfs
            .as_ref()
            .expect("build_histograms must populate cdfs");

        assert_eq!(histograms.len(), 2);
        assert_eq!(cdfs.len(), 2);

        for (hist, cdf) in histograms.iter().zip(cdfs) {
            assert_eq!(hist.len(), 10);
            assert_eq!(cdf.len(), 10);
            assert!(hist.iter().all(|&h| h >= 0.0));
            assert!(cdf.windows(2).all(|w| w[0] <= w[1]));
            assert!((cdf[cdf.len() - 1] - 1.0).abs() < 1e-5);
        }
    }
}