Skip to main content

sochdb_storage/
checkpoint.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! ARIES-Style Checkpointing with WAL Compaction
19//!
20//! From mm.md Task 1.4: Checkpoint and WAL Truncation for Bounded Recovery
21//!
22//! ## Problem
23//!
24//! Without active checkpointing + truncation, recovery requires replaying the entire WAL,
25//! trending toward unbounded startup time as WAL grows.
26//!
27//! ## Solution
28//!
29//! ARIES-style checkpointing with:
30//! 1. Periodic checkpoint triggers (time-based or size-based)
31//! 2. Checkpoint record with active_txns and dirty_pages
32//! 3. Flush all dirty pages to stable storage
33//! 4. Truncate WAL prefix up to checkpoint LSN
34//!
35//! ## Math
36//!
37//! ```text
38//! Without checkpointing:
39//!   Recovery time = O(total_WAL_records) = O(lifetime_operations)
40//!
41//! With checkpointing every C operations:
42//!   Recovery time = O(records_since_checkpoint) ≤ O(C)
43//!
//! For C = 100,000 records at ~10µs replay time per record:
45//!   Recovery time bounded at ~1s regardless of DB lifetime
46//!
47//! WAL size bounded: max_size = checkpoint_interval × avg_record_size
48//! ```
49
50use std::collections::HashMap;
51use std::fs;
52use std::path::{Path, PathBuf};
53use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
54use std::sync::Arc;
55use std::time::{Duration, Instant};
56
57use parking_lot::{Mutex, RwLock};
58use serde::{Deserialize, Serialize};
59
60use crate::hlc::HybridLogicalClock;
61use sochdb_core::{Result, SochDBError};
62
/// Log Sequence Number: the monotonically increasing `u64` that identifies a WAL record
pub type Lsn = u64;

/// Identifier of a storage page
pub type PageId = u64;
68
/// Tunables that decide when a checkpoint is due and what happens afterwards
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// WAL bytes written since the last checkpoint at which a checkpoint becomes due
    pub max_wal_size: u64,
    /// Time elapsed since the last checkpoint at which a checkpoint becomes due
    pub max_interval: Duration,
    /// WAL records written since the last checkpoint at which a checkpoint becomes due
    pub min_records: u64,
    /// Run WAL truncation after each completed checkpoint
    pub truncate_wal: bool,
    /// Master switch for checkpoint scheduling
    pub enabled: bool,
}

