1use super::record::{MIN_RECORD_SIZE, WalRecord, WalRecordType};
4use crate::error::{Result, XervError};
5use crate::types::{NodeId, TraceId};
6use byteorder::{LittleEndian, ReadBytesExt};
7use fs2::FileExt;
8use parking_lot::Mutex;
9use std::collections::HashMap;
10use std::fs::{File, OpenOptions};
11use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14
/// Settings that control where and how the write-ahead log is written.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory that holds the `wal_<sequence>.log` segment files.
    pub directory: PathBuf,
    /// Size threshold in bytes; a write that would push the current segment
    /// past this value triggers rotation to a new segment first.
    pub max_file_size: u64,
    /// When `true`, every write is flushed and synced to disk before returning.
    pub sync_on_write: bool,
    /// Capacity in bytes of the in-process write buffer wrapped around the file.
    pub buffer_size: usize,
}
27
28impl Default for WalConfig {
29 fn default() -> Self {
30 Self {
31 directory: PathBuf::from("/tmp/xerv/wal"),
32 max_file_size: 64 * 1024 * 1024, sync_on_write: true,
34 buffer_size: 64 * 1024, }
36 }
37}
38
39impl WalConfig {
40 pub fn in_memory() -> Self {
42 Self {
43 directory: std::env::temp_dir().join(format!("xerv_wal_{}", uuid::Uuid::new_v4())),
44 max_file_size: 64 * 1024 * 1024,
45 sync_on_write: false,
46 buffer_size: 64 * 1024,
47 }
48 }
49
50 pub fn with_directory(mut self, dir: impl Into<PathBuf>) -> Self {
52 self.directory = dir.into();
53 self
54 }
55
56 pub fn with_sync(mut self, sync: bool) -> Self {
58 self.sync_on_write = sync;
59 self
60 }
61}
62
/// Mutable writer state; only accessed while holding the mutex in [`Wal`].
struct WalInner {
    /// Buffered handle to the segment currently being appended to.
    file: BufWriter<File>,
    /// Path of the segment currently being appended to.
    path: PathBuf,
    /// Bytes accounted to the current segment: initialised from the file's
    /// size at open, increased on every write, reset to zero on rotation.
    file_size: u64,
    /// Configuration the WAL was opened with.
    config: WalConfig,
    /// Sequence number embedded in the current segment's file name.
    sequence: u64,
}
76
/// Append-only write-ahead log writer backed by rotating segment files.
pub struct Wal {
    /// Writer state behind a mutex so `write`/`flush` can take `&self`.
    inner: Arc<Mutex<WalInner>>,
}
81
82impl Wal {
83 pub fn open(config: WalConfig) -> Result<Self> {
85 std::fs::create_dir_all(&config.directory).map_err(|e| XervError::WalWrite {
87 trace_id: TraceId::new(),
88 cause: format!("Failed to create WAL directory: {}", e),
89 })?;
90
91 let (path, sequence) = find_or_create_wal_file(&config.directory)?;
93
94 let file = OpenOptions::new()
95 .create(true)
96 .append(true)
97 .open(&path)
98 .map_err(|e| XervError::WalWrite {
99 trace_id: TraceId::new(),
100 cause: format!("Failed to open WAL file: {}", e),
101 })?;
102
103 file.try_lock_exclusive().map_err(|e| XervError::WalWrite {
105 trace_id: TraceId::new(),
106 cause: format!("Failed to lock WAL file: {}", e),
107 })?;
108
109 let file_size = file.metadata().map(|m| m.len()).unwrap_or(0);
110
111 let inner = WalInner {
112 file: BufWriter::with_capacity(config.buffer_size, file),
113 path,
114 file_size,
115 config,
116 sequence,
117 };
118
119 Ok(Self {
120 inner: Arc::new(Mutex::new(inner)),
121 })
122 }
123
124 pub fn write(&self, record: &WalRecord) -> Result<()> {
126 let mut inner = self.inner.lock();
127
128 let bytes = record.to_bytes().map_err(|e| XervError::WalWrite {
129 trace_id: record.trace_id,
130 cause: e.to_string(),
131 })?;
132
133 if inner.file_size + bytes.len() as u64 > inner.config.max_file_size {
135 self.rotate_locked(&mut inner)?;
136 }
137
138 inner
139 .file
140 .write_all(&bytes)
141 .map_err(|e| XervError::WalWrite {
142 trace_id: record.trace_id,
143 cause: e.to_string(),
144 })?;
145
146 inner.file_size += bytes.len() as u64;
147
148 if inner.config.sync_on_write {
149 inner.file.flush().map_err(|e| XervError::WalWrite {
150 trace_id: record.trace_id,
151 cause: e.to_string(),
152 })?;
153 inner
154 .file
155 .get_ref()
156 .sync_data()
157 .map_err(|e| XervError::WalWrite {
158 trace_id: record.trace_id,
159 cause: e.to_string(),
160 })?;
161 }
162
163 Ok(())
164 }
165
166 pub fn flush(&self) -> Result<()> {
168 let mut inner = self.inner.lock();
169 inner.file.flush().map_err(|e| XervError::WalWrite {
170 trace_id: TraceId::new(),
171 cause: e.to_string(),
172 })?;
173 inner
174 .file
175 .get_ref()
176 .sync_data()
177 .map_err(|e| XervError::WalWrite {
178 trace_id: TraceId::new(),
179 cause: e.to_string(),
180 })
181 }
182
183 fn rotate_locked(&self, inner: &mut WalInner) -> Result<()> {
185 inner.file.flush().map_err(|e| XervError::WalWrite {
187 trace_id: TraceId::new(),
188 cause: e.to_string(),
189 })?;
190
191 inner.sequence += 1;
193 let new_path = inner
194 .config
195 .directory
196 .join(format!("wal_{:016x}.log", inner.sequence));
197
198 let new_file = OpenOptions::new()
199 .create(true)
200 .append(true)
201 .open(&new_path)
202 .map_err(|e| XervError::WalWrite {
203 trace_id: TraceId::new(),
204 cause: format!("Failed to create new WAL file: {}", e),
205 })?;
206
207 new_file
208 .try_lock_exclusive()
209 .map_err(|e| XervError::WalWrite {
210 trace_id: TraceId::new(),
211 cause: format!("Failed to lock new WAL file: {}", e),
212 })?;
213
214 let _ = inner.file.get_ref().unlock();
216
217 inner.file = BufWriter::with_capacity(inner.config.buffer_size, new_file);
218 inner.path = new_path;
219 inner.file_size = 0;
220
221 Ok(())
222 }
223
224 pub fn path(&self) -> PathBuf {
226 self.inner.lock().path.clone()
227 }
228
229 pub fn reader(&self) -> WalReader {
231 let inner = self.inner.lock();
232 WalReader {
233 directory: inner.config.directory.clone(),
234 }
235 }
236}
237
238impl Drop for Wal {
239 fn drop(&mut self) {
240 if let Some(inner) = Arc::get_mut(&mut self.inner) {
241 let inner = inner.get_mut();
242 let _ = inner.file.flush();
243 let _ = inner.file.get_ref().unlock();
244 }
245 }
246}
247
248fn find_or_create_wal_file(directory: &Path) -> Result<(PathBuf, u64)> {
250 let mut max_sequence = 0u64;
251
252 if let Ok(entries) = std::fs::read_dir(directory) {
253 for entry in entries.flatten() {
254 let name = entry.file_name();
255 let name_str = name.to_string_lossy();
256
257 if name_str.starts_with("wal_") && name_str.ends_with(".log") {
258 if let Some(seq_str) = name_str
259 .strip_prefix("wal_")
260 .and_then(|s| s.strip_suffix(".log"))
261 {
262 if let Ok(seq) = u64::from_str_radix(seq_str, 16) {
263 max_sequence = max_sequence.max(seq);
264 }
265 }
266 }
267 }
268 }
269
270 let path = directory.join(format!("wal_{:016x}.log", max_sequence));
272
273 if path.exists() {
274 if let Ok(meta) = std::fs::metadata(&path) {
275 if meta.len() > 32 * 1024 * 1024 {
277 let new_seq = max_sequence + 1;
278 let new_path = directory.join(format!("wal_{:016x}.log", new_seq));
279 return Ok((new_path, new_seq));
280 }
281 }
282 }
283
284 Ok((path, max_sequence))
286}
287
/// Reads records back from the segment files in a WAL directory.
pub struct WalReader {
    /// Directory that is scanned for segment files.
    directory: PathBuf,
}
292
293impl WalReader {
294 pub fn new(directory: impl Into<PathBuf>) -> Self {
296 Self {
297 directory: directory.into(),
298 }
299 }
300
301 pub fn read_all(&self) -> Result<Vec<WalRecord>> {
303 let mut records = Vec::new();
304 let mut files: Vec<PathBuf> = Vec::new();
305
306 if let Ok(entries) = std::fs::read_dir(&self.directory) {
308 for entry in entries.flatten() {
309 let path = entry.path();
310 if path.extension().is_some_and(|ext| ext == "log") {
311 files.push(path);
312 }
313 }
314 }
315
316 files.sort();
318
319 for path in files {
321 records.extend(self.read_file(&path)?);
322 }
323
324 Ok(records)
325 }
326
327 fn read_file(&self, path: &Path) -> Result<Vec<WalRecord>> {
329 let file = File::open(path).map_err(|e| XervError::WalRead {
330 cause: format!("Failed to open {}: {}", path.display(), e),
331 })?;
332
333 let mut reader = BufReader::new(file);
334 let mut records = Vec::new();
335
336 loop {
337 let length = match reader.read_u32::<LittleEndian>() {
339 Ok(len) => len as usize,
340 Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
341 Err(e) => {
342 return Err(XervError::WalRead {
343 cause: format!("Failed to read record length: {}", e),
344 });
345 }
346 };
347
348 if length < MIN_RECORD_SIZE {
349 return Err(XervError::WalCorruption {
350 position: reader.stream_position().unwrap_or(0),
351 cause: format!("Invalid record length: {}", length),
352 });
353 }
354
355 reader
357 .seek(SeekFrom::Current(-4))
358 .map_err(|e| XervError::WalRead {
359 cause: format!("Seek failed: {}", e),
360 })?;
361
362 let mut buf = vec![0u8; length];
364 reader
365 .read_exact(&mut buf)
366 .map_err(|e| XervError::WalRead {
367 cause: format!("Failed to read record: {}", e),
368 })?;
369
370 match WalRecord::from_bytes(&buf) {
371 Ok(record) => records.push(record),
372 Err(e) => {
373 tracing::warn!("Corrupted WAL record at {}: {}", path.display(), e);
375 }
376 }
377 }
378
379 Ok(records)
380 }
381
382 pub fn get_incomplete_traces(&self) -> Result<HashMap<TraceId, TraceRecoveryState>> {
386 let records = self.read_all()?;
387 let mut traces: HashMap<TraceId, TraceRecoveryState> = HashMap::new();
388
389 for record in records {
390 match record.record_type {
391 WalRecordType::TraceStart => {
392 traces.insert(
393 record.trace_id,
394 TraceRecoveryState {
395 trace_id: record.trace_id,
396 last_completed_node: None,
397 suspended_at: None,
398 started_nodes: Vec::new(),
399 completed_nodes: HashMap::new(),
400 },
401 );
402 }
403 WalRecordType::NodeStart => {
404 if let Some(state) = traces.get_mut(&record.trace_id) {
405 state.started_nodes.push(record.node_id);
406 }
407 }
408 WalRecordType::NodeDone => {
409 if let Some(state) = traces.get_mut(&record.trace_id) {
410 state.last_completed_node = Some(record.node_id);
411 state.started_nodes.retain(|&n| n != record.node_id);
412 state.completed_nodes.insert(
414 record.node_id,
415 NodeOutputLocation {
416 offset: record.output_offset,
417 size: record.output_size,
418 schema_hash: record.schema_hash,
419 },
420 );
421 }
422 }
423 WalRecordType::TraceComplete | WalRecordType::TraceFailed => {
424 traces.remove(&record.trace_id);
425 }
426 WalRecordType::TraceSuspended => {
427 if let Some(state) = traces.get_mut(&record.trace_id) {
428 state.suspended_at = Some(record.node_id);
429 }
430 }
431 WalRecordType::TraceResumed => {
432 if let Some(state) = traces.get_mut(&record.trace_id) {
433 state.suspended_at = None;
434 }
435 }
436 _ => {}
437 }
438 }
439
440 Ok(traces)
441 }
442}
443
/// What the WAL recorded about a trace that has not completed or failed.
#[derive(Debug, Clone)]
pub struct TraceRecoveryState {
    /// Identifier of the trace this state belongs to.
    pub trace_id: TraceId,
    /// Node of the most recent `NodeDone` record, if any.
    pub last_completed_node: Option<NodeId>,
    /// Node recorded by the latest `TraceSuspended`; cleared by `TraceResumed`.
    pub suspended_at: Option<NodeId>,
    /// Nodes with a `NodeStart` record but no `NodeDone` record yet.
    pub started_nodes: Vec<NodeId>,
    /// Output location recorded for each node that finished.
    pub completed_nodes: HashMap<NodeId, NodeOutputLocation>,
}
458
/// Where a completed node's output was recorded, as copied from its
/// `NodeDone` WAL record.
#[derive(Debug, Clone, Copy)]
pub struct NodeOutputLocation {
    /// Offset of the output (taken from the record's `output_offset`).
    pub offset: crate::types::ArenaOffset,
    /// Size of the output (taken from the record's `output_size`);
    /// presumably in bytes — confirm against the record format.
    pub size: u32,
    /// Schema hash stored alongside the output in the record.
    pub schema_hash: u64,
}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ArenaOffset;
    use tempfile::tempdir;

    /// Opens a WAL rooted at `dir` with per-write syncing turned off.
    fn open_unsynced(dir: &Path) -> Wal {
        let config = WalConfig::default().with_directory(dir).with_sync(false);
        Wal::open(config).unwrap()
    }

    /// Records written to the WAL come back in the same order.
    #[test]
    fn wal_write_and_read() {
        let dir = tempdir().unwrap();
        let wal = open_unsynced(dir.path());

        let trace_id = TraceId::new();
        let node_id = NodeId::new(1);

        let written = [
            WalRecord::trace_start(trace_id),
            WalRecord::node_start(trace_id, node_id),
            WalRecord::node_done(trace_id, node_id, ArenaOffset::new(0x100), 64, 0),
            WalRecord::trace_complete(trace_id),
        ];
        for record in &written {
            wal.write(record).unwrap();
        }
        wal.flush().unwrap();

        let records = wal.reader().read_all().unwrap();

        let expected_types = [
            WalRecordType::TraceStart,
            WalRecordType::NodeStart,
            WalRecordType::NodeDone,
            WalRecordType::TraceComplete,
        ];
        assert_eq!(records.len(), 4);
        for (record, expected) in records.iter().zip(&expected_types) {
            assert_eq!(&record.record_type, expected);
        }
    }

    /// A completed trace is dropped from recovery; an unfinished one is kept
    /// together with its in-flight node.
    #[test]
    fn wal_incomplete_trace_detection() {
        let dir = tempdir().unwrap();
        let wal = open_unsynced(dir.path());

        let finished = TraceId::new();
        let unfinished = TraceId::new();
        let node_id = NodeId::new(1);

        // First trace runs to completion.
        wal.write(&WalRecord::trace_start(finished)).unwrap();
        wal.write(&WalRecord::node_done(
            finished,
            node_id,
            ArenaOffset::NULL,
            0,
            0,
        ))
        .unwrap();
        wal.write(&WalRecord::trace_complete(finished)).unwrap();

        // Second trace starts a node and then stops.
        wal.write(&WalRecord::trace_start(unfinished)).unwrap();
        wal.write(&WalRecord::node_start(unfinished, node_id))
            .unwrap();
        wal.flush().unwrap();

        let incomplete = wal.reader().get_incomplete_traces().unwrap();

        assert!(!incomplete.contains_key(&finished));
        assert!(incomplete.contains_key(&unfinished));

        let state = &incomplete[&unfinished];
        assert!(state.started_nodes.contains(&node_id));
    }
}