1use std::collections::HashMap;
39use std::sync::RwLock;
40
/// Identifier of a single shard.
pub type ShardId = u32;

/// Identifier of a cluster (a group of shards that share one centroid).
pub type ClusterId = u32;
46
47#[derive(Debug, Clone)]
49pub struct Centroid {
50 pub id: ClusterId,
52 pub vector: Vec<f32>,
54 pub shards: Vec<ShardId>,
56 pub count: usize,
58}
59
60impl Centroid {
61 pub fn new(id: ClusterId, vector: Vec<f32>) -> Self {
63 Self {
64 id,
65 vector,
66 shards: Vec::new(),
67 count: 0,
68 }
69 }
70
71 #[inline]
73 pub fn distance_squared(&self, query: &[f32]) -> f32 {
74 self.vector
75 .iter()
76 .zip(query.iter())
77 .map(|(&a, &b)| {
78 let d = a - b;
79 d * d
80 })
81 .sum()
82 }
83}
84
85#[derive(Debug, Clone)]
87pub struct RoutingDecision {
88 pub shards: Vec<ShardId>,
90 pub distances: Vec<f32>,
92 pub clusters_probed: usize,
94}
95
96impl RoutingDecision {
97 pub fn work_reduction(&self, total_shards: usize) -> f32 {
99 if self.shards.is_empty() {
100 return 1.0;
101 }
102 self.shards.len() as f32 / total_shards as f32
103 }
104}
105
/// Tunable parameters for building and querying a shard topology.
#[derive(Debug, Clone)]
pub struct TopologyConfig {
    /// Upper bound on the number of clusters created when building from vectors.
    pub num_clusters: usize,
    /// Number of consecutive shard ids assigned to each cluster.
    pub shards_per_cluster: usize,
    /// Number of nearest clusters selected when routing a query.
    pub probe_clusters: usize,
    /// Largest-to-smallest cluster size ratio above which a rebalance is reported.
    pub rebalance_threshold: f32,
}

