Skip to main content

sochdb_vector/
shard_topology.rs

1// Copyright 2025 SochDB Authors
2//
3// Licensed under the Apache License, Version 2.0
4
5//! Shard-First ANN Topology
6//!
7//! This module provides coarse clustering for query routing to minimize
8//! fan-out across shards while maintaining high recall.
9//!
10//! # Problem
11//!
12//! Hash-based sharding forces broad fan-out:
13//! - Query must hit all S shards → S× work
14//! - Throughput limited by slowest shard
15//! - Network/coordination overhead grows with S
16//!
17//! # Solution
18//!
19//! Cluster-based routing:
20//! 1. Build K coarse centroids covering all vectors
21//! 2. Assign each vector to nearest centroid
22//! 3. Route queries to nearest 1-3 centroids only
23//! 4. Balance within clusters for hot/cold patterns
24//!
25//! # Fan-out Analysis
26//!
27//! For S shards and K clusters:
28//! - Hash routing: query all S shards, work = O(S × N/S) = O(N)
29//! - Cluster routing: query ~3 shards, work = O(3 × N/S) = O(3N/S)
30//! - Speedup: S/3 (e.g., 256 shards → 85× less work)
31//!
32//! # Trade-offs
33//!
34//! - Slight recall loss at cluster boundaries (~1-2%)
35//! - Centroid computation adds O(K × D) per query
36//! - Need rebalancing on insert skew
37
38use std::collections::HashMap;
39use std::sync::RwLock;
40
/// Shard identifier: the unit a routed query is sent to.
pub type ShardId = u32;

/// Cluster identifier: names one coarse [`Centroid`] used for query routing.
pub type ClusterId = u32;
46
47/// Coarse centroid for cluster routing.
48#[derive(Debug, Clone)]
49pub struct Centroid {
50    /// Cluster ID.
51    pub id: ClusterId,
52    /// Centroid vector.
53    pub vector: Vec<f32>,
54    /// Assigned shards.
55    pub shards: Vec<ShardId>,
56    /// Vector count in this cluster.
57    pub count: usize,
58}
59
60impl Centroid {
61    /// Create a new centroid.
62    pub fn new(id: ClusterId, vector: Vec<f32>) -> Self {
63        Self {
64            id,
65            vector,
66            shards: Vec::new(),
67            count: 0,
68        }
69    }
70
71    /// Compute squared L2 distance to query.
72    #[inline]
73    pub fn distance_squared(&self, query: &[f32]) -> f32 {
74        self.vector
75            .iter()
76            .zip(query.iter())
77            .map(|(&a, &b)| {
78                let d = a - b;
79                d * d
80            })
81            .sum()
82    }
83}
84
85/// Routing decision for a query.
86#[derive(Debug, Clone)]
87pub struct RoutingDecision {
88    /// Target shards to query.
89    pub shards: Vec<ShardId>,
90    /// Distances to cluster centroids.
91    pub distances: Vec<f32>,
92    /// Number of clusters considered.
93    pub clusters_probed: usize,
94}
95
96impl RoutingDecision {
97    /// Estimated work reduction vs full scan.
98    pub fn work_reduction(&self, total_shards: usize) -> f32 {
99        if self.shards.is_empty() {
100            return 1.0;
101        }
102        self.shards.len() as f32 / total_shards as f32
103    }
104}
105
/// Tuning knobs for building a shard topology and routing queries through it.
#[derive(Debug, Clone)]
pub struct TopologyConfig {
    /// Target number of coarse clusters (the k in k-means).
    pub num_clusters: usize,
    /// How many shards each cluster is given.
    pub shards_per_cluster: usize,
    /// How many of the nearest clusters a query is sent to.
    pub probe_clusters: usize,
    /// Largest tolerated ratio between the biggest and smallest cluster's
    /// vector count before a rebalance is reported as needed.
    pub rebalance_threshold: f32,
}

