1use crate::error::{Result, SQLRiteError};
71use crate::mvcc::{RowID, VersionPayload};
72use crate::sql::db::table::Value;
73use crate::sql::pager::page::PAGE_SIZE;
74
75pub const MVCC_FRAME_MARKER: u32 = u32::MAX;
85
86pub const MVCC_BODY_MAGIC: &[u8; 8] = b"MVCC0001";
91
92pub const MVCC_BODY_PAYLOAD_CAP: usize = PAGE_SIZE - 8 - 8 - 2;
96
97#[derive(Debug, Clone, PartialEq)]
102pub struct MvccLogRecord {
103 pub row: RowID,
104 pub payload: VersionPayload,
105}
106
107impl MvccLogRecord {
108 pub fn upsert(table: impl Into<String>, rowid: i64, columns: Vec<(String, Value)>) -> Self {
109 Self {
110 row: RowID::new(table, rowid),
111 payload: VersionPayload::Present(columns),
112 }
113 }
114
115 pub fn tombstone(table: impl Into<String>, rowid: i64) -> Self {
116 Self {
117 row: RowID::new(table, rowid),
118 payload: VersionPayload::Tombstone,
119 }
120 }
121}
122
123#[derive(Debug, Clone, PartialEq)]
127pub struct MvccCommitBatch {
128 pub commit_ts: u64,
129 pub records: Vec<MvccLogRecord>,
130}
131
132impl MvccCommitBatch {
133 pub fn encode(&self) -> Result<Box<[u8; PAGE_SIZE]>> {
143 let mut buf = Box::new([0u8; PAGE_SIZE]);
144 let mut cur = 0usize;
145 write_bytes(&mut buf, &mut cur, MVCC_BODY_MAGIC)?;
146 write_u64(&mut buf, &mut cur, self.commit_ts)?;
147 if self.records.len() > u16::MAX as usize {
148 return Err(SQLRiteError::General(format!(
149 "MVCC log: too many records in one commit ({}); cap is {}",
150 self.records.len(),
151 u16::MAX
152 )));
153 }
154 write_u16(&mut buf, &mut cur, self.records.len() as u16)?;
155 for rec in &self.records {
156 encode_record(&mut buf, &mut cur, rec)?;
157 }
158 Ok(buf)
159 }
160
161 pub fn decode(body: &[u8]) -> Result<Self> {
167 if body.len() < 8 + 8 + 2 {
168 return Err(SQLRiteError::General(
169 "MVCC log: body shorter than fixed header".to_string(),
170 ));
171 }
172 if &body[0..8] != MVCC_BODY_MAGIC {
173 return Err(SQLRiteError::General(format!(
174 "MVCC log: bad magic, expected {:?}, got {:?}",
175 MVCC_BODY_MAGIC,
176 &body[0..8],
177 )));
178 }
179 let commit_ts = read_u64(body, 8);
180 let record_count = read_u16(body, 16) as usize;
181 let mut cur = 18usize;
182 let mut records = Vec::with_capacity(record_count);
183 for _ in 0..record_count {
184 records.push(decode_record(body, &mut cur)?);
185 }
186 Ok(Self { commit_ts, records })
187 }
188}
189
190fn write_bytes(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, src: &[u8]) -> Result<()> {
193 if *cur + src.len() > PAGE_SIZE {
194 return Err(SQLRiteError::General(format!(
195 "MVCC log: encoded batch exceeds {PAGE_SIZE}-byte frame body cap"
196 )));
197 }
198 buf[*cur..*cur + src.len()].copy_from_slice(src);
199 *cur += src.len();
200 Ok(())
201}
202
203fn write_u16(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, v: u16) -> Result<()> {
204 write_bytes(buf, cur, &v.to_le_bytes())
205}
206
207fn write_u32(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, v: u32) -> Result<()> {
208 write_bytes(buf, cur, &v.to_le_bytes())
209}
210
211fn write_u64(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, v: u64) -> Result<()> {
212 write_bytes(buf, cur, &v.to_le_bytes())
213}
214
215fn write_i64(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, v: i64) -> Result<()> {
216 write_bytes(buf, cur, &v.to_le_bytes())
217}
218
219fn write_f64(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, v: f64) -> Result<()> {
220 write_bytes(buf, cur, &v.to_le_bytes())
221}
222
223fn write_str(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, s: &str) -> Result<()> {
224 if s.len() > u16::MAX as usize {
225 return Err(SQLRiteError::General(format!(
226 "MVCC log: string too long ({}); cap is {}",
227 s.len(),
228 u16::MAX,
229 )));
230 }
231 write_u16(buf, cur, s.len() as u16)?;
232 write_bytes(buf, cur, s.as_bytes())
233}
234
235fn encode_record(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, rec: &MvccLogRecord) -> Result<()> {
236 let op: u8 = match rec.payload {
237 VersionPayload::Tombstone => 0,
238 VersionPayload::Present(_) => 1,
239 };
240 write_bytes(buf, cur, &[op])?;
241 write_str(buf, cur, &rec.row.table)?;
242 write_i64(buf, cur, rec.row.rowid)?;
243 if let VersionPayload::Present(cols) = &rec.payload {
244 if cols.len() > u16::MAX as usize {
245 return Err(SQLRiteError::General(format!(
246 "MVCC log: column count {} exceeds cap {}",
247 cols.len(),
248 u16::MAX
249 )));
250 }
251 write_u16(buf, cur, cols.len() as u16)?;
252 for (name, value) in cols {
253 write_str(buf, cur, name)?;
254 encode_value(buf, cur, value)?;
255 }
256 }
257 Ok(())
258}
259
260fn encode_value(buf: &mut [u8; PAGE_SIZE], cur: &mut usize, v: &Value) -> Result<()> {
261 match v {
262 Value::Null => write_bytes(buf, cur, &[0u8]),
263 Value::Integer(n) => {
264 write_bytes(buf, cur, &[1u8])?;
265 write_i64(buf, cur, *n)
266 }
267 Value::Real(f) => {
268 write_bytes(buf, cur, &[2u8])?;
269 write_f64(buf, cur, *f)
270 }
271 Value::Text(s) => {
272 write_bytes(buf, cur, &[3u8])?;
273 if s.len() > u32::MAX as usize {
274 return Err(SQLRiteError::General(
275 "MVCC log: TEXT value exceeds u32 length cap".to_string(),
276 ));
277 }
278 write_u32(buf, cur, s.len() as u32)?;
279 write_bytes(buf, cur, s.as_bytes())
280 }
281 Value::Bool(b) => {
282 write_bytes(buf, cur, &[4u8])?;
283 write_bytes(buf, cur, &[*b as u8])
284 }
285 Value::Vector(elements) => {
286 write_bytes(buf, cur, &[5u8])?;
287 if elements.len() > u32::MAX as usize {
288 return Err(SQLRiteError::General(
289 "MVCC log: VECTOR value exceeds u32 length cap".to_string(),
290 ));
291 }
292 write_u32(buf, cur, elements.len() as u32)?;
293 for x in elements {
294 write_bytes(buf, cur, &x.to_le_bytes())?;
295 }
296 Ok(())
297 }
298 }
299}
300
/// Reads a little-endian `u16` from `buf[at..at + 2]`.
///
/// # Panics
///
/// Panics if that range is out of bounds; callers check the length first.
fn read_u16(buf: &[u8], at: usize) -> u16 {
    let mut raw = [0u8; 2];
    raw.copy_from_slice(&buf[at..at + 2]);
    u16::from_le_bytes(raw)
}

