sochdb_vector/list_bounds.rs
1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! # Cosine/Dot (MIPS) List Bounds (Task 3)
19//!
20//! Provides computable upper bounds for cosine/dot similarity over IVF lists,
21//! enabling best-first probing and bound-based termination for non-L2 metrics.
22//!
23//! ## Architecture
24//!
25//! For cosine/dot similarity, we store spherical cap metadata per list:
26//! - Centroid direction c (unit vector)
27//! - Max angular deviation θ_max (or min dot to centroid)
28//!
29//! ## Math/Algorithm
30//!
//! For a unit-norm query q and unit-norm list vectors v, the similarity achievable in list L is bounded by:
32//!
33//! ```text
34//! max_{v∈L} q·v ≤ cos(max(0, angle(q,c) - θ_max))
35//! ```
36//!
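//! where:
//! - angle(q,c) = arccos(q·c)
//! - θ_max = max_{v∈L} arccos(v·c)
//!
//! This follows from the triangle inequality for angles on the unit sphere,
//! angle(q,v) ≥ angle(q,c) − θ_max, together with cos being decreasing on [0, π].
//!
//! Evaluating the bound for a list costs one O(d) dot product q·c plus O(1) work.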
42//!
43//! ## Usage
44//!
45//! ```rust,ignore
//! use sochdb_vector::list_bounds::{DistanceMetric, ListBoundComputer, SphericalCapMetadata};
//!
//! // Build metadata during indexing
//! let metadata = SphericalCapMetadata::from_vectors(&vectors, &centroid);
//!
//! // At query time, compute bounds
//! let computer = ListBoundComputer::new(&query, DistanceMetric::Cosine);
//! let bound = computer.cosine_upper_bound(&metadata);
54//! ```
55
56use std::f32::consts::PI;
57
58// ============================================================================
59// Distance Metric
60// ============================================================================
61
/// Distance/similarity metrics understood by the list-bound machinery.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// L2 (Euclidean) distance - smaller is better
    L2,
    /// Cosine similarity - larger is better (1 = identical)
    Cosine,
    /// Inner product (dot) - larger is better
    InnerProduct,
    /// Negative inner product (for MIPS via nearest neighbor) - smaller is better
    NegativeInnerProduct,
}

impl DistanceMetric {
    /// Returns `true` for the similarity metrics (`Cosine`, `InnerProduct`), where a
    /// larger score is a closer match, and `false` for `L2` and `NegativeInnerProduct`.
    pub fn higher_is_better(&self) -> bool {
        match self {
            Self::Cosine | Self::InnerProduct => true,
            Self::L2 | Self::NegativeInnerProduct => false,
        }
    }

    /// Returns `true` only for `Cosine`, the one metric whose vectors are expected
    /// to be unit-normalized.
    pub fn requires_normalization(&self) -> bool {
        *self == Self::Cosine
    }
}
86
87// ============================================================================
88// Spherical Cap Metadata
89// ============================================================================
90
91/// Spherical cap metadata for a list/partition
92///
93/// Represents the region of the unit sphere covered by vectors in this list.
94/// Used to compute tight bounds for cosine/dot similarity.
95#[derive(Debug, Clone)]
96pub struct SphericalCapMetadata {
97 /// Centroid direction (unit vector)
98 pub centroid: Vec<f32>,
99
100 /// Maximum angular deviation from centroid (in radians)
101 /// θ_max = max_{v∈L} arccos(v·c)
102 pub theta_max: f32,
103
104 /// Minimum dot product with centroid
105 /// min_dot = min_{v∈L} v·c = cos(θ_max)
106 pub min_dot_to_centroid: f32,
107
108 /// Maximum dot product with centroid (typically ~1.0 for tight clusters)
109 pub max_dot_to_centroid: f32,
110
111 /// Number of vectors in this list
112 pub vector_count: u32,
113
114 /// Mean dot product with centroid (for statistics)
115 pub mean_dot_to_centroid: f32,
116}
117
118impl SphericalCapMetadata {
119 /// Build metadata from a set of normalized vectors and their centroid
120 ///
121 /// # Arguments
122 /// * `vectors` - Normalized vectors in the list (each row is a vector)
123 /// * `centroid` - Normalized centroid of the list
124 ///
125 /// # Complexity
126 /// O(n × d) where n = number of vectors, d = dimension
127 pub fn from_vectors(vectors: &[Vec<f32>], centroid: &[f32]) -> Self {
128 if vectors.is_empty() {
129 return Self {
130 centroid: centroid.to_vec(),
131 theta_max: 0.0,
132 min_dot_to_centroid: 1.0,
133 max_dot_to_centroid: 1.0,
134 vector_count: 0,
135 mean_dot_to_centroid: 1.0,
136 };
137 }
138
139 let mut min_dot = f32::MAX;
140 let mut max_dot = f32::MIN;
141 let mut sum_dot = 0.0;
142
143 for v in vectors {
144 let dot = dot_product(v, centroid);
145 min_dot = min_dot.min(dot);
146 max_dot = max_dot.max(dot);
147 sum_dot += dot;
148 }
149
150 // Clamp to valid range for arccos
151 let clamped_min = min_dot.clamp(-1.0, 1.0);
152 let theta_max = clamped_min.acos();
153
154 Self {
155 centroid: centroid.to_vec(),
156 theta_max,
157 min_dot_to_centroid: min_dot,
158 max_dot_to_centroid: max_dot,
159 vector_count: vectors.len() as u32,
160 mean_dot_to_centroid: sum_dot / vectors.len() as f32,
161 }
162 }
163
164 /// Build metadata from flat vector data
165 pub fn from_flat_vectors(data: &[f32], dim: usize, centroid: &[f32]) -> Self {
166 let n_vectors = data.len() / dim;
167
168 if n_vectors == 0 {
169 return Self {
170 centroid: centroid.to_vec(),
171 theta_max: 0.0,
172 min_dot_to_centroid: 1.0,
173 max_dot_to_centroid: 1.0,
174 vector_count: 0,
175 mean_dot_to_centroid: 1.0,
176 };
177 }
178
179 let mut min_dot = f32::MAX;
180 let mut max_dot = f32::MIN;
181 let mut sum_dot = 0.0;
182
183 for i in 0..n_vectors {
184 let v = &data[i * dim..(i + 1) * dim];
185 let dot = dot_product(v, centroid);
186 min_dot = min_dot.min(dot);
187 max_dot = max_dot.max(dot);
188 sum_dot += dot;
189 }
190
191 let clamped_min = min_dot.clamp(-1.0, 1.0);
192 let theta_max = clamped_min.acos();
193
194 Self {
195 centroid: centroid.to_vec(),
196 theta_max,
197 min_dot_to_centroid: min_dot,
198 max_dot_to_centroid: max_dot,
199 vector_count: n_vectors as u32,
200 mean_dot_to_centroid: sum_dot / n_vectors as f32,
201 }
202 }
203
204 /// Update metadata incrementally when a new vector is added
205 pub fn add_vector(&mut self, vector: &[f32]) {
206 let dot = dot_product(vector, &self.centroid);
207
208 let old_sum = self.mean_dot_to_centroid * self.vector_count as f32;
209 self.vector_count += 1;
210 self.mean_dot_to_centroid = (old_sum + dot) / self.vector_count as f32;
211
212 if dot < self.min_dot_to_centroid {
213 self.min_dot_to_centroid = dot;
214 self.theta_max = dot.clamp(-1.0, 1.0).acos();
215 }
216 if dot > self.max_dot_to_centroid {
217 self.max_dot_to_centroid = dot;
218 }
219 }
220
221 /// Get the angular radius of the spherical cap (in radians)
222 pub fn angular_radius(&self) -> f32 {
223 self.theta_max
224 }
225
226 /// Get the angular radius in degrees
227 pub fn angular_radius_degrees(&self) -> f32 {
228 self.theta_max * 180.0 / PI
229 }
230
231 /// Estimate the "tightness" of the cluster (0 = loose, 1 = tight)
232 pub fn tightness(&self) -> f32 {
233 // A perfectly tight cluster has all vectors at the centroid (theta_max = 0)
234 // A maximally loose cluster has theta_max = π
235 1.0 - (self.theta_max / PI)
236 }
237}
238
239// ============================================================================
240// L2 List Metadata
241// ============================================================================
242
243/// Metadata for L2 distance bounds (centroid + radius)
244#[derive(Debug, Clone)]
245pub struct L2ListMetadata {
246 /// Centroid of the list
247 pub centroid: Vec<f32>,
248
249 /// Maximum L2 distance from centroid to any vector in list
250 pub radius: f32,
251
252 /// Mean L2 distance from centroid
253 pub mean_radius: f32,
254
255 /// Number of vectors
256 pub vector_count: u32,
257}
258
259impl L2ListMetadata {
260 /// Build from vectors
261 pub fn from_vectors(vectors: &[Vec<f32>], centroid: &[f32]) -> Self {
262 if vectors.is_empty() {
263 return Self {
264 centroid: centroid.to_vec(),
265 radius: 0.0,
266 mean_radius: 0.0,
267 vector_count: 0,
268 };
269 }
270
271 let mut max_dist = 0.0f32;
272 let mut sum_dist = 0.0;
273
274 for v in vectors {
275 let dist = l2_distance(v, centroid);
276 max_dist = max_dist.max(dist);
277 sum_dist += dist;
278 }
279
280 Self {
281 centroid: centroid.to_vec(),
282 radius: max_dist,
283 mean_radius: sum_dist / vectors.len() as f32,
284 vector_count: vectors.len() as u32,
285 }
286 }
287
288 /// Compute lower bound on L2 distance from query to any vector in list
289 ///
290 /// LB = max(0, dist(q, c) - radius)
291 pub fn lower_bound(&self, query: &[f32]) -> f32 {
292 let dist_to_centroid = l2_distance(query, &self.centroid);
293 (dist_to_centroid - self.radius).max(0.0)
294 }
295}
296
297// ============================================================================
298// List Bound Computer
299// ============================================================================
300
301/// Computes bounds for a query across multiple lists
302///
303/// Precomputes query-related values once, then evaluates bounds per list in O(1).
304pub struct ListBoundComputer<'a> {
305 /// Query vector
306 query: &'a [f32],
307
308 /// Precomputed query norm (for L2)
309 query_norm: f32,
310
311 /// Distance metric
312 metric: DistanceMetric,
313}
314
315impl<'a> ListBoundComputer<'a> {
316 /// Create a new bound computer for a query
317 pub fn new(query: &'a [f32], metric: DistanceMetric) -> Self {
318 let query_norm = l2_norm(query);
319 Self {
320 query,
321 query_norm,
322 metric,
323 }
324 }
325
326 /// Compute upper bound on similarity for a spherical cap (cosine/dot)
327 ///
328 /// For normalized query q and list with centroid c and max deviation θ_max:
329 /// max_{v∈L} q·v ≤ cos(max(0, angle(q,c) - θ_max))
330 ///
331 /// Complexity: O(d) for dot product, rest is O(1)
332 pub fn cosine_upper_bound(&self, metadata: &SphericalCapMetadata) -> f32 {
333 // Compute q·c
334 let query_dot_centroid = dot_product(self.query, &metadata.centroid);
335
336 // angle(q,c) = arccos(q·c)
337 let clamped = query_dot_centroid.clamp(-1.0, 1.0);
338 let angle_to_centroid = clamped.acos();
339
340 // Upper bound angle to best vector: max(0, angle - θ_max)
341 let min_angle = (angle_to_centroid - metadata.theta_max).max(0.0);
342
343 // Upper bound on similarity: cos(min_angle)
344 min_angle.cos()
345 }
346
347 /// Compute lower bound on L2 distance for a list
348 ///
349 /// LB = max(0, dist(q, c) - radius)
350 pub fn l2_lower_bound(&self, metadata: &L2ListMetadata) -> f32 {
351 let dist_to_centroid = l2_distance(self.query, &metadata.centroid);
352 (dist_to_centroid - metadata.radius).max(0.0)
353 }
354
355 /// Compute bound appropriate for the configured metric
356 ///
357 /// For similarity metrics (cosine, dot): returns upper bound (higher = tighter)
358 /// For distance metrics (L2): returns lower bound (lower = tighter)
359 pub fn compute_bound(&self, cap: &SphericalCapMetadata, l2: Option<&L2ListMetadata>) -> f32 {
360 match self.metric {
361 DistanceMetric::Cosine | DistanceMetric::InnerProduct => self.cosine_upper_bound(cap),
362 DistanceMetric::L2 => {
363 if let Some(l2_meta) = l2 {
364 self.l2_lower_bound(l2_meta)
365 } else {
366 // Fall back to using spherical cap for normalized vectors
367 // LB ≈ sqrt(2 - 2*cos(angle))
368 let ub = self.cosine_upper_bound(cap);
369 (2.0 - 2.0 * ub).max(0.0).sqrt()
370 }
371 }
372 DistanceMetric::NegativeInnerProduct => -self.cosine_upper_bound(cap),
373 }
374 }
375}
376
377// ============================================================================
378// Multi-List Bound Ordering
379// ============================================================================
380
381/// Precomputed bounds for best-first list ordering
382#[derive(Debug, Clone)]
383pub struct ListBound {
384 /// List index
385 pub list_idx: u32,
386 /// Bound value (interpretation depends on metric)
387 pub bound: f32,
388}
389
390impl ListBound {
391 /// Order lists by bound for best-first probing
392 ///
393 /// For similarity metrics: descending order (best first)
394 /// For distance metrics: ascending order (best first)
395 pub fn order_for_probing(bounds: &mut [ListBound], metric: DistanceMetric) {
396 match metric {
397 DistanceMetric::Cosine | DistanceMetric::InnerProduct => {
398 // Higher similarity = better, sort descending
399 bounds.sort_by(|a, b| b.bound.partial_cmp(&a.bound).unwrap());
400 }
401 DistanceMetric::L2 | DistanceMetric::NegativeInnerProduct => {
402 // Lower distance = better, sort ascending
403 bounds.sort_by(|a, b| a.bound.partial_cmp(&b.bound).unwrap());
404 }
405 }
406 }
407
408 /// Check if we can terminate based on kth best score and remaining bounds
409 ///
410 /// For similarity: stop if kth_score > best_remaining_bound
411 /// For distance: stop if kth_score < best_remaining_bound
412 pub fn can_terminate(
413 kth_score: f32,
414 best_remaining_bound: f32,
415 metric: DistanceMetric,
416 ) -> bool {
417 match metric {
418 DistanceMetric::Cosine | DistanceMetric::InnerProduct => {
419 kth_score > best_remaining_bound
420 }
421 DistanceMetric::L2 | DistanceMetric::NegativeInnerProduct => {
422 kth_score < best_remaining_bound
423 }
424 }
425 }
426}
427
428// ============================================================================
429// Unified List Metadata
430// ============================================================================
431
432/// Combined metadata supporting all metrics
433#[derive(Debug, Clone)]
434pub struct UnifiedListMetadata {
435 /// Spherical cap for cosine/dot
436 pub cap: SphericalCapMetadata,
437
438 /// L2 metadata (optional, computed on demand)
439 pub l2: Option<L2ListMetadata>,
440
441 /// List index
442 pub list_idx: u32,
443}
444
445impl UnifiedListMetadata {
446 /// Build unified metadata
447 pub fn new(list_idx: u32, cap: SphericalCapMetadata) -> Self {
448 Self {
449 cap,
450 l2: None,
451 list_idx,
452 }
453 }
454
455 /// Add L2 metadata
456 pub fn with_l2(mut self, l2: L2ListMetadata) -> Self {
457 self.l2 = Some(l2);
458 self
459 }
460}
461
462// ============================================================================
463// Helper Functions
464// ============================================================================
465
/// Dot product `Σ a[i]·b[i]`; the slices must have equal length (checked in debug builds).
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}

/// Euclidean (L2) norm `sqrt(Σ v[i]²)`.
#[inline]
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}

/// Euclidean distance `sqrt(Σ (a[i] - b[i])²)`; slices must have equal length
/// (checked in debug builds).
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let squared: f32 = a.iter().zip(b).map(|(&x, &y)| (x - y).powi(2)).sum();
    squared.sqrt()
}

/// Scales `v` to unit L2 norm in place. Vectors whose norm is not above 1e-10
/// (including all-zero vectors) are left unchanged.
pub fn normalize_inplace(v: &mut [f32]) {
    let norm = l2_norm(v);
    if norm > 1e-10 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}

/// Returns a unit-L2-norm copy of `v`. Vectors whose norm is not above 1e-10
/// are returned as an unmodified copy.
pub fn normalize(v: &[f32]) -> Vec<f32> {
    match l2_norm(v) {
        norm if norm > 1e-10 => v.iter().map(|&x| x / norm).collect(),
        _ => v.to_vec(),
    }
}
509
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spherical_cap_metadata() {
        // Create a tight cluster of 3D unit vectors
        let centroid = vec![1.0, 0.0, 0.0];
        let vectors = vec![
            normalize(&[1.0, 0.1, 0.0]),
            normalize(&[1.0, -0.1, 0.0]),
            normalize(&[1.0, 0.0, 0.1]),
            normalize(&[1.0, 0.0, -0.1]),
        ];

        let metadata = SphericalCapMetadata::from_vectors(&vectors, &centroid);

        assert!(metadata.theta_max > 0.0);
        assert!(metadata.theta_max < PI / 4.0); // Should be a tight cluster
        assert!(metadata.tightness() > 0.5);
    }

    #[test]
    fn test_empty_metadata_is_zero_radius_cap() {
        let centroid = vec![0.0, 1.0];
        let metadata = SphericalCapMetadata::from_vectors(&[], &centroid);

        assert_eq!(metadata.vector_count, 0);
        assert_eq!(metadata.theta_max, 0.0);
        assert_eq!(metadata.centroid, centroid);
    }

    #[test]
    fn test_flat_vectors_match_nested() {
        let centroid = vec![1.0, 0.0, 0.0];
        let vectors = vec![normalize(&[1.0, 0.2, 0.0]), normalize(&[1.0, 0.0, -0.3])];
        let flat: Vec<f32> = vectors.iter().flatten().copied().collect();

        let nested = SphericalCapMetadata::from_vectors(&vectors, &centroid);
        let flat_meta = SphericalCapMetadata::from_flat_vectors(&flat, 3, &centroid);

        assert_eq!(nested.vector_count, flat_meta.vector_count);
        assert_eq!(nested.theta_max, flat_meta.theta_max);
        assert_eq!(nested.min_dot_to_centroid, flat_meta.min_dot_to_centroid);
        assert_eq!(nested.max_dot_to_centroid, flat_meta.max_dot_to_centroid);
    }

    #[test]
    fn test_cosine_upper_bound() {
        // Centroid pointing in x direction
        let centroid = vec![1.0, 0.0, 0.0];
        let metadata = SphericalCapMetadata {
            centroid: centroid.clone(),
            theta_max: 0.3, // About 17 degrees
            min_dot_to_centroid: 0.3_f32.cos(),
            max_dot_to_centroid: 1.0,
            vector_count: 10,
            mean_dot_to_centroid: 0.95,
        };

        // Query in same direction as centroid
        let query = vec![1.0, 0.0, 0.0];
        let computer = ListBoundComputer::new(&query, DistanceMetric::Cosine);
        let bound = computer.cosine_upper_bound(&metadata);

        // Should be close to 1.0 since query aligns with centroid
        assert!((bound - 1.0).abs() < 0.01);

        // Query perpendicular to centroid
        let query2 = vec![0.0, 1.0, 0.0];
        let computer2 = ListBoundComputer::new(&query2, DistanceMetric::Cosine);
        let bound2 = computer2.cosine_upper_bound(&metadata);

        // Upper bound should account for theta_max
        // angle = π/2, so upper bound = cos(π/2 - 0.3) = sin(0.3)
        assert!((bound2 - 0.3_f32.sin()).abs() < 0.01);
    }

    #[test]
    fn test_l2_lower_bound() {
        let centroid = vec![0.0, 0.0, 0.0];
        let metadata = L2ListMetadata {
            centroid,
            radius: 1.0,
            mean_radius: 0.5,
            vector_count: 100,
        };

        // Query at distance 2 from centroid
        let query = vec![2.0, 0.0, 0.0];
        let computer = ListBoundComputer::new(&query, DistanceMetric::L2);
        let lb = computer.l2_lower_bound(&metadata);

        // Lower bound should be 2 - 1 = 1
        assert!((lb - 1.0).abs() < 0.01);

        // Query inside the radius
        let query2 = vec![0.5, 0.0, 0.0];
        let computer2 = ListBoundComputer::new(&query2, DistanceMetric::L2);
        let lb2 = computer2.l2_lower_bound(&metadata);

        // Lower bound should be 0 (query is within radius)
        assert!((lb2 - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_list_ordering() {
        let mut bounds = vec![
            ListBound {
                list_idx: 0,
                bound: 0.5,
            },
            ListBound {
                list_idx: 1,
                bound: 0.9,
            },
            ListBound {
                list_idx: 2,
                bound: 0.3,
            },
        ];

        // For cosine, descending order (highest similarity first)
        ListBound::order_for_probing(&mut bounds, DistanceMetric::Cosine);
        assert_eq!(bounds[0].list_idx, 1); // 0.9
        assert_eq!(bounds[1].list_idx, 0); // 0.5
        assert_eq!(bounds[2].list_idx, 2); // 0.3

        // For L2, ascending order (lowest distance first)
        ListBound::order_for_probing(&mut bounds, DistanceMetric::L2);
        assert_eq!(bounds[0].list_idx, 2); // 0.3
        assert_eq!(bounds[1].list_idx, 0); // 0.5
        assert_eq!(bounds[2].list_idx, 1); // 0.9
    }

    #[test]
    fn test_can_terminate() {
        // Similarity: stop only once the kth score strictly beats the best remaining bound.
        assert!(ListBound::can_terminate(0.9, 0.8, DistanceMetric::Cosine));
        assert!(!ListBound::can_terminate(0.8, 0.8, DistanceMetric::Cosine));

        // Distance: stop only once the kth distance is strictly below the best remaining bound.
        assert!(ListBound::can_terminate(0.5, 0.6, DistanceMetric::L2));
        assert!(!ListBound::can_terminate(0.7, 0.6, DistanceMetric::L2));
    }
}