1use std::collections::HashMap;
52use std::fs::{File, OpenOptions};
53use std::io::{Read, Seek, SeekFrom, Write};
54use std::path::{Path, PathBuf};
55use std::time::{SystemTime, UNIX_EPOCH};
56
57use crate::error::{Result, SQLRiteError};
58use crate::sql::pager::page::PAGE_SIZE;
59use crate::sql::pager::pager::{AccessMode, acquire_lock};
60
/// Size in bytes of the fixed header at the start of a WAL file.
/// `write_header` fills bytes 0..24 and leaves the remaining bytes zero.
pub const WAL_HEADER_SIZE: usize = 32;
/// Eight-byte magic that must open every WAL file; `read_header` rejects
/// anything else.
pub const WAL_MAGIC: &[u8; 8] = b"SQLRWAL\0";
/// On-disk format version stored at header bytes 8..12; `read_header`
/// rejects any other value.
pub const WAL_FORMAT_VERSION: u32 = 1;
/// Size in bytes of a frame header: four little-endian `u32`s (page number,
/// commit page count, salt, checksum).
pub const FRAME_HEADER_SIZE: usize = 16;
/// Size in bytes of one complete frame: a frame header followed by one page
/// body.
pub const FRAME_SIZE: usize = FRAME_HEADER_SIZE + PAGE_SIZE;
66
/// The variable fields decoded from the fixed-size WAL file header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WalHeader {
    /// Stored at header bytes 16..20. `append_frame` stamps it into every
    /// frame header and `read_frame_at` rejects frames carrying a different
    /// value.
    pub salt: u32,
    /// Stored at header bytes 20..24. Starts at 0 in `Wal::create`;
    /// `Wal::truncate` bumps it by one (wrapping).
    pub checkpoint_seq: u32,
}
75
/// Decoded form of the header that precedes every page body in the WAL.
///
/// On disk the four fields are consecutive little-endian `u32`s, in the
/// order they are declared here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
    /// Page this frame carries a copy of.
    pub page_num: u32,
    /// Zero for an ordinary frame; non-zero marks a commit frame and holds
    /// the page count recorded at that commit.
    pub commit_page_count: u32,
    /// Salt copied from the WAL header when the frame was written.
    pub salt: u32,
    /// Checksum over the first 12 header bytes plus the page body.
    pub checksum: u32,
}