/// Reads a little-endian `u32` from `buf[at..at + 4]`.
///
/// # Panics
///
/// Panics if that range is out of bounds; callers check the length first.
fn read_u32(buf: &[u8], at: usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&buf[at..at + 4]);
    u32::from_le_bytes(raw)
}

/// Reads a little-endian `u64` from `buf[at..at + 8]`.
///
/// # Panics
///
/// Panics if that range is out of bounds; callers check the length first.
fn read_u64(buf: &[u8], at: usize) -> u64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&buf[at..at + 8]);
    u64::from_le_bytes(raw)
}

/// Reads a little-endian `i64` from `buf[at..at + 8]`.
///
/// # Panics
///
/// Panics if that range is out of bounds; callers check the length first.
fn read_i64(buf: &[u8], at: usize) -> i64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&buf[at..at + 8]);
    i64::from_le_bytes(raw)
}

/// Reads an `f64` from the little-endian bit pattern in `buf[at..at + 8]`.
///
/// # Panics
///
/// Panics if that range is out of bounds; callers check the length first.
fn read_f64(buf: &[u8], at: usize) -> f64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&buf[at..at + 8]);
    f64::from_le_bytes(raw)
}
322
323fn read_str(buf: &[u8], cur: &mut usize) -> Result<String> {
324 if *cur + 2 > buf.len() {
325 return Err(SQLRiteError::General(
326 "MVCC log: truncated string length".to_string(),
327 ));
328 }
329 let len = read_u16(buf, *cur) as usize;
330 *cur += 2;
331 if *cur + len > buf.len() {
332 return Err(SQLRiteError::General(format!(
333 "MVCC log: truncated string body (need {len} bytes)"
334 )));
335 }
336 let s = std::str::from_utf8(&buf[*cur..*cur + len])
337 .map_err(|e| SQLRiteError::General(format!("MVCC log: invalid UTF-8 in string: {e}")))?
338 .to_string();
339 *cur += len;
340 Ok(s)
341}
342
343fn decode_record(buf: &[u8], cur: &mut usize) -> Result<MvccLogRecord> {
344 if *cur + 1 > buf.len() {
345 return Err(SQLRiteError::General(
346 "MVCC log: truncated op tag".to_string(),
347 ));
348 }
349 let op = buf[*cur];
350 *cur += 1;
351 let table = read_str(buf, cur)?;
352 if *cur + 8 > buf.len() {
353 return Err(SQLRiteError::General(
354 "MVCC log: truncated rowid".to_string(),
355 ));
356 }
357 let rowid = read_i64(buf, *cur);
358 *cur += 8;
359 let payload = match op {
360 0 => VersionPayload::Tombstone,
361 1 => {
362 if *cur + 2 > buf.len() {
363 return Err(SQLRiteError::General(
364 "MVCC log: truncated column count".to_string(),
365 ));
366 }
367 let n = read_u16(buf, *cur) as usize;
368 *cur += 2;
369 let mut cols = Vec::with_capacity(n);
370 for _ in 0..n {
371 let name = read_str(buf, cur)?;
372 let value = decode_value(buf, cur)?;
373 cols.push((name, value));
374 }
375 VersionPayload::Present(cols)
376 }
377 other => {
378 return Err(SQLRiteError::General(format!(
379 "MVCC log: unknown op tag {other}"
380 )));
381 }
382 };
383 Ok(MvccLogRecord {
384 row: RowID::new(table, rowid),
385 payload,
386 })
387}
388
389fn decode_value(buf: &[u8], cur: &mut usize) -> Result<Value> {
390 if *cur + 1 > buf.len() {
391 return Err(SQLRiteError::General(
392 "MVCC log: truncated value tag".to_string(),
393 ));
394 }
395 let tag = buf[*cur];
396 *cur += 1;
397 let value = match tag {
398 0 => Value::Null,
399 1 => {
400 if *cur + 8 > buf.len() {
401 return Err(SQLRiteError::General(
402 "MVCC log: truncated Integer value".to_string(),
403 ));
404 }
405 let v = Value::Integer(read_i64(buf, *cur));
406 *cur += 8;
407 v
408 }
409 2 => {
410 if *cur + 8 > buf.len() {
411 return Err(SQLRiteError::General(
412 "MVCC log: truncated Real value".to_string(),
413 ));
414 }
415 let v = Value::Real(read_f64(buf, *cur));
416 *cur += 8;
417 v
418 }
419 3 => {
420 if *cur + 4 > buf.len() {
421 return Err(SQLRiteError::General(
422 "MVCC log: truncated Text length".to_string(),
423 ));
424 }
425 let len = read_u32(buf, *cur) as usize;
426 *cur += 4;
427 if *cur + len > buf.len() {
428 return Err(SQLRiteError::General(format!(
429 "MVCC log: truncated Text body (need {len} bytes)"
430 )));
431 }
432 let s = std::str::from_utf8(&buf[*cur..*cur + len])
433 .map_err(|e| {
434 SQLRiteError::General(format!("MVCC log: invalid UTF-8 in Text: {e}"))
435 })?
436 .to_string();
437 *cur += len;
438 Value::Text(s)
439 }
440 4 => {
441 if *cur + 1 > buf.len() {
442 return Err(SQLRiteError::General(
443 "MVCC log: truncated Bool".to_string(),
444 ));
445 }
446 let v = Value::Bool(buf[*cur] != 0);
447 *cur += 1;
448 v
449 }
450 5 => {
451 if *cur + 4 > buf.len() {
452 return Err(SQLRiteError::General(
453 "MVCC log: truncated Vector length".to_string(),
454 ));
455 }
456 let n = read_u32(buf, *cur) as usize;
457 *cur += 4;
458 if *cur + n * 4 > buf.len() {
459 return Err(SQLRiteError::General(format!(
460 "MVCC log: truncated Vector body (need {} bytes)",
461 n * 4
462 )));
463 }
464 let mut elements = Vec::with_capacity(n);
465 for _ in 0..n {
466 let f = f32::from_le_bytes(buf[*cur..*cur + 4].try_into().unwrap());
467 elements.push(f);
468 *cur += 4;
469 }
470 Value::Vector(elements)
471 }
472 other => {
473 return Err(SQLRiteError::General(format!(
474 "MVCC log: unknown value tag {other}"
475 )));
476 }
477 };
478 Ok(value)
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a batch.
    fn batch_of(commit_ts: u64, records: Vec<MvccLogRecord>) -> MvccCommitBatch {
        MvccCommitBatch { commit_ts, records }
    }

    /// Encodes `batch`, decodes the resulting body, and returns what came back.
    fn round_trip(batch: &MvccCommitBatch) -> MvccCommitBatch {
        let page = batch.encode().unwrap();
        MvccCommitBatch::decode(&page[..]).unwrap()
    }

    /// Hand-built body: valid magic, zero commit timestamp, the given record
    /// count, then `trailing` zero bytes for the caller to fill in.
    fn header_with_count(record_count: u16, trailing: usize) -> Vec<u8> {
        let mut body = vec![0u8; 8 + 8 + 2 + trailing];
        body[0..8].copy_from_slice(MVCC_BODY_MAGIC);
        body[16..18].copy_from_slice(&record_count.to_le_bytes());
        body
    }

    #[test]
    fn empty_batch_round_trips() {
        let batch = batch_of(42, Vec::new());
        assert_eq!(batch, round_trip(&batch));
    }

    #[test]
    fn upsert_round_trips_with_every_value_kind() {
        let cols = vec![
            ("a_null".to_string(), Value::Null),
            ("an_int".to_string(), Value::Integer(-42)),
            ("a_real".to_string(), Value::Real(2.5)),
            ("a_text".to_string(), Value::Text("héllo".to_string())),
            ("a_bool".to_string(), Value::Bool(true)),
            ("a_vec".to_string(), Value::Vector(vec![1.0, -2.5, 3.25])),
        ];
        let batch = batch_of(99, vec![MvccLogRecord::upsert("accounts", 7, cols)]);
        assert_eq!(batch, round_trip(&batch));
    }

    #[test]
    fn multiple_records_in_one_batch_round_trip() {
        let batch = batch_of(
            100,
            vec![
                MvccLogRecord::upsert("t", 1, vec![("v".into(), Value::Integer(10))]),
                MvccLogRecord::upsert("t", 2, vec![("v".into(), Value::Integer(20))]),
                MvccLogRecord::tombstone("t", 3),
            ],
        );
        assert_eq!(batch, round_trip(&batch));
    }

    #[test]
    fn unicode_table_and_column_names_round_trip() {
        let record = MvccLogRecord::upsert(
            "café_tablé",
            1,
            vec![("naïve_col".into(), Value::Text("日本語".into()))],
        );
        let batch = batch_of(1, vec![record]);
        assert_eq!(batch, round_trip(&batch));
    }

    #[test]
    fn bad_magic_decode_errors() {
        let mut page = [0u8; PAGE_SIZE];
        page[0..8].copy_from_slice(b"NOTVALID");
        let err = MvccCommitBatch::decode(&page).unwrap_err();
        assert!(err.to_string().contains("bad magic"));
    }

    #[test]
    fn truncated_body_decode_errors() {
        // The header promises one record but the body ends right after it.
        let body = header_with_count(1, 0);
        let err = MvccCommitBatch::decode(&body).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn unknown_op_tag_decode_errors() {
        // One record: op tag (1) + table "t" (2 + 1) + rowid (8).
        let mut body = header_with_count(1, 1 + 2 + 1 + 8);
        body[18] = 42;
        body[19..21].copy_from_slice(&1u16.to_le_bytes());
        body[21] = b't';
        body[22..30].copy_from_slice(&0i64.to_le_bytes());
        let err = MvccCommitBatch::decode(&body).unwrap_err();
        assert!(err.to_string().contains("unknown op tag"));
    }

    #[test]
    fn oversized_batch_encode_errors() {
        // A text value as long as the whole page cannot fit next to the header.
        let big = "x".repeat(PAGE_SIZE);
        let record = MvccLogRecord::upsert("t", 1, vec![("c".into(), Value::Text(big))]);
        let err = batch_of(1, vec![record]).encode().unwrap_err();
        assert!(err.to_string().contains("exceeds"));
    }

    #[test]
    fn column_order_is_preserved() {
        let cols = vec![
            ("z".to_string(), Value::Integer(1)),
            ("a".to_string(), Value::Integer(2)),
            ("m".to_string(), Value::Integer(3)),
        ];
        let batch = batch_of(1, vec![MvccLogRecord::upsert("t", 1, cols.clone())]);
        let back = round_trip(&batch);
        match &back.records[0].payload {
            VersionPayload::Present(decoded_cols) => {
                let names: Vec<&str> = decoded_cols.iter().map(|(n, _)| n.as_str()).collect();
                assert_eq!(names, vec!["z", "a", "m"]);
            }
            _ => panic!("expected Present payload"),
        }
    }
}