1use crate::error::{Result, SQLRiteError};
46use crate::sql::db::table::Value;
47use crate::sql::pager::varint;
48
// Cell kind tags. Each one is the first byte of a cell body, read back by
// `Cell::peek_kind` right after the body-length varint.

/// Kind tag for a row stored entirely inside the cell; it is the only kind that
/// `Cell::encode` writes and `Cell::decode` accepts.
pub const KIND_LOCAL: u8 = 1;
/// Kind tag for a non-local entry. NOTE(review): presumably a row whose payload
/// spills onto overflow pages — not produced in this file; confirm.
pub const KIND_OVERFLOW: u8 = 2;
/// Kind tag whose meaning is inferred from the name only (interior B-tree entry?)
/// — not used in this file; confirm.
pub const KIND_INTERIOR: u8 = 3;
/// Kind tag whose meaning is inferred from the name only (index entry?) — not
/// used in this file; confirm.
pub const KIND_INDEX: u8 = 4;
60
/// One-byte type tags that open every non-NULL value block in a cell body.
///
/// Written by `encode_value` and dispatched on by `decode_value`. There is no
/// tag for NULL: NULL columns are recorded only in the cell's null bitmap.
pub mod tag {
    /// Integer value; payload is an i64 varint.
    pub const INTEGER: u8 = 0x00;
    /// Real value; payload is 8 little-endian bytes (decoded as `f64`).
    pub const REAL: u8 = 0x01;
    /// Text value; payload is a u64 varint byte length followed by UTF-8 bytes.
    pub const TEXT: u8 = 0x02;
    /// Boolean value; payload is one byte (written as 1/0, any non-zero decodes as true).
    pub const BOOL: u8 = 0x03;
    /// Vector value; payload is a u64 varint element count followed by 4
    /// little-endian bytes per `f32` element.
    pub const VECTOR: u8 = 0x04;
}
73
74#[derive(Debug, Clone, PartialEq)]
79pub struct Cell {
80 pub rowid: i64,
81 pub values: Vec<Option<Value>>,
82}
83
84impl Cell {
85 pub fn new(rowid: i64, values: Vec<Option<Value>>) -> Self {
86 Self { rowid, values }
87 }
88
89 pub fn encode(&self) -> Result<Vec<u8>> {
94 let mut body = Vec::new();
97 body.push(KIND_LOCAL);
98 varint::write_i64(&mut body, self.rowid);
99 varint::write_u64(&mut body, self.values.len() as u64);
100 encode_null_bitmap(&mut body, &self.values);
101 for v in self.values.iter().flatten() {
102 encode_value(&mut body, v)?;
103 }
104
105 let mut out = Vec::with_capacity(body.len() + varint::MAX_VARINT_BYTES);
106 varint::write_u64(&mut out, body.len() as u64);
107 out.extend_from_slice(&body);
108 Ok(out)
109 }
110
111 pub fn encoded_len(&self) -> Result<usize> {
114 Ok(self.encode()?.len())
118 }
119
120 pub fn peek_rowid(buf: &[u8], pos: usize) -> Result<i64> {
125 let (_body_len, len_bytes) = varint::read_u64(buf, pos)?;
126 let body_start = pos + len_bytes;
127 if body_start >= buf.len() {
129 return Err(SQLRiteError::Internal(
130 "paged cell truncated before kind tag".to_string(),
131 ));
132 }
133 let (rowid, _) = varint::read_i64(buf, body_start + 1)?;
134 Ok(rowid)
135 }
136
137 pub fn encoded_size_at(buf: &[u8], pos: usize) -> Result<usize> {
141 let (body_len, len_bytes) = varint::read_u64(buf, pos)?;
142 Ok(len_bytes + body_len as usize)
143 }
144
145 pub fn peek_kind(buf: &[u8], pos: usize) -> Result<u8> {
148 let (_body_len, len_bytes) = varint::read_u64(buf, pos)?;
149 let kind_pos = pos + len_bytes;
150 buf.get(kind_pos).copied().ok_or_else(|| {
151 SQLRiteError::Internal("paged cell truncated before kind tag".to_string())
152 })
153 }
154
155 pub fn decode(buf: &[u8], pos: usize) -> Result<(Cell, usize)> {
160 let (body_len, len_bytes) = varint::read_u64(buf, pos)?;
161 let body_start = pos + len_bytes;
162 let body_end = body_start
163 .checked_add(body_len as usize)
164 .ok_or_else(|| SQLRiteError::Internal("cell length overflow".to_string()))?;
165 if body_end > buf.len() {
166 return Err(SQLRiteError::Internal(format!(
167 "cell extends past buffer: needs bytes {body_start}..{body_end}, have {}",
168 buf.len()
169 )));
170 }
171
172 let body = &buf[body_start..body_end];
173 if body.is_empty() {
174 return Err(SQLRiteError::Internal(
175 "paged cell body is empty (no kind tag)".to_string(),
176 ));
177 }
178 let kind_tag = body[0];
179 if kind_tag != KIND_LOCAL {
180 return Err(SQLRiteError::Internal(format!(
181 "Cell::decode called on non-local entry (kind_tag = {kind_tag:#x})"
182 )));
183 }
184 let mut cur = 1usize;
185
186 let (rowid, n) = varint::read_i64(body, cur)?;
187 cur += n;
188 let (col_count_u, n) = varint::read_u64(body, cur)?;
189 cur += n;
190 let col_count = col_count_u as usize;
191
192 let bitmap_bytes = col_count.div_ceil(8);
193 if cur + bitmap_bytes > body.len() {
194 return Err(SQLRiteError::Internal(
195 "cell body truncated before null bitmap ends".to_string(),
196 ));
197 }
198 let bitmap = &body[cur..cur + bitmap_bytes];
199 cur += bitmap_bytes;
200
201 let mut values = Vec::with_capacity(col_count);
202 for col in 0..col_count {
203 if is_null(bitmap, col) {
204 values.push(None);
205 } else {
206 let (v, n) = decode_value(body, cur)?;
207 cur += n;
208 values.push(Some(v));
209 }
210 }
211
212 if cur != body.len() {
213 return Err(SQLRiteError::Internal(format!(
214 "cell body had {} trailing bytes after last value",
215 body.len() - cur
216 )));
217 }
218
219 Ok((Cell { rowid, values }, body_end - pos))
220 }
221}
222
223fn encode_null_bitmap(out: &mut Vec<u8>, values: &[Option<Value>]) {
224 let n = values.len().div_ceil(8);
225 let start = out.len();
226 out.resize(start + n, 0);
227 for (i, v) in values.iter().enumerate() {
228 if v.is_none() {
229 let byte_idx = start + (i / 8);
230 let bit = i % 8;
231 out[byte_idx] |= 1 << bit;
232 }
233 }
234}
235
/// Reports whether column `col` is flagged NULL in `bitmap`.
///
/// A column past the end of the bitmap reads as not-NULL.
fn is_null(bitmap: &[u8], col: usize) -> bool {
    let mask = 1u8 << (col % 8);
    match bitmap.get(col / 8) {
        Some(&byte) => (byte & mask) != 0,
        None => false,
    }
}
241
242pub(super) fn encode_value(out: &mut Vec<u8>, value: &Value) -> Result<()> {
243 match value {
244 Value::Integer(i) => {
245 out.push(tag::INTEGER);
246 varint::write_i64(out, *i);
247 }
248 Value::Real(f) => {
249 out.push(tag::REAL);
250 out.extend_from_slice(&f.to_le_bytes());
251 }
252 Value::Text(s) => {
253 out.push(tag::TEXT);
254 let bytes = s.as_bytes();
255 varint::write_u64(out, bytes.len() as u64);
256 out.extend_from_slice(bytes);
257 }
258 Value::Bool(b) => {
259 out.push(tag::BOOL);
260 out.push(if *b { 1 } else { 0 });
261 }
262 Value::Vector(v) => {
263 out.push(tag::VECTOR);
264 varint::write_u64(out, v.len() as u64);
266 for x in v {
268 out.extend_from_slice(&x.to_le_bytes());
269 }
270 }
271 Value::Null => {
272 return Err(SQLRiteError::Internal(
273 "Null values are encoded via the null bitmap, not a value block".to_string(),
274 ));
275 }
276 }
277 Ok(())
278}
279
280pub(super) fn decode_value(buf: &[u8], pos: usize) -> Result<(Value, usize)> {
281 let tag = *buf
282 .get(pos)
283 .ok_or_else(|| SQLRiteError::Internal(format!("value block truncated at offset {pos}")))?;
284 let body_start = pos + 1;
285 match tag {
286 tag::INTEGER => {
287 let (v, n) = varint::read_i64(buf, body_start)?;
288 Ok((Value::Integer(v), 1 + n))
289 }
290 tag::REAL => {
291 let end = body_start + 8;
292 if end > buf.len() {
293 return Err(SQLRiteError::Internal(
294 "Real value truncated: needs 8 bytes".to_string(),
295 ));
296 }
297 let arr: [u8; 8] = buf[body_start..end].try_into().unwrap();
298 Ok((Value::Real(f64::from_le_bytes(arr)), 1 + 8))
299 }
300 tag::TEXT => {
301 let (len, n) = varint::read_u64(buf, body_start)?;
302 let text_start = body_start + n;
303 let text_end = text_start + (len as usize);
304 if text_end > buf.len() {
305 return Err(SQLRiteError::Internal("Text value truncated".to_string()));
306 }
307 let s = std::str::from_utf8(&buf[text_start..text_end])
308 .map_err(|e| SQLRiteError::Internal(format!("Text value is not valid UTF-8: {e}")))?
309 .to_string();
310 Ok((Value::Text(s), 1 + n + (len as usize)))
311 }
312 tag::BOOL => {
313 let byte = *buf
314 .get(body_start)
315 .ok_or_else(|| SQLRiteError::Internal("Bool value truncated".to_string()))?;
316 Ok((Value::Bool(byte != 0), 1 + 1))
317 }
318 tag::VECTOR => {
319 let (dim, n) = varint::read_u64(buf, body_start)?;
322 let dim = dim as usize;
323 let elements_start = body_start + n;
324 let elements_end = elements_start + dim * 4;
325 if elements_end > buf.len() {
326 return Err(SQLRiteError::Internal(format!(
327 "Vector value truncated: needs {dim} × 4 = {} bytes",
328 dim * 4
329 )));
330 }
331 let mut out = Vec::with_capacity(dim);
332 for i in 0..dim {
333 let off = elements_start + i * 4;
334 let arr: [u8; 4] = buf[off..off + 4].try_into().unwrap();
335 out.push(f32::from_le_bytes(arr));
336 }
337 Ok((Value::Vector(out), 1 + n + dim * 4))
338 }
339 other => Err(SQLRiteError::Internal(format!(
340 "unknown value tag {other:#x} at offset {pos}"
341 ))),
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `cell`, decodes it back, and checks both the decoded value and
    /// that `decode` reports consuming exactly the encoded bytes.
    fn round_trip(cell: &Cell) {
        let bytes = cell.encode().unwrap();
        let (back, consumed) = Cell::decode(&bytes, 0).unwrap();
        assert_eq!(&back, cell);
        assert_eq!(consumed, bytes.len());
    }

    #[test]
    fn empty_cell_no_columns() {
        round_trip(&Cell::new(1, vec![]));
    }

    #[test]
    fn integer_only_cell() {
        round_trip(&Cell::new(
            42,
            vec![Some(Value::Integer(1)), Some(Value::Integer(-1000))],
        ));
    }

    #[test]
    fn mixed_types_cell() {
        round_trip(&Cell::new(
            100,
            vec![
                Some(Value::Integer(7)),
                Some(Value::Text("hello".to_string())),
                Some(Value::Real(2.5)),
                Some(Value::Bool(true)),
            ],
        ));
    }

    #[test]
    fn nulls_interspersed() {
        round_trip(&Cell::new(
            5,
            vec![
                Some(Value::Integer(1)),
                None,
                Some(Value::Text("middle".to_string())),
                None,
                None,
                Some(Value::Bool(false)),
            ],
        ));
    }

    #[test]
    fn all_null_cell() {
        // Nine columns forces a second bitmap byte.
        round_trip(&Cell::new(
            9,
            vec![None, None, None, None, None, None, None, None, None],
        ));
    }

    #[test]
    fn large_text_cell() {
        let big = "abc".repeat(10_000);
        round_trip(&Cell::new(1, vec![Some(Value::Text(big))]));
    }

    #[test]
    fn utf8_text_cell() {
        round_trip(&Cell::new(
            1,
            vec![Some(Value::Text("héllo 🦀 世界".to_string()))],
        ));
    }

    #[test]
    fn negative_and_large_rowids() {
        round_trip(&Cell::new(i64::MIN, vec![Some(Value::Integer(1))]));
        round_trip(&Cell::new(i64::MAX, vec![Some(Value::Integer(1))]));
        round_trip(&Cell::new(-1, vec![Some(Value::Integer(1))]));
    }

    #[test]
    fn bool_edges() {
        round_trip(&Cell::new(
            1,
            vec![Some(Value::Bool(true)), Some(Value::Bool(false))],
        ));
    }

    #[test]
    fn real_edges() {
        // NaN is excluded: it never compares equal, so `round_trip` would fail.
        for v in [
            0.0f64,
            1.0,
            -1.0,
            f64::MIN,
            f64::MAX,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ] {
            round_trip(&Cell::new(1, vec![Some(Value::Real(v))]));
        }
    }

    #[test]
    fn vector_round_trip_small() {
        let v = vec![0.1f32, 0.2, 0.3];
        round_trip(&Cell::new(1, vec![Some(Value::Vector(v))]));
    }

    #[test]
    fn vector_round_trip_high_dim() {
        let v: Vec<f32> = (0..384).map(|i| i as f32 * 0.01).collect();
        round_trip(&Cell::new(7, vec![Some(Value::Vector(v))]));
    }

    #[test]
    fn vector_round_trip_edge_values() {
        let v = vec![
            0.0f32,
            -0.0,
            1.0,
            -1.0,
            f32::MIN,
            f32::MAX,
            f32::INFINITY,
            f32::NEG_INFINITY,
        ];
        let cell = Cell::new(2, vec![Some(Value::Vector(v.clone()))]);
        // Compare bit patterns so that 0.0 vs -0.0 is distinguished.
        let bytes = cell.encode().expect("encode");
        let (decoded, _) = Cell::decode(&bytes, 0).expect("decode");
        match &decoded.values[0] {
            Some(Value::Vector(out)) => {
                assert_eq!(out.len(), v.len());
                for (i, (a, b)) in out.iter().zip(v.iter()).enumerate() {
                    assert_eq!(
                        a.to_bits(),
                        b.to_bits(),
                        "element {i} bits mismatch: out {a:?}, expected {b:?}"
                    );
                }
            }
            other => panic!("decoded into wrong variant: {other:?}"),
        }
    }

    #[test]
    fn vector_round_trip_mixed_with_other_columns() {
        let cell = Cell::new(
            42,
            vec![
                Some(Value::Integer(7)),
                Some(Value::Text("alpha".to_string())),
                Some(Value::Vector(vec![1.0, 2.0, 3.0, 4.0])),
                Some(Value::Bool(true)),
            ],
        );
        round_trip(&cell);
    }

    #[test]
    fn vector_decode_truncated_buffer_errors() {
        let cell = Cell::new(1, vec![Some(Value::Vector(vec![1.0, 2.0, 3.0]))]);
        let bytes = cell.encode().expect("encode");
        for chop in 1..=4 {
            let truncated = &bytes[..bytes.len() - chop];
            assert!(
                Cell::decode(truncated, 0).is_err(),
                "expected error decoding {} bytes short of full {}",
                chop,
                bytes.len()
            );
        }
    }

    #[test]
    fn encoding_null_directly_is_rejected() {
        let bad = Cell::new(1, vec![Some(Value::Null)]);
        let err = bad.encode().unwrap_err();
        assert!(format!("{err}").contains("Null values are encoded"));
    }

    #[test]
    fn decode_rejects_truncated_buffer() {
        let cell = Cell::new(1, vec![Some(Value::Text("some text here".to_string()))]);
        let bytes = cell.encode().unwrap();
        let truncated = &bytes[..bytes.len() - 5];
        assert!(Cell::decode(truncated, 0).is_err());
    }

    #[test]
    fn decode_rejects_unknown_value_tag() {
        // Hand-built local cell with one non-NULL column whose value tag is bogus.
        let buf = [
            5,          // body length: the five bytes below
            KIND_LOCAL, // kind tag
            0,          // rowid 0 (assumes the varint encodes 0 as a single 0x00 byte)
            1,          // column count
            0,          // null bitmap: no NULL columns
            0xFE,       // value block with an unknown tag
        ];
        let err = Cell::decode(&buf, 0).unwrap_err();
        assert!(format!("{err}").contains("unknown value tag"));
    }

    #[test]
    fn decode_rejects_wrong_kind_tag() {
        // One-byte body consisting solely of a non-local kind tag.
        let buf = [1, KIND_OVERFLOW];
        let err = Cell::decode(&buf, 0).unwrap_err();
        assert!(format!("{err}").contains("non-local"));
    }

    #[test]
    fn concatenated_cells_read_sequentially() {
        let c1 = Cell::new(1, vec![Some(Value::Integer(100))]);
        let c2 = Cell::new(2, vec![Some(Value::Text("two".to_string()))]);
        let c3 = Cell::new(3, vec![None]);

        let mut buf = Vec::new();
        buf.extend_from_slice(&c1.encode().unwrap());
        buf.extend_from_slice(&c2.encode().unwrap());
        buf.extend_from_slice(&c3.encode().unwrap());

        let (d1, n1) = Cell::decode(&buf, 0).unwrap();
        let (d2, n2) = Cell::decode(&buf, n1).unwrap();
        let (d3, n3) = Cell::decode(&buf, n1 + n2).unwrap();
        assert_eq!(d1, c1);
        assert_eq!(d2, c2);
        assert_eq!(d3, c3);
        assert_eq!(n1 + n2 + n3, buf.len());
    }

    #[test]
    fn null_bitmap_byte_boundary() {
        // Exactly one full bitmap byte.
        let values: Vec<Option<Value>> = (0..8)
            .map(|i| {
                if i % 2 == 0 {
                    Some(Value::Integer(i))
                } else {
                    None
                }
            })
            .collect();
        round_trip(&Cell::new(1, values));

        // One column spills into a second bitmap byte.
        let values: Vec<Option<Value>> = (0..9)
            .map(|i| {
                if i % 3 == 0 {
                    Some(Value::Integer(i))
                } else {
                    None
                }
            })
            .collect();
        round_trip(&Cell::new(1, values));
    }
}