impl FrameHeader {
    /// Returns `true` when this frame seals a transaction, i.e. when it
    /// carries a non-zero commit page count.
    pub fn is_commit(&self) -> bool {
        self.commit_page_count > 0
    }
}
93
94pub struct Wal {
95 file: File,
98 path: PathBuf,
99 header: WalHeader,
100 latest_frame: HashMap<u32, u64>,
105 last_commit_offset: u64,
109 last_commit_page_count: Option<u32>,
111 frame_count: usize,
114}
115
116impl Wal {
117 pub fn create(path: &Path) -> Result<Self> {
122 let file = OpenOptions::new()
123 .read(true)
124 .write(true)
125 .create(true)
126 .truncate(true)
127 .open(path)?;
128 acquire_lock(&file, path, AccessMode::ReadWrite)?;
129
130 let salt = random_salt();
131 let header = WalHeader {
132 salt,
133 checkpoint_seq: 0,
134 };
135 let mut wal = Self {
136 file,
137 path: path.to_path_buf(),
138 header,
139 latest_frame: HashMap::new(),
140 last_commit_offset: WAL_HEADER_SIZE as u64,
141 last_commit_page_count: None,
142 frame_count: 0,
143 };
144 wal.write_header()?;
145 wal.file.flush()?;
146 wal.file.sync_all()?;
147 Ok(wal)
148 }
149
150 pub fn open(path: &Path) -> Result<Self> {
153 Self::open_with_mode(path, AccessMode::ReadWrite)
154 }
155
156 pub fn open_with_mode(path: &Path, mode: AccessMode) -> Result<Self> {
164 let mut file = match mode {
165 AccessMode::ReadWrite => OpenOptions::new().read(true).write(true).open(path)?,
166 AccessMode::ReadOnly => OpenOptions::new().read(true).open(path)?,
167 };
168 acquire_lock(&file, path, mode)?;
169
170 let header = read_header(&mut file)?;
171 let mut wal = Self {
172 file,
173 path: path.to_path_buf(),
174 header,
175 latest_frame: HashMap::new(),
176 last_commit_offset: WAL_HEADER_SIZE as u64,
177 last_commit_page_count: None,
178 frame_count: 0,
179 };
180 wal.replay_frames()?;
181 Ok(wal)
182 }
183
184 pub fn header(&self) -> WalHeader {
185 self.header
186 }
187
188 pub fn frame_count(&self) -> usize {
189 self.frame_count
190 }
191
192 pub fn last_commit_page_count(&self) -> Option<u32> {
193 self.last_commit_page_count
194 }
195
196 pub fn load_committed_into(
201 &mut self,
202 dest: &mut HashMap<u32, Box<[u8; PAGE_SIZE]>>,
203 ) -> Result<()> {
204 let pages: Vec<u32> = self.latest_frame.keys().copied().collect();
207 for page_num in pages {
208 if let Some(body) = self.read_page(page_num)? {
209 dest.insert(page_num, body);
210 }
211 }
212 Ok(())
213 }
214
215 pub fn append_frame(
221 &mut self,
222 page_num: u32,
223 content: &[u8; PAGE_SIZE],
224 commit_page_count: Option<u32>,
225 ) -> Result<()> {
226 let mut header_buf = [0u8; FRAME_HEADER_SIZE];
229 header_buf[0..4].copy_from_slice(&page_num.to_le_bytes());
230 header_buf[4..8].copy_from_slice(&commit_page_count.unwrap_or(0).to_le_bytes());
231 header_buf[8..12].copy_from_slice(&self.header.salt.to_le_bytes());
232 let sum = compute_checksum(&header_buf[0..12], content);
233 header_buf[12..16].copy_from_slice(&sum.to_le_bytes());
234
235 let offset = self.file.seek(SeekFrom::End(0))?;
237 self.file.write_all(&header_buf)?;
238 self.file.write_all(content)?;
239
240 if commit_page_count.is_some() {
242 self.file.flush()?;
243 self.file.sync_all()?;
244 }
245
246 self.latest_frame.insert(page_num, offset);
250 if let Some(pc) = commit_page_count {
251 self.last_commit_offset = offset + FRAME_SIZE as u64;
252 self.last_commit_page_count = Some(pc);
253 }
254 self.frame_count += 1;
255 Ok(())
256 }
257
258 pub fn read_page(&mut self, page_num: u32) -> Result<Option<Box<[u8; PAGE_SIZE]>>> {
263 let Some(&offset) = self.latest_frame.get(&page_num) else {
264 return Ok(None);
265 };
266 if offset + FRAME_SIZE as u64 > self.last_commit_offset {
269 return Ok(None);
270 }
271 let (_hdr, body) = self.read_frame_at(offset)?;
272 Ok(Some(body))
273 }
274
275 pub fn truncate(&mut self) -> Result<()> {
279 self.header.salt = random_salt();
280 self.header.checkpoint_seq = self.header.checkpoint_seq.wrapping_add(1);
281 self.file.set_len(WAL_HEADER_SIZE as u64)?;
282 self.write_header()?;
283 self.file.flush()?;
284 self.file.sync_all()?;
285 self.latest_frame.clear();
286 self.last_commit_offset = WAL_HEADER_SIZE as u64;
287 self.last_commit_page_count = None;
288 self.frame_count = 0;
289 Ok(())
290 }
291
292 fn write_header(&mut self) -> Result<()> {
295 let mut buf = [0u8; WAL_HEADER_SIZE];
296 buf[0..8].copy_from_slice(WAL_MAGIC);
297 buf[8..12].copy_from_slice(&WAL_FORMAT_VERSION.to_le_bytes());
298 buf[12..16].copy_from_slice(&(PAGE_SIZE as u32).to_le_bytes());
299 buf[16..20].copy_from_slice(&self.header.salt.to_le_bytes());
300 buf[20..24].copy_from_slice(&self.header.checkpoint_seq.to_le_bytes());
301 self.file.seek(SeekFrom::Start(0))?;
303 self.file.write_all(&buf)?;
304 Ok(())
305 }
306
307 fn read_frame_at(&mut self, offset: u64) -> Result<(FrameHeader, Box<[u8; PAGE_SIZE]>)> {
310 self.file.seek(SeekFrom::Start(offset))?;
311 let mut header_buf = [0u8; FRAME_HEADER_SIZE];
312 self.file.read_exact(&mut header_buf)?;
313 let mut body = Box::new([0u8; PAGE_SIZE]);
314 self.file.read_exact(body.as_mut())?;
315
316 let page_num = u32::from_le_bytes(header_buf[0..4].try_into().unwrap());
317 let commit_page_count = u32::from_le_bytes(header_buf[4..8].try_into().unwrap());
318 let salt = u32::from_le_bytes(header_buf[8..12].try_into().unwrap());
319 let stored_checksum = u32::from_le_bytes(header_buf[12..16].try_into().unwrap());
320
321 if salt != self.header.salt {
322 return Err(SQLRiteError::General(format!(
323 "WAL frame at offset {offset}: salt mismatch (expected {:x}, got {:x})",
324 self.header.salt, salt
325 )));
326 }
327 let computed = compute_checksum(&header_buf[0..12], &body);
328 if computed != stored_checksum {
329 return Err(SQLRiteError::General(format!(
330 "WAL frame at offset {offset}: bad checksum (expected {stored_checksum:x}, got {computed:x})"
331 )));
332 }
333
334 Ok((
335 FrameHeader {
336 page_num,
337 commit_page_count,
338 salt,
339 checksum: stored_checksum,
340 },
341 body,
342 ))
343 }
344
345 fn replay_frames(&mut self) -> Result<()> {
359 let file_len = self.file.seek(SeekFrom::End(0))?;
360 let mut offset = WAL_HEADER_SIZE as u64;
361 let mut pending: HashMap<u32, u64> = HashMap::new();
362 while offset + FRAME_SIZE as u64 <= file_len {
363 match self.read_frame_at(offset) {
364 Ok((header, _body)) => {
365 self.frame_count += 1;
366 pending.insert(header.page_num, offset);
367 if header.is_commit() {
368 for (p, o) in pending.drain() {
371 self.latest_frame.insert(p, o);
372 }
373 self.last_commit_offset = offset + FRAME_SIZE as u64;
374 self.last_commit_page_count = Some(header.commit_page_count);
375 }
376 offset += FRAME_SIZE as u64;
377 }
378 Err(_) => break,
381 }
382 }
383 Ok(())
386 }
387}
388
389impl std::fmt::Debug for Wal {
390 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391 f.debug_struct("Wal")
392 .field("path", &self.path)
393 .field("salt", &format_args!("{:#x}", self.header.salt))
394 .field("checkpoint_seq", &self.header.checkpoint_seq)
395 .field("frame_count", &self.frame_count)
396 .field("last_commit_page_count", &self.last_commit_page_count)
397 .finish()
398 }
399}
400
401fn read_header(file: &mut File) -> Result<WalHeader> {
402 let mut buf = [0u8; WAL_HEADER_SIZE];
403 file.seek(SeekFrom::Start(0))?;
404 if file.read_exact(&mut buf).is_err() {
409 return Err(SQLRiteError::General(
410 "file is not a SQLRite WAL (too short / bad magic)".to_string(),
411 ));
412 }
413 if &buf[0..8] != WAL_MAGIC {
414 return Err(SQLRiteError::General(
415 "file is not a SQLRite WAL (bad magic)".to_string(),
416 ));
417 }
418 let version = u32::from_le_bytes(buf[8..12].try_into().unwrap());
419 if version != WAL_FORMAT_VERSION {
420 return Err(SQLRiteError::General(format!(
421 "unsupported WAL format version {version}; this build understands {WAL_FORMAT_VERSION}"
422 )));
423 }
424 let page_size = u32::from_le_bytes(buf[12..16].try_into().unwrap()) as usize;
425 if page_size != PAGE_SIZE {
426 return Err(SQLRiteError::General(format!(
427 "WAL page size {page_size} doesn't match engine's {PAGE_SIZE}"
428 )));
429 }
430 let salt = u32::from_le_bytes(buf[16..20].try_into().unwrap());
431 let checkpoint_seq = u32::from_le_bytes(buf[20..24].try_into().unwrap());
432 Ok(WalHeader {
433 salt,
434 checkpoint_seq,
435 })
436}
437
/// Derives a salt from the current wall-clock time.
///
/// The low 32 bits of the nanosecond count since the Unix epoch are XORed
/// with the low 32 bits of the second count rotated left by 13. If the clock
/// reports a time before the epoch, the fixed value `0xdeadbeef` is returned.
fn random_salt() -> u32 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => {
            // Both casts intentionally keep only the low 32 bits.
            let low_nanos = since_epoch.as_nanos() as u32;
            let low_secs = since_epoch.as_secs() as u32;
            low_nanos ^ low_secs.rotate_left(13)
        }
        Err(_) => 0xdeadbeef,
    }
}
448
449fn compute_checksum(header_bytes: &[u8], body: &[u8; PAGE_SIZE]) -> u32 {
453 let mut sum: u32 = 0;
454 for &b in header_bytes {
455 sum = sum.rotate_left(1).wrapping_add(b as u32);
456 }
457 for &b in body.iter() {
458 sum = sum.rotate_left(1).wrapping_add(b as u32);
459 }
460 sum
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a scratch path in the system temp directory. The process id
    /// and a nanosecond timestamp keep concurrent runs from colliding.
    fn tmp_wal(name: &str) -> PathBuf {
        let pid = std::process::id();
        let nanos = match std::time::SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_nanos(),
            Err(_) => 0,
        };
        std::env::temp_dir().join(format!("sqlrite-wal-{pid}-{nanos}-{name}.wal"))
    }

    /// Builds a page whose i-th byte is `byte + i` (wrapping), so pages made
    /// from different seeds differ at every position.
    fn page(byte: u8) -> Box<[u8; PAGE_SIZE]> {
        Box::new(std::array::from_fn(|i| byte.wrapping_add(i as u8)))
    }

    /// A freshly created WAL reopens with the same salt and no frames.
    #[test]
    fn create_then_open_round_trips_an_empty_wal() {
        let path = tmp_wal("empty");
        let created_salt = {
            let fresh = Wal::create(&path).unwrap();
            assert_eq!(fresh.frame_count(), 0);
            assert_eq!(fresh.last_commit_page_count(), None);
            fresh.header().salt
        };

        let reopened = Wal::open(&path).unwrap();
        assert_eq!(reopened.header().salt, created_salt);
        assert_eq!(reopened.frame_count(), 0);
        assert_eq!(reopened.last_commit_page_count(), None);

        let _ = std::fs::remove_file(&path);
    }

    /// One commit frame is readable after reopening; other pages are absent.
    #[test]
    fn single_commit_frame_round_trips() {
        let path = tmp_wal("one_frame");
        let content = page(0xab);
        {
            let mut writer = Wal::create(&path).unwrap();
            writer.append_frame(7, &content, Some(42)).unwrap();
            assert_eq!(writer.frame_count(), 1);
            assert_eq!(writer.last_commit_page_count(), Some(42));
        }

        let mut reader = Wal::open(&path).unwrap();
        assert_eq!(reader.frame_count(), 1);
        assert_eq!(reader.last_commit_page_count(), Some(42));
        let read = reader
            .read_page(7)
            .unwrap()
            .expect("frame should be visible");
        assert_eq!(read.as_ref(), content.as_ref());
        assert!(
            reader.read_page(99).unwrap().is_none(),
            "untouched page is None"
        );

        let _ = std::fs::remove_file(&path);
    }

    /// When a page is committed several times the newest copy is returned.
    #[test]
    fn multi_frame_commits_and_latest_wins() {
        let path = tmp_wal("latest_wins");
        {
            let mut writer = Wal::create(&path).unwrap();
            for (page_num, seed) in [(1u32, 1u8), (1, 2), (1, 3), (2, 9)] {
                writer.append_frame(page_num, &page(seed), Some(10)).unwrap();
            }
            assert_eq!(writer.frame_count(), 4);
        }

        let mut reader = Wal::open(&path).unwrap();
        let first = reader.read_page(1).unwrap().unwrap();
        assert_eq!(first.as_ref(), page(3).as_ref());
        let second = reader.read_page(2).unwrap().unwrap();
        assert_eq!(second.as_ref(), page(9).as_ref());
        let _ = std::fs::remove_file(&path);
    }

    /// An uncommitted frame left at the tail must not hide the committed
    /// copy of the same page after reopening.
    #[test]
    fn orphan_dirty_tail_preserves_previous_commit() {
        let path = tmp_wal("dirty_tail");
        {
            let mut writer = Wal::create(&path).unwrap();
            // Committed version first, then a newer write that never commits.
            writer.append_frame(5, &page(50), Some(10)).unwrap();
            writer.append_frame(5, &page(51), None).unwrap();
        }

        let mut reader = Wal::open(&path).unwrap();
        let got = reader
            .read_page(5)
            .unwrap()
            .expect("committed V1 should still be visible");
        assert_eq!(got.as_ref(), page(50).as_ref());
        // Both frames are valid, so both are counted.
        assert_eq!(reader.frame_count(), 2);
        let _ = std::fs::remove_file(&path);
    }

    /// A page that was only ever written without a commit is not readable.
    #[test]
    fn uncommitted_frame_for_untouched_page_returns_none() {
        let path = tmp_wal("dirty_only");
        {
            let mut writer = Wal::create(&path).unwrap();
            writer.append_frame(7, &page(70), None).unwrap();
        }

        let mut reader = Wal::open(&path).unwrap();
        assert!(reader.read_page(7).unwrap().is_none());
        let _ = std::fs::remove_file(&path);
    }

    /// `truncate` empties the log and advances the checkpoint sequence, both
    /// in memory and as seen by a later open.
    #[test]
    fn truncate_resets_to_empty_and_rolls_salt() {
        let path = tmp_wal("truncate");
        {
            let mut writer = Wal::create(&path).unwrap();
            writer.append_frame(1, &page(11), Some(5)).unwrap();
            writer.append_frame(2, &page(22), Some(5)).unwrap();
            let seq_before = writer.header().checkpoint_seq;
            // Captured but deliberately not compared: the salt is derived
            // from the clock, so inequality is not asserted here.
            let _salt_before = writer.header().salt;
            writer.truncate().unwrap();
            assert_eq!(writer.frame_count(), 0);
            assert_eq!(writer.last_commit_page_count(), None);
            assert_eq!(writer.header().checkpoint_seq, seq_before + 1);
        }

        let mut reader = Wal::open(&path).unwrap();
        assert_eq!(reader.frame_count(), 0);
        for page_num in [1u32, 2] {
            assert!(reader.read_page(page_num).unwrap().is_none());
        }

        let _ = std::fs::remove_file(&path);
    }

    /// A file that is not a WAL is rejected with a "bad magic" error.
    #[test]
    fn bad_magic_file_is_rejected() {
        let path = tmp_wal("bad_magic");
        std::fs::write(&path, b"not a WAL file").unwrap();
        let err = Wal::open(&path).unwrap_err();
        assert!(err.to_string().contains("bad magic"));
        let _ = std::fs::remove_file(&path);
    }

    /// Flipping a byte in the second frame's body ends the log after the
    /// first frame.
    #[test]
    fn corrupt_frame_body_marks_end_of_log() {
        let path = tmp_wal("bit_flip");
        {
            let mut writer = Wal::create(&path).unwrap();
            writer.append_frame(1, &page(0x11), Some(5)).unwrap();
            writer.append_frame(2, &page(0x22), Some(5)).unwrap();
        }

        // First body byte of the second frame.
        let body_offset = WAL_HEADER_SIZE + FRAME_SIZE + FRAME_HEADER_SIZE;
        let mut raw = std::fs::read(&path).unwrap();
        raw[body_offset] ^= 0xff;
        std::fs::write(&path, &raw).unwrap();

        let mut reader = Wal::open(&path).unwrap();
        let intact = reader.read_page(1).unwrap().unwrap();
        assert_eq!(intact.as_ref(), page(0x11).as_ref());
        assert!(reader.read_page(2).unwrap().is_none());
        assert_eq!(reader.frame_count(), 1);

        let _ = std::fs::remove_file(&path);
    }

    /// Trailing bytes that do not make up a whole frame are ignored.
    #[test]
    fn partial_trailing_frame_is_ignored() {
        let path = tmp_wal("partial");
        {
            let mut writer = Wal::create(&path).unwrap();
            writer.append_frame(42, &page(42), Some(1)).unwrap();
        }
        {
            // Simulate a torn write: 2000 stray bytes after the last frame.
            let mut raw = OpenOptions::new().write(true).open(&path).unwrap();
            raw.seek(SeekFrom::End(0)).unwrap();
            raw.write_all(&[0xaa; 2000]).unwrap();
        }

        let mut reader = Wal::open(&path).unwrap();
        let survivor = reader.read_page(42).unwrap().unwrap();
        assert_eq!(survivor.as_ref(), page(42).as_ref());
        assert_eq!(reader.frame_count(), 1);
        let _ = std::fs::remove_file(&path);
    }
}