ruvector_attention/transport/
centroid_ot.rs1use crate::error::{AttentionError, AttentionResult};
14use crate::traits::Attention;
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CentroidOTConfig {
20 pub dim: usize,
22 pub num_centroids: usize,
24 pub kmeans_iterations: usize,
26 pub temperature: f32,
28 pub sinkhorn_reg: f32,
30 pub sinkhorn_iterations: usize,
32 pub seed: u64,
34}
35
36impl Default for CentroidOTConfig {
37 fn default() -> Self {
38 Self {
39 dim: 512,
40 num_centroids: 16,
41 kmeans_iterations: 10,
42 temperature: 1.0,
43 sinkhorn_reg: 0.1,
44 sinkhorn_iterations: 20,
45 seed: 42,
46 }
47 }
48}
49
50#[derive(Debug, Clone)]
52pub struct CentroidCache {
53 pub centroids: Vec<Vec<f32>>,
55 pub weights: Vec<f32>,
57 pub assignments: Vec<usize>,
59 pub num_keys: usize,
61}
62
63impl CentroidCache {
64 pub fn build(keys: &[&[f32]], num_centroids: usize, iterations: usize, seed: u64) -> Self {
66 let num_keys = keys.len();
67 let m = num_centroids.min(num_keys);
68
69 if num_keys == 0 || keys[0].is_empty() {
70 return Self {
71 centroids: vec![],
72 weights: vec![],
73 assignments: vec![],
74 num_keys: 0,
75 };
76 }
77
78 let dim = keys[0].len();
79
80 use rand::prelude::*;
82 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
83 let mut indices: Vec<usize> = (0..num_keys).collect();
84 indices.shuffle(&mut rng);
85
86 let mut centroids: Vec<Vec<f32>> = indices
87 .iter()
88 .take(m)
89 .map(|&i| keys[i].to_vec())
90 .collect();
91
92 let mut assignments = vec![0usize; num_keys];
93
94 for _ in 0..iterations {
96 for (key_idx, key) in keys.iter().enumerate() {
98 let mut min_dist = f32::MAX;
99 let mut best_centroid = 0;
100
101 for (c_idx, centroid) in centroids.iter().enumerate() {
102 let dist = Self::squared_distance(key, centroid);
103 if dist < min_dist {
104 min_dist = dist;
105 best_centroid = c_idx;
106 }
107 }
108
109 assignments[key_idx] = best_centroid;
110 }
111
112 let mut new_centroids = vec![vec![0.0f32; dim]; m];
114 let mut counts = vec![0usize; m];
115
116 for (key_idx, &assignment) in assignments.iter().enumerate() {
117 counts[assignment] += 1;
118 for (d, &v) in keys[key_idx].iter().enumerate() {
119 new_centroids[assignment][d] += v;
120 }
121 }
122
123 for c_idx in 0..m {
124 if counts[c_idx] > 0 {
125 for d in 0..dim {
126 new_centroids[c_idx][d] /= counts[c_idx] as f32;
127 }
128 centroids[c_idx] = new_centroids[c_idx].clone();
129 }
130 }
131 }
132
133 let mut counts = vec![0usize; m];
135 for &a in &assignments {
136 counts[a] += 1;
137 }
138 let weights: Vec<f32> = counts
139 .iter()
140 .map(|&c| c as f32 / num_keys as f32)
141 .collect();
142
143 Self {
144 centroids,
145 weights,
146 assignments,
147 num_keys,
148 }
149 }
150
151 #[inline]
153 fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
154 let len = a.len();
155 let chunks = len / 4;
156 let remainder = len % 4;
157
158 let mut sum0 = 0.0f32;
159 let mut sum1 = 0.0f32;
160 let mut sum2 = 0.0f32;
161 let mut sum3 = 0.0f32;
162
163 for i in 0..chunks {
164 let base = i * 4;
165 let d0 = a[base] - b[base];
166 let d1 = a[base + 1] - b[base + 1];
167 let d2 = a[base + 2] - b[base + 2];
168 let d3 = a[base + 3] - b[base + 3];
169 sum0 += d0 * d0;
170 sum1 += d1 * d1;
171 sum2 += d2 * d2;
172 sum3 += d3 * d3;
173 }
174
175 let base = chunks * 4;
176 for i in 0..remainder {
177 let d = a[base + i] - b[base + i];
178 sum0 += d * d;
179 }
180
181 sum0 + sum1 + sum2 + sum3
182 }
183}
184
185#[derive(Debug, Clone)]
190pub struct CentroidOTAttention {
191 config: CentroidOTConfig,
192}
193
194impl CentroidOTAttention {
195 pub fn new(config: CentroidOTConfig) -> Self {
197 Self { config }
198 }
199
200 pub fn with_dim(dim: usize) -> Self {
202 Self::new(CentroidOTConfig {
203 dim,
204 ..Default::default()
205 })
206 }
207
208 pub fn build_cache(&self, keys: &[&[f32]]) -> CentroidCache {
210 CentroidCache::build(
211 keys,
212 self.config.num_centroids,
213 self.config.kmeans_iterations,
214 self.config.seed,
215 )
216 }
217
218 pub fn compute_with_cache(
220 &self,
221 query: &[f32],
222 cache: &CentroidCache,
223 values: &[&[f32]],
224 ) -> AttentionResult<Vec<f32>> {
225 if cache.centroids.is_empty() {
226 return Err(AttentionError::InvalidConfig("Empty cache".into()));
227 }
228
229 let centroid_distances: Vec<f32> = cache
231 .centroids
232 .iter()
233 .map(|c| CentroidCache::squared_distance(query, c).sqrt())
234 .collect();
235
236 let centroid_logits: Vec<f32> = centroid_distances
238 .iter()
239 .map(|d| -d / self.config.temperature)
240 .collect();
241
242 let centroid_weights = Self::stable_softmax(¢roid_logits);
243
244 let mut key_weights = vec![0.0f32; cache.num_keys];
246 for (key_idx, &assignment) in cache.assignments.iter().enumerate() {
247 let cluster_size = cache.assignments.iter().filter(|&&a| a == assignment).count();
249 if cluster_size > 0 {
250 key_weights[key_idx] = centroid_weights[assignment] / cluster_size as f32;
251 }
252 }
253
254 self.weighted_sum(&key_weights, values)
256 }
257
258 #[allow(dead_code)]
260 fn sinkhorn_distance(&self, query: &[f32], cache: &CentroidCache) -> f32 {
261 let m = cache.centroids.len();
262 if m == 0 {
263 return 0.0;
264 }
265
266 let costs: Vec<f32> = cache
268 .centroids
269 .iter()
270 .map(|c| CentroidCache::squared_distance(query, c))
271 .collect();
272
273 let reg = self.config.sinkhorn_reg;
278 let log_k: Vec<f32> = costs.iter().map(|c| -c / reg).collect();
279
280 let mut log_v = vec![0.0f32; m];
281 let log_b: Vec<f32> = cache.weights.iter().map(|w| w.ln().max(-20.0)).collect();
282
283 for _ in 0..self.config.sinkhorn_iterations {
284 let log_sum: f32 = log_k
286 .iter()
287 .zip(log_v.iter())
288 .map(|(&lk, &lv)| lk + lv)
289 .fold(f32::NEG_INFINITY, |max, x| if x > max { x } else { max });
290
291 let exp_sum: f32 = log_k
292 .iter()
293 .zip(log_v.iter())
294 .map(|(&lk, &lv)| (lk + lv - log_sum).exp())
295 .sum();
296
297 let log_u = -log_sum - exp_sum.ln();
298
299 for j in 0..m {
301 log_v[j] = log_b[j] - (log_u + log_k[j]);
302 }
303 }
304
305 let mut total_cost = 0.0f32;
307 for j in 0..m {
308 let gamma = (log_v[j] + log_k[j]).exp();
309 total_cost += gamma * costs[j];
310 }
311
312 total_cost
313 }
314
315 fn stable_softmax(logits: &[f32]) -> Vec<f32> {
317 if logits.is_empty() {
318 return vec![];
319 }
320
321 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
322 let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
323 let sum: f32 = exp_logits.iter().sum();
324
325 exp_logits.iter().map(|&e| e / sum).collect()
326 }
327
328 fn weighted_sum(&self, weights: &[f32], values: &[&[f32]]) -> AttentionResult<Vec<f32>> {
330 if weights.is_empty() || values.is_empty() {
331 return Err(AttentionError::InvalidConfig("Empty inputs".into()));
332 }
333
334 let dim = values[0].len();
335 let mut output = vec![0.0f32; dim];
336
337 for (weight, value) in weights.iter().zip(values.iter()) {
338 for (o, &v) in output.iter_mut().zip(value.iter()) {
339 *o += weight * v;
340 }
341 }
342
343 Ok(output)
344 }
345}
346
347impl Attention for CentroidOTAttention {
348 fn compute(
349 &self,
350 query: &[f32],
351 keys: &[&[f32]],
352 values: &[&[f32]],
353 ) -> AttentionResult<Vec<f32>> {
354 let cache = self.build_cache(keys);
355 self.compute_with_cache(query, &cache, values)
356 }
357
358 fn compute_with_mask(
359 &self,
360 query: &[f32],
361 keys: &[&[f32]],
362 values: &[&[f32]],
363 mask: Option<&[bool]>,
364 ) -> AttentionResult<Vec<f32>> {
365 if let Some(m) = mask {
366 let filtered: Vec<(&[f32], &[f32])> = keys
367 .iter()
368 .zip(values.iter())
369 .enumerate()
370 .filter(|(i, _)| m.get(*i).copied().unwrap_or(true))
371 .map(|(_, (k, v))| (*k, *v))
372 .collect();
373
374 let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(k, _)| *k).collect();
375 let filtered_values: Vec<&[f32]> = filtered.iter().map(|(_, v)| *v).collect();
376
377 self.compute(query, &filtered_keys, &filtered_values)
378 } else {
379 self.compute(query, keys, values)
380 }
381 }
382
383 fn dim(&self) -> usize {
384 self.config.dim
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` constant rows of length `dim`; row `i` is filled with `i * step`.
    fn ramp(n: usize, dim: usize, step: f32) -> Vec<Vec<f32>> {
        (0..n).map(|i| vec![i as f32 * step; dim]).collect()
    }

    /// Borrows each row as a slice, matching the `&[&[f32]]` inputs of the API.
    fn as_refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_centroid_cache() {
        let keys = ramp(50, 32, 0.1);
        let keys_refs = as_refs(&keys);

        let cache = CentroidCache::build(&keys_refs, 8, 5, 42);

        assert_eq!(cache.centroids.len(), 8);
        assert_eq!(cache.weights.len(), 8);
        assert_eq!(cache.assignments.len(), 50);
        assert_eq!(cache.num_keys, 50);
        assert!(cache.centroids.iter().all(|c| c.len() == 32));
        assert!(cache.assignments.iter().all(|&a| a < cache.centroids.len()));

        let weight_sum: f32 = cache.weights.iter().sum();
        assert!((weight_sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cache_is_deterministic_for_seed() {
        let keys = ramp(50, 32, 0.1);
        let keys_refs = as_refs(&keys);

        let first = CentroidCache::build(&keys_refs, 8, 5, 42);
        let second = CentroidCache::build(&keys_refs, 8, 5, 42);

        assert_eq!(first.centroids, second.centroids);
        assert_eq!(first.weights, second.weights);
        assert_eq!(first.assignments, second.assignments);
    }

    #[test]
    fn test_centroid_ot_attention() {
        let attention = CentroidOTAttention::with_dim(32);

        let query = vec![0.5f32; 32];
        let keys = ramp(30, 32, 0.05);
        let values = ramp(30, 32, 1.0);
        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);

        let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(output.len(), 32);

        // Every value row is constant, so all output components are the same
        // non-negative combination of 0..=29.
        assert!(output.iter().all(|&o| o == output[0]));
        assert!(output[0] >= 0.0 && output[0] <= 29.0 + 1e-3);
    }

    #[test]
    fn test_mask_behaviour() {
        let attention = CentroidOTAttention::with_dim(32);

        let query = vec![0.5f32; 32];
        let keys = ramp(30, 32, 0.05);
        let values = ramp(30, 32, 1.0);
        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);

        let keep_all = vec![true; 30];
        let unmasked = attention.compute(&query, &keys_refs, &values_refs).unwrap();
        let masked = attention
            .compute_with_mask(&query, &keys_refs, &values_refs, Some(keep_all.as_slice()))
            .unwrap();
        assert_eq!(unmasked, masked);

        let keep_none = vec![false; 30];
        assert!(attention
            .compute_with_mask(&query, &keys_refs, &values_refs, Some(keep_none.as_slice()))
            .is_err());
        assert!(attention.compute(&query, &[], &[]).is_err());
    }

    #[test]
    fn test_cache_reuse() {
        let attention = CentroidOTAttention::with_dim(64);

        let keys = ramp(40, 64, 0.025);
        let values = ramp(40, 64, 1.0);
        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);

        let cache = attention.build_cache(&keys_refs);

        for q in 0..10 {
            let query = vec![q as f32 * 0.1; 64];
            let output = attention
                .compute_with_cache(&query, &cache, &values_refs)
                .unwrap();
            assert_eq!(output.len(), 64);
            assert!(output.iter().all(|&o| o == output[0]));
        }
    }
}