Skip to main content

sochdb_vector/
quantization_calibration.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! # Quantization Error Calibration (Task 5)
19//!
20//! Calibrates quantization error per-list (or per-cluster) and uses it in stopping decisions.
21//!
22//! ## Problem
23//!
24//! With PQ/ADC, the kth score is a proxy score, not the true score.
25//! Stopping on list bounds vs kth requires them to be comparable in the true metric.
26//!
27//! ## Solution
28//!
29//! Learn empirical error envelopes:
30//! - Per-list quantiles for ε = ŝ - s under representative queries
31//! - At query time, convert proxy thresholds into safe true-score thresholds
32//!
33//! ## Math/Algorithm
34//!
35//! PAC-style calibration:
36//! - Store ε_L(1-δ) such that P(ε ≤ ε_L) ≥ 1-δ
//! - Stopping compares UB_true(list) against LB_true(kth): probing stops once every remaining list's true-score upper bound falls below the kth result's true-score lower bound
38//!
39//! ## Usage
40//!
41//! ```rust,ignore
42//! use sochdb_vector::quantization_calibration::{ErrorCalibrator, ErrorEnvelope};
43//!
44//! // During offline training
45//! let mut calibrator = ErrorCalibrator::new(n_lists);
46//! calibrator.record_error(list_idx, proxy_score, true_score);
47//! let envelopes = calibrator.finalize();
48//!
49//! // At query time
50//! let proxy_kth = 0.85;
51//! let safe_threshold = envelopes.safe_true_threshold(list_idx, proxy_kth, 0.99);
52//! ```
53
54use serde::{Deserialize, Serialize};
55use std::collections::HashMap;
56
57// ============================================================================
58// Error Sample
59// ============================================================================
60
/// One calibration observation pairing a proxy score with its true score.
///
/// `error` is stored pre-computed as `proxy - true_score`; a positive value
/// means the proxy over-estimated the true score.
#[derive(Clone, Copy, Debug)]
pub struct ErrorSample {
    /// Score computed from the quantized representation.
    pub proxy: f32,
    /// Score computed at full precision.
    pub true_score: f32,
    /// Signed difference `proxy - true_score`.
    pub error: f32,
}

