Skip to main content

trueno_rag/multivector/
codec.rs

1//! Residual quantization codec for WARP
2//!
3//! This module implements the residual quantization scheme used to compress
4//! token embeddings in the WARP algorithm. Each vector is decomposed into:
5//! - A centroid (learned via k-means)
//! - A residual (difference from centroid), quantized to 2 or 4 bits per dimension
7//!
8//! The codec enables efficient scoring without full decompression by using
9//! precomputed centroid scores and bucket weights.
10
11use crate::multivector::types::WarpIndexConfig;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14
15/// Residual quantization codec for compressing token embeddings.
16///
17/// The codec learns centroids via k-means clustering, then quantizes the
18/// residuals (v - centroid) to a small number of bits per dimension.
19///
20/// # Compression Process
21///
22/// 1. Find nearest centroid for input vector
23/// 2. Compute residual = vector - centroid
24/// 3. Quantize each dimension to `nbits` using learned bucket boundaries
25/// 4. Pack quantized values into bytes
26///
27/// # Scoring
28///
29/// Score computation avoids full decompression:
30/// ```text
31/// q · v ≈ q · c + Σ_d q[d] × bucket_weight[d, code[d]]
32/// ```
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ResidualCodec {
35    /// Centroid vectors: [num_centroids × dim], flattened
36    centroids: Vec<f32>,
37    /// Number of centroids
38    num_centroids: usize,
39    /// Token dimension
40    dim: usize,
41    /// Quantization bucket boundaries per dimension: [dim × (num_buckets - 1)]
42    bucket_cutoffs: Vec<f32>,
43    /// Reconstruction weights per bucket: [dim × num_buckets]
44    bucket_weights: Vec<f32>,
45    /// Bits per dimension (2 or 4)
46    nbits: u8,
47}
48
49impl ResidualCodec {
50    /// Train a codec from sample embeddings.
51    ///
52    /// # Arguments
53    ///
54    /// * `embeddings` - Flattened sample embeddings [n × dim]
55    /// * `dim` - Embedding dimension
56    /// * `num_centroids` - Number of k-means centroids
57    /// * `nbits` - Bits per dimension (2 or 4)
58    /// * `iterations` - K-means iterations
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if training data is insufficient or parameters invalid.
63    pub fn train(
64        embeddings: &[f32],
65        dim: usize,
66        num_centroids: usize,
67        nbits: u8,
68        iterations: usize,
69    ) -> Result<Self> {
70        if nbits != 2 && nbits != 4 {
71            return Err(crate::Error::InvalidInput("nbits must be 2 or 4".to_string()));
72        }
73
74        let n = embeddings.len() / dim;
75        if n < num_centroids {
76            return Err(crate::Error::InvalidInput(format!(
77                "Insufficient training data: {n} samples for {num_centroids} centroids"
78            )));
79        }
80
81        // Step 1: K-means clustering to find centroids
82        let centroids = Self::kmeans_clustering(embeddings, dim, num_centroids, iterations);
83
84        // Step 2: Compute residuals for all training points
85        let residuals = Self::compute_all_residuals(embeddings, dim, &centroids, num_centroids);
86
87        // Step 3: Learn quantization boundaries from residual distribution
88        let (bucket_cutoffs, bucket_weights) =
89            Self::learn_quantization_params(&residuals, dim, nbits);
90
91        Ok(Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits })
92    }
93
94    /// Create a codec with pre-trained parameters.
95    #[must_use]
96    pub fn with_params(
97        centroids: Vec<f32>,
98        num_centroids: usize,
99        dim: usize,
100        bucket_cutoffs: Vec<f32>,
101        bucket_weights: Vec<f32>,
102        nbits: u8,
103    ) -> Self {
104        Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits }
105    }
106
107    /// Get the number of centroids.
108    #[must_use]
109    pub fn num_centroids(&self) -> usize {
110        self.num_centroids
111    }
112
113    /// Get the embedding dimension.
114    #[must_use]
115    pub fn dim(&self) -> usize {
116        self.dim
117    }
118
119    /// Get bits per dimension.
120    #[must_use]
121    pub fn nbits(&self) -> u8 {
122        self.nbits
123    }
124
125    /// Get the packed residual size in bytes.
126    #[must_use]
127    pub fn packed_size(&self) -> usize {
128        (self.dim * self.nbits as usize + 7) / 8
129    }
130
131    /// Get centroid slice by ID.
132    #[must_use]
133    pub fn centroid(&self, id: usize) -> &[f32] {
134        let start = id * self.dim;
135        &self.centroids[start..start + self.dim]
136    }
137
138    /// Get all centroids as a flat slice.
139    #[must_use]
140    pub fn centroids(&self) -> &[f32] {
141        &self.centroids
142    }
143
144    /// Find the nearest centroid for a vector.
145    #[must_use]
146    pub fn find_nearest_centroid(&self, embedding: &[f32]) -> usize {
147        let mut best_id = 0;
148        let mut best_dist = f32::MAX;
149
150        for c in 0..self.num_centroids {
151            let centroid = self.centroid(c);
152            let dist = Self::squared_distance(embedding, centroid);
153            if dist < best_dist {
154                best_dist = dist;
155                best_id = c;
156            }
157        }
158
159        best_id
160    }
161
162    /// Compress an embedding to (centroid_id, packed_residual).
163    #[must_use]
164    pub fn compress(&self, embedding: &[f32]) -> (usize, Vec<u8>) {
165        // Find nearest centroid
166        let centroid_id = self.find_nearest_centroid(embedding);
167        let centroid = self.centroid(centroid_id);
168
169        // Compute residual
170        let residual: Vec<f32> =
171            embedding.iter().zip(centroid.iter()).map(|(e, c)| e - c).collect();
172
173        // Quantize residual
174        let codes = self.quantize_residual(&residual);
175
176        // Pack codes into bytes
177        let packed = self.pack_codes(&codes);
178
179        (centroid_id, packed)
180    }
181
182    /// Compute score between query token and compressed document token.
183    ///
184    /// score ≈ q · d = q · c + q · r
185    ///
186    /// # Arguments
187    ///
188    /// * `query_token` - Query embedding
189    /// * `centroid_id` - Assigned centroid
190    /// * `centroid_score` - Precomputed q · c
191    /// * `packed_residual` - Packed quantized residual
192    #[must_use]
193    pub fn decompress_score(
194        &self,
195        query_token: &[f32],
196        centroid_id: usize,
197        centroid_score: f32,
198        packed_residual: &[u8],
199    ) -> f32 {
200        let _ = centroid_id; // Centroid info already in centroid_score
201
202        // Unpack residual codes
203        let codes = self.unpack_codes(packed_residual);
204
205        // Compute q · r using bucket weights
206        let num_buckets = 1usize << self.nbits;
207        let residual_score: f32 = codes
208            .iter()
209            .enumerate()
210            .map(|(d, &code)| {
211                let weight_idx = d * num_buckets + code as usize;
212                query_token[d] * self.bucket_weights[weight_idx]
213            })
214            .sum();
215
216        centroid_score + residual_score
217    }
218
219    /// Compute dot product between query and centroid.
220    #[must_use]
221    pub fn centroid_score(&self, query_token: &[f32], centroid_id: usize) -> f32 {
222        let centroid = self.centroid(centroid_id);
223        Self::dot_product(query_token, centroid)
224    }
225
226    /// Quantize a residual vector to codes.
227    fn quantize_residual(&self, residual: &[f32]) -> Vec<u8> {
228        let num_buckets = 1usize << self.nbits;
229
230        residual
231            .iter()
232            .enumerate()
233            .map(|(d, &value)| {
234                // Binary search for bucket
235                let cutoff_start = d * (num_buckets - 1);
236                let cutoffs = &self.bucket_cutoffs[cutoff_start..cutoff_start + num_buckets - 1];
237
238                // Find first cutoff >= value
239                cutoffs.iter().position(|&c| value < c).unwrap_or(num_buckets - 1) as u8
240            })
241            .collect()
242    }
243
244    /// Pack quantization codes into bytes.
245    fn pack_codes(&self, codes: &[u8]) -> Vec<u8> {
246        match self.nbits {
247            2 => {
248                // Pack 4 codes per byte
249                codes
250                    .chunks(4)
251                    .map(|chunk| {
252                        let mut byte = 0u8;
253                        for (i, &code) in chunk.iter().enumerate() {
254                            byte |= (code & 0x03) << (i * 2);
255                        }
256                        byte
257                    })
258                    .collect()
259            }
260            4 => {
261                // Pack 2 codes per byte
262                codes
263                    .chunks(2)
264                    .map(|chunk| {
265                        let low = chunk.first().copied().unwrap_or(0) & 0x0F;
266                        let high = chunk.get(1).copied().unwrap_or(0) & 0x0F;
267                        low | (high << 4)
268                    })
269                    .collect()
270            }
271            _ => panic!("Unsupported nbits: {}", self.nbits),
272        }
273    }
274
275    /// Unpack codes from packed bytes.
276    fn unpack_codes(&self, packed: &[u8]) -> Vec<u8> {
277        match self.nbits {
278            2 => packed
279                .iter()
280                .flat_map(|&byte| (0..4).map(move |i| (byte >> (i * 2)) & 0x03))
281                .take(self.dim)
282                .collect(),
283            4 => packed
284                .iter()
285                .flat_map(|&byte| vec![byte & 0x0F, (byte >> 4) & 0x0F])
286                .take(self.dim)
287                .collect(),
288            _ => panic!("Unsupported nbits: {}", self.nbits),
289        }
290    }
291
292    // ============ K-means Implementation ============
293
294    /// K-means clustering with k-means++ initialization.
295    fn kmeans_clustering(embeddings: &[f32], dim: usize, k: usize, iterations: usize) -> Vec<f32> {
296        let n = embeddings.len() / dim;
297
298        // K-means++ initialization
299        let mut centroids = Self::kmeans_plus_plus_init(embeddings, dim, k);
300        let mut assignments = vec![0usize; n];
301
302        for _ in 0..iterations {
303            // Assign points to nearest centroid
304            for i in 0..n {
305                let point = &embeddings[i * dim..(i + 1) * dim];
306                let mut best_dist = f32::MAX;
307                let mut best_c = 0;
308
309                for c in 0..k {
310                    let centroid = &centroids[c * dim..(c + 1) * dim];
311                    let dist = Self::squared_distance(point, centroid);
312                    if dist < best_dist {
313                        best_dist = dist;
314                        best_c = c;
315                    }
316                }
317                assignments[i] = best_c;
318            }
319
320            // Update centroids as mean of assigned points
321            let mut new_centroids = vec![0.0f32; k * dim];
322            let mut counts = vec![0usize; k];
323
324            for i in 0..n {
325                let c = assignments[i];
326                counts[c] += 1;
327                let point = &embeddings[i * dim..(i + 1) * dim];
328                for d in 0..dim {
329                    new_centroids[c * dim + d] += point[d];
330                }
331            }
332
333            for c in 0..k {
334                if counts[c] > 0 {
335                    for d in 0..dim {
336                        new_centroids[c * dim + d] /= counts[c] as f32;
337                    }
338                } else {
339                    // Keep old centroid if no points assigned
340                    for d in 0..dim {
341                        new_centroids[c * dim + d] = centroids[c * dim + d];
342                    }
343                }
344            }
345
346            centroids = new_centroids;
347        }
348
349        centroids
350    }
351
352    /// K-means++ initialization.
353    fn kmeans_plus_plus_init(embeddings: &[f32], dim: usize, k: usize) -> Vec<f32> {
354        let n = embeddings.len() / dim;
355        let mut centroids = Vec::with_capacity(k * dim);
356        let mut rng_state = 42u64; // Simple deterministic RNG
357
358        // Choose first centroid uniformly at random
359        let first_idx = Self::simple_random(&mut rng_state, n);
360        centroids.extend_from_slice(&embeddings[first_idx * dim..(first_idx + 1) * dim]);
361
362        let mut distances = vec![f32::MAX; n];
363
364        for _ in 1..k {
365            let num_centroids = centroids.len() / dim;
366
367            // Update distances to nearest centroid
368            for i in 0..n {
369                let point = &embeddings[i * dim..(i + 1) * dim];
370                let centroid = &centroids[(num_centroids - 1) * dim..num_centroids * dim];
371                let dist = Self::squared_distance(point, centroid);
372                distances[i] = distances[i].min(dist);
373            }
374
375            // Choose next centroid with probability proportional to D²
376            let total: f32 = distances.iter().sum();
377            if total <= 0.0 {
378                // All points are centroids already, pick random
379                let idx = Self::simple_random(&mut rng_state, n);
380                centroids.extend_from_slice(&embeddings[idx * dim..(idx + 1) * dim]);
381                continue;
382            }
383
384            let threshold = Self::simple_random_f32(&mut rng_state) * total;
385            let mut cumsum = 0.0f32;
386            let mut chosen = 0;
387
388            for (i, &d) in distances.iter().enumerate() {
389                cumsum += d;
390                if cumsum >= threshold {
391                    chosen = i;
392                    break;
393                }
394            }
395
396            centroids.extend_from_slice(&embeddings[chosen * dim..(chosen + 1) * dim]);
397        }
398
399        centroids
400    }
401
402    /// Compute residuals for all embeddings.
403    fn compute_all_residuals(
404        embeddings: &[f32],
405        dim: usize,
406        centroids: &[f32],
407        num_centroids: usize,
408    ) -> Vec<f32> {
409        let n = embeddings.len() / dim;
410        let mut residuals = Vec::with_capacity(n * dim);
411
412        for i in 0..n {
413            let point = &embeddings[i * dim..(i + 1) * dim];
414
415            // Find nearest centroid
416            let mut best_c = 0;
417            let mut best_dist = f32::MAX;
418            for c in 0..num_centroids {
419                let centroid = &centroids[c * dim..(c + 1) * dim];
420                let dist = Self::squared_distance(point, centroid);
421                if dist < best_dist {
422                    best_dist = dist;
423                    best_c = c;
424                }
425            }
426
427            // Compute residual
428            let centroid = &centroids[best_c * dim..(best_c + 1) * dim];
429            for d in 0..dim {
430                residuals.push(point[d] - centroid[d]);
431            }
432        }
433
434        residuals
435    }
436
437    /// Learn quantization bucket boundaries and weights from residuals.
438    fn learn_quantization_params(residuals: &[f32], dim: usize, nbits: u8) -> (Vec<f32>, Vec<f32>) {
439        let num_buckets = 1usize << nbits;
440        let n = residuals.len() / dim;
441
442        let mut cutoffs = Vec::with_capacity(dim * (num_buckets - 1));
443        let mut weights = Vec::with_capacity(dim * num_buckets);
444
445        for d in 0..dim {
446            // Collect residual values for dimension d
447            let mut values: Vec<f32> = (0..n).map(|i| residuals[i * dim + d]).collect();
448            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
449
450            // Quantile-based boundaries for equal-frequency buckets
451            for b in 1..num_buckets {
452                let quantile_idx = (b * n) / num_buckets;
453                cutoffs.push(values[quantile_idx.min(n - 1)]);
454            }
455
456            // Bucket weights = mean value in each bucket
457            for b in 0..num_buckets {
458                let start = (b * n) / num_buckets;
459                let end = ((b + 1) * n) / num_buckets;
460                let end = end.max(start + 1).min(n);
461
462                let sum: f32 = values[start..end].iter().sum();
463                let mean = sum / (end - start) as f32;
464                weights.push(mean);
465            }
466        }
467
468        (cutoffs, weights)
469    }
470
471    // ============ Math Utilities ============
472
473    fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
474        a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
475    }
476
477    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
478        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
479    }
480
481    fn simple_random(state: &mut u64, max: usize) -> usize {
482        *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
483        ((*state >> 33) as usize) % max
484    }
485
486    fn simple_random_f32(state: &mut u64) -> f32 {
487        *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
488        ((*state >> 33) as f32) / (u32::MAX as f32)
489    }
490}
491
492/// Builder for creating a `ResidualCodec` from a `WarpIndexConfig`.
493pub struct ResidualCodecBuilder {
494    config: WarpIndexConfig,
495}
496
497impl ResidualCodecBuilder {
498    /// Create a new builder from config.
499    #[must_use]
500    pub fn new(config: WarpIndexConfig) -> Self {
501        Self { config }
502    }
503
504    /// Train the codec from sample embeddings.
505    pub fn train(&self, embeddings: &[f32]) -> Result<ResidualCodec> {
506        ResidualCodec::train(
507            embeddings,
508            self.config.token_dim,
509            self.config.num_centroids,
510            self.config.nbits,
511            self.config.kmeans_iterations,
512        )
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Deterministic pseudo-random embeddings of shape `[n × dim]`.
    ///
    /// NOTE: `rng_state >> 33` keeps 31 random bits, so dividing by `u32::MAX`
    /// maps into `[0, 0.5)` and the generated values lie in `[-1, 0)`. The
    /// generator is intentionally left unchanged so existing expectations
    /// (e.g. the scoring tolerance) keep the distribution they were written for.
    fn generate_test_embeddings(n: usize, dim: usize) -> Vec<f32> {
        let mut embeddings = Vec::with_capacity(n * dim);
        let mut rng_state = 12345u64;

        for _ in 0..(n * dim) {
            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let val = ((rng_state >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
            embeddings.push(val);
        }

        embeddings
    }

    /// Codec with placeholder parameters; only `dim` and `nbits` matter for packing.
    fn packing_codec(dim: usize, nbits: u8) -> ResidualCodec {
        ResidualCodec::with_params(vec![0.0; dim], 1, dim, Vec::new(), Vec::new(), nbits)
    }

    // ============ Basic Codec Tests ============

    #[test]
    fn test_codec_train_2bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
        assert_eq!(codec.nbits(), 2);
        assert_eq!(codec.centroids().len(), 16 * 32);
    }

    #[test]
    fn test_codec_train_4bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 4, 5).unwrap();

        assert_eq!(codec.nbits(), 4);
    }

    #[test]
    fn test_codec_train_zero_iterations() {
        // With no Lloyd iterations the k-means++ seeds are returned as-is.
        let embeddings = generate_test_embeddings(100, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 4, 2, 0).unwrap();

        assert_eq!(codec.centroids().len(), 4 * 8);
    }

    #[test]
    fn test_codec_train_insufficient_data() {
        let embeddings = generate_test_embeddings(5, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 2, 5);

        assert!(matches!(result, Err(crate::Error::InvalidInput(_))));
    }

    #[test]
    fn test_codec_train_invalid_nbits() {
        let embeddings = generate_test_embeddings(100, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 3, 5);

        assert!(matches!(result, Err(crate::Error::InvalidInput(_))));
    }

    // ============ Compression Tests ============

    #[test]
    fn test_codec_compress() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let test_vec = &embeddings[0..32];
        let (centroid_id, packed) = codec.compress(test_vec);

        assert!(centroid_id < 16);
        assert_eq!(packed.len(), codec.packed_size());
    }

    #[test]
    fn test_compress_codes_match_quantization() {
        let embeddings = generate_test_embeddings(200, 12);
        let codec = ResidualCodec::train(&embeddings, 12, 8, 4, 3).unwrap();

        let v = &embeddings[12..24];
        let (id, packed) = codec.compress(v);
        let residual: Vec<f32> = v.iter().zip(codec.centroid(id)).map(|(a, b)| a - b).collect();

        assert_eq!(codec.unpack_codes(&packed), codec.quantize_residual(&residual));
    }

    #[test]
    fn test_codec_packed_size_2bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 2, 5).unwrap();

        // 128 dims × 2 bits = 256 bits = 32 bytes
        assert_eq!(codec.packed_size(), 32);
    }

    #[test]
    fn test_codec_packed_size_4bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 4, 5).unwrap();

        // 128 dims × 4 bits = 512 bits = 64 bytes
        assert_eq!(codec.packed_size(), 64);
    }

    // ============ Pack/Unpack Tests ============

    #[test]
    fn test_pack_unpack_2bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 2, 5).unwrap();

        let codes: Vec<u8> = vec![0, 1, 2, 3, 0, 1, 2, 3];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);

        assert_eq!(codes, unpacked);
    }

    #[test]
    fn test_pack_unpack_4bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 4, 5).unwrap();

        let codes: Vec<u8> = vec![0, 5, 10, 15, 1, 6, 11, 14];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);

        assert_eq!(codes, unpacked);
    }

    #[test]
    fn test_pack_2bit_odd_dim_layout() {
        // Four codes per byte, lowest bits first; the last byte is zero-padded.
        let codec = packing_codec(5, 2);
        let codes = vec![0u8, 1, 2, 3, 1];
        let packed = codec.pack_codes(&codes);

        assert_eq!(packed, vec![0xE4, 0x01]);
        assert_eq!(packed.len(), codec.packed_size());
        assert_eq!(codec.unpack_codes(&packed), codes);
    }

    #[test]
    fn test_pack_4bit_odd_dim_layout() {
        // Two codes per byte, low nibble first; the last byte is zero-padded.
        let codec = packing_codec(3, 4);
        let codes = vec![5u8, 10, 15];
        let packed = codec.pack_codes(&codes);

        assert_eq!(packed, vec![0xA5, 0x0F]);
        assert_eq!(packed.len(), codec.packed_size());
        assert_eq!(codec.unpack_codes(&packed), codes);
    }

    // ============ Scoring Tests ============

    #[test]
    fn test_decompress_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let query = &embeddings[0..32];
        let doc = &embeddings[32..64];

        let (centroid_id, packed) = codec.compress(doc);
        let centroid_score = codec.centroid_score(query, centroid_id);
        let approx_score = codec.decompress_score(query, centroid_id, centroid_score, &packed);
        let exact_score: f32 = query.iter().zip(doc.iter()).map(|(q, d)| q * d).sum();

        // Approximate score should be close to exact (within reasonable tolerance)
        let error = (approx_score - exact_score).abs();
        assert!(
            error < exact_score.abs() * 0.5 + 1.0,
            "Error too large: approx={}, exact={}, error={}",
            approx_score,
            exact_score,
            error
        );
    }

    #[test]
    fn test_decompress_score_matches_reconstruction() {
        // The score must equal q · c plus q dotted with the bucket reconstruction.
        let embeddings = generate_test_embeddings(300, 16);
        let codec = ResidualCodec::train(&embeddings, 16, 8, 4, 3).unwrap();

        let query = &embeddings[0..16];
        let (id, packed) = codec.compress(&embeddings[16..32]);
        let centroid_score = codec.centroid_score(query, id);

        let num_buckets = 1usize << codec.nbits();
        let residual_score: f32 = codec
            .unpack_codes(&packed)
            .iter()
            .enumerate()
            .map(|(d, &code)| query[d] * codec.bucket_weights[d * num_buckets + code as usize])
            .sum();

        assert_eq!(
            codec.decompress_score(query, id, centroid_score, &packed),
            centroid_score + residual_score
        );
    }

    #[test]
    fn test_centroid_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let query = &embeddings[0..32];
        let centroid = codec.centroid(0);

        let expected: f32 = query.iter().zip(centroid.iter()).map(|(q, c)| q * c).sum();
        let actual = codec.centroid_score(query, 0);

        assert!((expected - actual).abs() < 1e-6);
    }

    // ============ K-means Tests ============

    #[test]
    fn test_find_nearest_centroid() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        // A centroid should be nearest to itself
        let centroid_0 = codec.centroid(0).to_vec();
        let nearest = codec.find_nearest_centroid(&centroid_0);
        assert_eq!(nearest, 0);
    }

    // ============ Builder Tests ============

    #[test]
    fn test_codec_builder() {
        let config = WarpIndexConfig::new(2, 16, 32).with_kmeans_iterations(5);
        let builder = ResidualCodecBuilder::new(config);

        let embeddings = generate_test_embeddings(500, 32);
        let codec = builder.train(&embeddings).unwrap();

        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
    }

    // ============ Serialization Tests ============

    #[test]
    fn test_codec_serialization() {
        let embeddings = generate_test_embeddings(500, 16);
        let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 5).unwrap();

        let json = serde_json::to_string(&codec).unwrap();
        let deserialized: ResidualCodec = serde_json::from_str(&json).unwrap();

        assert_eq!(codec.num_centroids(), deserialized.num_centroids());
        assert_eq!(codec.dim(), deserialized.dim());
        assert_eq!(codec.nbits(), deserialized.nbits());
        assert_eq!(codec.centroids(), deserialized.centroids());
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_compress_produces_valid_centroid_id(
            seed in 0u64..1000
        ) {
            let mut embeddings = Vec::with_capacity(200 * 16);
            let mut rng_state = seed;
            for _ in 0..(200 * 16) {
                rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
                embeddings.push(((rng_state >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0);
            }

            let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 3).unwrap();
            let test_vec = &embeddings[0..16];
            let (centroid_id, _) = codec.compress(test_vec);

            prop_assert!(centroid_id < 8);
        }

        #[test]
        fn prop_packed_size_matches_config(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 8usize..64
        ) {
            // 100 samples always suffice for 8 centroids, so training must succeed;
            // unwrapping keeps a training failure from passing vacuously.
            let embeddings = generate_test_embeddings(100, dim);
            let codec = ResidualCodec::train(&embeddings, dim, 8, nbits, 3).unwrap();

            let expected_size = (dim * nbits as usize + 7) / 8;
            prop_assert_eq!(codec.packed_size(), expected_size);

            let (_, packed) = codec.compress(&embeddings[0..dim]);
            prop_assert_eq!(packed.len(), expected_size);
        }
    }
}