Skip to main content

selene_core/vector/turbo_quant/
blocked.rs

1use super::{
2    TurboQuantBitWidth, TurboQuantCodecError, TurboQuantCodecResult, TurboQuantPackedCodes,
3    bytes_per_row, validate_dimension,
4};
5
/// Number of vector rows stored together in blocked TurboQuant code storage.
///
/// Blocked storage is always allocated in whole blocks of this many row lanes,
/// so a partially filled final block is padded up to this count.
pub const TURBO_QUANT_BLOCK_ROWS: usize = 32;
8
/// Block-major packed TurboQuant coordinate codes.
///
/// Rows are grouped in fixed 32-row blocks. For each block, all rows for packed
/// byte 0 are stored contiguously, then all rows for packed byte 1, and so on.
/// This preserves the same packed row representation as [`TurboQuantPackedCodes`]
/// while making search-time scans load one byte position across a whole block.
///
/// The backing-store index of packed byte `byte` of row `row` is
/// `((row / 32) * bytes_per_row + byte) * 32 + row % 32`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TurboQuantBlockedCodes {
    /// Bit width of every stored coordinate code.
    bit_width: TurboQuantBitWidth,
    /// Number of coordinate codes encoded in each row.
    dimensions: usize,
    /// Number of real (non-padding) rows.
    rows: usize,
    /// Number of packed bytes that make up one row.
    bytes_per_row: usize,
    /// Block-major backing store, sized in whole 32-row blocks.
    bytes: Vec<u8>,
}
23
24impl TurboQuantBlockedCodes {
25    /// Allocate zero-filled blocked-code storage.
26    ///
27    /// # Errors
28    ///
29    /// Returns an error when dimensions are invalid or the computed byte size
30    /// overflows `usize`.
31    pub fn new(
32        bit_width: TurboQuantBitWidth,
33        dimensions: usize,
34        rows: usize,
35    ) -> TurboQuantCodecResult<Self> {
36        let bytes_per_row = bytes_per_row(bit_width, dimensions)?;
37        let byte_len = byte_len(bytes_per_row, rows)?;
38        Ok(Self {
39            bit_width,
40            dimensions,
41            rows,
42            bytes_per_row,
43            bytes: vec![0; byte_len],
44        })
45    }
46
47    /// Repack row-major packed codes into block-major storage.
48    ///
49    /// # Errors
50    ///
51    /// Returns an error if the blocked storage size overflows `usize`.
52    pub fn from_row_major(codes: &TurboQuantPackedCodes) -> TurboQuantCodecResult<Self> {
53        let mut blocked = Self::new(codes.bit_width(), codes.dimensions(), codes.rows())?;
54        for row in 0..codes.rows() {
55            let source = row * codes.bytes_per_row();
56            for byte in 0..codes.bytes_per_row() {
57                blocked.set_row_byte(row, byte, codes.as_bytes()[source + byte]);
58            }
59        }
60        Ok(blocked)
61    }
62
63    /// Return the bit width used by the blocked codes.
64    #[must_use]
65    pub const fn bit_width(&self) -> TurboQuantBitWidth {
66        self.bit_width
67    }
68
69    /// Return the number of dimensions encoded in each row.
70    #[must_use]
71    pub const fn dimensions(&self) -> usize {
72        self.dimensions
73    }
74
75    /// Return the number of encoded rows.
76    #[must_use]
77    pub const fn rows(&self) -> usize {
78        self.rows
79    }
80
81    /// Return the byte stride for one packed row.
82    #[must_use]
83    pub const fn bytes_per_row(&self) -> usize {
84        self.bytes_per_row
85    }
86
87    /// Return the number of row blocks in the backing storage.
88    #[must_use]
89    pub fn block_count(&self) -> usize {
90        block_count(self.rows)
91    }
92
93    /// Return the number of real rows in `block`.
94    #[must_use]
95    pub fn block_len(&self, block: usize) -> usize {
96        debug_assert!(block < self.block_count());
97        let remaining = self.rows - block * TURBO_QUANT_BLOCK_ROWS;
98        remaining.min(TURBO_QUANT_BLOCK_ROWS)
99    }
100
101    /// Return the packed bytes for one byte position across a 32-row block.
102    ///
103    /// The returned slice always has [`TURBO_QUANT_BLOCK_ROWS`] bytes. Tail
104    /// lanes beyond [`Self::block_len`] are zero padding.
105    #[must_use]
106    pub fn block_byte(&self, block: usize, byte: usize) -> &[u8] {
107        debug_assert!(block < self.block_count());
108        debug_assert!(byte < self.bytes_per_row);
109        let offset = (block * self.bytes_per_row + byte) * TURBO_QUANT_BLOCK_ROWS;
110        &self.bytes[offset..offset + TURBO_QUANT_BLOCK_ROWS]
111    }
112
113    /// Return one packed row byte.
114    ///
115    /// # Errors
116    ///
117    /// Returns a bounds error when `row` or `byte` is outside the matrix.
118    pub fn row_byte(&self, row: usize, byte: usize) -> TurboQuantCodecResult<u8> {
119        let offset = self.byte_offset(row, byte)?;
120        Ok(self.bytes[offset])
121    }
122
123    /// Overwrite one packed row from caller-provided row-major bytes.
124    ///
125    /// # Errors
126    ///
127    /// Returns a bounds error when `row` is outside the matrix, or a byte-length
128    /// mismatch when `bytes` does not contain exactly one packed row.
129    pub fn write_row_bytes(&mut self, row: usize, bytes: &[u8]) -> TurboQuantCodecResult<()> {
130        self.validate_row(row)?;
131        if bytes.len() != self.bytes_per_row {
132            return Err(TurboQuantCodecError::ByteLengthMismatch {
133                expected: self.bytes_per_row,
134                actual: bytes.len(),
135            });
136        }
137        for (byte, value) in bytes.iter().copied().enumerate() {
138            self.set_row_byte(row, byte, value);
139        }
140        Ok(())
141    }
142
143    /// Return the blocked backing bytes.
144    #[must_use]
145    pub fn as_bytes(&self) -> &[u8] {
146        &self.bytes
147    }
148
149    /// Return the blocked-code byte footprint.
150    #[must_use]
151    pub fn estimated_bytes(&self) -> usize {
152        self.bytes.len()
153    }
154
155    /// Resize the row count while preserving existing packed rows.
156    ///
157    /// Newly added rows are zero-filled. Shrinking clears rows that remain in a
158    /// retained tail block, so later growth cannot expose stale packed bytes.
159    ///
160    /// # Errors
161    ///
162    /// Returns an error when the computed byte size overflows `usize`.
163    pub fn resize_rows(&mut self, rows: usize) -> TurboQuantCodecResult<()> {
164        validate_dimension(self.dimensions)?;
165        let old_rows = self.rows;
166        let byte_len = byte_len(self.bytes_per_row, rows)?;
167        self.bytes.resize(byte_len, 0);
168        for row in old_rows.min(rows)..old_rows.max(rows) {
169            for byte in 0..self.bytes_per_row {
170                if let Some(offset) = self.byte_offset_if_allocated(row, byte) {
171                    self.bytes[offset] = 0;
172                }
173            }
174        }
175        self.rows = rows;
176        Ok(())
177    }
178
179    /// Remove one row by moving the current last row into its slot.
180    ///
181    /// This preserves the packed row bytes without decoding individual
182    /// coordinate codes. Tail bytes are cleared through [`Self::resize_rows`].
183    ///
184    /// # Errors
185    ///
186    /// Returns a bounds error when `row` is outside the matrix, or an overflow
187    /// error if shrinking the storage would overflow internal size accounting.
188    pub fn swap_remove_row(&mut self, row: usize) -> TurboQuantCodecResult<()> {
189        self.validate_row(row)?;
190        let last = self.rows - 1;
191        if row != last {
192            for byte in 0..self.bytes_per_row {
193                let source = self.byte_offset_unchecked(last, byte);
194                let destination = self.byte_offset_unchecked(row, byte);
195                self.bytes[destination] = self.bytes[source];
196            }
197        }
198        self.resize_rows(last)
199    }
200
201    /// Read one packed coordinate code.
202    ///
203    /// # Errors
204    ///
205    /// Returns bounds errors when `row` or `dimension` is outside the packed
206    /// matrix.
207    pub fn read(&self, row: usize, dimension: usize) -> TurboQuantCodecResult<u8> {
208        let bit_offset = self.bit_offset(row, dimension)?;
209        let byte = bit_offset / u8::BITS as usize;
210        let shift = bit_offset % u8::BITS as usize;
211        let mut word = u16::from(self.bytes[self.byte_offset(row, byte)?]);
212        if byte + 1 < self.bytes_per_row {
213            word |= u16::from(self.bytes[self.byte_offset(row, byte + 1)?]) << u8::BITS;
214        }
215        let mask = (1_u16 << self.bit_width.bits()) - 1;
216        Ok(((word >> shift) & mask) as u8)
217    }
218
219    /// Write one packed coordinate code.
220    ///
221    /// # Errors
222    ///
223    /// Returns bounds errors when `row` or `dimension` is outside the packed
224    /// matrix, and [`TurboQuantCodecError::InvalidCode`] when `code` exceeds
225    /// this storage's bit width.
226    pub fn write(&mut self, row: usize, dimension: usize, code: u8) -> TurboQuantCodecResult<()> {
227        self.validate_code(code)?;
228        let bit_offset = self.bit_offset(row, dimension)?;
229        let byte = bit_offset / u8::BITS as usize;
230        let shift = bit_offset % u8::BITS as usize;
231        let mask = ((1_u16 << self.bit_width.bits()) - 1) << shift;
232        let first = self.byte_offset(row, byte)?;
233        let mut word = u16::from(self.bytes[first]);
234        let second = (byte + 1 < self.bytes_per_row)
235            .then(|| self.byte_offset(row, byte + 1))
236            .transpose()?;
237        if let Some(second) = second {
238            word |= u16::from(self.bytes[second]) << u8::BITS;
239        }
240        word = (word & !mask) | (u16::from(code) << shift);
241        self.bytes[first] = (word & 0xff) as u8;
242        if shift + usize::from(self.bit_width.bits()) > u8::BITS as usize
243            && let Some(second) = second
244        {
245            self.bytes[second] = (word >> u8::BITS) as u8;
246        }
247        Ok(())
248    }
249
250    fn validate_code(&self, code: u8) -> TurboQuantCodecResult<()> {
251        let max = self.bit_width.max_code();
252        if code <= max {
253            Ok(())
254        } else {
255            Err(TurboQuantCodecError::InvalidCode { code, max })
256        }
257    }
258
259    fn bit_offset(&self, row: usize, dimension: usize) -> TurboQuantCodecResult<usize> {
260        self.validate_row(row)?;
261        if dimension >= self.dimensions {
262            return Err(TurboQuantCodecError::DimensionOutOfBounds {
263                dimension,
264                dimensions: self.dimensions,
265            });
266        }
267        dimension
268            .checked_mul(usize::from(self.bit_width.bits()))
269            .ok_or(TurboQuantCodecError::SizeOverflow)
270    }
271
272    fn byte_offset(&self, row: usize, byte: usize) -> TurboQuantCodecResult<usize> {
273        self.validate_row(row)?;
274        if byte >= self.bytes_per_row {
275            return Err(TurboQuantCodecError::DimensionOutOfBounds {
276                dimension: byte.saturating_mul(u8::BITS as usize),
277                dimensions: self.dimensions,
278            });
279        }
280        Ok(self.byte_offset_unchecked(row, byte))
281    }
282
283    fn validate_row(&self, row: usize) -> TurboQuantCodecResult<()> {
284        if row >= self.rows {
285            Err(TurboQuantCodecError::RowOutOfBounds {
286                row,
287                rows: self.rows,
288            })
289        } else {
290            Ok(())
291        }
292    }
293
294    fn set_row_byte(&mut self, row: usize, byte: usize, value: u8) {
295        let offset = self.byte_offset_unchecked(row, byte);
296        self.bytes[offset] = value;
297    }
298
299    fn byte_offset_if_allocated(&self, row: usize, byte: usize) -> Option<usize> {
300        let offset = self.byte_offset_unchecked(row, byte);
301        (offset < self.bytes.len()).then_some(offset)
302    }
303
304    fn byte_offset_unchecked(&self, row: usize, byte: usize) -> usize {
305        let block = row / TURBO_QUANT_BLOCK_ROWS;
306        let lane = row % TURBO_QUANT_BLOCK_ROWS;
307        (block * self.bytes_per_row + byte) * TURBO_QUANT_BLOCK_ROWS + lane
308    }
309}
310
311fn block_count(rows: usize) -> usize {
312    rows.div_ceil(TURBO_QUANT_BLOCK_ROWS)
313}
314
315fn byte_len(bytes_per_row: usize, rows: usize) -> TurboQuantCodecResult<usize> {
316    block_count(rows)
317        .checked_mul(bytes_per_row)
318        .and_then(|bytes| bytes.checked_mul(TURBO_QUANT_BLOCK_ROWS))
319        .ok_or(TurboQuantCodecError::SizeOverflow)
320}
321
322#[cfg(test)]
323mod tests {
324    use super::*;
325
326    #[test]
327    fn blocked_codes_match_row_major_reads() {
328        for bits in 2..=4 {
329            let bit_width = TurboQuantBitWidth::new(bits).unwrap();
330            let mut row_major = TurboQuantPackedCodes::new(bit_width, 11, 35).unwrap();
331            let mut blocked = TurboQuantBlockedCodes::new(bit_width, 11, 35).unwrap();
332            for row in 0..row_major.rows() {
333                for dimension in 0..row_major.dimensions() {
334                    let code = ((row * 3 + dimension) % bit_width.levels()) as u8;
335                    row_major.write(row, dimension, code).unwrap();
336                    blocked.write(row, dimension, code).unwrap();
337                }
338            }
339
340            for row in 0..row_major.rows() {
341                for dimension in 0..row_major.dimensions() {
342                    assert_eq!(
343                        blocked.read(row, dimension).unwrap(),
344                        row_major.read(row, dimension).unwrap()
345                    );
346                }
347            }
348        }
349    }
350
351    #[test]
352    fn row_major_repack_uses_block_byte_layout() {
353        let bit_width = TurboQuantBitWidth::new(4).unwrap();
354        let mut row_major = TurboQuantPackedCodes::new(bit_width, 4, 35).unwrap();
355        for row in 0..row_major.rows() {
356            for dimension in 0..row_major.dimensions() {
357                row_major
358                    .write(row, dimension, ((row + dimension) % 16) as u8)
359                    .unwrap();
360            }
361        }
362
363        let blocked = TurboQuantBlockedCodes::from_row_major(&row_major).unwrap();
364
365        assert_eq!(blocked.block_count(), 2);
366        assert_eq!(blocked.block_len(0), TURBO_QUANT_BLOCK_ROWS);
367        assert_eq!(blocked.block_len(1), 3);
368        for byte in 0..row_major.bytes_per_row() {
369            let block_byte = blocked.block_byte(0, byte);
370            for (row, packed) in block_byte.iter().enumerate() {
371                assert_eq!(
372                    *packed,
373                    row_major.as_bytes()[row * row_major.bytes_per_row() + byte]
374                );
375            }
376        }
377    }
378
379    #[test]
380    fn write_row_bytes_overwrites_one_blocked_row() {
381        let bit_width = TurboQuantBitWidth::new(4).unwrap();
382        let mut blocked = TurboQuantBlockedCodes::new(bit_width, 4, 35).unwrap();
383
384        blocked.write_row_bytes(33, &[0x21, 0x43]).unwrap();
385
386        assert_eq!(blocked.read(33, 0).unwrap(), 1);
387        assert_eq!(blocked.read(33, 1).unwrap(), 2);
388        assert_eq!(blocked.read(33, 2).unwrap(), 3);
389        assert_eq!(blocked.read(33, 3).unwrap(), 4);
390        assert_eq!(blocked.block_byte(1, 0)[1], 0x21);
391        assert_eq!(blocked.block_byte(1, 1)[1], 0x43);
392    }
393
394    #[test]
395    fn write_row_bytes_rejects_wrong_length() {
396        let bit_width = TurboQuantBitWidth::new(4).unwrap();
397        let mut blocked = TurboQuantBlockedCodes::new(bit_width, 4, 1).unwrap();
398
399        assert_eq!(
400            blocked.write_row_bytes(0, &[0x21]).unwrap_err(),
401            TurboQuantCodecError::ByteLengthMismatch {
402                expected: 2,
403                actual: 1
404            }
405        );
406    }
407
408    #[test]
409    fn resize_rows_clears_retained_tail_slots() {
410        let bit_width = TurboQuantBitWidth::new(4).unwrap();
411        let mut blocked = TurboQuantBlockedCodes::new(bit_width, 2, 4).unwrap();
412        blocked.write(3, 0, 15).unwrap();
413        blocked.resize_rows(2).unwrap();
414        blocked.resize_rows(4).unwrap();
415
416        assert_eq!(blocked.read(3, 0).unwrap(), 0);
417    }
418
419    #[test]
420    fn swap_remove_row_moves_last_row_and_clears_tail() {
421        for bits in 2..=4 {
422            let bit_width = TurboQuantBitWidth::new(bits).unwrap();
423            let mut blocked = TurboQuantBlockedCodes::new(bit_width, 11, 35).unwrap();
424            let last = blocked.rows() - 1;
425            let removed = 7;
426            let max_code = usize::from(bit_width.max_code());
427            let moved_codes = (0..blocked.dimensions())
428                .map(|dim| ((last * 5 + dim * 3) % (max_code + 1)) as u8)
429                .collect::<Vec<_>>();
430            for row in 0..blocked.rows() {
431                for dim in 0..blocked.dimensions() {
432                    let code = ((row * 5 + dim * 3) % (max_code + 1)) as u8;
433                    blocked.write(row, dim, code).unwrap();
434                }
435            }
436
437            blocked.swap_remove_row(removed).unwrap();
438
439            assert_eq!(blocked.rows(), last);
440            for (dim, expected) in moved_codes.into_iter().enumerate() {
441                assert_eq!(blocked.read(removed, dim).unwrap(), expected);
442            }
443            blocked.resize_rows(last + 1).unwrap();
444            for dim in 0..blocked.dimensions() {
445                assert_eq!(blocked.read(last, dim).unwrap(), 0);
446            }
447        }
448    }
449}