impl ErrorSample {
    /// Build a sample from a proxy/true score pair, deriving `error` from them.
    pub fn new(proxy: f32, true_score: f32) -> Self {
        let error = proxy - true_score;
        ErrorSample {
            proxy,
            true_score,
            error,
        }
    }
}
82
83// ============================================================================
84// Error Envelope
85// ============================================================================
86
87/// Pre-computed error envelope for a list
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ErrorEnvelope {
90    /// List index
91    pub list_idx: u32,
92
93    /// Quantiles of error distribution
94    /// Key: quantile (e.g., 0.95, 0.99, 0.999)
95    /// Value: error bound at that quantile
96    pub quantiles: HashMap<u32, f32>, // Use u32 for serialization (quantile * 10000)
97
98    /// Mean error
99    pub mean_error: f32,
100
101    /// Standard deviation of error
102    pub std_error: f32,
103
104    /// Maximum observed error
105    pub max_error: f32,
106
107    /// Minimum observed error  
108    pub min_error: f32,
109
110    /// Number of samples used for calibration
111    pub sample_count: u32,
112}
113
114impl ErrorEnvelope {
115    /// Get error bound for a quantile
116    ///
117    /// Returns ε such that P(error ≤ ε) ≥ quantile
118    pub fn error_at_quantile(&self, quantile: f32) -> f32 {
119        let key = (quantile * 10000.0).round() as u32;
120
121        // Direct lookup
122        if let Some(&error) = self.quantiles.get(&key) {
123            return error;
124        }
125
126        // Interpolate between available quantiles
127        let mut below_key = 0u32;
128        let mut above_key = 10000u32;
129        let mut below_val = self.min_error;
130        let mut above_val = self.max_error;
131
132        for (&k, &v) in &self.quantiles {
133            if k < key && k > below_key {
134                below_key = k;
135                below_val = v;
136            }
137            if k > key && k < above_key {
138                above_key = k;
139                above_val = v;
140            }
141        }
142
143        // Linear interpolation
144        if above_key > below_key {
145            let t = (key - below_key) as f32 / (above_key - below_key) as f32;
146            below_val + t * (above_val - below_val)
147        } else {
148            self.max_error
149        }
150    }
151
152    /// Convert proxy threshold to safe true-score threshold
153    ///
154    /// For similarity (higher is better):
155    /// true_score ≥ proxy - error_bound
156    ///
157    /// Returns threshold such that P(true ≥ threshold | proxy = p) ≥ confidence
158    pub fn safe_true_threshold(&self, proxy: f32, confidence: f32) -> f32 {
159        let error_bound = self.error_at_quantile(confidence);
160        proxy - error_bound
161    }
162
163    /// Convert true threshold to safe proxy threshold
164    ///
165    /// For filtering candidates before rerank:
166    /// proxy ≥ true + error_bound (conservative)
167    pub fn safe_proxy_threshold(&self, true_threshold: f32, confidence: f32) -> f32 {
168        let error_bound = self.error_at_quantile(confidence);
169        true_threshold + error_bound
170    }
171
172    /// Check if proxy score definitely beats true threshold
173    pub fn definitely_beats(&self, proxy: f32, true_threshold: f32) -> bool {
174        // Use max error for deterministic guarantee
175        proxy - self.max_error > true_threshold
176    }
177
178    /// Check if proxy score might beat true threshold
179    pub fn might_beat(&self, proxy: f32, true_threshold: f32, confidence: f32) -> bool {
180        let error_bound = self.error_at_quantile(confidence);
181        proxy - error_bound > true_threshold
182    }
183}
184
185impl Default for ErrorEnvelope {
186    fn default() -> Self {
187        Self {
188            list_idx: 0,
189            quantiles: HashMap::new(),
190            mean_error: 0.0,
191            std_error: 0.0,
192            max_error: 0.0,
193            min_error: 0.0,
194            sample_count: 0,
195        }
196    }
197}
198
199// ============================================================================
200// Error Calibrator
201// ============================================================================
202
203/// Collects error samples and computes envelopes
204pub struct ErrorCalibrator {
205    /// Samples per list
206    samples: Vec<Vec<ErrorSample>>,
207    /// Number of lists
208    n_lists: usize,
209    /// Quantiles to compute
210    quantiles: Vec<f32>,
211}
212
213impl ErrorCalibrator {
214    /// Create new calibrator for n_lists
215    pub fn new(n_lists: usize) -> Self {
216        Self {
217            samples: vec![Vec::new(); n_lists],
218            n_lists,
219            quantiles: vec![0.50, 0.75, 0.90, 0.95, 0.99, 0.999],
220        }
221    }
222
223    /// Create with custom quantiles
224    pub fn with_quantiles(n_lists: usize, quantiles: Vec<f32>) -> Self {
225        Self {
226            samples: vec![Vec::new(); n_lists],
227            n_lists,
228            quantiles,
229        }
230    }
231
232    /// Record an error sample for a list
233    pub fn record_error(&mut self, list_idx: usize, proxy: f32, true_score: f32) {
234        if list_idx < self.n_lists {
235            self.samples[list_idx].push(ErrorSample::new(proxy, true_score));
236        }
237    }
238
239    /// Record multiple samples for a list
240    pub fn record_errors(&mut self, list_idx: usize, samples: &[(f32, f32)]) {
241        if list_idx < self.n_lists {
242            for &(proxy, true_score) in samples {
243                self.samples[list_idx].push(ErrorSample::new(proxy, true_score));
244            }
245        }
246    }
247
248    /// Compute envelopes for all lists
249    pub fn finalize(&self) -> ErrorEnvelopeSet {
250        let envelopes: Vec<ErrorEnvelope> = (0..self.n_lists)
251            .map(|i| self.compute_envelope(i))
252            .collect();
253
254        // Also compute global envelope
255        let global = self.compute_global_envelope();
256
257        ErrorEnvelopeSet { envelopes, global }
258    }
259
260    /// Compute envelope for a single list
261    fn compute_envelope(&self, list_idx: usize) -> ErrorEnvelope {
262        let samples = &self.samples[list_idx];
263
264        if samples.is_empty() {
265            return ErrorEnvelope {
266                list_idx: list_idx as u32,
267                ..Default::default()
268            };
269        }
270
271        // Extract errors
272        let mut errors: Vec<f32> = samples.iter().map(|s| s.error).collect();
273        errors.sort_by(|a, b| a.partial_cmp(b).unwrap());
274
275        let n = errors.len();
276
277        // Compute statistics
278        let sum: f32 = errors.iter().sum();
279        let mean = sum / n as f32;
280        let variance: f32 = errors.iter().map(|&e| (e - mean).powi(2)).sum::<f32>() / n as f32;
281        let std = variance.sqrt();
282
283        // Compute quantiles
284        let mut quantiles = HashMap::new();
285        for &q in &self.quantiles {
286            let idx = ((n as f32 * q) as usize).min(n - 1);
287            let key = (q * 10000.0).round() as u32;
288            quantiles.insert(key, errors[idx]);
289        }
290
291        ErrorEnvelope {
292            list_idx: list_idx as u32,
293            quantiles,
294            mean_error: mean,
295            std_error: std,
296            max_error: errors[n - 1],
297            min_error: errors[0],
298            sample_count: n as u32,
299        }
300    }
301
302    /// Compute global envelope across all lists
303    fn compute_global_envelope(&self) -> ErrorEnvelope {
304        let mut all_errors: Vec<f32> = self
305            .samples
306            .iter()
307            .flat_map(|s| s.iter().map(|e| e.error))
308            .collect();
309
310        if all_errors.is_empty() {
311            return ErrorEnvelope::default();
312        }
313
314        all_errors.sort_by(|a, b| a.partial_cmp(b).unwrap());
315        let n = all_errors.len();
316
317        let sum: f32 = all_errors.iter().sum();
318        let mean = sum / n as f32;
319        let variance: f32 = all_errors.iter().map(|&e| (e - mean).powi(2)).sum::<f32>() / n as f32;
320        let std = variance.sqrt();
321
322        let mut quantiles = HashMap::new();
323        for &q in &self.quantiles {
324            let idx = ((n as f32 * q) as usize).min(n - 1);
325            let key = (q * 10000.0).round() as u32;
326            quantiles.insert(key, all_errors[idx]);
327        }
328
329        ErrorEnvelope {
330            list_idx: u32::MAX, // Indicates global
331            quantiles,
332            mean_error: mean,
333            std_error: std,
334            max_error: all_errors[n - 1],
335            min_error: all_errors[0],
336            sample_count: n as u32,
337        }
338    }
339}
340
341// ============================================================================
342// Error Envelope Set
343// ============================================================================
344
345/// Collection of error envelopes for all lists
346#[derive(Debug, Clone, Serialize, Deserialize)]
347pub struct ErrorEnvelopeSet {
348    /// Per-list envelopes
349    pub envelopes: Vec<ErrorEnvelope>,
350    /// Global envelope
351    pub global: ErrorEnvelope,
352}
353
354impl ErrorEnvelopeSet {
355    /// Get envelope for a list, falling back to global if not available
356    pub fn get(&self, list_idx: usize) -> &ErrorEnvelope {
357        if list_idx < self.envelopes.len() && self.envelopes[list_idx].sample_count > 0 {
358            &self.envelopes[list_idx]
359        } else {
360            &self.global
361        }
362    }
363
364    /// Convert proxy kth to safe true threshold
365    pub fn safe_true_threshold(&self, list_idx: usize, proxy: f32, confidence: f32) -> f32 {
366        self.get(list_idx).safe_true_threshold(proxy, confidence)
367    }
368
369    /// Check if we can terminate: all remaining lists have bounds below kth true threshold
370    pub fn can_terminate(
371        &self,
372        kth_proxy: f32,
373        remaining_list_bounds: &[(usize, f32)],
374        confidence: f32,
375    ) -> bool {
376        // Convert kth proxy to safe true threshold (lower bound on true kth)
377        let kth_true_lower = self.global.safe_true_threshold(kth_proxy, confidence);
378
379        // Check if all remaining list bounds are below kth true threshold
380        remaining_list_bounds.iter().all(|(list_idx, bound)| {
381            // Use per-list envelope for tighter bounds
382            let envelope = self.get(*list_idx);
383            // The bound is an upper bound on proxy scores in the list
384            // Convert to upper bound on true scores
385            let true_upper = *bound + envelope.max_error.abs();
386            true_upper < kth_true_lower
387        })
388    }
389
390    /// Serialize to bytes
391    pub fn to_bytes(&self) -> Vec<u8> {
392        bincode::serialize(self).unwrap_or_default()
393    }
394
395    /// Deserialize from bytes
396    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
397        bincode::deserialize(bytes).ok()
398    }
399}
400
401// ============================================================================
402// Calibration Runner
403// ============================================================================
404
/// Offline driver that measures quantization error on representative queries.
///
/// Every scoring hook is optional; when no proxy scorer is installed, proxy scores
/// equal true scores and no error is observed.
pub struct CalibrationRunner {
    /// Number of lists (clusters) to calibrate.
    n_lists: usize,
    /// Quantizer hook: full-precision vector in, quantized codes out.
    quantize_fn: Option<Box<dyn Fn(&[f32]) -> Vec<u8> + Send + Sync>>,
    /// Proxy scorer hook: `(query, codes) -> score` on the quantized representation.
    proxy_distance_fn: Option<Box<dyn Fn(&[f32], &[u8]) -> f32 + Send + Sync>>,
    /// True scorer hook: `(query, vector) -> score` at full precision.
    true_distance_fn: Option<Box<dyn Fn(&[f32], &[f32]) -> f32 + Send + Sync>>,
}
416
417impl CalibrationRunner {
418    /// Create new calibration runner
419    pub fn new(n_lists: usize) -> Self {
420        Self {
421            n_lists,
422            quantize_fn: None,
423            proxy_distance_fn: None,
424            true_distance_fn: None,
425        }
426    }
427
428    /// Run calibration with given queries and vectors per list
429    ///
430    /// For each query, computes proxy and true scores for vectors in each list,
431    /// collecting error samples.
432    pub fn calibrate(
433        &self,
434        queries: &[Vec<f32>],
435        lists: &[Vec<Vec<f32>>],
436        quantized_lists: &[Vec<Vec<u8>>],
437    ) -> ErrorEnvelopeSet {
438        let mut calibrator = ErrorCalibrator::new(self.n_lists);
439
440        for query in queries {
441            for (list_idx, (vectors, codes)) in lists.iter().zip(quantized_lists.iter()).enumerate()
442            {
443                for (vec, code) in vectors.iter().zip(codes.iter()) {
444                    // Compute true and proxy scores
445                    let true_score = dot_product(query, vec);
446                    let proxy_score = if let Some(ref f) = self.proxy_distance_fn {
447                        f(query, code)
448                    } else {
449                        true_score // Fallback: no quantization error
450                    };
451
452                    calibrator.record_error(list_idx, proxy_score, true_score);
453                }
454            }
455        }
456
457        calibrator.finalize()
458    }
459
460    /// Simplified calibration using synthetic error model
461    ///
462    /// Generates error samples based on assumed error distribution.
463    pub fn calibrate_synthetic(
464        n_lists: usize,
465        mean_error: f32,
466        std_error: f32,
467        samples_per_list: usize,
468    ) -> ErrorEnvelopeSet {
469        let mut calibrator = ErrorCalibrator::new(n_lists);
470
471        // Use simple random number generation for reproducibility
472        let mut rng_state: u64 = 12345;
473        let mut rand = || {
474            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
475            (rng_state >> 33) as f32 / (1u64 << 31) as f32
476        };
477
478        for list_idx in 0..n_lists {
479            for _ in 0..samples_per_list {
480                // Box-Muller transform for normal distribution
481                let u1 = rand();
482                let u2 = rand();
483                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
484                let error = mean_error + std_error * z;
485
486                let true_score = 0.5 + rand() * 0.5; // Random true score in [0.5, 1.0]
487                let proxy_score = true_score + error;
488
489                calibrator.record_error(list_idx, proxy_score, true_score);
490            }
491        }
492
493        calibrator.finalize()
494    }
495}
496
/// Inner product of `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length, the extra
/// trailing elements of the longer one are ignored.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b).map(|(lhs, rhs)| lhs * rhs).sum()
}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_error_sample() {
        let sample = ErrorSample::new(0.92, 0.90);
        assert!(close(sample.error, 0.02, 1e-6));
    }

    #[test]
    fn test_calibrator() {
        let observations: [(f32, f32); 5] = [
            (0.90, 0.88),
            (0.85, 0.82),
            (0.92, 0.91),
            (0.88, 0.85),
            (0.95, 0.90),
        ];
        let mut calibrator = ErrorCalibrator::new(3);
        for &(proxy, true_score) in observations.iter() {
            calibrator.record_error(0, proxy, true_score);
        }

        let set = calibrator.finalize();
        let list0 = &set.envelopes[0];

        assert_eq!(list0.sample_count, 5);
        assert!(list0.mean_error > 0.0);
        assert!(list0.max_error > list0.mean_error);
    }

    #[test]
    fn test_envelope_threshold() {
        // 95% of errors are <= 0.05 and 99% are <= 0.08.
        let envelope = ErrorEnvelope {
            list_idx: 0,
            quantiles: HashMap::from([(9500, 0.05), (9900, 0.08)]),
            mean_error: 0.03,
            std_error: 0.02,
            max_error: 0.10,
            min_error: 0.00,
            sample_count: 100,
        };

        // 0.90 - 0.05 = 0.85 at 95% confidence.
        assert!(close(envelope.safe_true_threshold(0.90, 0.95), 0.85, 0.01));

        // Higher confidence needs a wider margin: 0.90 - 0.08 = 0.82.
        assert!(close(envelope.safe_true_threshold(0.90, 0.99), 0.82, 0.01));
    }

    #[test]
    fn test_can_terminate() {
        let envelopes = CalibrationRunner::calibrate_synthetic(5, 0.03, 0.01, 100);
        let kth_proxy = 0.95;
        let confidence = 0.99;

        // A high kth score with low remaining bounds permits early termination.
        let low_bounds: [(usize, f32); 2] = [(1, 0.70), (2, 0.65)];
        assert!(
            envelopes.can_terminate(kth_proxy, &low_bounds, confidence),
            "Should be able to terminate with high kth and low bounds"
        );

        // Bounds close to the kth score must keep the search going.
        let high_bounds: [(usize, f32); 2] = [(1, 0.94), (2, 0.93)];
        assert!(
            !envelopes.can_terminate(kth_proxy, &high_bounds, confidence),
            "Should not terminate with close bounds"
        );
    }

    #[test]
    fn test_synthetic_calibration() {
        let envelopes = CalibrationRunner::calibrate_synthetic(10, 0.02, 0.01, 500);

        assert_eq!(envelopes.envelopes.len(), 10);
        assert!(envelopes.global.sample_count > 0);

        // The pooled mean should track the requested synthetic mean.
        assert!(close(envelopes.global.mean_error, 0.02, 0.01));
    }
}