// turboquant/codebook/gen.rs

//! Lloyd-Max codebook generation for the Beta distribution.
//!
//! This module contains the Lloyd-Max iterative algorithm and all generation
//! helpers.  It is separated from [`crate::codebook`] (which owns the
//! [`Codebook`] struct, static lookup tables, and the Beta PDF) to respect
//! the Single Responsibility Principle.

8use super::{centroid_count, Codebook, SUPPORT_MAX, SUPPORT_MIN};
9use crate::math::{converge, ln_gamma, simpsons_integrate, HALF};
10
// ---------------------------------------------------------------------------
// Constants — generation-specific
// ---------------------------------------------------------------------------

/// Hard cap on the number of Lloyd-Max iterations.  The loop stops here even
/// if the distortion has not yet settled.
const MAX_ITERATIONS: usize = 200;

/// Relative tolerance on the change in distortion between two successive
/// iterations.  A smaller change counts as converged.
const CONVERGENCE_EPS: f64 = 1e-12;

/// Number of sub-intervals handed to Simpson's rule for every integral.
const INTEGRATION_STEPS: usize = 1024;

/// Magnitude below which a value is treated as zero (division guard).
const EPSILON_ZERO: f64 = 1e-30;

/// Smallest dimension for which [`beta_pdf`] returns a non-zero density.
///
/// For d < 3 the exponent (d-3)/2 is negative, so the kernel (1 - x^2)^((d-3)/2)
/// is unbounded as x -> ±1.  At d = 1 the normalising Gamma((d-1)/2) = Gamma(0)
/// diverges as well.
const MIN_DIMENSION_FOR_PDF: usize = 3;

/// The `3` in the kernel exponent `(d - 3) / 2`.
const KERNEL_EXPONENT_OFFSET: f64 = 3.0;

