Skip to main content

sochdb_vector/
list_bounds.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! # Cosine/Dot (MIPS) List Bounds (Task 3)
19//!
20//! Provides computable upper bounds for cosine/dot similarity over IVF lists,
21//! enabling best-first probing and bound-based termination for non-L2 metrics.
22//!
23//! ## Architecture
24//!
25//! For cosine/dot similarity, we store spherical cap metadata per list:
26//! - Centroid direction c (unit vector)
27//! - Max angular deviation θ_max (or min dot to centroid)
28//!
29//! ## Math/Algorithm
30//!
31//! For normalized queries q, the upper bound on achievable similarity in list L is:
32//!
33//! ```text
34//! max_{v∈L} q·v ≤ cos(max(0, angle(q,c) - θ_max))
35//! ```
36//!
37//! where:
38//! - angle(q,c) = arccos(q·c)
39//! - θ_max = max_{v∈L} arccos(v·c)
40//!
41//! Bound evaluation is O(1) per list after precomputing q·c.
42//!
43//! ## Usage
44//!
45//! ```rust,ignore
//! use sochdb_vector::list_bounds::{DistanceMetric, ListBoundComputer, SphericalCapMetadata};
//!
//! // Build metadata during indexing
//! let metadata = SphericalCapMetadata::from_vectors(&vectors, &centroid);
//!
//! // At query time, compute bounds
//! let computer = ListBoundComputer::new(&query, DistanceMetric::Cosine);
//! let bound = computer.cosine_upper_bound(&metadata);
54//! ```
55
56use std::f32::consts::PI;
57
58// ============================================================================
59// Distance Metric
60// ============================================================================
61
/// Similarity/distance measures for which list bounds can be computed.
///
/// The variants differ in score direction: for some a larger score is a
/// closer match, for others a smaller one is (see [`Self::higher_is_better`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance; smaller values are closer.
    L2,
    /// Cosine similarity; larger values are closer (1 = identical direction).
    Cosine,
    /// Inner (dot) product; larger values are closer.
    InnerProduct,
    /// Negated inner product, so MIPS can be served by a nearest-neighbor search.
    NegativeInnerProduct,
}

impl DistanceMetric {
    /// Reports whether a larger score denotes a better match under this metric.
    ///
    /// True for [`Self::Cosine`] and [`Self::InnerProduct`], false for the rest.
    pub fn higher_is_better(&self) -> bool {
        match self {
            Self::Cosine | Self::InnerProduct => true,
            Self::L2 | Self::NegativeInnerProduct => false,
        }
    }

    /// Reports whether vectors should be normalized before use with this metric.
    ///
    /// Only [`Self::Cosine`] answers true.
    pub fn requires_normalization(&self) -> bool {
        *self == Self::Cosine
    }
}
86
87// ============================================================================
88// Spherical Cap Metadata
89// ============================================================================
90
/// Per-list summary of where the list's vectors sit on the unit sphere.
///
/// The list is described as a spherical cap: a center direction plus the
/// largest angle any member makes with that direction. Cosine/dot bounds for
/// a whole list are derived from these values without touching the vectors.
#[derive(Debug, Clone)]
pub struct SphericalCapMetadata {
    /// Center direction of the cap (expected to be a unit vector).
    pub centroid: Vec<f32>,

    /// Largest angle, in radians, between the centroid and any list member:
    /// θ_max = max_{v∈L} arccos(v·c).
    pub theta_max: f32,

    /// Smallest member-to-centroid dot product:
    /// min_{v∈L} v·c, i.e. cos(θ_max) for unit vectors.
    pub min_dot_to_centroid: f32,

    /// Largest member-to-centroid dot product (close to 1.0 for tight clusters).
    pub max_dot_to_centroid: f32,

    /// How many vectors the list contains.
    pub vector_count: u32,

