// ruvector_attention/transport/centroid_ot.rs

//! Centroid-Based Optimal Transport Attention
//!
//! Clusters keys into M centroids and compares each query against the centroids
//! instead of against every key, which is much cheaper than full pairwise OT.
//!
//! ## Algorithm
//!
//! 1. Cluster keys into M centroids using k-means.
//! 2. Store the centroid vectors and their weights (the fraction of keys in each cluster).
//! 3. For each query, compute the transport cost to each centroid (the Euclidean distance).
//! 4. Convert the negated costs into attention logits, apply a temperature softmax over
//!    clusters, and split each cluster's weight evenly among its keys.
12
13use crate::error::{AttentionError, AttentionResult};
14use crate::traits::Attention;
15use serde::{Deserialize, Serialize};
16
17/// Configuration for Centroid OT Attention
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CentroidOTConfig {
20    /// Model dimension
21    pub dim: usize,
22    /// Number of centroids (16-32 typical)
23    pub num_centroids: usize,
24    /// Number of k-means iterations
25    pub kmeans_iterations: usize,
26    /// Temperature for softmax
27    pub temperature: f32,
28    /// Regularization for Sinkhorn (0.1 typical)
29    pub sinkhorn_reg: f32,
30    /// Max Sinkhorn iterations
31    pub sinkhorn_iterations: usize,
32    /// Random seed
33    pub seed: u64,
34}
35
36impl Default for CentroidOTConfig {
37    fn default() -> Self {
38        Self {
39            dim: 512,
40            num_centroids: 16,
41            kmeans_iterations: 10,
42            temperature: 1.0,
43            sinkhorn_reg: 0.1,
44            sinkhorn_iterations: 20,
45            seed: 42,
46        }
47    }
48}
49
/// K-means summary of one window of keys, reusable across many queries.
#[derive(Debug, Clone)]
pub struct CentroidCache {
    /// One centroid vector per cluster (`M` rows, each as long as a key).
    pub centroids: Vec<Vec<f32>>,
    /// Fraction of the keys belonging to each cluster; these total 1 for a non-empty cache.
    pub weights: Vec<f32>,
    /// For every key, the index of the cluster it was assigned to.
    pub assignments: Vec<usize>,
    /// How many keys the cache was built from.
    pub num_keys: usize,
}
62
63impl CentroidCache {
64    /// Build centroid cache using k-means
65    pub fn build(keys: &[&[f32]], num_centroids: usize, iterations: usize, seed: u64) -> Self {
66        let num_keys = keys.len();
67        let m = num_centroids.min(num_keys);
68
69        if num_keys == 0 || keys[0].is_empty() {
70            return Self {
71                centroids: vec![],
72                weights: vec![],
73                assignments: vec![],
74                num_keys: 0,
75            };
76        }
77
78        let dim = keys[0].len();
79
80        // Initialize centroids with random keys
81        use rand::prelude::*;
82        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
83        let mut indices: Vec<usize> = (0..num_keys).collect();
84        indices.shuffle(&mut rng);
85
86        let mut centroids: Vec<Vec<f32>> = indices
87            .iter()
88            .take(m)
89            .map(|&i| keys[i].to_vec())
90            .collect();
91
92        let mut assignments = vec![0usize; num_keys];
93
94        // K-means iterations
95        for _ in 0..iterations {
96            // Assign each key to nearest centroid
97            for (key_idx, key) in keys.iter().enumerate() {
98                let mut min_dist = f32::MAX;
99                let mut best_centroid = 0;
100
101                for (c_idx, centroid) in centroids.iter().enumerate() {
102                    let dist = Self::squared_distance(key, centroid);
103                    if dist < min_dist {
104                        min_dist = dist;
105                        best_centroid = c_idx;
106                    }
107                }
108
109                assignments[key_idx] = best_centroid;
110            }
111
112            // Update centroids
113            let mut new_centroids = vec![vec![0.0f32; dim]; m];
114            let mut counts = vec![0usize; m];
115
116            for (key_idx, &assignment) in assignments.iter().enumerate() {
117                counts[assignment] += 1;
118                for (d, &v) in keys[key_idx].iter().enumerate() {
119                    new_centroids[assignment][d] += v;
120                }
121            }
122
123            for c_idx in 0..m {
124                if counts[c_idx] > 0 {
125                    for d in 0..dim {
126                        new_centroids[c_idx][d] /= counts[c_idx] as f32;
127                    }
128                    centroids[c_idx] = new_centroids[c_idx].clone();
129                }
130            }
131        }
132
133        // Compute weights
134        let mut counts = vec![0usize; m];
135        for &a in &assignments {
136            counts[a] += 1;
137        }
138        let weights: Vec<f32> = counts
139            .iter()
140            .map(|&c| c as f32 / num_keys as f32)
141            .collect();
142
143        Self {
144            centroids,
145            weights,
146            assignments,
147            num_keys,
148        }
149    }
150
151    /// Squared Euclidean distance (SIMD-friendly)
152    #[inline]
153    fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
154        let len = a.len();
155        let chunks = len / 4;
156        let remainder = len % 4;
157
158        let mut sum0 = 0.0f32;
159        let mut sum1 = 0.0f32;
160        let mut sum2 = 0.0f32;
161        let mut sum3 = 0.0f32;
162
163        for i in 0..chunks {
164            let base = i * 4;
165            let d0 = a[base] - b[base];
166            let d1 = a[base + 1] - b[base + 1];
167            let d2 = a[base + 2] - b[base + 2];
168            let d3 = a[base + 3] - b[base + 3];
169            sum0 += d0 * d0;
170            sum1 += d1 * d1;
171            sum2 += d2 * d2;
172            sum3 += d3 * d3;
173        }
174
175        let base = chunks * 4;
176        for i in 0..remainder {
177            let d = a[base + i] - b[base + i];
178            sum0 += d * d;
179        }
180
181        sum0 + sum1 + sum2 + sum3
182    }
183}
184
185/// Centroid-based OT Attention
186///
187/// Computes attention by finding optimal transport between query and
188/// centroid distribution, then distributing attention to original keys.
189#[derive(Debug, Clone)]
190pub struct CentroidOTAttention {
191    config: CentroidOTConfig,
192}
193
194impl CentroidOTAttention {
195    /// Create new Centroid OT attention
196    pub fn new(config: CentroidOTConfig) -> Self {
197        Self { config }
198    }
199
200    /// Create with dimension only
201    pub fn with_dim(dim: usize) -> Self {
202        Self::new(CentroidOTConfig {
203            dim,
204            ..Default::default()
205        })
206    }
207
208    /// Build centroid cache for a window
209    pub fn build_cache(&self, keys: &[&[f32]]) -> CentroidCache {
210        CentroidCache::build(
211            keys,
212            self.config.num_centroids,
213            self.config.kmeans_iterations,
214            self.config.seed,
215        )
216    }
217
218    /// Compute attention using cached centroids
219    pub fn compute_with_cache(
220        &self,
221        query: &[f32],
222        cache: &CentroidCache,
223        values: &[&[f32]],
224    ) -> AttentionResult<Vec<f32>> {
225        if cache.centroids.is_empty() {
226            return Err(AttentionError::InvalidConfig("Empty cache".into()));
227        }
228
229        // Compute distances from query to each centroid
230        let centroid_distances: Vec<f32> = cache
231            .centroids
232            .iter()
233            .map(|c| CentroidCache::squared_distance(query, c).sqrt())
234            .collect();
235
236        // Convert to centroid attention weights
237        let centroid_logits: Vec<f32> = centroid_distances
238            .iter()
239            .map(|d| -d / self.config.temperature)
240            .collect();
241
242        let centroid_weights = Self::stable_softmax(&centroid_logits);
243
244        // Distribute centroid weights to original keys
245        let mut key_weights = vec![0.0f32; cache.num_keys];
246        for (key_idx, &assignment) in cache.assignments.iter().enumerate() {
247            // Key weight = centroid weight / number of keys in cluster
248            let cluster_size = cache.assignments.iter().filter(|&&a| a == assignment).count();
249            if cluster_size > 0 {
250                key_weights[key_idx] = centroid_weights[assignment] / cluster_size as f32;
251            }
252        }
253
254        // Weighted sum of values
255        self.weighted_sum(&key_weights, values)
256    }
257
258    /// Fast Sinkhorn transport (simplified for point-to-distribution)
259    #[allow(dead_code)]
260    fn sinkhorn_distance(&self, query: &[f32], cache: &CentroidCache) -> f32 {
261        let m = cache.centroids.len();
262        if m == 0 {
263            return 0.0;
264        }
265
266        // Cost matrix: 1 × M (query to each centroid)
267        let costs: Vec<f32> = cache
268            .centroids
269            .iter()
270            .map(|c| CentroidCache::squared_distance(query, c))
271            .collect();
272
273        // Source is delta at query (weight 1)
274        // Target is centroid distribution (cache.weights)
275
276        // Log-domain Sinkhorn
277        let reg = self.config.sinkhorn_reg;
278        let log_k: Vec<f32> = costs.iter().map(|c| -c / reg).collect();
279
280        let mut log_v = vec![0.0f32; m];
281        let log_b: Vec<f32> = cache.weights.iter().map(|w| w.ln().max(-20.0)).collect();
282
283        for _ in 0..self.config.sinkhorn_iterations {
284            // Update log_v
285            let log_sum: f32 = log_k
286                .iter()
287                .zip(log_v.iter())
288                .map(|(&lk, &lv)| lk + lv)
289                .fold(f32::NEG_INFINITY, |max, x| if x > max { x } else { max });
290
291            let exp_sum: f32 = log_k
292                .iter()
293                .zip(log_v.iter())
294                .map(|(&lk, &lv)| (lk + lv - log_sum).exp())
295                .sum();
296
297            let log_u = -log_sum - exp_sum.ln();
298
299            // Update log_v
300            for j in 0..m {
301                log_v[j] = log_b[j] - (log_u + log_k[j]);
302            }
303        }
304
305        // Compute transport cost
306        let mut total_cost = 0.0f32;
307        for j in 0..m {
308            let gamma = (log_v[j] + log_k[j]).exp();
309            total_cost += gamma * costs[j];
310        }
311
312        total_cost
313    }
314
315    /// Stable softmax
316    fn stable_softmax(logits: &[f32]) -> Vec<f32> {
317        if logits.is_empty() {
318            return vec![];
319        }
320
321        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
322        let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
323        let sum: f32 = exp_logits.iter().sum();
324
325        exp_logits.iter().map(|&e| e / sum).collect()
326    }
327
328    /// Weighted sum of values
329    fn weighted_sum(&self, weights: &[f32], values: &[&[f32]]) -> AttentionResult<Vec<f32>> {
330        if weights.is_empty() || values.is_empty() {
331            return Err(AttentionError::InvalidConfig("Empty inputs".into()));
332        }
333
334        let dim = values[0].len();
335        let mut output = vec![0.0f32; dim];
336
337        for (weight, value) in weights.iter().zip(values.iter()) {
338            for (o, &v) in output.iter_mut().zip(value.iter()) {
339                *o += weight * v;
340            }
341        }
342
343        Ok(output)
344    }
345}
346
347impl Attention for CentroidOTAttention {
348    fn compute(
349        &self,
350        query: &[f32],
351        keys: &[&[f32]],
352        values: &[&[f32]],
353    ) -> AttentionResult<Vec<f32>> {
354        let cache = self.build_cache(keys);
355        self.compute_with_cache(query, &cache, values)
356    }
357
358    fn compute_with_mask(
359        &self,
360        query: &[f32],
361        keys: &[&[f32]],
362        values: &[&[f32]],
363        mask: Option<&[bool]>,
364    ) -> AttentionResult<Vec<f32>> {
365        if let Some(m) = mask {
366            let filtered: Vec<(&[f32], &[f32])> = keys
367                .iter()
368                .zip(values.iter())
369                .enumerate()
370                .filter(|(i, _)| m.get(*i).copied().unwrap_or(true))
371                .map(|(_, (k, v))| (*k, *v))
372                .collect();
373
374            let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(k, _)| *k).collect();
375            let filtered_values: Vec<&[f32]> = filtered.iter().map(|(_, v)| *v).collect();
376
377            self.compute(query, &filtered_keys, &filtered_values)
378        } else {
379            self.compute(query, keys, values)
380        }
381    }
382
383    fn dim(&self) -> usize {
384        self.config.dim
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// `count` constant rows of length `dim`; row `i` is filled with `i * step`.
    fn constant_rows(count: usize, step: f32, dim: usize) -> Vec<Vec<f32>> {
        (0..count).map(|i| vec![i as f32 * step; dim]).collect()
    }

    /// Borrow every row as a slice, the shape the attention API takes.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_centroid_cache() {
        let keys = constant_rows(50, 0.1, 32);
        let cache = CentroidCache::build(&as_slices(&keys), 8, 5, 42);

        assert_eq!(cache.centroids.len(), 8);
        assert_eq!(cache.weights.len(), 8);
        assert_eq!(cache.assignments.len(), 50);

        // Cluster weights are fractions of the key count, so they must total 1.
        let total: f32 = cache.weights.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_centroid_ot_attention() {
        let attention = CentroidOTAttention::with_dim(32);

        let query = vec![0.5f32; 32];
        let keys = constant_rows(30, 0.05, 32);
        let values = constant_rows(30, 1.0, 32);

        let output = attention
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(output.len(), 32);
    }

    #[test]
    fn test_cache_reuse() {
        let attention = CentroidOTAttention::with_dim(64);

        let keys = constant_rows(40, 0.025, 64);
        let values = constant_rows(40, 1.0, 64);
        let value_slices = as_slices(&values);

        // One cache, many queries.
        let cache = attention.build_cache(&as_slices(&keys));

        for q in 0..10 {
            let query = vec![q as f32 * 0.1; 64];
            let output = attention
                .compute_with_cache(&query, &cache, &value_slices)
                .unwrap();
            assert_eq!(output.len(), 64);
        }
    }
}