1use std::fs::{File, OpenOptions};
47use std::io::{self, IoSlice, Write};
48use std::path::Path;
49use std::sync::atomic::{AtomicU64, Ordering};
50
51use crate::txn_wal::TxnWalEntry;
52use parking_lot::Mutex;
53use sochdb_core::{Result, SochDBError};
54
/// Size in bytes of the fixed header written before every batch:
/// magic (u32) | version (u16) | entry count (u16) | payload length (u32) | header CRC32 (u32).
const BATCH_HEADER_SIZE: usize = 4 + 2 + 2 + 4 + 4;

/// Batch magic number; `0x42415443` spells ASCII "BATC" when read big-endian
/// (it is stored little-endian in the header).
const BATCH_MAGIC: u32 = u32::from_be_bytes(*b"BATC");

/// On-disk batch format version written into every header.
const BATCH_VERSION: u16 = 1;

/// Default number of pending entries that triggers an automatic flush.
pub const DEFAULT_MAX_BATCH_SIZE: usize = 1_000;

/// Default number of pending payload bytes (64 KiB) that triggers an automatic flush.
pub const DEFAULT_MAX_BATCH_BYTES: usize = 64 << 10;
66
/// Cumulative counters reported by [`BatchedWalWriter::stats`].
#[derive(Clone, Debug, Default)]
pub struct BatchedWalStats {
    /// Total entries written across all flushed batches.
    pub entries_written: u64,
    /// Number of batches written to the file.
    pub batches_written: u64,
    /// Total bytes written, batch headers included.
    pub bytes_written: u64,
    /// Number of `sync` calls that completed successfully.
    pub syncs_performed: u64,
    /// Mean entries per batch (`entries_written / batches_written`); 0.0 before the first batch.
    pub avg_batch_size: f64,
}
81
82pub struct BatchedWalWriter {
87 file: File,
89 pending: Vec<Vec<u8>>,
91 pending_bytes: usize,
93 max_batch_size: usize,
95 max_batch_bytes: usize,
97 header_buf: Vec<u8>,
99 stats: BatchedWalStats,
101}
102
103impl BatchedWalWriter {
104 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
109 Self::with_config(path, DEFAULT_MAX_BATCH_SIZE, DEFAULT_MAX_BATCH_BYTES)
110 }
111
112 pub fn with_config<P: AsRef<Path>>(
119 path: P,
120 max_batch_size: usize,
121 max_batch_bytes: usize,
122 ) -> Result<Self> {
123 let file = OpenOptions::new()
124 .create(true)
125 .append(true)
126 .open(path.as_ref())
127 .map_err(SochDBError::Io)?;
128
129 Ok(Self {
130 file,
131 pending: Vec::with_capacity(max_batch_size),
132 pending_bytes: 0,
133 max_batch_size,
134 max_batch_bytes,
135 header_buf: vec![0u8; BATCH_HEADER_SIZE],
136 stats: BatchedWalStats::default(),
137 })
138 }
139
140 pub fn from_file(file: File) -> Self {
142 Self {
143 file,
144 pending: Vec::with_capacity(DEFAULT_MAX_BATCH_SIZE),
145 pending_bytes: 0,
146 max_batch_size: DEFAULT_MAX_BATCH_SIZE,
147 max_batch_bytes: DEFAULT_MAX_BATCH_BYTES,
148 header_buf: vec![0u8; BATCH_HEADER_SIZE],
149 stats: BatchedWalStats::default(),
150 }
151 }
152
153 pub fn append(&mut self, entry: &TxnWalEntry) -> Result<()> {
158 let serialized = entry.to_bytes();
159 self.pending_bytes += serialized.len();
160 self.pending.push(serialized);
161
162 if self.pending.len() >= self.max_batch_size || self.pending_bytes >= self.max_batch_bytes {
164 self.flush()?;
165 }
166
167 Ok(())
168 }
169
170 #[inline]
172 pub fn append_bytes(&mut self, bytes: Vec<u8>) -> Result<()> {
173 self.pending_bytes += bytes.len();
174 self.pending.push(bytes);
175
176 if self.pending.len() >= self.max_batch_size || self.pending_bytes >= self.max_batch_bytes {
177 self.flush()?;
178 }
179
180 Ok(())
181 }
182
183 pub fn flush(&mut self) -> Result<usize> {
187 if self.pending.is_empty() {
188 return Ok(0);
189 }
190
191 let count = self.pending.len();
192
193 self.header_buf[0..4].copy_from_slice(&BATCH_MAGIC.to_le_bytes());
195 self.header_buf[4..6].copy_from_slice(&BATCH_VERSION.to_le_bytes());
196 self.header_buf[6..8].copy_from_slice(&(count as u16).to_le_bytes());
197 self.header_buf[8..12].copy_from_slice(&(self.pending_bytes as u32).to_le_bytes());
198
199 let checksum = crc32fast::hash(&self.header_buf[..12]);
201 self.header_buf[12..16].copy_from_slice(&checksum.to_le_bytes());
202
203 let mut iovecs: Vec<IoSlice> = Vec::with_capacity(1 + self.pending.len());
205 iovecs.push(IoSlice::new(&self.header_buf));
206 for entry in &self.pending {
207 iovecs.push(IoSlice::new(entry));
208 }
209
210 let expected = BATCH_HEADER_SIZE + self.pending_bytes;
212 let written = self.file.write_vectored(&iovecs).map_err(SochDBError::Io)?;
213
214 if written != expected {
215 return Err(SochDBError::Io(io::Error::new(
216 io::ErrorKind::WriteZero,
217 format!("Incomplete batch write: {} < {}", written, expected),
218 )));
219 }
220
221 self.stats.entries_written += count as u64;
223 self.stats.batches_written += 1;
224 self.stats.bytes_written += written as u64;
225 self.stats.avg_batch_size =
226 self.stats.entries_written as f64 / self.stats.batches_written as f64;
227
228 self.pending.clear();
230 self.pending_bytes = 0;
231
232 Ok(count)
233 }
234
235 pub fn sync(&mut self) -> Result<()> {
237 if !self.pending.is_empty() {
239 self.flush()?;
240 }
241
242 self.file.sync_data().map_err(SochDBError::Io)?;
243
244 self.stats.syncs_performed += 1;
245 Ok(())
246 }
247
248 pub fn stats(&self) -> BatchedWalStats {
250 self.stats.clone()
251 }
252
253 #[inline]
255 pub fn pending_count(&self) -> usize {
256 self.pending.len()
257 }
258
259 #[inline]
261 pub fn pending_bytes(&self) -> usize {
262 self.pending_bytes
263 }
264}
265
266impl Drop for BatchedWalWriter {
267 fn drop(&mut self) {
268 let _ = self.flush();
270 }
271}
272
273pub struct BatchAccumulator {
277 txn_id: u64,
279 entries: Vec<TxnWalEntry>,
281}
282
283impl BatchAccumulator {
284 pub fn new(txn_id: u64) -> Self {
286 Self {
287 txn_id,
288 entries: Vec::with_capacity(16),
289 }
290 }
291
292 pub fn write(&mut self, key: Vec<u8>, value: Vec<u8>) {
294 self.entries
295 .push(TxnWalEntry::data(self.txn_id, key, value));
296 }
297
298 pub fn delete(&mut self, key: Vec<u8>) {
300 self.entries
303 .push(TxnWalEntry::data(self.txn_id, key, Vec::new()));
304 }
305
306 #[inline]
308 pub fn len(&self) -> usize {
309 self.entries.len()
310 }
311
312 #[inline]
314 pub fn is_empty(&self) -> bool {
315 self.entries.is_empty()
316 }
317
318 pub fn commit(mut self, writer: &mut BatchedWalWriter) -> Result<usize> {
326 self.entries.push(TxnWalEntry::txn_commit(self.txn_id));
328
329 let count = self.entries.len();
330
331 for entry in &self.entries {
333 writer.append(entry)?;
334 }
335
336 writer.flush()?;
338 writer.sync()?;
339
340 Ok(count)
341 }
342
343 pub fn abort(self) {
345 }
347
348 #[inline]
350 pub fn txn_id(&self) -> u64 {
351 self.txn_id
352 }
353}
354
355pub struct ConcurrentBatchedWal {
359 inner: Mutex<BatchedWalWriter>,
360 next_txn_id: AtomicU64,
362}
363
364impl ConcurrentBatchedWal {
365 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
367 Ok(Self {
368 inner: Mutex::new(BatchedWalWriter::new(path)?),
369 next_txn_id: AtomicU64::new(1),
370 })
371 }
372
373 pub fn begin(&self) -> BatchAccumulator {
375 let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
376 BatchAccumulator::new(txn_id)
377 }
378
379 pub fn commit(&self, batch: BatchAccumulator) -> Result<usize> {
381 let mut writer = self.inner.lock();
382 batch.commit(&mut writer)
383 }
384
385 pub fn append(&self, entry: &TxnWalEntry) -> Result<()> {
387 self.inner.lock().append(entry)
388 }
389
390 pub fn flush(&self) -> Result<usize> {
392 self.inner.lock().flush()
393 }
394
395 pub fn sync(&self) -> Result<()> {
397 self.inner.lock().sync()
398 }
399
400 pub fn stats(&self) -> BatchedWalStats {
402 self.inner.lock().stats()
403 }
404}
405
/// Sequential reader for files produced by [`BatchedWalWriter`].
pub struct BatchedWalReader {
    /// Underlying file, read from its current offset.
    file: File,
    /// Byte offset just past the last successfully decoded batch.
    position: u64,
}
413
414impl BatchedWalReader {
415 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
417 let file = File::open(path.as_ref()).map_err(SochDBError::Io)?;
418
419 Ok(Self { file, position: 0 })
420 }
421
422 pub fn read_batch(&mut self) -> Result<Option<Vec<TxnWalEntry>>> {
426 use std::io::Read;
427
428 let mut header = [0u8; BATCH_HEADER_SIZE];
430 match self.file.read_exact(&mut header) {
431 Ok(_) => {}
432 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
433 Err(e) => return Err(SochDBError::Io(e)),
434 }
435
436 let magic = u32::from_le_bytes(header[0..4].try_into().unwrap());
438 if magic != BATCH_MAGIC {
439 return Err(SochDBError::Internal("Invalid batch magic".into()));
440 }
441
442 let _version = u16::from_le_bytes(header[4..6].try_into().unwrap());
444 let entry_count = u16::from_le_bytes(header[6..8].try_into().unwrap()) as usize;
445 let total_bytes = u32::from_le_bytes(header[8..12].try_into().unwrap()) as usize;
446 let stored_checksum = u32::from_le_bytes(header[12..16].try_into().unwrap());
447
448 let computed_checksum = crc32fast::hash(&header[..12]);
450 if stored_checksum != computed_checksum {
451 return Err(SochDBError::Internal(
452 "Batch header checksum mismatch".into(),
453 ));
454 }
455
456 let mut data = vec![0u8; total_bytes];
458 self.file.read_exact(&mut data).map_err(SochDBError::Io)?;
459
460 let mut entries = Vec::with_capacity(entry_count);
462 let mut cursor = std::io::Cursor::new(&data);
463
464 for _ in 0..entry_count {
465 let entry = TxnWalEntry::from_reader(&mut cursor)?;
466 entries.push(entry);
467 }
468
469 self.position += BATCH_HEADER_SIZE as u64 + total_bytes as u64;
470
471 Ok(Some(entries))
472 }
473
474 pub fn position(&self) -> u64 {
476 self.position
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Entries flushed as one batch are read back in order with their keys and values.
    #[test]
    fn test_batch_write_and_read() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.wal");

        // Scope the writer so it is dropped (and its file closed) before reading.
        {
            let mut writer = BatchedWalWriter::new(&path).unwrap();
            (0..10).for_each(|i| {
                let key = format!("key{}", i).into_bytes();
                let value = format!("value{}", i).into_bytes();
                writer.append(&TxnWalEntry::data(1, key, value)).unwrap();
            });
            writer.flush().unwrap();
        }

        let mut reader = BatchedWalReader::open(&path).unwrap();
        let batch = reader.read_batch().unwrap().unwrap();

        assert_eq!(batch.len(), 10);
        for (i, entry) in batch.iter().enumerate() {
            assert_eq!(entry.key, format!("key{}", i).into_bytes());
            assert_eq!(entry.value, format!("value{}", i).into_bytes());
        }
    }

    /// Reaching `max_batch_size` pending entries triggers exactly one automatic flush.
    #[test]
    fn test_auto_flush_on_limit() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.wal");
        let mut writer = BatchedWalWriter::with_config(&path, 5, 1024 * 1024).unwrap();

        // Four entries stay buffered below the five-entry limit.
        for i in 0u8..4 {
            writer.append(&TxnWalEntry::data(1, vec![i], vec![i])).unwrap();
        }
        assert_eq!(writer.pending_count(), 4);

        // The fifth entry hits the limit and flushes the whole batch.
        writer.append(&TxnWalEntry::data(1, vec![4], vec![4])).unwrap();
        assert_eq!(writer.pending_count(), 0);

        let stats = writer.stats();
        assert_eq!(stats.batches_written, 1);
        assert_eq!(stats.entries_written, 5);
    }

    /// Committing an accumulator writes its entries plus one commit marker.
    #[test]
    fn test_batch_accumulator() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.wal");
        let wal = ConcurrentBatchedWal::new(&path).unwrap();

        let mut batch = wal.begin();
        for (key, value) in [
            (b"key1", b"value1"),
            (b"key2", b"value2"),
            (b"key3", b"value3"),
        ] {
            batch.write(key.to_vec(), value.to_vec());
        }
        assert_eq!(batch.len(), 3);

        // Three data entries plus the commit marker.
        let count = wal.commit(batch).unwrap();
        assert_eq!(count, 4);
        assert_eq!(wal.stats().entries_written, 4);
    }

    /// Aborting an accumulator leaves the WAL untouched.
    #[test]
    fn test_batch_abort() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.wal");
        let wal = ConcurrentBatchedWal::new(&path).unwrap();
        let entries_before = wal.stats().entries_written;

        let mut batch = wal.begin();
        batch.write(b"key1".to_vec(), b"value1".to_vec());
        batch.write(b"key2".to_vec(), b"value2".to_vec());
        batch.abort();

        assert_eq!(wal.stats().entries_written, entries_before);
    }
}