    /// Average member-to-centroid dot product; kept for statistics.
    pub mean_dot_to_centroid: f32,
}
117
118impl SphericalCapMetadata {
119    /// Build metadata from a set of normalized vectors and their centroid
120    ///
121    /// # Arguments
122    /// * `vectors` - Normalized vectors in the list (each row is a vector)
123    /// * `centroid` - Normalized centroid of the list
124    ///
125    /// # Complexity
126    /// O(n × d) where n = number of vectors, d = dimension
127    pub fn from_vectors(vectors: &[Vec<f32>], centroid: &[f32]) -> Self {
128        if vectors.is_empty() {
129            return Self {
130                centroid: centroid.to_vec(),
131                theta_max: 0.0,
132                min_dot_to_centroid: 1.0,
133                max_dot_to_centroid: 1.0,
134                vector_count: 0,
135                mean_dot_to_centroid: 1.0,
136            };
137        }
138
139        let mut min_dot = f32::MAX;
140        let mut max_dot = f32::MIN;
141        let mut sum_dot = 0.0;
142
143        for v in vectors {
144            let dot = dot_product(v, centroid);
145            min_dot = min_dot.min(dot);
146            max_dot = max_dot.max(dot);
147            sum_dot += dot;
148        }
149
150        // Clamp to valid range for arccos
151        let clamped_min = min_dot.clamp(-1.0, 1.0);
152        let theta_max = clamped_min.acos();
153
154        Self {
155            centroid: centroid.to_vec(),
156            theta_max,
157            min_dot_to_centroid: min_dot,
158            max_dot_to_centroid: max_dot,
159            vector_count: vectors.len() as u32,
160            mean_dot_to_centroid: sum_dot / vectors.len() as f32,
161        }
162    }
163
164    /// Build metadata from flat vector data
165    pub fn from_flat_vectors(data: &[f32], dim: usize, centroid: &[f32]) -> Self {
166        let n_vectors = data.len() / dim;
167
168        if n_vectors == 0 {
169            return Self {
170                centroid: centroid.to_vec(),
171                theta_max: 0.0,
172                min_dot_to_centroid: 1.0,
173                max_dot_to_centroid: 1.0,
174                vector_count: 0,
175                mean_dot_to_centroid: 1.0,
176            };
177        }
178
179        let mut min_dot = f32::MAX;
180        let mut max_dot = f32::MIN;
181        let mut sum_dot = 0.0;
182
183        for i in 0..n_vectors {
184            let v = &data[i * dim..(i + 1) * dim];
185            let dot = dot_product(v, centroid);
186            min_dot = min_dot.min(dot);
187            max_dot = max_dot.max(dot);
188            sum_dot += dot;
189        }
190
191        let clamped_min = min_dot.clamp(-1.0, 1.0);
192        let theta_max = clamped_min.acos();
193
194        Self {
195            centroid: centroid.to_vec(),
196            theta_max,
197            min_dot_to_centroid: min_dot,
198            max_dot_to_centroid: max_dot,
199            vector_count: n_vectors as u32,
200            mean_dot_to_centroid: sum_dot / n_vectors as f32,
201        }
202    }
203
204    /// Update metadata incrementally when a new vector is added
205    pub fn add_vector(&mut self, vector: &[f32]) {
206        let dot = dot_product(vector, &self.centroid);
207
208        let old_sum = self.mean_dot_to_centroid * self.vector_count as f32;
209        self.vector_count += 1;
210        self.mean_dot_to_centroid = (old_sum + dot) / self.vector_count as f32;
211
212        if dot < self.min_dot_to_centroid {
213            self.min_dot_to_centroid = dot;
214            self.theta_max = dot.clamp(-1.0, 1.0).acos();
215        }
216        if dot > self.max_dot_to_centroid {
217            self.max_dot_to_centroid = dot;
218        }
219    }
220
221    /// Get the angular radius of the spherical cap (in radians)
222    pub fn angular_radius(&self) -> f32 {
223        self.theta_max
224    }
225
226    /// Get the angular radius in degrees
227    pub fn angular_radius_degrees(&self) -> f32 {
228        self.theta_max * 180.0 / PI
229    }
230
231    /// Estimate the "tightness" of the cluster (0 = loose, 1 = tight)
232    pub fn tightness(&self) -> f32 {
233        // A perfectly tight cluster has all vectors at the centroid (theta_max = 0)
234        // A maximally loose cluster has theta_max = π
235        1.0 - (self.theta_max / PI)
236    }
237}
238
239// ============================================================================
240// L2 List Metadata
241// ============================================================================
242
/// Per-list summary used for L2 bounds: a center point and an enclosing radius.
#[derive(Debug, Clone)]
pub struct L2ListMetadata {
    /// Center point of the list.
    pub centroid: Vec<f32>,

