// ruvector_attention/transport/cached_projections.rs
use rand::prelude::*;
use rand::rngs::StdRng;

9#[derive(Debug, Clone)]
11pub struct ProjectionCache {
12 pub directions: Vec<Vec<f32>>,
14 pub num_projections: usize,
16 pub dim: usize,
18}
19
20impl ProjectionCache {
21 pub fn new(dim: usize, num_projections: usize, seed: u64) -> Self {
23 let mut rng = StdRng::seed_from_u64(seed);
24
25 let directions: Vec<Vec<f32>> = (0..num_projections)
26 .map(|_| {
27 let mut dir: Vec<f32> = (0..dim)
28 .map(|_| rng.sample::<f32, _>(rand::distributions::Standard) * 2.0 - 1.0)
29 .collect();
30 let norm: f32 = dir.iter().map(|x| x * x).sum::<f32>().sqrt();
32 if norm > 1e-8 {
33 for x in &mut dir {
34 *x /= norm;
35 }
36 }
37 dir
38 })
39 .collect();
40
41 Self {
42 directions,
43 num_projections,
44 dim,
45 }
46 }
47
48 #[inline]
51 pub fn project(&self, vector: &[f32]) -> Vec<f32> {
52 self.directions
53 .iter()
54 .map(|dir| Self::dot_product_simd(vector, dir))
55 .collect()
56 }
57
58 #[inline]
60 pub fn project_into(&self, vector: &[f32], out: &mut [f32]) {
61 for (i, dir) in self.directions.iter().enumerate() {
62 out[i] = Self::dot_product_simd(vector, dir);
63 }
64 }
65
66 #[inline(always)]
68 fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
69 let len = a.len();
70 let chunks = len / 4;
71 let remainder = len % 4;
72
73 let mut sum0 = 0.0f32;
74 let mut sum1 = 0.0f32;
75 let mut sum2 = 0.0f32;
76 let mut sum3 = 0.0f32;
77
78 for i in 0..chunks {
79 let base = i * 4;
80 sum0 += a[base] * b[base];
81 sum1 += a[base + 1] * b[base + 1];
82 sum2 += a[base + 2] * b[base + 2];
83 sum3 += a[base + 3] * b[base + 3];
84 }
85
86 let base = chunks * 4;
87 for i in 0..remainder {
88 sum0 += a[base + i] * b[base + i];
89 }
90
91 sum0 + sum1 + sum2 + sum3
92 }
93}
94
95#[derive(Debug, Clone)]
97pub struct WindowCache {
98 pub key_projections: Vec<Vec<f32>>,
100 pub sorted_indices: Vec<Vec<usize>>,
102 pub sorted_values: Vec<Vec<f32>>,
104 pub histograms: Option<Vec<Vec<f32>>>,
106 pub cdfs: Option<Vec<Vec<f32>>>,
108 pub num_keys: usize,
110}
111
112impl WindowCache {
113 pub fn build(keys: &[&[f32]], proj_cache: &ProjectionCache) -> Self {
115 let num_keys = keys.len();
116 let num_proj = proj_cache.num_projections;
117
118 let key_projections: Vec<Vec<f32>> = keys.iter().map(|k| proj_cache.project(k)).collect();
120
121 let mut sorted_indices = vec![Vec::with_capacity(num_keys); num_proj];
123 let mut sorted_values = vec![Vec::with_capacity(num_keys); num_proj];
124
125 for p in 0..num_proj {
126 let mut indexed: Vec<(usize, f32)> = key_projections
127 .iter()
128 .enumerate()
129 .map(|(i, projs)| (i, projs[p]))
130 .collect();
131 indexed.sort_unstable_by(|a, b| {
132 a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
133 });
134
135 sorted_indices[p] = indexed.iter().map(|(i, _)| *i).collect();
136 sorted_values[p] = indexed.iter().map(|(_, v)| *v).collect();
137 }
138
139 Self {
140 key_projections,
141 sorted_indices,
142 sorted_values,
143 histograms: None,
144 cdfs: None,
145 num_keys,
146 }
147 }
148
149 pub fn build_histograms(&mut self, num_bins: usize) {
151 let num_proj = self.sorted_values.len();
152
153 let mut histograms = vec![vec![0.0f32; num_bins]; num_proj];
154 let mut cdfs = vec![vec![0.0f32; num_bins]; num_proj];
155
156 for p in 0..num_proj {
157 let vals = &self.sorted_values[p];
158 if vals.is_empty() {
159 continue;
160 }
161
162 let min_val = vals[0];
163 let max_val = vals[vals.len() - 1];
164 let range = (max_val - min_val).max(1e-8);
165
166 for &v in vals {
168 let bin = ((v - min_val) / range * (num_bins - 1) as f32)
169 .clamp(0.0, (num_bins - 1) as f32) as usize;
170 histograms[p][bin] += 1.0 / self.num_keys as f32;
171 }
172
173 let mut cumsum = 0.0f32;
175 for bin in 0..num_bins {
176 cumsum += histograms[p][bin];
177 cdfs[p][bin] = cumsum;
178 }
179 }
180
181 self.histograms = Some(histograms);
182 self.cdfs = Some(cdfs);
183 }
184
185 #[inline]
187 pub fn get_sorted(&self, projection_idx: usize) -> &[f32] {
188 &self.sorted_values[projection_idx]
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Naive reference dot product used to cross-check the unrolled kernel.
    fn naive_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    fn test_projection_cache() {
        let cache = ProjectionCache::new(64, 8, 42);

        assert_eq!(cache.num_projections, 8);
        assert_eq!(cache.dim, 64);
        assert_eq!(cache.directions.len(), 8);

        for dir in &cache.directions {
            assert_eq!(dir.len(), 64);
            // Every direction should be unit length.
            let norm: f32 = dir.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_projection_cache_is_deterministic_per_seed() {
        let a = ProjectionCache::new(16, 4, 7);
        let b = ProjectionCache::new(16, 4, 7);
        assert_eq!(a.directions, b.directions);
    }

    #[test]
    fn test_project_into_matches_project() {
        // 13 is not a multiple of 4, so the kernel's remainder path is exercised.
        let cache = ProjectionCache::new(13, 5, 3);
        let v: Vec<f32> = (0..13).map(|i| i as f32 * 0.25 - 1.0).collect();

        let expected = cache.project(&v);
        let mut out = vec![0.0f32; 5];
        cache.project_into(&v, &mut out);
        assert_eq!(out, expected);

        for (dir, &got) in cache.directions.iter().zip(&expected) {
            assert!((naive_dot(&v, dir) - got).abs() < 1e-4);
        }
    }

    #[test]
    fn test_window_cache() {
        let proj_cache = ProjectionCache::new(32, 4, 42);

        let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 32]).collect();
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();

        let window_cache = WindowCache::build(&keys_refs, &proj_cache);

        assert_eq!(window_cache.num_keys, 10);
        assert_eq!(window_cache.sorted_indices.len(), 4);
        assert_eq!(window_cache.sorted_values.len(), 4);
        assert!(window_cache.histograms.is_none());
        assert!(window_cache.cdfs.is_none());

        for p in 0..4 {
            let vals = window_cache.get_sorted(p);
            let idxs = &window_cache.sorted_indices[p];
            assert_eq!(vals.len(), 10);

            // Values are ascending...
            assert!(vals.windows(2).all(|w| w[0] <= w[1]));
            // ...each index points at the key whose projection is that value...
            for (&idx, &val) in idxs.iter().zip(vals) {
                assert_eq!(window_cache.key_projections[idx][p], val);
            }
            // ...and the indices form a permutation of 0..num_keys.
            let mut perm = idxs.clone();
            perm.sort_unstable();
            assert_eq!(perm, (0..10).collect::<Vec<usize>>());
        }
    }

    #[test]
    fn test_empty_window() {
        let proj_cache = ProjectionCache::new(8, 3, 1);
        let keys_refs: Vec<&[f32]> = Vec::new();

        let mut window_cache = WindowCache::build(&keys_refs, &proj_cache);
        assert_eq!(window_cache.num_keys, 0);
        assert!(window_cache.sorted_values.iter().all(|row| row.is_empty()));

        window_cache.build_histograms(4);
        let cdfs = window_cache.cdfs.as_ref().unwrap();
        assert_eq!(cdfs.len(), 3);
        assert!(cdfs.iter().all(|cdf| cdf.iter().all(|&c| c == 0.0)));
    }

    #[test]
    fn test_histograms() {
        let proj_cache = ProjectionCache::new(16, 2, 42);

        let keys: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.05; 16]).collect();
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();

        let mut window_cache = WindowCache::build(&keys_refs, &proj_cache);
        window_cache.build_histograms(10);

        assert!(window_cache.cdfs.is_some());
        assert!(window_cache.histograms.is_some());

        let hists = window_cache.histograms.as_ref().unwrap();
        for hist in hists {
            assert_eq!(hist.len(), 10);
            assert!((hist.iter().sum::<f32>() - 1.0).abs() < 1e-5);
        }

        let cdfs = window_cache.cdfs.as_ref().unwrap();
        for cdf in cdfs {
            assert_eq!(cdf.len(), 10);
            // The CDF is non-decreasing and ends at 1.
            assert!(cdf.windows(2).all(|w| w[0] <= w[1]));
            assert!((cdf[cdf.len() - 1] - 1.0).abs() < 1e-5);
        }
    }
}