impl Default for TopologyConfig {
    /// Returns 16 clusters of 16 shards each, probing 2 clusters per query,
    /// with a rebalance threshold of 2.0.
    fn default() -> Self {
        const CLUSTERS: usize = 16;
        const SHARDS_PER_CLUSTER: usize = 16;
        const PROBES: usize = 2;
        const REBALANCE_RATIO: f32 = 2.0;

        TopologyConfig {
            num_clusters: CLUSTERS,
            shards_per_cluster: SHARDS_PER_CLUSTER,
            probe_clusters: PROBES,
            rebalance_threshold: REBALANCE_RATIO,
        }
    }
}
129
130pub struct ShardTopology {
132 centroids: Vec<Centroid>,
134 shard_to_cluster: HashMap<ShardId, ClusterId>,
136 config: TopologyConfig,
138 total_shards: usize,
140 stats: RwLock<TopologyStats>,
142}
143
/// Running counters describing how queries have been routed so far.
#[derive(Debug, Clone, Default)]
pub struct TopologyStats {
    /// Number of queries routed through a non-empty topology.
    pub queries_routed: u64,
    /// Total number of shards returned across all routed queries.
    pub shards_probed: u64,
    /// Mean shards per query (`shards_probed / queries_routed`).
    pub avg_fanout: f32,
    /// Per-cluster probe counters, one slot per centroid.
    pub cluster_loads: Vec<u64>,
}
156
157impl ShardTopology {
158 pub fn new(centroids: Vec<Centroid>, config: TopologyConfig) -> Self {
160 let total_shards = centroids.iter().map(|c| c.shards.len()).sum();
161
162 let mut shard_to_cluster = HashMap::new();
163 for centroid in ¢roids {
164 for &shard in ¢roid.shards {
165 shard_to_cluster.insert(shard, centroid.id);
166 }
167 }
168
169 let cluster_loads = vec![0; centroids.len()];
170
171 Self {
172 centroids,
173 shard_to_cluster,
174 config,
175 total_shards,
176 stats: RwLock::new(TopologyStats {
177 cluster_loads,
178 ..Default::default()
179 }),
180 }
181 }
182
183 pub fn build_from_vectors(vectors: &[Vec<f32>], config: TopologyConfig) -> Self {
185 if vectors.is_empty() {
186 return Self::empty(config);
187 }
188
189 let dimension = vectors[0].len();
190 let num_clusters = config.num_clusters.min(vectors.len());
191
192 let mut centroids: Vec<Centroid> = (0..num_clusters)
194 .map(|i| {
195 let idx = (i * vectors.len()) / num_clusters;
196 Centroid::new(i as ClusterId, vectors[idx].clone())
197 })
198 .collect();
199
200 for _ in 0..10 {
202 let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); num_clusters];
204
205 for (vec_idx, vector) in vectors.iter().enumerate() {
206 let nearest = Self::find_nearest_centroid(vector, ¢roids);
207 assignments[nearest].push(vec_idx);
208 }
209
210 for (cluster_idx, assigned) in assignments.iter().enumerate() {
212 if assigned.is_empty() {
213 continue;
214 }
215
216 let mut new_centroid = vec![0.0f32; dimension];
217 for &vec_idx in assigned {
218 for (i, &v) in vectors[vec_idx].iter().enumerate() {
219 new_centroid[i] += v;
220 }
221 }
222
223 let count = assigned.len() as f32;
224 for v in &mut new_centroid {
225 *v /= count;
226 }
227
228 centroids[cluster_idx].vector = new_centroid;
229 centroids[cluster_idx].count = assigned.len();
230 }
231 }
232
233 let _total_shards = config.num_clusters * config.shards_per_cluster;
235 for (i, centroid) in centroids.iter_mut().enumerate() {
236 let start_shard = i * config.shards_per_cluster;
237 let end_shard = start_shard + config.shards_per_cluster;
238 centroid.shards = (start_shard..end_shard).map(|s| s as ShardId).collect();
239 }
240
241 Self::new(centroids, config)
242 }
243
244 pub fn empty(config: TopologyConfig) -> Self {
246 Self {
247 centroids: Vec::new(),
248 shard_to_cluster: HashMap::new(),
249 config,
250 total_shards: 0,
251 stats: RwLock::new(TopologyStats::default()),
252 }
253 }
254
255 pub fn route(&self, query: &[f32]) -> RoutingDecision {
257 if self.centroids.is_empty() {
258 return RoutingDecision {
259 shards: Vec::new(),
260 distances: Vec::new(),
261 clusters_probed: 0,
262 };
263 }
264
265 let mut cluster_dists: Vec<(ClusterId, f32)> = self
267 .centroids
268 .iter()
269 .map(|c| (c.id, c.distance_squared(query)))
270 .collect();
271
272 cluster_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
273
274 let probe_count = self.config.probe_clusters.min(cluster_dists.len());
276 let probed: Vec<_> = cluster_dists[..probe_count].to_vec();
277
278 let mut shards = Vec::new();
280 let mut distances = Vec::new();
281
282 for (cluster_id, dist) in &probed {
283 if let Some(centroid) = self.centroids.get(*cluster_id as usize) {
284 shards.extend_from_slice(¢roid.shards);
285 distances.push(*dist);
286 }
287 }
288
289 if let Ok(mut stats) = self.stats.write() {
291 stats.queries_routed += 1;
292 stats.shards_probed += shards.len() as u64;
293 stats.avg_fanout = stats.shards_probed as f32 / stats.queries_routed as f32;
294
295 for (cluster_id, _) in &probed {
296 if (*cluster_id as usize) < stats.cluster_loads.len() {
297 stats.cluster_loads[*cluster_id as usize] += 1;
298 }
299 }
300 }
301
302 RoutingDecision {
303 shards,
304 distances,
305 clusters_probed: probe_count,
306 }
307 }
308
309 pub fn shard_cluster(&self, shard: ShardId) -> Option<ClusterId> {
311 self.shard_to_cluster.get(&shard).copied()
312 }
313
314 pub fn all_shards(&self) -> Vec<ShardId> {
316 self.shard_to_cluster.keys().copied().collect()
317 }
318
319 pub fn cluster(&self, id: ClusterId) -> Option<&Centroid> {
321 self.centroids.get(id as usize)
322 }
323
324 pub fn num_clusters(&self) -> usize {
326 self.centroids.len()
327 }
328
329 pub fn num_shards(&self) -> usize {
331 self.total_shards
332 }
333
334 pub fn needs_rebalance(&self) -> bool {
336 if self.centroids.len() < 2 {
337 return false;
338 }
339
340 let counts: Vec<usize> = self.centroids.iter().map(|c| c.count).collect();
341 let max_count = *counts.iter().max().unwrap_or(&1) as f32;
342 let min_count = *counts.iter().min().unwrap_or(&1).max(&1) as f32;
343
344 max_count / min_count > self.config.rebalance_threshold
345 }
346
347 pub fn stats(&self) -> TopologyStats {
349 self.stats.read().unwrap().clone()
350 }
351
352 fn find_nearest_centroid(vector: &[f32], centroids: &[Centroid]) -> usize {
354 centroids
355 .iter()
356 .enumerate()
357 .min_by(|(_, a), (_, b)| {
358 a.distance_squared(vector)
359 .partial_cmp(&b.distance_squared(vector))
360 .unwrap()
361 })
362 .map(|(i, _)| i)
363 .unwrap_or(0)
364 }
365}
366
367pub struct ShardRouter {
369 topology: ShardTopology,
371 #[allow(dead_code)]
373 adaptive: bool,
374}
375
376impl ShardRouter {
377 pub fn new(topology: ShardTopology) -> Self {
379 Self {
380 topology,
381 adaptive: true,
382 }
383 }
384
385 pub fn route_adaptive(&self, query: &[f32], target_recall: f32) -> RoutingDecision {
387 let base_probe = self.topology.config.probe_clusters;
389
390 let _probe = if target_recall > 0.99 {
391 (base_probe * 2).min(self.topology.num_clusters())
393 } else if target_recall > 0.95 {
394 base_probe
395 } else {
396 (base_probe / 2).max(1)
398 };
399
400 let mut decision = self.topology.route(query);
402
403 if target_recall > 0.95 && decision.shards.len() < 4 {
405 decision.shards.extend(
407 self.topology
408 .all_shards()
409 .into_iter()
410 .take(4 - decision.shards.len()),
411 );
412 }
413
414 decision
415 }
416
417 pub fn estimated_recall(&self, decision: &RoutingDecision) -> f32 {
419 if self.topology.num_shards() == 0 {
420 return 0.0;
421 }
422
423 let coverage = decision.shards.len() as f32 / self.topology.num_shards() as f32;
425 coverage.sqrt().min(1.0)
426 }
427
428 pub fn topology(&self) -> &ShardTopology {
430 &self.topology
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `count` deterministic pseudo-random vectors of length `dim`
    /// with components in `[-1.0, 1.0)`, varied by `seed`.
    fn random_vectors(count: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        (0..count)
            .map(|i| {
                (0..dim)
                    .map(|d| {
                        let x = ((i as u64 * 13 + d as u64 * 7 + seed) % 1000) as f32 / 1000.0;
                        x * 2.0 - 1.0
                    })
                    .collect()
            })
            .collect()
    }

    /// Returns a single 128-dimensional query vector for `seed`.
    fn query_vector(seed: u64) -> Vec<f32> {
        random_vectors(1, 128, seed).remove(0)
    }

    /// Returns a config with `num_clusters` clusters of four shards each,
    /// probing two clusters per query.
    fn test_config(num_clusters: usize) -> TopologyConfig {
        TopologyConfig {
            num_clusters,
            shards_per_cluster: 4,
            probe_clusters: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_centroid_distance() {
        let centroid = Centroid::new(0, vec![1.0, 0.0, 0.0]);
        let query = vec![0.0, 0.0, 0.0];

        assert!((centroid.distance_squared(&query) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_topology_build() {
        let vectors = random_vectors(1000, 128, 42);

        let topology = ShardTopology::build_from_vectors(&vectors, test_config(4));

        assert_eq!(topology.num_clusters(), 4);
        assert_eq!(topology.num_shards(), 16);
    }

    #[test]
    fn test_query_routing() {
        let vectors = random_vectors(1000, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(4));

        let decision = topology.route(&query_vector(99));

        // Two of four clusters probed, four shards each.
        assert_eq!(decision.clusters_probed, 2);
        assert_eq!(decision.shards.len(), 8);

        // 8 of 16 shards touched.
        assert!((decision.work_reduction(16) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_shard_cluster_mapping() {
        let centroids: Vec<Centroid> = (0..4)
            .map(|i| {
                let mut c = Centroid::new(i, vec![i as f32; 128]);
                c.shards = vec![i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3];
                c
            })
            .collect();

        let topology = ShardTopology::new(centroids, test_config(4));

        assert_eq!(topology.shard_cluster(0), Some(0));
        assert_eq!(topology.shard_cluster(5), Some(1));
        assert_eq!(topology.shard_cluster(10), Some(2));
        assert_eq!(topology.shard_cluster(15), Some(3));
    }

    #[test]
    fn test_adaptive_routing() {
        let vectors = random_vectors(1000, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(8));
        let router = ShardRouter::new(topology);
        let query = query_vector(99);

        let low_recall = router.route_adaptive(&query, 0.80);
        let high_recall = router.route_adaptive(&query, 0.99);

        // A higher recall target must never select fewer shards.
        assert!(high_recall.shards.len() >= low_recall.shards.len());
    }

    #[test]
    fn test_empty_topology() {
        let topology = ShardTopology::empty(TopologyConfig::default());

        assert_eq!(topology.num_clusters(), 0);
        assert_eq!(topology.num_shards(), 0);

        let decision = topology.route(&[0.0, 0.0, 0.0]);
        assert!(decision.shards.is_empty());
    }

    #[test]
    fn test_stats_tracking() {
        let vectors = random_vectors(1000, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(4));

        for seed in 0..10 {
            topology.route(&query_vector(seed));
        }

        let stats = topology.stats();
        assert_eq!(stats.queries_routed, 10);
        assert!(stats.avg_fanout > 0.0);
    }

    #[test]
    fn test_rebalance_detection() {
        // Cluster 0 holds ten times as many vectors as the others.
        let centroids: Vec<Centroid> = (0..4)
            .map(|i| {
                let mut c = Centroid::new(i, vec![i as f32; 128]);
                c.shards = vec![i * 4];
                c.count = if i == 0 { 1000 } else { 100 };
                c
            })
            .collect();

        let config = TopologyConfig {
            rebalance_threshold: 2.0,
            ..Default::default()
        };

        let topology = ShardTopology::new(centroids, config);
        assert!(topology.needs_rebalance());
    }

    #[test]
    fn test_estimated_recall() {
        let vectors = random_vectors(100, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(4));
        let router = ShardRouter::new(topology);

        let decision = router.topology().route(&query_vector(99));
        let recall = router.estimated_recall(&decision);

        // Half the shards are covered, so the estimate is sqrt(0.5).
        assert!(recall > 0.5 && recall < 1.0);
    }
}