// ruvector_attention/transport/centroid_ot.rs
use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;
use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CentroidOTConfig {
20 pub dim: usize,
22 pub num_centroids: usize,
24 pub kmeans_iterations: usize,
26 pub temperature: f32,
28 pub sinkhorn_reg: f32,
30 pub sinkhorn_iterations: usize,
32 pub seed: u64,
34}
35
36impl Default for CentroidOTConfig {
37 fn default() -> Self {
38 Self {
39 dim: 512,
40 num_centroids: 16,
41 kmeans_iterations: 10,
42 temperature: 1.0,
43 sinkhorn_reg: 0.1,
44 sinkhorn_iterations: 20,
45 seed: 42,
46 }
47 }
48}
49
/// Result of clustering a set of keys, reusable across many queries.
#[derive(Debug, Clone)]
pub struct CentroidCache {
    /// Cluster centres, one vector per centroid.
    pub centroids: Vec<Vec<f32>>,
    /// Fraction of all keys assigned to each centroid (same order as `centroids`).
    pub weights: Vec<f32>,
    /// For every key, the index into `centroids` of the cluster it belongs to.
    pub assignments: Vec<usize>,
    /// Number of keys the cache was built from.
    pub num_keys: usize,
}
62
63impl CentroidCache {
64 pub fn build(keys: &[&[f32]], num_centroids: usize, iterations: usize, seed: u64) -> Self {
66 let num_keys = keys.len();
67 let m = num_centroids.min(num_keys);
68
69 if num_keys == 0 || keys[0].is_empty() {
70 return Self {
71 centroids: vec![],
72 weights: vec![],
73 assignments: vec![],
74 num_keys: 0,
75 };
76 }
77
78 let dim = keys[0].len();
79
80 use rand::prelude::*;
82 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
83 let mut indices: Vec<usize> = (0..num_keys).collect();
84 indices.shuffle(&mut rng);
85
86 let mut centroids: Vec<Vec<f32>> =
87 indices.iter().take(m).map(|&i| keys[i].to_vec()).collect();
88
89 let mut assignments = vec![0usize; num_keys];
90
91 for _ in 0..iterations {
93 for (key_idx, key) in keys.iter().enumerate() {
95 let mut min_dist = f32::MAX;
96 let mut best_centroid = 0;
97
98 for (c_idx, centroid) in centroids.iter().enumerate() {
99 let dist = Self::squared_distance(key, centroid);
100 if dist < min_dist {
101 min_dist = dist;
102 best_centroid = c_idx;
103 }
104 }
105
106 assignments[key_idx] = best_centroid;
107 }
108
109 let mut new_centroids = vec![vec![0.0f32; dim]; m];
111 let mut counts = vec![0usize; m];
112
113 for (key_idx, &assignment) in assignments.iter().enumerate() {
114 counts[assignment] += 1;
115 for (d, &v) in keys[key_idx].iter().enumerate() {
116 new_centroids[assignment][d] += v;
117 }
118 }
119
120 for c_idx in 0..m {
121 if counts[c_idx] > 0 {
122 for d in 0..dim {
123 new_centroids[c_idx][d] /= counts[c_idx] as f32;
124 }
125 centroids[c_idx] = new_centroids[c_idx].clone();
126 }
127 }
128 }
129
130 let mut counts = vec![0usize; m];
132 for &a in &assignments {
133 counts[a] += 1;
134 }
135 let weights: Vec<f32> = counts.iter().map(|&c| c as f32 / num_keys as f32).collect();
136
137 Self {
138 centroids,
139 weights,
140 assignments,
141 num_keys,
142 }
143 }
144
/// Squared Euclidean distance over the first `a.len()` components of `a` and `b`.
///
/// Four independent partial sums are kept (one per position within each group of four
/// components); leftover components are added to the first partial sum.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`. Extra trailing components of `b` are ignored.
#[inline]
fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
    // `a` decides how many components take part.
    let b = &b[..a.len()];

    let a_quads = a.chunks_exact(4);
    let b_quads = b.chunks_exact(4);
    let a_tail = a_quads.remainder();
    let b_tail = b_quads.remainder();

    let mut lanes = [0.0f32; 4];
    for (qa, qb) in a_quads.zip(b_quads) {
        for (lane, (&x, &y)) in lanes.iter_mut().zip(qa.iter().zip(qb.iter())) {
            let diff = x - y;
            *lane += diff * diff;
        }
    }

    for (&x, &y) in a_tail.iter().zip(b_tail.iter()) {
        let diff = x - y;
        lanes[0] += diff * diff;
    }

    lanes[0] + lanes[1] + lanes[2] + lanes[3]
}
177}
178
179#[derive(Debug, Clone)]
184pub struct CentroidOTAttention {
185 config: CentroidOTConfig,
186}
187
188impl CentroidOTAttention {
189 pub fn new(config: CentroidOTConfig) -> Self {
191 Self { config }
192 }
193
194 pub fn with_dim(dim: usize) -> Self {
196 Self::new(CentroidOTConfig {
197 dim,
198 ..Default::default()
199 })
200 }
201
202 pub fn build_cache(&self, keys: &[&[f32]]) -> CentroidCache {
204 CentroidCache::build(
205 keys,
206 self.config.num_centroids,
207 self.config.kmeans_iterations,
208 self.config.seed,
209 )
210 }
211
212 pub fn compute_with_cache(
214 &self,
215 query: &[f32],
216 cache: &CentroidCache,
217 values: &[&[f32]],
218 ) -> AttentionResult<Vec<f32>> {
219 if cache.centroids.is_empty() {
220 return Err(AttentionError::InvalidConfig("Empty cache".into()));
221 }
222
223 let centroid_distances: Vec<f32> = cache
225 .centroids
226 .iter()
227 .map(|c| CentroidCache::squared_distance(query, c).sqrt())
228 .collect();
229
230 let centroid_logits: Vec<f32> = centroid_distances
232 .iter()
233 .map(|d| -d / self.config.temperature)
234 .collect();
235
236 let centroid_weights = Self::stable_softmax(¢roid_logits);
237
238 let mut key_weights = vec![0.0f32; cache.num_keys];
240 for (key_idx, &assignment) in cache.assignments.iter().enumerate() {
241 let cluster_size = cache
243 .assignments
244 .iter()
245 .filter(|&&a| a == assignment)
246 .count();
247 if cluster_size > 0 {
248 key_weights[key_idx] = centroid_weights[assignment] / cluster_size as f32;
249 }
250 }
251
252 self.weighted_sum(&key_weights, values)
254 }
255
256 #[allow(dead_code)]
258 fn sinkhorn_distance(&self, query: &[f32], cache: &CentroidCache) -> f32 {
259 let m = cache.centroids.len();
260 if m == 0 {
261 return 0.0;
262 }
263
264 let costs: Vec<f32> = cache
266 .centroids
267 .iter()
268 .map(|c| CentroidCache::squared_distance(query, c))
269 .collect();
270
271 let reg = self.config.sinkhorn_reg;
276 let log_k: Vec<f32> = costs.iter().map(|c| -c / reg).collect();
277
278 let mut log_v = vec![0.0f32; m];
279 let log_b: Vec<f32> = cache.weights.iter().map(|w| w.ln().max(-20.0)).collect();
280
281 for _ in 0..self.config.sinkhorn_iterations {
282 let log_sum: f32 = log_k
284 .iter()
285 .zip(log_v.iter())
286 .map(|(&lk, &lv)| lk + lv)
287 .fold(f32::NEG_INFINITY, |max, x| if x > max { x } else { max });
288
289 let exp_sum: f32 = log_k
290 .iter()
291 .zip(log_v.iter())
292 .map(|(&lk, &lv)| (lk + lv - log_sum).exp())
293 .sum();
294
295 let log_u = -log_sum - exp_sum.ln();
296
297 for j in 0..m {
299 log_v[j] = log_b[j] - (log_u + log_k[j]);
300 }
301 }
302
303 let mut total_cost = 0.0f32;
305 for j in 0..m {
306 let gamma = (log_v[j] + log_k[j]).exp();
307 total_cost += gamma * costs[j];
308 }
309
310 total_cost
311 }
312
/// Softmax over `logits`, shifting by the largest logit before exponentiating so the
/// exponentials cannot overflow. Returns an empty vector for empty input.
fn stable_softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }

    let mut peak = f32::NEG_INFINITY;
    for &logit in logits {
        peak = f32::max(peak, logit);
    }

    let mut probs: Vec<f32> = logits.iter().map(|&logit| (logit - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= total;
    }
    probs
}
325
326 fn weighted_sum(&self, weights: &[f32], values: &[&[f32]]) -> AttentionResult<Vec<f32>> {
328 if weights.is_empty() || values.is_empty() {
329 return Err(AttentionError::InvalidConfig("Empty inputs".into()));
330 }
331
332 let dim = values[0].len();
333 let mut output = vec![0.0f32; dim];
334
335 for (weight, value) in weights.iter().zip(values.iter()) {
336 for (o, &v) in output.iter_mut().zip(value.iter()) {
337 *o += weight * v;
338 }
339 }
340
341 Ok(output)
342 }
343}
344
345impl Attention for CentroidOTAttention {
346 fn compute(
347 &self,
348 query: &[f32],
349 keys: &[&[f32]],
350 values: &[&[f32]],
351 ) -> AttentionResult<Vec<f32>> {
352 let cache = self.build_cache(keys);
353 self.compute_with_cache(query, &cache, values)
354 }
355
356 fn compute_with_mask(
357 &self,
358 query: &[f32],
359 keys: &[&[f32]],
360 values: &[&[f32]],
361 mask: Option<&[bool]>,
362 ) -> AttentionResult<Vec<f32>> {
363 if let Some(m) = mask {
364 let filtered: Vec<(&[f32], &[f32])> = keys
365 .iter()
366 .zip(values.iter())
367 .enumerate()
368 .filter(|(i, _)| m.get(*i).copied().unwrap_or(true))
369 .map(|(_, (k, v))| (*k, *v))
370 .collect();
371
372 let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(k, _)| *k).collect();
373 let filtered_values: Vec<&[f32]> = filtered.iter().map(|(_, v)| *v).collect();
374
375 self.compute(query, &filtered_keys, &filtered_values)
376 } else {
377 self.compute(query, keys, values)
378 }
379 }
380
381 fn dim(&self) -> usize {
382 self.config.dim
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` constant vectors of length `dim`; row `i` is filled with `i * step`.
    fn ramp(n: usize, step: f32, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| vec![i as f32 * step; dim]).collect()
    }

    /// Builds `n` constant vectors of length `dim`; row `i` is filled with `i`.
    fn index_rows(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| vec![i as f32; dim]).collect()
    }

    /// Borrows every row as a slice.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    /// The cache has the requested number of centroids and its weights sum to one.
    #[test]
    fn test_centroid_cache() {
        let keys = ramp(50, 0.1, 32);
        let keys_refs = as_slices(&keys);

        let cache = CentroidCache::build(&keys_refs, 8, 5, 42);

        assert_eq!(cache.centroids.len(), 8);
        assert_eq!(cache.weights.len(), 8);
        assert_eq!(cache.assignments.len(), 50);

        let weight_sum: f32 = cache.weights.iter().sum();
        assert!((weight_sum - 1.0).abs() < 1e-5);
    }

    /// End-to-end `compute` yields an output of the value dimension.
    #[test]
    fn test_centroid_ot_attention() {
        let attention = CentroidOTAttention::with_dim(32);

        let query = vec![0.5f32; 32];
        let keys = ramp(30, 0.05, 32);
        let values = index_rows(30, 32);

        let keys_refs = as_slices(&keys);
        let values_refs = as_slices(&values);

        let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(output.len(), 32);
    }

    /// One cache can serve several queries.
    #[test]
    fn test_cache_reuse() {
        let attention = CentroidOTAttention::with_dim(64);

        let keys = ramp(40, 0.025, 64);
        let values = index_rows(40, 64);

        let keys_refs = as_slices(&keys);
        let values_refs = as_slices(&values);

        let cache = attention.build_cache(&keys_refs);

        for q in 0..10 {
            let query = vec![q as f32 * 0.1; 64];
            let output = attention
                .compute_with_cache(&query, &cache, &values_refs)
                .unwrap();
            assert_eq!(output.len(), 64);
        }
    }
}