    /// Largest L2 distance from the centroid to any list member.
    pub radius: f32,

    /// Average L2 distance from the centroid to the list members.
    pub mean_radius: f32,

    /// How many vectors the list contains.
    pub vector_count: u32,
}
258
259impl L2ListMetadata {
260    /// Build from vectors
261    pub fn from_vectors(vectors: &[Vec<f32>], centroid: &[f32]) -> Self {
262        if vectors.is_empty() {
263            return Self {
264                centroid: centroid.to_vec(),
265                radius: 0.0,
266                mean_radius: 0.0,
267                vector_count: 0,
268            };
269        }
270
271        let mut max_dist = 0.0f32;
272        let mut sum_dist = 0.0;
273
274        for v in vectors {
275            let dist = l2_distance(v, centroid);
276            max_dist = max_dist.max(dist);
277            sum_dist += dist;
278        }
279
280        Self {
281            centroid: centroid.to_vec(),
282            radius: max_dist,
283            mean_radius: sum_dist / vectors.len() as f32,
284            vector_count: vectors.len() as u32,
285        }
286    }
287
288    /// Compute lower bound on L2 distance from query to any vector in list
289    ///
290    /// LB = max(0, dist(q, c) - radius)
291    pub fn lower_bound(&self, query: &[f32]) -> f32 {
292        let dist_to_centroid = l2_distance(query, &self.centroid);
293        (dist_to_centroid - self.radius).max(0.0)
294    }
295}
296
297// ============================================================================
298// List Bound Computer
299// ============================================================================
300
301/// Computes bounds for a query across multiple lists
302///
303/// Precomputes query-related values once, then evaluates bounds per list in O(1).
304pub struct ListBoundComputer<'a> {
305    /// Query vector
306    query: &'a [f32],
307
308    /// Precomputed query norm (for L2)
309    query_norm: f32,
310
311    /// Distance metric
312    metric: DistanceMetric,
313}
314
315impl<'a> ListBoundComputer<'a> {
316    /// Create a new bound computer for a query
317    pub fn new(query: &'a [f32], metric: DistanceMetric) -> Self {
318        let query_norm = l2_norm(query);
319        Self {
320            query,
321            query_norm,
322            metric,
323        }
324    }
325
326    /// Compute upper bound on similarity for a spherical cap (cosine/dot)
327    ///
328    /// For normalized query q and list with centroid c and max deviation θ_max:
329    /// max_{v∈L} q·v ≤ cos(max(0, angle(q,c) - θ_max))
330    ///
331    /// Complexity: O(d) for dot product, rest is O(1)
332    pub fn cosine_upper_bound(&self, metadata: &SphericalCapMetadata) -> f32 {
333        // Compute q·c
334        let query_dot_centroid = dot_product(self.query, &metadata.centroid);
335
336        // angle(q,c) = arccos(q·c)
337        let clamped = query_dot_centroid.clamp(-1.0, 1.0);
338        let angle_to_centroid = clamped.acos();
339
340        // Upper bound angle to best vector: max(0, angle - θ_max)
341        let min_angle = (angle_to_centroid - metadata.theta_max).max(0.0);
342
343        // Upper bound on similarity: cos(min_angle)
344        min_angle.cos()
345    }
346
347    /// Compute lower bound on L2 distance for a list
348    ///
349    /// LB = max(0, dist(q, c) - radius)
350    pub fn l2_lower_bound(&self, metadata: &L2ListMetadata) -> f32 {
351        let dist_to_centroid = l2_distance(self.query, &metadata.centroid);
352        (dist_to_centroid - metadata.radius).max(0.0)
353    }
354
355    /// Compute bound appropriate for the configured metric
356    ///
357    /// For similarity metrics (cosine, dot): returns upper bound (higher = tighter)
358    /// For distance metrics (L2): returns lower bound (lower = tighter)
359    pub fn compute_bound(&self, cap: &SphericalCapMetadata, l2: Option<&L2ListMetadata>) -> f32 {
360        match self.metric {
361            DistanceMetric::Cosine | DistanceMetric::InnerProduct => self.cosine_upper_bound(cap),
362            DistanceMetric::L2 => {
363                if let Some(l2_meta) = l2 {
364                    self.l2_lower_bound(l2_meta)
365                } else {
366                    // Fall back to using spherical cap for normalized vectors
367                    // LB ≈ sqrt(2 - 2*cos(angle))
368                    let ub = self.cosine_upper_bound(cap);
369                    (2.0 - 2.0 * ub).max(0.0).sqrt()
370                }
371            }
372            DistanceMetric::NegativeInnerProduct => -self.cosine_upper_bound(cap),
373        }
374    }
375}
376
377// ============================================================================
378// Multi-List Bound Ordering
379// ============================================================================
380
/// A list paired with its precomputed bound, used to order lists for probing.
#[derive(Debug, Clone)]
pub struct ListBound {
    /// Index of the list this bound belongs to.
    pub list_idx: u32,
    /// The bound itself; whether larger or smaller is better depends on the metric.
    pub bound: f32,
}
389
390impl ListBound {
391    /// Order lists by bound for best-first probing
392    ///
393    /// For similarity metrics: descending order (best first)
394    /// For distance metrics: ascending order (best first)
395    pub fn order_for_probing(bounds: &mut [ListBound], metric: DistanceMetric) {
396        match metric {
397            DistanceMetric::Cosine | DistanceMetric::InnerProduct => {
398                // Higher similarity = better, sort descending
399                bounds.sort_by(|a, b| b.bound.partial_cmp(&a.bound).unwrap());
400            }
401            DistanceMetric::L2 | DistanceMetric::NegativeInnerProduct => {
402                // Lower distance = better, sort ascending
403                bounds.sort_by(|a, b| a.bound.partial_cmp(&b.bound).unwrap());
404            }
405        }
406    }
407
408    /// Check if we can terminate based on kth best score and remaining bounds
409    ///
410    /// For similarity: stop if kth_score > best_remaining_bound
411    /// For distance: stop if kth_score < best_remaining_bound
412    pub fn can_terminate(
413        kth_score: f32,
414        best_remaining_bound: f32,
415        metric: DistanceMetric,
416    ) -> bool {
417        match metric {
418            DistanceMetric::Cosine | DistanceMetric::InnerProduct => {
419                kth_score > best_remaining_bound
420            }
421            DistanceMetric::L2 | DistanceMetric::NegativeInnerProduct => {
422                kth_score < best_remaining_bound
423            }
424        }
425    }
426}
427
428// ============================================================================
429// Unified List Metadata
430// ============================================================================
431
432/// Combined metadata supporting all metrics
433#[derive(Debug, Clone)]
434pub struct UnifiedListMetadata {
435    /// Spherical cap for cosine/dot
436    pub cap: SphericalCapMetadata,
437
438    /// L2 metadata (optional, computed on demand)
439    pub l2: Option<L2ListMetadata>,
440
441    /// List index
442    pub list_idx: u32,
443}
444
445impl UnifiedListMetadata {
446    /// Build unified metadata
447    pub fn new(list_idx: u32, cap: SphericalCapMetadata) -> Self {
448        Self {
449            cap,
450            l2: None,
451            list_idx,
452        }
453    }
454
455    /// Add L2 metadata
456    pub fn with_l2(mut self, l2: L2ListMetadata) -> Self {
457        self.l2 = Some(l2);
458        self
459    }
460}
461
462// ============================================================================
463// Helper Functions
464// ============================================================================
465
/// Inner product of `a` and `b`.
///
/// Lengths are checked only in debug builds; otherwise the extra tail of the
/// longer slice is ignored by `zip`.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
472
/// Euclidean length of `v`: the square root of the sum of squared components.
#[inline]
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
478
/// Euclidean distance between `a` and `b`.
///
/// Lengths are checked only in debug builds; otherwise the extra tail of the
/// longer slice is ignored by `zip`.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let squared_total: f32 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta.powi(2)
        })
        .sum();
    squared_total.sqrt()
}
489
490/// Normalize a vector in-place
491pub fn normalize_inplace(v: &mut [f32]) {
492    let norm = l2_norm(v);
493    if norm > 1e-10 {
494        for x in v.iter_mut() {
495            *x /= norm;
496        }
497    }
498}
499
500/// Normalize a vector, returning new vector
501pub fn normalize(v: &[f32]) -> Vec<f32> {
502    let norm = l2_norm(v);
503    if norm > 1e-10 {
504        v.iter().map(|x| x / norm).collect()
505    } else {
506        v.to_vec()
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the approximate comparisons below.
    const TOLERANCE: f32 = 0.01;

    /// Shorthand for building a `ListBound`.
    fn entry(list_idx: u32, bound: f32) -> ListBound {
        ListBound { list_idx, bound }
    }

    #[test]
    fn test_spherical_cap_metadata() {
        // Four unit vectors tilted slightly away from the x axis.
        let centroid = vec![1.0, 0.0, 0.0];
        let raw: [[f32; 3]; 4] = [
            [1.0, 0.1, 0.0],
            [1.0, -0.1, 0.0],
            [1.0, 0.0, 0.1],
            [1.0, 0.0, -0.1],
        ];
        let vectors: Vec<Vec<f32>> = raw.iter().map(|v| normalize(v)).collect();

        let metadata = SphericalCapMetadata::from_vectors(&vectors, &centroid);

        assert!(metadata.theta_max > 0.0);
        // A tight cluster stays well inside 45 degrees.
        assert!(metadata.theta_max < PI / 4.0);
        assert!(metadata.tightness() > 0.5);
    }

    #[test]
    fn test_cosine_upper_bound() {
        // Cap centered on the x axis with an angular radius of 0.3 rad (~17 degrees).
        let theta = 0.3_f32;
        let metadata = SphericalCapMetadata {
            centroid: vec![1.0, 0.0, 0.0],
            theta_max: theta,
            min_dot_to_centroid: theta.cos(),
            max_dot_to_centroid: 1.0,
            vector_count: 10,
            mean_dot_to_centroid: 0.95,
        };

        // A query aligned with the centroid can reach similarity ~1.
        let aligned = vec![1.0, 0.0, 0.0];
        let aligned_bound =
            ListBoundComputer::new(&aligned, DistanceMetric::Cosine).cosine_upper_bound(&metadata);
        assert!((aligned_bound - 1.0).abs() < TOLERANCE);

        // A perpendicular query: angle = π/2, so the bound is
        // cos(π/2 - 0.3) = sin(0.3).
        let perpendicular = vec![0.0, 1.0, 0.0];
        let perpendicular_bound = ListBoundComputer::new(&perpendicular, DistanceMetric::Cosine)
            .cosine_upper_bound(&metadata);
        assert!((perpendicular_bound - theta.sin()).abs() < TOLERANCE);
    }

    #[test]
    fn test_l2_lower_bound() {
        let metadata = L2ListMetadata {
            centroid: vec![0.0, 0.0, 0.0],
            radius: 1.0,
            mean_radius: 0.5,
            vector_count: 100,
        };

        // Query at distance 2 from the centroid: bound is 2 - 1 = 1.
        let outside = vec![2.0, 0.0, 0.0];
        let outside_lb =
            ListBoundComputer::new(&outside, DistanceMetric::L2).l2_lower_bound(&metadata);
        assert!((outside_lb - 1.0).abs() < TOLERANCE);

        // Query inside the radius: bound collapses to 0.
        let inside = vec![0.5, 0.0, 0.0];
        let inside_lb =
            ListBoundComputer::new(&inside, DistanceMetric::L2).l2_lower_bound(&metadata);
        assert!(inside_lb.abs() < TOLERANCE);
    }

    #[test]
    fn test_list_ordering() {
        let mut bounds = vec![entry(0, 0.5), entry(1, 0.9), entry(2, 0.3)];

        // Cosine: highest similarity first (0.9, 0.5, 0.3).
        ListBound::order_for_probing(&mut bounds, DistanceMetric::Cosine);
        let cosine_order: Vec<u32> = bounds.iter().map(|b| b.list_idx).collect();
        assert_eq!(cosine_order, vec![1, 0, 2]);

        // L2: lowest distance first (0.3, 0.5, 0.9).
        ListBound::order_for_probing(&mut bounds, DistanceMetric::L2);
        let l2_order: Vec<u32> = bounds.iter().map(|b| b.list_idx).collect();
        assert_eq!(l2_order, vec![2, 0, 1]);
    }
}
619}