impl Default for TopologyConfig {
    /// 16 clusters of 16 shards each, probing the 2 nearest clusters, with a
    /// 2:1 imbalance threshold.
    fn default() -> Self {
        let (num_clusters, shards_per_cluster) = (16, 16);
        TopologyConfig {
            probe_clusters: 2,
            rebalance_threshold: 2.0,
            num_clusters,
            shards_per_cluster,
        }
    }
}
129
130/// Shard topology with cluster-based routing.
131pub struct ShardTopology {
132    /// Cluster centroids.
133    centroids: Vec<Centroid>,
134    /// Shard to cluster mapping.
135    shard_to_cluster: HashMap<ShardId, ClusterId>,
136    /// Configuration.
137    config: TopologyConfig,
138    /// Total shards.
139    total_shards: usize,
140    /// Statistics lock.
141    stats: RwLock<TopologyStats>,
142}
143
/// Counters describing how queries have been routed so far.
#[derive(Clone, Debug, Default)]
pub struct TopologyStats {
    /// Number of queries routed.
    pub queries_routed: u64,
    /// Cumulative number of shards selected across all routed queries.
    pub shards_probed: u64,
    /// Mean shards selected per routed query (`shards_probed / queries_routed`).
    pub avg_fanout: f32,
    /// Per-cluster tally of how many routed queries probed that cluster.
    pub cluster_loads: Vec<u64>,
}
156
157impl ShardTopology {
158    /// Create a new topology with given centroids.
159    pub fn new(centroids: Vec<Centroid>, config: TopologyConfig) -> Self {
160        let total_shards = centroids.iter().map(|c| c.shards.len()).sum();
161
162        let mut shard_to_cluster = HashMap::new();
163        for centroid in &centroids {
164            for &shard in &centroid.shards {
165                shard_to_cluster.insert(shard, centroid.id);
166            }
167        }
168
169        let cluster_loads = vec![0; centroids.len()];
170
171        Self {
172            centroids,
173            shard_to_cluster,
174            config,
175            total_shards,
176            stats: RwLock::new(TopologyStats {
177                cluster_loads,
178                ..Default::default()
179            }),
180        }
181    }
182
183    /// Build topology from vectors using k-means clustering.
184    pub fn build_from_vectors(vectors: &[Vec<f32>], config: TopologyConfig) -> Self {
185        if vectors.is_empty() {
186            return Self::empty(config);
187        }
188
189        let dimension = vectors[0].len();
190        let num_clusters = config.num_clusters.min(vectors.len());
191
192        // Simple k-means initialization (random sampling)
193        let mut centroids: Vec<Centroid> = (0..num_clusters)
194            .map(|i| {
195                let idx = (i * vectors.len()) / num_clusters;
196                Centroid::new(i as ClusterId, vectors[idx].clone())
197            })
198            .collect();
199
200        // K-means iterations
201        for _ in 0..10 {
202            // Assign vectors to nearest centroid
203            let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); num_clusters];
204
205            for (vec_idx, vector) in vectors.iter().enumerate() {
206                let nearest = Self::find_nearest_centroid(vector, &centroids);
207                assignments[nearest].push(vec_idx);
208            }
209
210            // Update centroids
211            for (cluster_idx, assigned) in assignments.iter().enumerate() {
212                if assigned.is_empty() {
213                    continue;
214                }
215
216                let mut new_centroid = vec![0.0f32; dimension];
217                for &vec_idx in assigned {
218                    for (i, &v) in vectors[vec_idx].iter().enumerate() {
219                        new_centroid[i] += v;
220                    }
221                }
222
223                let count = assigned.len() as f32;
224                for v in &mut new_centroid {
225                    *v /= count;
226                }
227
228                centroids[cluster_idx].vector = new_centroid;
229                centroids[cluster_idx].count = assigned.len();
230            }
231        }
232
233        // Assign shards to clusters (round-robin for now)
234        let _total_shards = config.num_clusters * config.shards_per_cluster;
235        for (i, centroid) in centroids.iter_mut().enumerate() {
236            let start_shard = i * config.shards_per_cluster;
237            let end_shard = start_shard + config.shards_per_cluster;
238            centroid.shards = (start_shard..end_shard).map(|s| s as ShardId).collect();
239        }
240
241        Self::new(centroids, config)
242    }
243
244    /// Create empty topology.
245    pub fn empty(config: TopologyConfig) -> Self {
246        Self {
247            centroids: Vec::new(),
248            shard_to_cluster: HashMap::new(),
249            config,
250            total_shards: 0,
251            stats: RwLock::new(TopologyStats::default()),
252        }
253    }
254
255    /// Route a query to target shards.
256    pub fn route(&self, query: &[f32]) -> RoutingDecision {
257        if self.centroids.is_empty() {
258            return RoutingDecision {
259                shards: Vec::new(),
260                distances: Vec::new(),
261                clusters_probed: 0,
262            };
263        }
264
265        // Find nearest clusters
266        let mut cluster_dists: Vec<(ClusterId, f32)> = self
267            .centroids
268            .iter()
269            .map(|c| (c.id, c.distance_squared(query)))
270            .collect();
271
272        cluster_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
273
274        // Take top probe_clusters
275        let probe_count = self.config.probe_clusters.min(cluster_dists.len());
276        let probed: Vec<_> = cluster_dists[..probe_count].to_vec();
277
278        // Collect shards from probed clusters
279        let mut shards = Vec::new();
280        let mut distances = Vec::new();
281
282        for (cluster_id, dist) in &probed {
283            if let Some(centroid) = self.centroids.get(*cluster_id as usize) {
284                shards.extend_from_slice(&centroid.shards);
285                distances.push(*dist);
286            }
287        }
288
289        // Update stats
290        if let Ok(mut stats) = self.stats.write() {
291            stats.queries_routed += 1;
292            stats.shards_probed += shards.len() as u64;
293            stats.avg_fanout = stats.shards_probed as f32 / stats.queries_routed as f32;
294
295            for (cluster_id, _) in &probed {
296                if (*cluster_id as usize) < stats.cluster_loads.len() {
297                    stats.cluster_loads[*cluster_id as usize] += 1;
298                }
299            }
300        }
301
302        RoutingDecision {
303            shards,
304            distances,
305            clusters_probed: probe_count,
306        }
307    }
308
309    /// Find which cluster a shard belongs to.
310    pub fn shard_cluster(&self, shard: ShardId) -> Option<ClusterId> {
311        self.shard_to_cluster.get(&shard).copied()
312    }
313
314    /// Get all shards.
315    pub fn all_shards(&self) -> Vec<ShardId> {
316        self.shard_to_cluster.keys().copied().collect()
317    }
318
319    /// Get cluster by ID.
320    pub fn cluster(&self, id: ClusterId) -> Option<&Centroid> {
321        self.centroids.get(id as usize)
322    }
323
324    /// Number of clusters.
325    pub fn num_clusters(&self) -> usize {
326        self.centroids.len()
327    }
328
329    /// Total shards.
330    pub fn num_shards(&self) -> usize {
331        self.total_shards
332    }
333
334    /// Check if rebalancing is needed.
335    pub fn needs_rebalance(&self) -> bool {
336        if self.centroids.len() < 2 {
337            return false;
338        }
339
340        let counts: Vec<usize> = self.centroids.iter().map(|c| c.count).collect();
341        let max_count = *counts.iter().max().unwrap_or(&1) as f32;
342        let min_count = *counts.iter().min().unwrap_or(&1).max(&1) as f32;
343
344        max_count / min_count > self.config.rebalance_threshold
345    }
346
347    /// Get topology statistics.
348    pub fn stats(&self) -> TopologyStats {
349        self.stats.read().unwrap().clone()
350    }
351
352    /// Find nearest centroid index.
353    fn find_nearest_centroid(vector: &[f32], centroids: &[Centroid]) -> usize {
354        centroids
355            .iter()
356            .enumerate()
357            .min_by(|(_, a), (_, b)| {
358                a.distance_squared(vector)
359                    .partial_cmp(&b.distance_squared(vector))
360                    .unwrap()
361            })
362            .map(|(i, _)| i)
363            .unwrap_or(0)
364    }
365}
366
367/// Router for shard-aware ANN search.
368pub struct ShardRouter {
369    /// Topology.
370    topology: ShardTopology,
371    /// Enable adaptive probing.
372    #[allow(dead_code)]
373    adaptive: bool,
374}
375
376impl ShardRouter {
377    /// Create a new router.
378    pub fn new(topology: ShardTopology) -> Self {
379        Self {
380            topology,
381            adaptive: true,
382        }
383    }
384
385    /// Route query with adaptive probe depth.
386    pub fn route_adaptive(&self, query: &[f32], target_recall: f32) -> RoutingDecision {
387        // Adjust probe depth based on target recall
388        let base_probe = self.topology.config.probe_clusters;
389
390        let _probe = if target_recall > 0.99 {
391            // High recall: probe more clusters
392            (base_probe * 2).min(self.topology.num_clusters())
393        } else if target_recall > 0.95 {
394            base_probe
395        } else {
396            // Low recall acceptable: probe fewer
397            (base_probe / 2).max(1)
398        };
399
400        // Temporarily adjust config (clone and modify)
401        let mut decision = self.topology.route(query);
402
403        // For high recall, ensure minimum shard coverage
404        if target_recall > 0.95 && decision.shards.len() < 4 {
405            // Add more shards from nearby clusters
406            decision.shards.extend(
407                self.topology
408                    .all_shards()
409                    .into_iter()
410                    .take(4 - decision.shards.len()),
411            );
412        }
413
414        decision
415    }
416
417    /// Get estimated recall for routing decision.
418    pub fn estimated_recall(&self, decision: &RoutingDecision) -> f32 {
419        if self.topology.num_shards() == 0 {
420            return 0.0;
421        }
422
423        // Simple model: recall ≈ coverage^0.5
424        let coverage = decision.shards.len() as f32 / self.topology.num_shards() as f32;
425        coverage.sqrt().min(1.0)
426    }
427
428    /// Get underlying topology.
429    pub fn topology(&self) -> &ShardTopology {
430        &self.topology
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vectors with components in [-1, 1).
    ///
    /// The same `(count, dim, seed)` always yields the same data, which keeps the tests reproducible.
    fn random_vectors(count: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        (0..count)
            .map(|i| {
                (0..dim)
                    .map(|d| {
                        let x = ((i as u64 * 13 + d as u64 * 7 + seed) % 1000) as f32 / 1000.0;
                        x * 2.0 - 1.0
                    })
                    .collect()
            })
            .collect()
    }

    #[test]
    fn test_centroid_distance() {
        let centroid = Centroid::new(0, vec![1.0, 0.0, 0.0]);
        let query = vec![0.0, 0.0, 0.0];

        assert!((centroid.distance_squared(&query) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_topology_build() {
        let vectors = random_vectors(1000, 128, 42);
        let config = TopologyConfig {
            num_clusters: 4,
            shards_per_cluster: 4,
            probe_clusters: 2,
            ..Default::default()
        };

        let topology = ShardTopology::build_from_vectors(&vectors, config);

        assert_eq!(topology.num_clusters(), 4);
        assert_eq!(topology.num_shards(), 16);
    }

    #[test]
    fn test_query_routing() {
        let vectors = random_vectors(1000, 128, 42);
        let config = TopologyConfig {
            num_clusters: 4,
            shards_per_cluster: 4,
            probe_clusters: 2,
            ..Default::default()
        };

        let topology = ShardTopology::build_from_vectors(&vectors, config);
        let query = random_vectors(1, 128, 99)[0].clone();

        let decision = topology.route(&query);

        // Should probe 2 clusters × 4 shards = 8 shards
        assert_eq!(decision.clusters_probed, 2);
        assert_eq!(decision.shards.len(), 8);

        // Work reduction: 8/16 = 0.5
        assert!((decision.work_reduction(16) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_shard_cluster_mapping() {
        let config = TopologyConfig {
            num_clusters: 4,
            shards_per_cluster: 4,
            ..Default::default()
        };

        let centroids: Vec<Centroid> = (0..4)
            .map(|i| {
                let mut c = Centroid::new(i, vec![i as f32; 128]);
                c.shards = vec![i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3];
                c
            })
            .collect();

        let topology = ShardTopology::new(centroids, config);

        assert_eq!(topology.shard_cluster(0), Some(0));
        assert_eq!(topology.shard_cluster(5), Some(1));
        assert_eq!(topology.shard_cluster(10), Some(2));
        assert_eq!(topology.shard_cluster(15), Some(3));
    }

    #[test]
    fn test_adaptive_routing() {
        let vectors = random_vectors(1000, 128, 42);
        let config = TopologyConfig {
            num_clusters: 8,
            shards_per_cluster: 4,
            probe_clusters: 2,
            ..Default::default()
        };

        let topology = ShardTopology::build_from_vectors(&vectors, config);
        let router = ShardRouter::new(topology);
        let query = random_vectors(1, 128, 99)[0].clone();

        // Low recall: fewer shards
        let low_recall = router.route_adaptive(&query, 0.80);

        // High recall: more shards
        let high_recall = router.route_adaptive(&query, 0.99);

        // High recall should probe at least as many shards
        assert!(high_recall.shards.len() >= low_recall.shards.len());
    }

    #[test]
    fn test_empty_topology() {
        let config = TopologyConfig::default();
        let topology = ShardTopology::empty(config);

        assert_eq!(topology.num_clusters(), 0);
        assert_eq!(topology.num_shards(), 0);

        let decision = topology.route(&[0.0, 0.0, 0.0]);
        assert!(decision.shards.is_empty());
    }

    #[test]
    fn test_stats_tracking() {
        let vectors = random_vectors(1000, 128, 42);
        let config = TopologyConfig {
            num_clusters: 4,
            shards_per_cluster: 4,
            probe_clusters: 2,
            ..Default::default()
        };

        let topology = ShardTopology::build_from_vectors(&vectors, config);

        // Route multiple queries
        for i in 0..10 {
            let query = random_vectors(1, 128, i)[0].clone();
            topology.route(&query);
        }

        let stats = topology.stats();
        assert_eq!(stats.queries_routed, 10);
        assert!(stats.avg_fanout > 0.0);
    }

    #[test]
    fn test_rebalance_detection() {
        // Not `mut`: the vector is moved into the topology unchanged
        // (the stray `mut` triggered an `unused_mut` warning).
        let centroids: Vec<Centroid> = (0..4)
            .map(|i| {
                let mut c = Centroid::new(i, vec![i as f32; 128]);
                c.shards = vec![i * 4];
                c.count = if i == 0 { 1000 } else { 100 }; // Imbalanced
                c
            })
            .collect();

        let config = TopologyConfig {
            rebalance_threshold: 2.0,
            ..Default::default()
        };

        let topology = ShardTopology::new(centroids, config);
        assert!(topology.needs_rebalance());
    }

    #[test]
    fn test_estimated_recall() {
        let vectors = random_vectors(100, 128, 42);
        let config = TopologyConfig {
            num_clusters: 4,
            shards_per_cluster: 4,
            probe_clusters: 2,
            ..Default::default()
        };

        let topology = ShardTopology::build_from_vectors(&vectors, config);
        let router = ShardRouter::new(topology);
        let query = random_vectors(1, 128, 99)[0].clone();

        let decision = router.topology().route(&query);
        let recall = router.estimated_recall(&decision);

        // With 50% shard coverage, recall ≈ sqrt(0.5) ≈ 0.707
        assert!(recall > 0.5 && recall < 1.0);
    }
}