xerv_core/wal/
writer.rs

1//! WAL writer and reader implementation.
2
3use super::record::{MIN_RECORD_SIZE, WalRecord, WalRecordType};
4use crate::error::{Result, XervError};
5use crate::types::{NodeId, TraceId};
6use byteorder::{LittleEndian, ReadBytesExt};
7use fs2::FileExt;
8use parking_lot::Mutex;
9use std::collections::HashMap;
10use std::fs::{File, OpenOptions};
11use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14
/// Settings used when opening a [`Wal`].
#[derive(Clone, Debug)]
pub struct WalConfig {
    /// Directory that holds the `wal_<seq>.log` files.
    pub directory: PathBuf,
    /// Size in bytes past which the writer switches to a new WAL file.
    pub max_file_size: u64,
    /// Flush and fsync after every record written.
    pub sync_on_write: bool,
    /// Capacity in bytes of the buffered writer placed in front of each WAL file.
    pub buffer_size: usize,
}
27
28impl Default for WalConfig {
29    fn default() -> Self {
30        Self {
31            directory: PathBuf::from("/tmp/xerv/wal"),
32            max_file_size: 64 * 1024 * 1024, // 64 MB
33            sync_on_write: true,
34            buffer_size: 64 * 1024, // 64 KB
35        }
36    }
37}
38
39impl WalConfig {
40    /// Create an in-memory WAL configuration (uses a temp directory).
41    pub fn in_memory() -> Self {
42        Self {
43            directory: std::env::temp_dir().join(format!("xerv_wal_{}", uuid::Uuid::new_v4())),
44            max_file_size: 64 * 1024 * 1024,
45            sync_on_write: false,
46            buffer_size: 64 * 1024,
47        }
48    }
49
50    /// Set the WAL directory.
51    pub fn with_directory(mut self, dir: impl Into<PathBuf>) -> Self {
52        self.directory = dir.into();
53        self
54    }
55
56    /// Set sync on write.
57    pub fn with_sync(mut self, sync: bool) -> Self {
58        self.sync_on_write = sync;
59        self
60    }
61}
62
63/// Internal state of the WAL writer.
64struct WalInner {
65    /// Current WAL file.
66    file: BufWriter<File>,
67    /// Path to current file.
68    path: PathBuf,
69    /// Current file size.
70    file_size: u64,
71    /// Configuration.
72    config: WalConfig,
73    /// Sequence number.
74    sequence: u64,
75}
76
77/// Write-Ahead Log for durability.
78pub struct Wal {
79    inner: Arc<Mutex<WalInner>>,
80}
81
82impl Wal {
83    /// Create or open a WAL.
84    pub fn open(config: WalConfig) -> Result<Self> {
85        // Ensure directory exists
86        std::fs::create_dir_all(&config.directory).map_err(|e| XervError::WalWrite {
87            trace_id: TraceId::new(),
88            cause: format!("Failed to create WAL directory: {}", e),
89        })?;
90
91        // Find or create WAL file
92        let (path, sequence) = find_or_create_wal_file(&config.directory)?;
93
94        let file = OpenOptions::new()
95            .create(true)
96            .append(true)
97            .open(&path)
98            .map_err(|e| XervError::WalWrite {
99                trace_id: TraceId::new(),
100                cause: format!("Failed to open WAL file: {}", e),
101            })?;
102
103        // Lock the file
104        file.try_lock_exclusive().map_err(|e| XervError::WalWrite {
105            trace_id: TraceId::new(),
106            cause: format!("Failed to lock WAL file: {}", e),
107        })?;
108
109        let file_size = file.metadata().map(|m| m.len()).unwrap_or(0);
110
111        let inner = WalInner {
112            file: BufWriter::with_capacity(config.buffer_size, file),
113            path,
114            file_size,
115            config,
116            sequence,
117        };
118
119        Ok(Self {
120            inner: Arc::new(Mutex::new(inner)),
121        })
122    }
123
124    /// Write a record to the WAL.
125    pub fn write(&self, record: &WalRecord) -> Result<()> {
126        let mut inner = self.inner.lock();
127
128        let bytes = record.to_bytes().map_err(|e| XervError::WalWrite {
129            trace_id: record.trace_id,
130            cause: e.to_string(),
131        })?;
132
133        // Check if we need to rotate
134        if inner.file_size + bytes.len() as u64 > inner.config.max_file_size {
135            self.rotate_locked(&mut inner)?;
136        }
137
138        inner
139            .file
140            .write_all(&bytes)
141            .map_err(|e| XervError::WalWrite {
142                trace_id: record.trace_id,
143                cause: e.to_string(),
144            })?;
145
146        inner.file_size += bytes.len() as u64;
147
148        if inner.config.sync_on_write {
149            inner.file.flush().map_err(|e| XervError::WalWrite {
150                trace_id: record.trace_id,
151                cause: e.to_string(),
152            })?;
153            inner
154                .file
155                .get_ref()
156                .sync_data()
157                .map_err(|e| XervError::WalWrite {
158                    trace_id: record.trace_id,
159                    cause: e.to_string(),
160                })?;
161        }
162
163        Ok(())
164    }
165
166    /// Flush pending writes.
167    pub fn flush(&self) -> Result<()> {
168        let mut inner = self.inner.lock();
169        inner.file.flush().map_err(|e| XervError::WalWrite {
170            trace_id: TraceId::new(),
171            cause: e.to_string(),
172        })?;
173        inner
174            .file
175            .get_ref()
176            .sync_data()
177            .map_err(|e| XervError::WalWrite {
178                trace_id: TraceId::new(),
179                cause: e.to_string(),
180            })
181    }
182
183    /// Rotate to a new WAL file.
184    fn rotate_locked(&self, inner: &mut WalInner) -> Result<()> {
185        // Flush current file
186        inner.file.flush().map_err(|e| XervError::WalWrite {
187            trace_id: TraceId::new(),
188            cause: e.to_string(),
189        })?;
190
191        // Create new file
192        inner.sequence += 1;
193        let new_path = inner
194            .config
195            .directory
196            .join(format!("wal_{:016x}.log", inner.sequence));
197
198        let new_file = OpenOptions::new()
199            .create(true)
200            .append(true)
201            .open(&new_path)
202            .map_err(|e| XervError::WalWrite {
203                trace_id: TraceId::new(),
204                cause: format!("Failed to create new WAL file: {}", e),
205            })?;
206
207        new_file
208            .try_lock_exclusive()
209            .map_err(|e| XervError::WalWrite {
210                trace_id: TraceId::new(),
211                cause: format!("Failed to lock new WAL file: {}", e),
212            })?;
213
214        // Unlock old file
215        let _ = inner.file.get_ref().unlock();
216
217        inner.file = BufWriter::with_capacity(inner.config.buffer_size, new_file);
218        inner.path = new_path;
219        inner.file_size = 0;
220
221        Ok(())
222    }
223
224    /// Get the current WAL file path.
225    pub fn path(&self) -> PathBuf {
226        self.inner.lock().path.clone()
227    }
228
229    /// Create a reader for recovery.
230    pub fn reader(&self) -> WalReader {
231        let inner = self.inner.lock();
232        WalReader {
233            directory: inner.config.directory.clone(),
234        }
235    }
236}
237
238impl Drop for Wal {
239    fn drop(&mut self) {
240        if let Some(inner) = Arc::get_mut(&mut self.inner) {
241            let inner = inner.get_mut();
242            let _ = inner.file.flush();
243            let _ = inner.file.get_ref().unlock();
244        }
245    }
246}
247
248/// Find the latest WAL file or create a new one.
249fn find_or_create_wal_file(directory: &Path) -> Result<(PathBuf, u64)> {
250    let mut max_sequence = 0u64;
251
252    if let Ok(entries) = std::fs::read_dir(directory) {
253        for entry in entries.flatten() {
254            let name = entry.file_name();
255            let name_str = name.to_string_lossy();
256
257            if name_str.starts_with("wal_") && name_str.ends_with(".log") {
258                if let Some(seq_str) = name_str
259                    .strip_prefix("wal_")
260                    .and_then(|s| s.strip_suffix(".log"))
261                {
262                    if let Ok(seq) = u64::from_str_radix(seq_str, 16) {
263                        max_sequence = max_sequence.max(seq);
264                    }
265                }
266            }
267        }
268    }
269
270    // Use the latest file if it exists and is not too large, otherwise create new
271    let path = directory.join(format!("wal_{:016x}.log", max_sequence));
272
273    if path.exists() {
274        if let Ok(meta) = std::fs::metadata(&path) {
275            // If file is already large, create a new one
276            if meta.len() > 32 * 1024 * 1024 {
277                let new_seq = max_sequence + 1;
278                let new_path = directory.join(format!("wal_{:016x}.log", new_seq));
279                return Ok((new_path, new_seq));
280            }
281        }
282    }
283
284    // If max_sequence is 0, this creates wal_0000000000000000.log
285    Ok((path, max_sequence))
286}
287
/// Reader over the WAL files in a directory, used for crash recovery.
#[derive(Debug, Clone)]
pub struct WalReader {
    /// Directory that holds the `wal_<seq>.log` files.
    directory: PathBuf,
}
292
293impl WalReader {
294    /// Create a new WAL reader.
295    pub fn new(directory: impl Into<PathBuf>) -> Self {
296        Self {
297            directory: directory.into(),
298        }
299    }
300
301    /// Read all records from all WAL files.
302    pub fn read_all(&self) -> Result<Vec<WalRecord>> {
303        let mut records = Vec::new();
304        let mut files: Vec<PathBuf> = Vec::new();
305
306        // Collect all WAL files
307        if let Ok(entries) = std::fs::read_dir(&self.directory) {
308            for entry in entries.flatten() {
309                let path = entry.path();
310                if path.extension().is_some_and(|ext| ext == "log") {
311                    files.push(path);
312                }
313            }
314        }
315
316        // Sort by name (which includes sequence number)
317        files.sort();
318
319        // Read each file
320        for path in files {
321            records.extend(self.read_file(&path)?);
322        }
323
324        Ok(records)
325    }
326
327    /// Read records from a single WAL file.
328    fn read_file(&self, path: &Path) -> Result<Vec<WalRecord>> {
329        let file = File::open(path).map_err(|e| XervError::WalRead {
330            cause: format!("Failed to open {}: {}", path.display(), e),
331        })?;
332
333        let mut reader = BufReader::new(file);
334        let mut records = Vec::new();
335
336        loop {
337            // Try to read record length
338            let length = match reader.read_u32::<LittleEndian>() {
339                Ok(len) => len as usize,
340                Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
341                Err(e) => {
342                    return Err(XervError::WalRead {
343                        cause: format!("Failed to read record length: {}", e),
344                    });
345                }
346            };
347
348            if length < MIN_RECORD_SIZE {
349                return Err(XervError::WalCorruption {
350                    position: reader.stream_position().unwrap_or(0),
351                    cause: format!("Invalid record length: {}", length),
352                });
353            }
354
355            // Seek back to include length in the record
356            reader
357                .seek(SeekFrom::Current(-4))
358                .map_err(|e| XervError::WalRead {
359                    cause: format!("Seek failed: {}", e),
360                })?;
361
362            // Read full record
363            let mut buf = vec![0u8; length];
364            reader
365                .read_exact(&mut buf)
366                .map_err(|e| XervError::WalRead {
367                    cause: format!("Failed to read record: {}", e),
368                })?;
369
370            match WalRecord::from_bytes(&buf) {
371                Ok(record) => records.push(record),
372                Err(e) => {
373                    // Log corruption but continue reading
374                    tracing::warn!("Corrupted WAL record at {}: {}", path.display(), e);
375                }
376            }
377        }
378
379        Ok(records)
380    }
381
382    /// Get the state of in-flight traces from WAL records.
383    ///
384    /// Returns traces that started but did not complete or fail.
385    pub fn get_incomplete_traces(&self) -> Result<HashMap<TraceId, TraceRecoveryState>> {
386        let records = self.read_all()?;
387        let mut traces: HashMap<TraceId, TraceRecoveryState> = HashMap::new();
388
389        for record in records {
390            match record.record_type {
391                WalRecordType::TraceStart => {
392                    traces.insert(
393                        record.trace_id,
394                        TraceRecoveryState {
395                            trace_id: record.trace_id,
396                            last_completed_node: None,
397                            suspended_at: None,
398                            started_nodes: Vec::new(),
399                            completed_nodes: HashMap::new(),
400                        },
401                    );
402                }
403                WalRecordType::NodeStart => {
404                    if let Some(state) = traces.get_mut(&record.trace_id) {
405                        state.started_nodes.push(record.node_id);
406                    }
407                }
408                WalRecordType::NodeDone => {
409                    if let Some(state) = traces.get_mut(&record.trace_id) {
410                        state.last_completed_node = Some(record.node_id);
411                        state.started_nodes.retain(|&n| n != record.node_id);
412                        // Store the output location for recovery
413                        state.completed_nodes.insert(
414                            record.node_id,
415                            NodeOutputLocation {
416                                offset: record.output_offset,
417                                size: record.output_size,
418                                schema_hash: record.schema_hash,
419                            },
420                        );
421                    }
422                }
423                WalRecordType::TraceComplete | WalRecordType::TraceFailed => {
424                    traces.remove(&record.trace_id);
425                }
426                WalRecordType::TraceSuspended => {
427                    if let Some(state) = traces.get_mut(&record.trace_id) {
428                        state.suspended_at = Some(record.node_id);
429                    }
430                }
431                WalRecordType::TraceResumed => {
432                    if let Some(state) = traces.get_mut(&record.trace_id) {
433                        state.suspended_at = None;
434                    }
435                }
436                _ => {}
437            }
438        }
439
440        Ok(traces)
441    }
442}
443
444/// State needed to recover an incomplete trace.
445#[derive(Debug, Clone)]
446pub struct TraceRecoveryState {
447    /// The trace identifier.
448    pub trace_id: TraceId,
449    /// The last node that completed successfully (if any).
450    pub last_completed_node: Option<NodeId>,
451    /// Node where the trace is currently suspended (if any).
452    pub suspended_at: Option<NodeId>,
453    /// Nodes that have started execution but did not complete.
454    pub started_nodes: Vec<NodeId>,
455    /// Map of completed nodes to their output locations in the arena.
456    pub completed_nodes: HashMap<NodeId, NodeOutputLocation>,
457}
458
459/// Location of a node's output in the arena.
460#[derive(Debug, Clone, Copy)]
461pub struct NodeOutputLocation {
462    /// The arena offset where the node's output data is stored.
463    pub offset: crate::types::ArenaOffset,
464    /// Size of the output data in bytes.
465    pub size: u32,
466    /// Schema hash used to verify the output data type.
467    pub schema_hash: u64,
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ArenaOffset;
    use tempfile::tempdir;

    /// Open a WAL in `dir` with per-write fsync disabled so the tests stay fast.
    fn open_unsynced(dir: &Path) -> Wal {
        Wal::open(WalConfig::default().with_directory(dir).with_sync(false)).unwrap()
    }

    #[test]
    fn wal_write_and_read() {
        let dir = tempdir().unwrap();
        let wal = open_unsynced(dir.path());

        let trace = TraceId::new();
        let node = NodeId::new(1);

        let to_write = [
            WalRecord::trace_start(trace),
            WalRecord::node_start(trace, node),
            WalRecord::node_done(trace, node, ArenaOffset::new(0x100), 64, 0),
            WalRecord::trace_complete(trace),
        ];
        for record in &to_write {
            wal.write(record).unwrap();
        }
        wal.flush().unwrap();

        // Records must come back in exactly the order they were written.
        let read_back = wal.reader().read_all().unwrap();
        let expected = [
            WalRecordType::TraceStart,
            WalRecordType::NodeStart,
            WalRecordType::NodeDone,
            WalRecordType::TraceComplete,
        ];
        assert_eq!(read_back.len(), expected.len());
        for (record, kind) in read_back.iter().zip(expected.iter()) {
            assert_eq!(&record.record_type, kind);
        }
    }

    #[test]
    fn wal_incomplete_trace_detection() {
        let dir = tempdir().unwrap();
        let wal = open_unsynced(dir.path());

        let finished = TraceId::new();
        let crashed = TraceId::new();
        let node = NodeId::new(1);

        // `finished` runs to completion.
        for record in [
            WalRecord::trace_start(finished),
            WalRecord::node_done(finished, node, ArenaOffset::NULL, 0, 0),
            WalRecord::trace_complete(finished),
        ] {
            wal.write(&record).unwrap();
        }

        // `crashed` dies mid-node: a NodeStart with no NodeDone or TraceComplete after it.
        for record in [
            WalRecord::trace_start(crashed),
            WalRecord::node_start(crashed, node),
        ] {
            wal.write(&record).unwrap();
        }

        wal.flush().unwrap();

        let incomplete = wal.reader().get_incomplete_traces().unwrap();

        assert!(!incomplete.contains_key(&finished));
        let state = incomplete
            .get(&crashed)
            .expect("crashed trace should be reported as incomplete");
        assert!(state.started_nodes.contains(&node));
    }
}