selene_core/vector/turbo_quant/
blocked.rs1use super::{
2 TurboQuantBitWidth, TurboQuantCodecError, TurboQuantCodecResult, TurboQuantPackedCodes,
3 bytes_per_row, validate_dimension,
4};
5
/// Number of rows grouped into one storage block of the blocked code layout.
///
/// Storage is always sized in whole multiples of this many rows, and the lane of a row inside
/// its block is `row % TURBO_QUANT_BLOCK_ROWS`.
pub const TURBO_QUANT_BLOCK_ROWS: usize = 32;
8
/// Bit-packed codes stored in a blocked, lane-interleaved byte layout.
///
/// Rows are grouped into blocks of [`TURBO_QUANT_BLOCK_ROWS`]. Inside a block, packed byte `b`
/// of every row is stored contiguously (one lane per row), so the byte for row `r` and byte
/// index `b` lives at
/// `((r / TURBO_QUANT_BLOCK_ROWS) * bytes_per_row + b) * TURBO_QUANT_BLOCK_ROWS
/// + r % TURBO_QUANT_BLOCK_ROWS`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TurboQuantBlockedCodes {
    /// Width of every stored code.
    bit_width: TurboQuantBitWidth,
    /// Number of codes stored per row.
    dimensions: usize,
    /// Number of logical rows; lanes at or past this row in the last block are padding.
    rows: usize,
    /// Packed bytes occupied by one row, as computed by `bytes_per_row(bit_width, dimensions)`.
    bytes_per_row: usize,
    /// Backing storage, sized in whole blocks of `TURBO_QUANT_BLOCK_ROWS` rows.
    bytes: Vec<u8>,
}
23
24impl TurboQuantBlockedCodes {
25 pub fn new(
32 bit_width: TurboQuantBitWidth,
33 dimensions: usize,
34 rows: usize,
35 ) -> TurboQuantCodecResult<Self> {
36 let bytes_per_row = bytes_per_row(bit_width, dimensions)?;
37 let byte_len = byte_len(bytes_per_row, rows)?;
38 Ok(Self {
39 bit_width,
40 dimensions,
41 rows,
42 bytes_per_row,
43 bytes: vec![0; byte_len],
44 })
45 }
46
47 pub fn from_row_major(codes: &TurboQuantPackedCodes) -> TurboQuantCodecResult<Self> {
53 let mut blocked = Self::new(codes.bit_width(), codes.dimensions(), codes.rows())?;
54 for row in 0..codes.rows() {
55 let source = row * codes.bytes_per_row();
56 for byte in 0..codes.bytes_per_row() {
57 blocked.set_row_byte(row, byte, codes.as_bytes()[source + byte]);
58 }
59 }
60 Ok(blocked)
61 }
62
63 #[must_use]
65 pub const fn bit_width(&self) -> TurboQuantBitWidth {
66 self.bit_width
67 }
68
69 #[must_use]
71 pub const fn dimensions(&self) -> usize {
72 self.dimensions
73 }
74
75 #[must_use]
77 pub const fn rows(&self) -> usize {
78 self.rows
79 }
80
81 #[must_use]
83 pub const fn bytes_per_row(&self) -> usize {
84 self.bytes_per_row
85 }
86
87 #[must_use]
89 pub fn block_count(&self) -> usize {
90 block_count(self.rows)
91 }
92
93 #[must_use]
95 pub fn block_len(&self, block: usize) -> usize {
96 debug_assert!(block < self.block_count());
97 let remaining = self.rows - block * TURBO_QUANT_BLOCK_ROWS;
98 remaining.min(TURBO_QUANT_BLOCK_ROWS)
99 }
100
101 #[must_use]
106 pub fn block_byte(&self, block: usize, byte: usize) -> &[u8] {
107 debug_assert!(block < self.block_count());
108 debug_assert!(byte < self.bytes_per_row);
109 let offset = (block * self.bytes_per_row + byte) * TURBO_QUANT_BLOCK_ROWS;
110 &self.bytes[offset..offset + TURBO_QUANT_BLOCK_ROWS]
111 }
112
113 pub fn row_byte(&self, row: usize, byte: usize) -> TurboQuantCodecResult<u8> {
119 let offset = self.byte_offset(row, byte)?;
120 Ok(self.bytes[offset])
121 }
122
123 pub fn write_row_bytes(&mut self, row: usize, bytes: &[u8]) -> TurboQuantCodecResult<()> {
130 self.validate_row(row)?;
131 if bytes.len() != self.bytes_per_row {
132 return Err(TurboQuantCodecError::ByteLengthMismatch {
133 expected: self.bytes_per_row,
134 actual: bytes.len(),
135 });
136 }
137 for (byte, value) in bytes.iter().copied().enumerate() {
138 self.set_row_byte(row, byte, value);
139 }
140 Ok(())
141 }
142
143 #[must_use]
145 pub fn as_bytes(&self) -> &[u8] {
146 &self.bytes
147 }
148
149 #[must_use]
151 pub fn estimated_bytes(&self) -> usize {
152 self.bytes.len()
153 }
154
155 pub fn resize_rows(&mut self, rows: usize) -> TurboQuantCodecResult<()> {
164 validate_dimension(self.dimensions)?;
165 let old_rows = self.rows;
166 let byte_len = byte_len(self.bytes_per_row, rows)?;
167 self.bytes.resize(byte_len, 0);
168 for row in old_rows.min(rows)..old_rows.max(rows) {
169 for byte in 0..self.bytes_per_row {
170 if let Some(offset) = self.byte_offset_if_allocated(row, byte) {
171 self.bytes[offset] = 0;
172 }
173 }
174 }
175 self.rows = rows;
176 Ok(())
177 }
178
179 pub fn swap_remove_row(&mut self, row: usize) -> TurboQuantCodecResult<()> {
189 self.validate_row(row)?;
190 let last = self.rows - 1;
191 if row != last {
192 for byte in 0..self.bytes_per_row {
193 let source = self.byte_offset_unchecked(last, byte);
194 let destination = self.byte_offset_unchecked(row, byte);
195 self.bytes[destination] = self.bytes[source];
196 }
197 }
198 self.resize_rows(last)
199 }
200
201 pub fn read(&self, row: usize, dimension: usize) -> TurboQuantCodecResult<u8> {
208 let bit_offset = self.bit_offset(row, dimension)?;
209 let byte = bit_offset / u8::BITS as usize;
210 let shift = bit_offset % u8::BITS as usize;
211 let mut word = u16::from(self.bytes[self.byte_offset(row, byte)?]);
212 if byte + 1 < self.bytes_per_row {
213 word |= u16::from(self.bytes[self.byte_offset(row, byte + 1)?]) << u8::BITS;
214 }
215 let mask = (1_u16 << self.bit_width.bits()) - 1;
216 Ok(((word >> shift) & mask) as u8)
217 }
218
219 pub fn write(&mut self, row: usize, dimension: usize, code: u8) -> TurboQuantCodecResult<()> {
227 self.validate_code(code)?;
228 let bit_offset = self.bit_offset(row, dimension)?;
229 let byte = bit_offset / u8::BITS as usize;
230 let shift = bit_offset % u8::BITS as usize;
231 let mask = ((1_u16 << self.bit_width.bits()) - 1) << shift;
232 let first = self.byte_offset(row, byte)?;
233 let mut word = u16::from(self.bytes[first]);
234 let second = (byte + 1 < self.bytes_per_row)
235 .then(|| self.byte_offset(row, byte + 1))
236 .transpose()?;
237 if let Some(second) = second {
238 word |= u16::from(self.bytes[second]) << u8::BITS;
239 }
240 word = (word & !mask) | (u16::from(code) << shift);
241 self.bytes[first] = (word & 0xff) as u8;
242 if shift + usize::from(self.bit_width.bits()) > u8::BITS as usize
243 && let Some(second) = second
244 {
245 self.bytes[second] = (word >> u8::BITS) as u8;
246 }
247 Ok(())
248 }
249
250 fn validate_code(&self, code: u8) -> TurboQuantCodecResult<()> {
251 let max = self.bit_width.max_code();
252 if code <= max {
253 Ok(())
254 } else {
255 Err(TurboQuantCodecError::InvalidCode { code, max })
256 }
257 }
258
259 fn bit_offset(&self, row: usize, dimension: usize) -> TurboQuantCodecResult<usize> {
260 self.validate_row(row)?;
261 if dimension >= self.dimensions {
262 return Err(TurboQuantCodecError::DimensionOutOfBounds {
263 dimension,
264 dimensions: self.dimensions,
265 });
266 }
267 dimension
268 .checked_mul(usize::from(self.bit_width.bits()))
269 .ok_or(TurboQuantCodecError::SizeOverflow)
270 }
271
272 fn byte_offset(&self, row: usize, byte: usize) -> TurboQuantCodecResult<usize> {
273 self.validate_row(row)?;
274 if byte >= self.bytes_per_row {
275 return Err(TurboQuantCodecError::DimensionOutOfBounds {
276 dimension: byte.saturating_mul(u8::BITS as usize),
277 dimensions: self.dimensions,
278 });
279 }
280 Ok(self.byte_offset_unchecked(row, byte))
281 }
282
283 fn validate_row(&self, row: usize) -> TurboQuantCodecResult<()> {
284 if row >= self.rows {
285 Err(TurboQuantCodecError::RowOutOfBounds {
286 row,
287 rows: self.rows,
288 })
289 } else {
290 Ok(())
291 }
292 }
293
294 fn set_row_byte(&mut self, row: usize, byte: usize, value: u8) {
295 let offset = self.byte_offset_unchecked(row, byte);
296 self.bytes[offset] = value;
297 }
298
299 fn byte_offset_if_allocated(&self, row: usize, byte: usize) -> Option<usize> {
300 let offset = self.byte_offset_unchecked(row, byte);
301 (offset < self.bytes.len()).then_some(offset)
302 }
303
304 fn byte_offset_unchecked(&self, row: usize, byte: usize) -> usize {
305 let block = row / TURBO_QUANT_BLOCK_ROWS;
306 let lane = row % TURBO_QUANT_BLOCK_ROWS;
307 (block * self.bytes_per_row + byte) * TURBO_QUANT_BLOCK_ROWS + lane
308 }
309}
310
311fn block_count(rows: usize) -> usize {
312 rows.div_ceil(TURBO_QUANT_BLOCK_ROWS)
313}
314
315fn byte_len(bytes_per_row: usize, rows: usize) -> TurboQuantCodecResult<usize> {
316 block_count(rows)
317 .checked_mul(bytes_per_row)
318 .and_then(|bytes| bytes.checked_mul(TURBO_QUANT_BLOCK_ROWS))
319 .ok_or(TurboQuantCodecError::SizeOverflow)
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Every (row, dimension) read from the blocked layout must equal the row-major read
    /// after identical writes, for each supported bit width.
    #[test]
    fn blocked_codes_match_row_major_reads() {
        for bits in 2..=4 {
            let bit_width = TurboQuantBitWidth::new(bits).unwrap();
            let levels = bit_width.levels();
            let code_for = |row: usize, dimension: usize| ((row * 3 + dimension) % levels) as u8;

            let mut row_major = TurboQuantPackedCodes::new(bit_width, 11, 35).unwrap();
            let mut blocked = TurboQuantBlockedCodes::new(bit_width, 11, 35).unwrap();
            let (rows, dimensions) = (row_major.rows(), row_major.dimensions());

            for row in 0..rows {
                for dimension in 0..dimensions {
                    let code = code_for(row, dimension);
                    row_major.write(row, dimension, code).unwrap();
                    blocked.write(row, dimension, code).unwrap();
                }
            }

            for row in 0..rows {
                for dimension in 0..dimensions {
                    let actual = blocked.read(row, dimension).unwrap();
                    let expected = row_major.read(row, dimension).unwrap();
                    assert_eq!(actual, expected);
                }
            }
        }
    }

    /// Repacking from row-major must place byte `b` of each row of block 0 in lane order.
    #[test]
    fn row_major_repack_uses_block_byte_layout() {
        let bit_width = TurboQuantBitWidth::new(4).unwrap();
        let mut row_major = TurboQuantPackedCodes::new(bit_width, 4, 35).unwrap();
        let (rows, dimensions) = (row_major.rows(), row_major.dimensions());
        for row in 0..rows {
            for dimension in 0..dimensions {
                let code = ((row + dimension) % 16) as u8;
                row_major.write(row, dimension, code).unwrap();
            }
        }

        let blocked = TurboQuantBlockedCodes::from_row_major(&row_major).unwrap();

        assert_eq!(blocked.block_count(), 2);
        assert_eq!(blocked.block_len(0), TURBO_QUANT_BLOCK_ROWS);
        assert_eq!(blocked.block_len(1), 3);

        let stride = row_major.bytes_per_row();
        let source = row_major.as_bytes();
        for byte in 0..stride {
            let lanes = blocked.block_byte(0, byte);
            for (row, packed) in lanes.iter().enumerate() {
                assert_eq!(*packed, source[row * stride + byte]);
            }
        }
    }

    /// Writing a whole row must be visible through both code reads and raw lane access.
    #[test]
    fn write_row_bytes_overwrites_one_blocked_row() {
        let bit_width = TurboQuantBitWidth::new(4).unwrap();
        let mut blocked = TurboQuantBlockedCodes::new(bit_width, 4, 35).unwrap();

        blocked.write_row_bytes(33, &[0x21, 0x43]).unwrap();

        for (dimension, expected) in [1_u8, 2, 3, 4].into_iter().enumerate() {
            assert_eq!(blocked.read(33, dimension).unwrap(), expected);
        }
        // Row 33 is lane 1 of block 1.
        for (byte, expected) in [0x21_u8, 0x43].into_iter().enumerate() {
            assert_eq!(blocked.block_byte(1, byte)[1], expected);
        }
    }

    /// A row payload of the wrong length must be rejected with both lengths reported.
    #[test]
    fn write_row_bytes_rejects_wrong_length() {
        let bit_width = TurboQuantBitWidth::new(4).unwrap();
        let mut blocked = TurboQuantBlockedCodes::new(bit_width, 4, 1).unwrap();

        let error = blocked.write_row_bytes(0, &[0x21]).unwrap_err();

        assert_eq!(
            error,
            TurboQuantCodecError::ByteLengthMismatch {
                expected: 2,
                actual: 1
            }
        );
    }

    /// Shrinking and regrowing within one block must not resurrect the old code.
    #[test]
    fn resize_rows_clears_retained_tail_slots() {
        let bit_width = TurboQuantBitWidth::new(4).unwrap();
        let mut blocked = TurboQuantBlockedCodes::new(bit_width, 2, 4).unwrap();
        blocked.write(3, 0, 15).unwrap();

        for rows in [2, 4] {
            blocked.resize_rows(rows).unwrap();
        }

        assert_eq!(blocked.read(3, 0).unwrap(), 0);
    }

    /// Swap-removal must move the last row's codes into the removed slot and zero the tail.
    #[test]
    fn swap_remove_row_moves_last_row_and_clears_tail() {
        const REMOVED: usize = 7;
        for bits in 2..=4 {
            let bit_width = TurboQuantBitWidth::new(bits).unwrap();
            let mut blocked = TurboQuantBlockedCodes::new(bit_width, 11, 35).unwrap();
            let (rows, dimensions) = (blocked.rows(), blocked.dimensions());
            let last = rows - 1;
            let modulus = usize::from(bit_width.max_code()) + 1;
            let code_for = |row: usize, dim: usize| ((row * 5 + dim * 3) % modulus) as u8;

            for row in 0..rows {
                for dim in 0..dimensions {
                    blocked.write(row, dim, code_for(row, dim)).unwrap();
                }
            }

            blocked.swap_remove_row(REMOVED).unwrap();

            assert_eq!(blocked.rows(), last);
            for dim in 0..dimensions {
                assert_eq!(blocked.read(REMOVED, dim).unwrap(), code_for(last, dim));
            }
            blocked.resize_rows(last + 1).unwrap();
            for dim in 0..blocked.dimensions() {
                assert_eq!(blocked.read(last, dim).unwrap(), 0);
            }
        }
    }
}