1use std::fs::{File, OpenOptions};
50use std::io::{self, IoSlice, Write};
51use std::path::Path;
52use std::sync::atomic::{AtomicU64, Ordering};
53
54use crate::txn_wal::TxnWalEntry;
55use parking_lot::Mutex;
56use sochdb_core::{Result, SochDBError};
57
/// Size in bytes of the fixed header written in front of every batch:
/// magic (4) + version (2) + entry count (2) + payload length (4) + CRC32 (4).
const BATCH_HEADER_SIZE: usize = 4 + 2 + 2 + 4 + 4;

/// Marker identifying the start of a batch header. The hex digits
/// 42 41 54 43 are the ASCII codes for "BATC".
const BATCH_MAGIC: u32 = 0x4241_5443;

/// Format version recorded in every batch header.
const BATCH_VERSION: u16 = 1;

/// Default number of pending entries at which the writer flushes on its own.
pub const DEFAULT_MAX_BATCH_SIZE: usize = 1_000;

/// Default number of pending payload bytes (64 KiB) at which the writer
/// flushes on its own.
pub const DEFAULT_MAX_BATCH_BYTES: usize = 1 << 16;
69
/// Counters describing the work a batched WAL writer has performed so far.
///
/// Every field starts at zero (see the derived `Default`).
#[derive(Clone, Debug, Default)]
pub struct BatchedWalStats {
    /// Total number of entries written to the file.
    pub entries_written: u64,
    /// Number of batches written to the file.
    pub batches_written: u64,
    /// Total number of bytes written to the file.
    pub bytes_written: u64,
    /// Number of completed sync operations.
    pub syncs_performed: u64,
    /// Average number of entries per batch.
    pub avg_batch_size: f64,
}
84
85pub struct BatchedWalWriter {
90 file: File,
92 pending: Vec<Vec<u8>>,
94 pending_bytes: usize,
96 max_batch_size: usize,
98 max_batch_bytes: usize,
100 header_buf: Vec<u8>,
102 stats: BatchedWalStats,
104}
105
106impl BatchedWalWriter {
107 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
112 Self::with_config(path, DEFAULT_MAX_BATCH_SIZE, DEFAULT_MAX_BATCH_BYTES)
113 }
114
115 pub fn with_config<P: AsRef<Path>>(
122 path: P,
123 max_batch_size: usize,
124 max_batch_bytes: usize,
125 ) -> Result<Self> {
126 let file = OpenOptions::new()
127 .create(true)
128 .append(true)
129 .open(path.as_ref())
130 .map_err(SochDBError::Io)?;
131
132 Ok(Self {
133 file,
134 pending: Vec::with_capacity(max_batch_size),
135 pending_bytes: 0,
136 max_batch_size,
137 max_batch_bytes,
138 header_buf: vec![0u8; BATCH_HEADER_SIZE],
139 stats: BatchedWalStats::default(),
140 })
141 }
142
143 pub fn from_file(file: File) -> Self {
145 Self {
146 file,
147 pending: Vec::with_capacity(DEFAULT_MAX_BATCH_SIZE),
148 pending_bytes: 0,
149 max_batch_size: DEFAULT_MAX_BATCH_SIZE,
150 max_batch_bytes: DEFAULT_MAX_BATCH_BYTES,
151 header_buf: vec![0u8; BATCH_HEADER_SIZE],
152 stats: BatchedWalStats::default(),
153 }
154 }
155
156 pub fn append(&mut self, entry: &TxnWalEntry) -> Result<()> {
161 let serialized = entry.to_bytes();
162 self.pending_bytes += serialized.len();
163 self.pending.push(serialized);
164
165 if self.pending.len() >= self.max_batch_size || self.pending_bytes >= self.max_batch_bytes {
167 self.flush()?;
168 }
169
170 Ok(())
171 }
172
173 #[inline]
175 pub fn append_bytes(&mut self, bytes: Vec<u8>) -> Result<()> {
176 self.pending_bytes += bytes.len();
177 self.pending.push(bytes);
178
179 if self.pending.len() >= self.max_batch_size || self.pending_bytes >= self.max_batch_bytes {
180 self.flush()?;
181 }
182
183 Ok(())
184 }
185
186 pub fn flush(&mut self) -> Result<usize> {
190 if self.pending.is_empty() {
191 return Ok(0);
192 }
193
194 let count = self.pending.len();
195
196 self.header_buf[0..4].copy_from_slice(&BATCH_MAGIC.to_le_bytes());
198 self.header_buf[4..6].copy_from_slice(&BATCH_VERSION.to_le_bytes());
199 self.header_buf[6..8].copy_from_slice(&(count as u16).to_le_bytes());
200 self.header_buf[8..12].copy_from_slice(&(self.pending_bytes as u32).to_le_bytes());
201
202 let checksum = crc32fast::hash(&self.header_buf[..12]);
204 self.header_buf[12..16].copy_from_slice(&checksum.to_le_bytes());
205
206 let mut iovecs: Vec<IoSlice> = Vec::with_capacity(1 + self.pending.len());
208 iovecs.push(IoSlice::new(&self.header_buf));
209 for entry in &self.pending {
210 iovecs.push(IoSlice::new(entry));
211 }
212
213 let expected = BATCH_HEADER_SIZE + self.pending_bytes;
215 let written = self.file.write_vectored(&iovecs).map_err(SochDBError::Io)?;
216
217 if written != expected {
218 return Err(SochDBError::Io(io::Error::new(
219 io::ErrorKind::WriteZero,
220 format!("Incomplete batch write: {} < {}", written, expected),
221 )));
222 }
223
224 self.stats.entries_written += count as u64;
226 self.stats.batches_written += 1;
227 self.stats.bytes_written += written as u64;
228 self.stats.avg_batch_size =
229 self.stats.entries_written as f64 / self.stats.batches_written as f64;
230
231 self.pending.clear();
233 self.pending_bytes = 0;
234
235 Ok(count)
236 }
237
238 pub fn sync(&mut self) -> Result<()> {
240 if !self.pending.is_empty() {
242 self.flush()?;
243 }
244
245 self.file.sync_data().map_err(SochDBError::Io)?;
246
247 self.stats.syncs_performed += 1;
248 Ok(())
249 }
250
251 pub fn stats(&self) -> BatchedWalStats {
253 self.stats.clone()
254 }
255
256 #[inline]
258 pub fn pending_count(&self) -> usize {
259 self.pending.len()
260 }
261
262 #[inline]
264 pub fn pending_bytes(&self) -> usize {
265 self.pending_bytes
266 }
267}
268
269impl Drop for BatchedWalWriter {
270 fn drop(&mut self) {
271 let _ = self.flush();
273 }
274}
275
276pub struct BatchAccumulator {
280 txn_id: u64,
282 entries: Vec<TxnWalEntry>,
284}
285
286impl BatchAccumulator {
287 pub fn new(txn_id: u64) -> Self {
289 Self {
290 txn_id,
291 entries: Vec::with_capacity(16),
292 }
293 }
294
295 pub fn write(&mut self, key: Vec<u8>, value: Vec<u8>) {
297 self.entries
298 .push(TxnWalEntry::data(self.txn_id, key, value));
299 }
300
301 pub fn delete(&mut self, key: Vec<u8>) {
303 self.entries
306 .push(TxnWalEntry::data(self.txn_id, key, Vec::new()));
307 }
308
309 #[inline]
311 pub fn len(&self) -> usize {
312 self.entries.len()
313 }
314
315 #[inline]
317 pub fn is_empty(&self) -> bool {
318 self.entries.is_empty()
319 }
320
321 pub fn commit(mut self, writer: &mut BatchedWalWriter) -> Result<usize> {
329 self.entries.push(TxnWalEntry::txn_commit(self.txn_id));
331
332 let count = self.entries.len();
333
334 for entry in &self.entries {
336 writer.append(entry)?;
337 }
338
339 writer.flush()?;
341 writer.sync()?;
342
343 Ok(count)
344 }
345
346 pub fn abort(self) {
348 }
350
351 #[inline]
353 pub fn txn_id(&self) -> u64 {
354 self.txn_id
355 }
356}
357
358pub struct ConcurrentBatchedWal {
362 inner: Mutex<BatchedWalWriter>,
363 next_txn_id: AtomicU64,
365}
366
367impl ConcurrentBatchedWal {
368 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
370 Ok(Self {
371 inner: Mutex::new(BatchedWalWriter::new(path)?),
372 next_txn_id: AtomicU64::new(1),
373 })
374 }
375
376 pub fn begin(&self) -> BatchAccumulator {
378 let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
379 BatchAccumulator::new(txn_id)
380 }
381
382 pub fn commit(&self, batch: BatchAccumulator) -> Result<usize> {
384 let mut writer = self.inner.lock();
385 batch.commit(&mut writer)
386 }
387
388 pub fn append(&self, entry: &TxnWalEntry) -> Result<()> {
390 self.inner.lock().append(entry)
391 }
392
393 pub fn flush(&self) -> Result<usize> {
395 self.inner.lock().flush()
396 }
397
398 pub fn sync(&self) -> Result<()> {
400 self.inner.lock().sync()
401 }
402
403 pub fn stats(&self) -> BatchedWalStats {
405 self.inner.lock().stats()
406 }
407}
408
/// Sequential reader for a WAL file made of header-prefixed batches.
pub struct BatchedWalReader {
    /// File the batches are read from.
    file: File,
    /// Number of bytes covered by the batches decoded so far.
    position: u64,
}
416
417impl BatchedWalReader {
418 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
420 let file = File::open(path.as_ref()).map_err(SochDBError::Io)?;
421
422 Ok(Self { file, position: 0 })
423 }
424
425 pub fn read_batch(&mut self) -> Result<Option<Vec<TxnWalEntry>>> {
429 use std::io::Read;
430
431 let mut header = [0u8; BATCH_HEADER_SIZE];
433 match self.file.read_exact(&mut header) {
434 Ok(_) => {}
435 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
436 Err(e) => return Err(SochDBError::Io(e)),
437 }
438
439 let magic = u32::from_le_bytes(header[0..4].try_into().unwrap());
441 if magic != BATCH_MAGIC {
442 return Err(SochDBError::Internal("Invalid batch magic".into()));
443 }
444
445 let _version = u16::from_le_bytes(header[4..6].try_into().unwrap());
447 let entry_count = u16::from_le_bytes(header[6..8].try_into().unwrap()) as usize;
448 let total_bytes = u32::from_le_bytes(header[8..12].try_into().unwrap()) as usize;
449 let stored_checksum = u32::from_le_bytes(header[12..16].try_into().unwrap());
450
451 let computed_checksum = crc32fast::hash(&header[..12]);
453 if stored_checksum != computed_checksum {
454 return Err(SochDBError::Internal(
455 "Batch header checksum mismatch".into(),
456 ));
457 }
458
459 let mut data = vec![0u8; total_bytes];
461 self.file.read_exact(&mut data).map_err(SochDBError::Io)?;
462
463 let mut entries = Vec::with_capacity(entry_count);
465 let mut cursor = std::io::Cursor::new(&data);
466
467 for _ in 0..entry_count {
468 let entry = TxnWalEntry::from_reader(&mut cursor)?;
469 entries.push(entry);
470 }
471
472 self.position += BATCH_HEADER_SIZE as u64 + total_bytes as u64;
473
474 Ok(Some(entries))
475 }
476
477 pub fn position(&self) -> u64 {
479 self.position
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Creates a temporary directory and the path of a WAL file inside it.
    /// The returned directory guard must stay alive for as long as the path
    /// is used.
    fn scratch_wal() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.wal");
        (dir, path)
    }

    /// Entries written as one batch are read back unchanged and in order.
    #[test]
    fn test_batch_write_and_read() {
        let (_dir, path) = scratch_wal();

        // Write phase: ten entries, flushed explicitly as one batch.
        {
            let mut writer = BatchedWalWriter::new(&path).unwrap();

            for i in 0..10 {
                let key = format!("key{}", i).into_bytes();
                let value = format!("value{}", i).into_bytes();
                writer.append(&TxnWalEntry::data(1, key, value)).unwrap();
            }

            writer.flush().unwrap();
        }

        // Read phase: the first batch holds exactly those ten entries.
        {
            let mut reader = BatchedWalReader::open(&path).unwrap();
            let batch = reader.read_batch().unwrap().unwrap();

            assert_eq!(batch.len(), 10);
            for (i, entry) in batch.iter().enumerate() {
                assert_eq!(entry.key, format!("key{}", i).into_bytes());
                assert_eq!(entry.value, format!("value{}", i).into_bytes());
            }
        }
    }

    /// Reaching the configured entry limit flushes without an explicit call.
    #[test]
    fn test_auto_flush_on_limit() {
        let (_dir, path) = scratch_wal();

        let mut writer = BatchedWalWriter::with_config(&path, 5, 1024 * 1024).unwrap();

        // Four entries stay below the limit of five.
        for i in 0..4 {
            writer
                .append(&TxnWalEntry::data(1, vec![i], vec![i]))
                .unwrap();
        }
        assert_eq!(writer.pending_count(), 4);

        // The fifth entry reaches the limit and triggers the flush.
        writer
            .append(&TxnWalEntry::data(1, vec![4], vec![4]))
            .unwrap();
        assert_eq!(writer.pending_count(), 0);

        let stats = writer.stats();
        assert_eq!(stats.batches_written, 1);
        assert_eq!(stats.entries_written, 5);
    }

    /// Committing an accumulator writes its entries plus one commit entry.
    #[test]
    fn test_batch_accumulator() {
        let (_dir, path) = scratch_wal();

        let wal = ConcurrentBatchedWal::new(&path).unwrap();

        let mut batch = wal.begin();
        batch.write(b"key1".to_vec(), b"value1".to_vec());
        batch.write(b"key2".to_vec(), b"value2".to_vec());
        batch.write(b"key3".to_vec(), b"value3".to_vec());

        assert_eq!(batch.len(), 3);

        // Three data entries plus the commit entry.
        let count = wal.commit(batch).unwrap();
        assert_eq!(count, 4);

        let stats = wal.stats();
        assert_eq!(stats.entries_written, 4);
    }

    /// Aborting an accumulator leaves the writer's counters untouched.
    #[test]
    fn test_batch_abort() {
        let (_dir, path) = scratch_wal();

        let wal = ConcurrentBatchedWal::new(&path).unwrap();
        let wal_stats_before = wal.stats();

        let mut batch = wal.begin();
        batch.write(b"key1".to_vec(), b"value1".to_vec());
        batch.write(b"key2".to_vec(), b"value2".to_vec());

        batch.abort();

        let stats = wal.stats();
        assert_eq!(stats.entries_written, wal_stats_before.entries_written);
    }
}