impl Default for CheckpointConfig {
    /// Defaults: 64 MiB of WAL, 60 seconds, 100,000 records, truncation on, checkpointing on.
    fn default() -> Self {
        const MIB: u64 = 1 << 20;

        let max_wal_size = 64 * MIB;
        let max_interval = Duration::from_secs(60);

        Self {
            max_wal_size,
            max_interval,
            min_records: 100_000,
            truncate_wal: true,
            enabled: true,
        }
    }
}
95
96/// Active transaction entry for checkpoint
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct ActiveTransactionEntry {
99    /// Transaction ID
100    pub txn_id: u64,
101    /// First LSN written by this transaction
102    pub first_lsn: Lsn,
103    /// Last LSN written by this transaction
104    pub last_lsn: Lsn,
105    /// Transaction start timestamp
106    pub start_ts: u64,
107}
108
109/// Dirty page entry for checkpoint
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct DirtyPageEntry {
112    /// Page ID
113    pub page_id: PageId,
114    /// Recovery LSN (first LSN that dirtied this page)
115    pub recovery_lsn: Lsn,
116}
117
118/// Checkpoint data written to WAL
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct CheckpointData {
121    /// Checkpoint ID (monotonically increasing)
122    pub checkpoint_id: u64,
123    /// LSN at start of checkpoint
124    pub begin_checkpoint_lsn: Lsn,
125    /// LSN at end of checkpoint
126    pub end_checkpoint_lsn: Lsn,
127    /// Active transactions at checkpoint time
128    pub active_transactions: Vec<ActiveTransactionEntry>,
129    /// Dirty pages at checkpoint time
130    pub dirty_pages: Vec<DirtyPageEntry>,
131    /// Timestamp when checkpoint was taken
132    pub timestamp: u64,
133    /// Oldest LSN needed for recovery (min of active txn first_lsn and dirty page recovery_lsn)
134    pub oldest_required_lsn: Lsn,
135}
136
137impl CheckpointData {
138    /// Create a new checkpoint
139    pub fn new(
140        checkpoint_id: u64,
141        begin_lsn: Lsn,
142        active_txns: Vec<ActiveTransactionEntry>,
143        dirty_pages: Vec<DirtyPageEntry>,
144    ) -> Self {
145        // Calculate oldest required LSN
146        let oldest_txn_lsn = active_txns.iter().map(|t| t.first_lsn).min().unwrap_or(Lsn::MAX);
147        let oldest_page_lsn = dirty_pages.iter().map(|p| p.recovery_lsn).min().unwrap_or(Lsn::MAX);
148        let oldest_required_lsn = oldest_txn_lsn.min(oldest_page_lsn).min(begin_lsn);
149
150        Self {
151            checkpoint_id,
152            begin_checkpoint_lsn: begin_lsn,
153            end_checkpoint_lsn: 0, // Set after checkpoint is complete
154            active_transactions: active_txns,
155            dirty_pages,
156            timestamp: std::time::SystemTime::now()
157                .duration_since(std::time::UNIX_EPOCH)
158                .unwrap()
159                .as_micros() as u64,
160            oldest_required_lsn,
161        }
162    }
163}
164
165/// Checkpoint state persisted to disk
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct CheckpointMeta {
168    /// Last completed checkpoint data
169    pub last_checkpoint: Option<CheckpointData>,
170    /// Total checkpoints taken
171    pub total_checkpoints: u64,
172    /// Total bytes truncated from WAL
173    pub total_bytes_truncated: u64,
174}
175
176impl Default for CheckpointMeta {
177    fn default() -> Self {
178        Self {
179            last_checkpoint: None,
180            total_checkpoints: 0,
181            total_bytes_truncated: 0,
182        }
183    }
184}
185
186/// Dirty page tracker for efficient checkpointing
187pub struct DirtyPageTracker {
188    /// Map of page_id -> recovery_lsn (first LSN that dirtied page)
189    dirty_pages: RwLock<HashMap<PageId, Lsn>>,
190}
191
192impl DirtyPageTracker {
193    pub fn new() -> Self {
194        Self {
195            dirty_pages: RwLock::new(HashMap::new()),
196        }
197    }
198
199    /// Mark a page as dirty with its recovery LSN
200    pub fn mark_dirty(&self, page_id: PageId, lsn: Lsn) {
201        let mut dirty = self.dirty_pages.write();
202        dirty.entry(page_id).or_insert(lsn);
203    }
204
205    /// Mark a page as clean (after flush to disk)
206    pub fn mark_clean(&self, page_id: PageId) {
207        self.dirty_pages.write().remove(&page_id);
208    }
209
210    /// Get all dirty pages for checkpoint
211    pub fn get_dirty_pages(&self) -> Vec<DirtyPageEntry> {
212        self.dirty_pages
213            .read()
214            .iter()
215            .map(|(&page_id, &recovery_lsn)| DirtyPageEntry { page_id, recovery_lsn })
216            .collect()
217    }
218
219    /// Get count of dirty pages
220    pub fn dirty_count(&self) -> usize {
221        self.dirty_pages.read().len()
222    }
223}
224
225impl Default for DirtyPageTracker {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
231/// Active transaction tracker for checkpointing
232pub struct ActiveTransactionTracker {
233    /// Map of txn_id -> (first_lsn, last_lsn, start_ts)
234    active_txns: RwLock<HashMap<u64, (Lsn, Lsn, u64)>>,
235}
236
237impl ActiveTransactionTracker {
238    pub fn new() -> Self {
239        Self {
240            active_txns: RwLock::new(HashMap::new()),
241        }
242    }
243
244    /// Register a new transaction
245    pub fn register(&self, txn_id: u64, start_ts: u64) {
246        self.active_txns
247            .write()
248            .insert(txn_id, (Lsn::MAX, 0, start_ts));
249    }
250
251    /// Update transaction's LSN range
252    pub fn update_lsn(&self, txn_id: u64, lsn: Lsn) {
253        if let Some(entry) = self.active_txns.write().get_mut(&txn_id) {
254            if entry.0 == Lsn::MAX {
255                entry.0 = lsn; // First LSN
256            }
257            entry.1 = lsn; // Last LSN
258        }
259    }
260
261    /// Remove a transaction (on commit or abort)
262    pub fn remove(&self, txn_id: u64) {
263        self.active_txns.write().remove(&txn_id);
264    }
265
266    /// Get all active transactions for checkpoint
267    pub fn get_active_transactions(&self) -> Vec<ActiveTransactionEntry> {
268        self.active_txns
269            .read()
270            .iter()
271            .filter(|(_, (first_lsn, _, _))| *first_lsn != Lsn::MAX)
272            .map(|(&txn_id, &(first_lsn, last_lsn, start_ts))| ActiveTransactionEntry {
273                txn_id,
274                first_lsn,
275                last_lsn,
276                start_ts,
277            })
278            .collect()
279    }
280
281    /// Get count of active transactions
282    pub fn active_count(&self) -> usize {
283        self.active_txns.read().len()
284    }
285}
286
287impl Default for ActiveTransactionTracker {
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
293/// Checkpoint manager
294pub struct CheckpointManager {
295    /// Configuration
296    config: CheckpointConfig,
297    /// Path to checkpoint metadata file
298    meta_path: PathBuf,
299    /// Path to WAL directory
300    #[allow(dead_code)]
301    wal_dir: PathBuf,
302    /// Current checkpoint metadata
303    meta: RwLock<CheckpointMeta>,
304    /// Dirty page tracker
305    dirty_pages: Arc<DirtyPageTracker>,
306    /// Active transaction tracker
307    active_txns: Arc<ActiveTransactionTracker>,
308    /// Current LSN counter
309    current_lsn: AtomicU64,
310    /// Records since last checkpoint
311    records_since_checkpoint: AtomicU64,
312    /// WAL bytes since last checkpoint
313    wal_bytes_since_checkpoint: AtomicU64,
314    /// Last checkpoint time
315    last_checkpoint_time: Mutex<Instant>,
316    /// Checkpoint in progress flag
317    checkpoint_in_progress: AtomicBool,
318    /// Next checkpoint ID
319    next_checkpoint_id: AtomicU64,
320    /// HLC for timestamps
321    #[allow(dead_code)]
322    hlc: Arc<HybridLogicalClock>,
323}
324
325impl CheckpointManager {
326    /// Create a new checkpoint manager
327    pub fn new(
328        data_dir: &Path,
329        config: CheckpointConfig,
330        dirty_pages: Arc<DirtyPageTracker>,
331        active_txns: Arc<ActiveTransactionTracker>,
332        hlc: Arc<HybridLogicalClock>,
333    ) -> Result<Self> {
334        let meta_path = data_dir.join("checkpoint.meta");
335        let wal_dir = data_dir.join("wal");
336
337        // Ensure directories exist
338        fs::create_dir_all(&wal_dir)?;
339
340        // Load existing metadata
341        let meta = if meta_path.exists() {
342            let data = fs::read(&meta_path)?;
343            bincode::deserialize(&data).unwrap_or_default()
344        } else {
345            CheckpointMeta::default()
346        };
347
348        let next_id = meta.last_checkpoint.as_ref().map(|c| c.checkpoint_id + 1).unwrap_or(1);
349        let last_lsn = meta.last_checkpoint.as_ref().map(|c| c.end_checkpoint_lsn).unwrap_or(0);
350
351        Ok(Self {
352            config,
353            meta_path,
354            wal_dir,
355            meta: RwLock::new(meta),
356            dirty_pages,
357            active_txns,
358            current_lsn: AtomicU64::new(last_lsn),
359            records_since_checkpoint: AtomicU64::new(0),
360            wal_bytes_since_checkpoint: AtomicU64::new(0),
361            last_checkpoint_time: Mutex::new(Instant::now()),
362            checkpoint_in_progress: AtomicBool::new(false),
363            next_checkpoint_id: AtomicU64::new(next_id),
364            hlc,
365        })
366    }
367
368    /// Allocate the next LSN
369    #[inline]
370    pub fn next_lsn(&self) -> Lsn {
371        self.current_lsn.fetch_add(1, Ordering::SeqCst)
372    }
373
374    /// Record a WAL write for checkpoint tracking
375    pub fn record_wal_write(&self, bytes: u64) {
376        self.records_since_checkpoint.fetch_add(1, Ordering::Relaxed);
377        self.wal_bytes_since_checkpoint.fetch_add(bytes, Ordering::Relaxed);
378    }
379
380    /// Check if checkpoint is needed
381    pub fn should_checkpoint(&self) -> bool {
382        if !self.config.enabled {
383            return false;
384        }
385
386        if self.checkpoint_in_progress.load(Ordering::Relaxed) {
387            return false;
388        }
389
390        let records = self.records_since_checkpoint.load(Ordering::Relaxed);
391        let bytes = self.wal_bytes_since_checkpoint.load(Ordering::Relaxed);
392        let elapsed = self.last_checkpoint_time.lock().elapsed();
393
394        records >= self.config.min_records
395            || bytes >= self.config.max_wal_size
396            || elapsed >= self.config.max_interval
397    }
398
399    /// Take a checkpoint
400    ///
401    /// This is the main checkpoint operation:
402    /// 1. Write BEGIN_CHECKPOINT record
403    /// 2. Collect active transactions and dirty pages
404    /// 3. Flush all dirty pages to stable storage
405    /// 4. Write END_CHECKPOINT record with collected data
406    /// 5. Optionally truncate WAL
407    pub fn checkpoint<F>(&self, flush_dirty_pages: F) -> Result<CheckpointData>
408    where
409        F: FnOnce(&[DirtyPageEntry]) -> Result<()>,
410    {
411        // Set checkpoint in progress
412        if self
413            .checkpoint_in_progress
414            .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
415            .is_err()
416        {
417            return Err(SochDBError::Internal("Checkpoint already in progress".into()));
418        }
419
420        // Guard to reset flag on exit (manual scope guard)
421        struct CheckpointGuard<'a>(&'a AtomicBool);
422        impl<'a> Drop for CheckpointGuard<'a> {
423            fn drop(&mut self) {
424                self.0.store(false, Ordering::SeqCst);
425            }
426        }
427        let _guard = CheckpointGuard(&self.checkpoint_in_progress);
428
429        let checkpoint_id = self.next_checkpoint_id.fetch_add(1, Ordering::SeqCst);
430        let begin_lsn = self.next_lsn();
431
432        // Collect state
433        let active_txns = self.active_txns.get_active_transactions();
434        let dirty_pages = self.dirty_pages.get_dirty_pages();
435
436        // Create checkpoint data
437        let mut checkpoint = CheckpointData::new(checkpoint_id, begin_lsn, active_txns, dirty_pages.clone());
438
439        // Flush all dirty pages to stable storage
440        flush_dirty_pages(&dirty_pages)?;
441
442        // Mark pages as clean
443        for page in &dirty_pages {
444            self.dirty_pages.mark_clean(page.page_id);
445        }
446
447        // Record end LSN
448        let end_lsn = self.next_lsn();
449        checkpoint.end_checkpoint_lsn = end_lsn;
450
451        // Update metadata
452        {
453            let mut meta = self.meta.write();
454            meta.last_checkpoint = Some(checkpoint.clone());
455            meta.total_checkpoints += 1;
456
457            // Persist metadata
458            let data = bincode::serialize(&*meta).map_err(|e| SochDBError::Serialization(e.to_string()))?;
459            fs::write(&self.meta_path, data)?;
460        }
461
462        // Reset counters
463        self.records_since_checkpoint.store(0, Ordering::Relaxed);
464        self.wal_bytes_since_checkpoint.store(0, Ordering::Relaxed);
465        *self.last_checkpoint_time.lock() = Instant::now();
466
467        // Truncate WAL if configured
468        if self.config.truncate_wal {
469            self.truncate_wal(checkpoint.oldest_required_lsn)?;
470        }
471
472        Ok(checkpoint)
473    }
474
475    /// Truncate WAL up to the given LSN
476    fn truncate_wal(&self, safe_lsn: Lsn) -> Result<()> {
477        // In a real implementation, this would:
478        // 1. Identify WAL segments that can be removed
479        // 2. Rename/archive or delete old segments
480        // 3. Update metadata
481
482        // For now, we just track the truncation point
483        let mut meta = self.meta.write();
484        if let Some(ref checkpoint) = meta.last_checkpoint {
485            let truncated = checkpoint.begin_checkpoint_lsn.saturating_sub(safe_lsn);
486            meta.total_bytes_truncated += truncated;
487        }
488
489        Ok(())
490    }
491
492    /// Get the LSN that is safe for recovery (oldest required LSN)
493    pub fn recovery_lsn(&self) -> Option<Lsn> {
494        self.meta
495            .read()
496            .last_checkpoint
497            .as_ref()
498            .map(|c| c.oldest_required_lsn)
499    }
500
501    /// Get the last checkpoint
502    pub fn last_checkpoint(&self) -> Option<CheckpointData> {
503        self.meta.read().last_checkpoint.clone()
504    }
505
506    /// Get checkpoint statistics
507    pub fn stats(&self) -> CheckpointStats {
508        let meta = self.meta.read();
509        CheckpointStats {
510            total_checkpoints: meta.total_checkpoints,
511            total_bytes_truncated: meta.total_bytes_truncated,
512            records_since_checkpoint: self.records_since_checkpoint.load(Ordering::Relaxed),
513            wal_bytes_since_checkpoint: self.wal_bytes_since_checkpoint.load(Ordering::Relaxed),
514            dirty_pages: self.dirty_pages.dirty_count(),
515            active_transactions: self.active_txns.active_count(),
516        }
517    }
518}
519
/// Point-in-time counters describing checkpoint activity
#[derive(Debug, Clone)]
pub struct CheckpointStats {
    /// Number of checkpoints taken so far
    pub total_checkpoints: u64,
    /// Running total reported by WAL truncation
    pub total_bytes_truncated: u64,
    /// WAL records written since the last checkpoint
    pub records_since_checkpoint: u64,
    /// WAL bytes written since the last checkpoint
    pub wal_bytes_since_checkpoint: u64,
    /// Number of pages currently tracked as dirty
    pub dirty_pages: usize,
    /// Number of transactions currently registered as active
    pub active_transactions: usize,
}
530
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// The oldest required LSN is the minimum over the begin LSN, the
    /// transactions' first LSNs and the dirty pages' recovery LSNs.
    #[test]
    fn test_checkpoint_data_creation() {
        let txn = |txn_id, first_lsn, last_lsn, start_ts| ActiveTransactionEntry {
            txn_id,
            first_lsn,
            last_lsn,
            start_ts,
        };
        let page = |page_id, recovery_lsn| DirtyPageEntry { page_id, recovery_lsn };

        let checkpoint = CheckpointData::new(
            1,
            200,
            vec![txn(1, 100, 150, 1000), txn(2, 120, 180, 1100)],
            vec![page(10, 90), page(20, 110)],
        );

        // Page 10 carries the smallest LSN of all inputs.
        assert_eq!(checkpoint.oldest_required_lsn, 90);
    }

    /// Re-dirtying a page keeps its original recovery LSN; cleaning removes it.
    #[test]
    fn test_dirty_page_tracker() {
        let tracker = DirtyPageTracker::new();

        // Page 1 is reported twice; the second report must not move its LSN.
        for (page_id, lsn) in [(1, 100), (2, 110), (1, 120)] {
            tracker.mark_dirty(page_id, lsn);
        }

        assert_eq!(tracker.dirty_count(), 2);

        let pages = tracker.get_dirty_pages();
        assert_eq!(pages.len(), 2);

        let page1_lsn = pages
            .iter()
            .find(|p| p.page_id == 1)
            .map(|p| p.recovery_lsn);
        assert_eq!(page1_lsn, Some(100));

        tracker.mark_clean(1);
        assert_eq!(tracker.dirty_count(), 1);
    }

    /// First and last LSN are tracked per transaction; removal drops the entry.
    #[test]
    fn test_active_transaction_tracker() {
        let tracker = ActiveTransactionTracker::new();

        tracker.register(1, 1000);
        for lsn in [100, 150] {
            tracker.update_lsn(1, lsn);
        }

        tracker.register(2, 1100);
        tracker.update_lsn(2, 120);

        assert_eq!(tracker.active_count(), 2);

        let txns = tracker.get_active_transactions();
        assert_eq!(txns.len(), 2);

        let lsn_range = txns
            .iter()
            .find(|t| t.txn_id == 1)
            .map(|t| (t.first_lsn, t.last_lsn));
        assert_eq!(lsn_range, Some((100, 150)));

        tracker.remove(1);
        assert_eq!(tracker.active_count(), 1);
    }

    /// A checkpoint captures the dirty pages and active transactions present when it runs.
    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let temp_dir = TempDir::new().unwrap();
        let dirty_pages = Arc::new(DirtyPageTracker::new());
        let active_txns = Arc::new(ActiveTransactionTracker::new());

        let manager = CheckpointManager::new(
            temp_dir.path(),
            CheckpointConfig::default(),
            Arc::clone(&dirty_pages),
            Arc::clone(&active_txns),
            Arc::new(HybridLogicalClock::new()),
        )?;

        // Two dirty pages, each with a freshly allocated LSN.
        for page_id in [1, 2] {
            dirty_pages.mark_dirty(page_id, manager.next_lsn());
        }

        // One transaction that has written a single record.
        active_txns.register(100, 1000);
        active_txns.update_lsn(100, manager.next_lsn());

        // The flush callback is a no-op for this test.
        let checkpoint = manager.checkpoint(|_pages| Ok(()))?;

        assert_eq!(checkpoint.checkpoint_id, 1);
        assert_eq!(checkpoint.dirty_pages.len(), 2);
        assert_eq!(checkpoint.active_transactions.len(), 1);

        Ok(())
    }
}
641}