35// ---------------------------------------------------------------------------
36// Pure Operation: Beta PDF
37// ---------------------------------------------------------------------------
38
39/// Evaluate the Beta-type PDF of a rotated unit-vector coordinate.
40///
41/// ```text
42/// f_X(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)
43/// ```
44///
45/// Pure Operation: all arithmetic (kernel + normalization) is computed
46/// inline without calls to other project functions.
47pub fn beta_pdf(x: f64, d: usize) -> f64 {
48    // Guard: dimension too low.
49    if d < MIN_DIMENSION_FOR_PDF {
50        return 0.0;
51    }
52    let df = d as f64;
53    let exponent = (df - KERNEL_EXPONENT_OFFSET) * HALF;
54    let one_minus_x2 = 1.0 - x * x;
55    if one_minus_x2 <= 0.0 {
56        return 0.0;
57    }
58    let kernel = one_minus_x2.powf(exponent);
59
60    // Normalization: ln(Gamma(d/2)) - 0.5*ln(pi) - ln(Gamma((d-1)/2))
61    let half_df = df * HALF;
62    let half_df_minus_one = (df - 1.0) * HALF;
63    let half_ln_pi = HALF * core::f64::consts::PI.ln();
64    let log_norm = ln_gamma(half_df) - half_ln_pi - ln_gamma(half_df_minus_one);
65
66    log_norm.exp() * kernel
67}
68
69// ---------------------------------------------------------------------------
70// Pure Operation: initialization
71// ---------------------------------------------------------------------------
72
73/// Place `k` centroids uniformly on `(SUPPORT_MIN, SUPPORT_MAX)` (excluding endpoints).
74fn initialize_centroids(k: usize) -> Vec<f64> {
75    let range = SUPPORT_MAX - SUPPORT_MIN; // 2.0
76    (0..k)
77        .map(|i| SUPPORT_MIN + (range * (i as f64 + HALF)) / k as f64)
78        .collect()
79}
80
81/// Compute midpoint boundaries between adjacent centroids.
82fn midpoint_boundaries(centroids: &[f64]) -> Vec<f64> {
83    centroids.windows(2).map(|w| (w[0] + w[1]) * HALF).collect()
84}
85
86// ---------------------------------------------------------------------------
87// Pure Operation: bin geometry
88// ---------------------------------------------------------------------------
89
90/// Determine the lower bound of the i-th bin given boundaries.
91fn bin_lower_bound(i: usize, boundaries: &[f64]) -> f64 {
92    if i == 0 {
93        SUPPORT_MIN
94    } else {
95        boundaries[i - 1]
96    }
97}
98
99/// Determine the upper bound of the i-th bin given boundaries and total
100/// number of centroids `k`.
101fn bin_upper_bound(i: usize, k: usize, boundaries: &[f64]) -> f64 {
102    if i == k - 1 {
103        SUPPORT_MAX
104    } else {
105        boundaries[i]
106    }
107}
108
109// ---------------------------------------------------------------------------
110// Pure Operation: convergence check & conditional selection
111// ---------------------------------------------------------------------------
112
113/// Check whether the Lloyd-Max iteration has converged by comparing the
114/// relative change in distortion against [`CONVERGENCE_EPS`].
115fn has_converged(prev_distortion: f64, distortion: f64) -> bool {
116    (prev_distortion - distortion).abs() < CONVERGENCE_EPS * prev_distortion.abs().max(EPSILON_ZERO)
117}
118
119/// Select the conditional expectation or the interval midpoint depending
120/// on whether the denominator is near zero.
121///
122/// Pure Operation: only arithmetic and comparison, no calls.
123fn select_conditional_or_midpoint(numerator: f64, denominator: f64, a: f64, b: f64) -> f64 {
124    if denominator.abs() < EPSILON_ZERO {
125        (a + b) * HALF
126    } else {
127        numerator / denominator
128    }
129}
130
131// ---------------------------------------------------------------------------
132// Pure Integration: numerical integration wrappers
133// ---------------------------------------------------------------------------
134
135/// Simpson's rule numerical integration of `f` over `[a, b]`, using the
136/// module-level [`INTEGRATION_STEPS`] constant.
137///
138/// Pure Integration: delegates to [`crate::math::simpsons_integrate`].
139fn integrate<F: Fn(f64) -> f64>(f: F, a: f64, b: f64) -> f64 {
140    simpsons_integrate(f, a, b, INTEGRATION_STEPS)
141}
142
143/// Compute `integral_a^b f(x) dx` where `f(x) = beta_pdf(x, d)`.
144///
145/// Pure Integration: delegates to `integrate` and `beta_pdf`.
146fn integrate_pdf(a: f64, b: f64, d: usize) -> f64 {
147    integrate(|x| beta_pdf(x, d), a, b)
148}
149
150/// Compute `integral_a^b x * f(x) dx` where `f(x) = beta_pdf(x, d)`.
151///
152/// Pure Integration: delegates to `integrate` and `beta_pdf`.
153fn integrate_x_pdf(a: f64, b: f64, d: usize) -> f64 {
154    integrate(|x| x * beta_pdf(x, d), a, b)
155}
156
157/// Conditional expectation `E[X | X in [a, b]]` under the Beta-type PDF.
158///
159/// Pure Integration: delegates computation to `integrate_pdf`,
160/// `integrate_x_pdf`, and `select_conditional_or_midpoint`.
161fn conditional_expectation(a: f64, b: f64, d: usize) -> f64 {
162    let denom = integrate_pdf(a, b, d);
163    let numer = integrate_x_pdf(a, b, d);
164    select_conditional_or_midpoint(numer, denom, a, b)
165}
166
167// ---------------------------------------------------------------------------
168// Pure Integration: distortion computation
169// ---------------------------------------------------------------------------
170
171/// Compute the MSE-distortion contribution of a single bin `[lo, hi]` with
172/// centroid `c` under the Beta PDF for dimension `d`.
173///
174/// Pure Integration: delegates to `integrate` and `beta_pdf`.
175fn bin_distortion(lo: f64, hi: f64, c: f64, d: usize) -> f64 {
176    integrate(|x| (x - c).powi(2) * beta_pdf(x, d), lo, hi)
177}
178
179/// Compute the MSE distortion of the current codebook under the Beta PDF.
180///
181/// Pure Integration: delegates bin bounds to `bin_lower_bound`/`bin_upper_bound`
182/// and per-bin distortion to `bin_distortion`.  Uses an iterator chain instead
183/// of explicit loop logic.
184fn compute_distortion(centroids: &[f64], boundaries: &[f64], d: usize) -> f64 {
185    let k = centroids.len();
186    centroids
187        .iter()
188        .enumerate()
189        .map(|(i, &centroid)| {
190            let lo = bin_lower_bound(i, boundaries);
191            let hi = bin_upper_bound(i, k, boundaries);
192            bin_distortion(lo, hi, centroid, d)
193        })
194        .sum()
195}
196
197// ---------------------------------------------------------------------------
198// Pure Integration: centroid update
199// ---------------------------------------------------------------------------
200
201/// Compute updated centroids for one Lloyd-Max iteration.
202///
203/// Pure Integration: delegates bin bounds to `bin_lower_bound`/`bin_upper_bound`
204/// and centroid updates to `conditional_expectation`.  Uses an iterator chain
205/// instead of explicit loop logic.
206fn update_centroids(centroids_len: usize, boundaries: &[f64], d: usize) -> Vec<f64> {
207    (0..centroids_len)
208        .map(|i| {
209            let lo = bin_lower_bound(i, boundaries);
210            let hi = bin_upper_bound(i, centroids_len, boundaries);
211            conditional_expectation(lo, hi, d)
212        })
213        .collect()
214}
215
216// ---------------------------------------------------------------------------
217// Lloyd-Max core — Pure Integration (orchestrates operation helpers)
218// ---------------------------------------------------------------------------
219
220/// Perform one Lloyd-Max iteration step: compute boundaries, update centroids,
221/// measure distortion, and check convergence.
222///
223/// Pure Integration: delegates to `midpoint_boundaries`, `update_centroids`,
224/// `compute_distortion`, and `has_converged`.  Returns the new centroids,
225/// the new distortion, and a convergence flag.
226fn lloyd_max_step(centroids: &[f64], prev_distortion: f64, d: usize) -> (Vec<f64>, f64, bool) {
227    let boundaries = midpoint_boundaries(centroids);
228    let new_centroids = update_centroids(centroids.len(), &boundaries, d);
229    let distortion = compute_distortion(&new_centroids, &boundaries, d);
230    let converged = has_converged(prev_distortion, distortion);
231    (new_centroids, distortion, converged)
232}
233
234/// Run Lloyd-Max iterations starting from the given initial `centroids` for
235/// dimension `d`.  Returns the converged [`Codebook`].
236///
237/// Pure Integration: delegates each iteration to `lloyd_max_step` and
238/// final boundary computation to `midpoint_boundaries`.
239fn lloyd_max_iterate(mut centroids: Vec<f64>, d: usize) -> Codebook {
240    let mut prev_distortion = f64::MAX;
241
242    converge(MAX_ITERATIONS, || {
243        let (new_centroids, distortion, converged) = lloyd_max_step(&centroids, prev_distortion, d);
244        centroids = new_centroids;
245        prev_distortion = distortion;
246        converged
247    });
248
249    let boundaries = midpoint_boundaries(&centroids);
250    Codebook {
251        centroids,
252        boundaries,
253    }
254}
255
256// ---------------------------------------------------------------------------
257// Public API
258// ---------------------------------------------------------------------------
259
260/// Run the Lloyd-Max algorithm from scratch for arbitrary `(bits, dim)`.
261///
262/// Pure Integration: delegates centroid count to `centroid_count`,
263/// initialization to `initialize_centroids`, and iteration to
264/// `lloyd_max_iterate`.
265pub fn generate_codebook(bits: u8, dim: usize) -> Codebook {
266    let k = centroid_count(bits);
267    let centroids = initialize_centroids(k);
268    lloyd_max_iterate(centroids, dim)
269}
270
// ---------------------------------------------------------------------------
// Unit tests for generation helpers
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // -- Named constants for test parameters --------------------------------

    /// Dimension used in integration / beta-PDF tests.
    const TEST_DIM: usize = 128;
    /// Number of centroids when using 3-bit quantization (2^3).
    const TEST_CENTROIDS_8: usize = 8;
    /// Number of centroids when using 4-bit quantization (2^4).
    const TEST_CENTROIDS_16: usize = 16;
    /// Bit width for 3-bit quantization in tests.
    const TEST_BITS_3: u8 = 3;
    /// Dimension 64 used in generate_codebook tests.
    const TEST_DIM_64: usize = 64;
    /// Numerator used in select_conditional_or_midpoint tests.
    const TEST_NUMERATOR: f64 = 3.0;
    /// Normal-case denominator for select_conditional_or_midpoint test.
    const TEST_DENOMINATOR: f64 = 2.0;
    /// Near-zero denominator for select_conditional_or_midpoint fallback test.
    const TEST_NEAR_ZERO_DENOM: f64 = 1e-31;
    /// A point with |x| > 1, where the PDF must vanish.
    const TEST_OUTSIDE_SUPPORT: f64 = 1.5;
    /// Interior point used for PDF symmetry / uniformity checks.
    const TEST_INTERIOR_POINT: f64 = 0.5;
    /// For d = 3 the kernel exponent is 0, so the density is the constant
    /// Gamma(3/2) / (sqrt(pi) * Gamma(1)) = 1/2 on [-1, 1].
    const UNIFORM_DENSITY_D3: f64 = 0.5;

    // -- beta_pdf -----------------------------------------------------------

    #[test]
    fn beta_pdf_zero_below_min_dimension() {
        assert_relative_eq!(
            beta_pdf(0.0, MIN_DIMENSION_FOR_PDF - 1),
            0.0,
            epsilon = 1e-15
        );
    }

    #[test]
    fn beta_pdf_zero_outside_support() {
        assert_relative_eq!(beta_pdf(TEST_OUTSIDE_SUPPORT, TEST_DIM), 0.0, epsilon = 1e-15);
        assert_relative_eq!(beta_pdf(-TEST_OUTSIDE_SUPPORT, TEST_DIM), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn beta_pdf_symmetric() {
        assert_relative_eq!(
            beta_pdf(TEST_INTERIOR_POINT, TEST_DIM),
            beta_pdf(-TEST_INTERIOR_POINT, TEST_DIM),
            epsilon = 1e-15
        );
    }

    #[test]
    fn beta_pdf_dimension_three_is_uniform_in_interior() {
        assert_relative_eq!(
            beta_pdf(0.0, MIN_DIMENSION_FOR_PDF),
            UNIFORM_DENSITY_D3,
            epsilon = 1e-6
        );
        assert_relative_eq!(
            beta_pdf(TEST_INTERIOR_POINT, MIN_DIMENSION_FOR_PDF),
            UNIFORM_DENSITY_D3,
            epsilon = 1e-6
        );
    }

    // -- integrate_pdf ------------------------------------------------------

    #[test]
    fn integrate_pdf_total_mass_is_one() {
        // A correctly normalised density integrates to 1 over its support.
        let mass = integrate_pdf(SUPPORT_MIN, SUPPORT_MAX, TEST_DIM);
        assert_relative_eq!(mass, 1.0, epsilon = 1e-6);
    }

    // -- initialize_centroids -----------------------------------------------

    #[test]
    fn initialize_centroids_correct_count() {
        assert_eq!(
            initialize_centroids(TEST_CENTROIDS_8).len(),
            TEST_CENTROIDS_8
        );
        assert_eq!(
            initialize_centroids(TEST_CENTROIDS_16).len(),
            TEST_CENTROIDS_16
        );
    }

    #[test]
    fn initialize_centroids_sorted() {
        let c = initialize_centroids(TEST_CENTROIDS_8);
        for w in c.windows(2) {
            assert!(w[0] < w[1]);
        }
    }

    #[test]
    fn initialize_centroids_symmetric() {
        let c = initialize_centroids(TEST_CENTROIDS_8);
        let half = TEST_CENTROIDS_8 / 2;
        for i in 0..half {
            assert_relative_eq!(c[i], -c[TEST_CENTROIDS_8 - 1 - i], epsilon = 1e-14);
        }
    }

    #[test]
    fn initialize_centroids_within_support() {
        let c = initialize_centroids(TEST_CENTROIDS_16);
        for &v in &c {
            assert!(v > SUPPORT_MIN && v < SUPPORT_MAX);
        }
    }

    // -- midpoint_boundaries ------------------------------------------------

    #[test]
    fn midpoint_boundaries_correct_values() {
        let centroids = vec![-0.5, 0.0, 0.5];
        let b = midpoint_boundaries(&centroids);
        assert_eq!(b.len(), 2);
        assert_relative_eq!(b[0], -0.25, epsilon = 1e-14);
        assert_relative_eq!(b[1], 0.25, epsilon = 1e-14);
    }

    // -- bin_lower_bound / bin_upper_bound -----------------------------------

    #[test]
    fn bin_lower_bound_first() {
        let boundaries = vec![0.0];
        assert_relative_eq!(
            bin_lower_bound(0, &boundaries),
            SUPPORT_MIN,
            epsilon = 1e-15
        );
    }

    #[test]
    fn bin_lower_bound_second() {
        let boundaries = vec![0.0];
        assert_relative_eq!(bin_lower_bound(1, &boundaries), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn bin_upper_bound_last() {
        let boundaries = vec![0.0];
        assert_relative_eq!(
            bin_upper_bound(1, 2, &boundaries),
            SUPPORT_MAX,
            epsilon = 1e-15
        );
    }

    #[test]
    fn bin_upper_bound_first() {
        let boundaries = vec![0.0];
        assert_relative_eq!(bin_upper_bound(0, 2, &boundaries), 0.0, epsilon = 1e-15);
    }

    // -- has_converged ------------------------------------------------------

    #[test]
    fn has_converged_identical_values() {
        assert!(has_converged(1.0, 1.0));
    }

    #[test]
    fn has_converged_large_change() {
        assert!(!has_converged(1.0, 0.5));
    }

    // -- select_conditional_or_midpoint -------------------------------------

    #[test]
    fn select_conditional_or_midpoint_normal_case() {
        let result = select_conditional_or_midpoint(TEST_NUMERATOR, TEST_DENOMINATOR, 0.0, 1.0);
        assert_relative_eq!(result, TEST_NUMERATOR / TEST_DENOMINATOR, epsilon = 1e-15);
    }

    #[test]
    fn select_conditional_or_midpoint_near_zero_denom() {
        let result = select_conditional_or_midpoint(TEST_NUMERATOR, TEST_NEAR_ZERO_DENOM, 0.0, 1.0);
        assert_relative_eq!(result, 0.5, epsilon = 1e-15);
    }

    // -- conditional_expectation --------------------------------------------

    #[test]
    fn conditional_expectation_symmetric_interval() {
        // E[X | X in [-1, 1]] should be 0 by symmetry.
        let result = conditional_expectation(SUPPORT_MIN, SUPPORT_MAX, TEST_DIM);
        assert_relative_eq!(result, 0.0, epsilon = 1e-8);
    }

    // -- compute_distortion -------------------------------------------------

    #[test]
    fn compute_distortion_nonnegative() {
        let centroids = vec![-0.5, 0.0, 0.5];
        let boundaries = vec![-0.25, 0.25];
        let d = compute_distortion(&centroids, &boundaries, TEST_DIM);
        assert!(d >= 0.0);
    }

    // -- update_centroids ---------------------------------------------------

    #[test]
    fn update_centroids_correct_count() {
        let boundaries = midpoint_boundaries(&initialize_centroids(TEST_CENTROIDS_8));
        let updated = update_centroids(TEST_CENTROIDS_8, &boundaries, TEST_DIM);
        assert_eq!(updated.len(), TEST_CENTROIDS_8);
    }

    #[test]
    fn update_centroids_within_support() {
        let boundaries = midpoint_boundaries(&initialize_centroids(TEST_CENTROIDS_8));
        let updated = update_centroids(TEST_CENTROIDS_8, &boundaries, TEST_DIM);
        for &c in &updated {
            assert!((SUPPORT_MIN..=SUPPORT_MAX).contains(&c));
        }
    }

    // -- generate_codebook --------------------------------------------------

    #[test]
    fn generate_codebook_valid_structure() {
        let cb = generate_codebook(TEST_BITS_3, TEST_DIM_64);
        assert_eq!(cb.centroids.len(), TEST_CENTROIDS_8);
        assert_eq!(cb.boundaries.len(), TEST_CENTROIDS_8 - 1);
        for w in cb.centroids.windows(2) {
            assert!(w[0] < w[1]);
        }
    }

    // -- lloyd_max_step -----------------------------------------------------

    #[test]
    fn lloyd_max_step_reduces_distortion() {
        let centroids = initialize_centroids(TEST_CENTROIDS_8);
        let boundaries = midpoint_boundaries(&centroids);
        let initial_dist = compute_distortion(&centroids, &boundaries, TEST_DIM);
        let (new_centroids, new_dist, _) = lloyd_max_step(&centroids, f64::MAX, TEST_DIM);
        // The new distortion should be <= initial (Lloyd-Max is monotonically improving).
        assert!(new_dist <= initial_dist + 1e-15);
        assert_eq!(new_centroids.len(), TEST_CENTROIDS_8);
    }
}