// sochdb_storage/checkpoint.rs

// Copyright 2025 Sushanth (https://github.com/sushanthpy)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! ARIES-Style Checkpointing with WAL Compaction
//!
//! From mm.md Task 1.4: Checkpoint and WAL Truncation for Bounded Recovery
//!
//! ## Problem
//!
//! Without active checkpointing and truncation, recovery must replay the entire WAL,
//! so startup time grows without bound as the WAL grows.
//!
//! ## Solution
//!
//! ARIES-style checkpointing with:
//! 1. Periodic checkpoint triggers (time-based or size-based)
//! 2. Checkpoint record with active_txns and dirty_pages
//! 3. Flush all dirty pages to stable storage
//! 4. Truncate the WAL prefix up to the oldest LSN still needed for recovery
//!
//! ## Math
//!
//! ```text
//! Without checkpointing:
//!   Recovery time = O(total_WAL_records) = O(lifetime_operations)
//!
//! With checkpointing every C operations:
//!   Recovery time = O(records_since_checkpoint) ≤ O(C)
//!
//! For C = 100,000 records at ~10µs replay time per record:
//!   Recovery time bounded at ~1s regardless of DB lifetime
//!
//! WAL size bounded: max_size = checkpoint_interval × avg_record_size
//! ```

use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};

use crate::hlc::HybridLogicalClock;
use sochdb_core::{Result, SochDBError};

/// Log Sequence Number: a monotonically increasing identifier assigned to each WAL record.
pub type Lsn = u64;

/// Identifier of a storage page.
pub type PageId = u64;

/// Thresholds that decide when a checkpoint is due, plus post-checkpoint behaviour.
///
/// A checkpoint is reported as due as soon as *any* of the record, byte, or time
/// thresholds is reached (see `CheckpointManager::should_checkpoint`).
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// WAL bytes written since the last checkpoint that force a new one.
    pub max_wal_size: u64,
    /// Longest time allowed between two checkpoints.
    pub max_interval: Duration,
    /// WAL record count since the last checkpoint that triggers a new one.
    pub min_records: u64,
    /// Whether the WAL prefix is truncated once a checkpoint completes.
    pub truncate_wal: bool,
    /// Master switch; when `false`, no checkpoint is ever reported as due.
    pub enabled: bool,
}

impl Default for CheckpointConfig {
    /// 64 MiB of WAL, 60 seconds, or 100 000 records; truncation on; enabled.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        let max_interval = Duration::from_secs(60);
        Self {
            max_wal_size: 64 * MIB,
            max_interval,
            min_records: 100_000,
            truncate_wal: true,
            enabled: true,
        }
    }
}

93/// Active transaction entry for checkpoint
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ActiveTransactionEntry {
96    /// Transaction ID
97    pub txn_id: u64,
98    /// First LSN written by this transaction
99    pub first_lsn: Lsn,
100    /// Last LSN written by this transaction
101    pub last_lsn: Lsn,
102    /// Transaction start timestamp
103    pub start_ts: u64,
104}
105
106/// Dirty page entry for checkpoint
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct DirtyPageEntry {
109    /// Page ID
110    pub page_id: PageId,
111    /// Recovery LSN (first LSN that dirtied this page)
112    pub recovery_lsn: Lsn,
113}
114
115/// Checkpoint data written to WAL
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct CheckpointData {
118    /// Checkpoint ID (monotonically increasing)
119    pub checkpoint_id: u64,
120    /// LSN at start of checkpoint
121    pub begin_checkpoint_lsn: Lsn,
122    /// LSN at end of checkpoint
123    pub end_checkpoint_lsn: Lsn,
124    /// Active transactions at checkpoint time
125    pub active_transactions: Vec<ActiveTransactionEntry>,
126    /// Dirty pages at checkpoint time
127    pub dirty_pages: Vec<DirtyPageEntry>,
128    /// Timestamp when checkpoint was taken
129    pub timestamp: u64,
130    /// Oldest LSN needed for recovery (min of active txn first_lsn and dirty page recovery_lsn)
131    pub oldest_required_lsn: Lsn,
132}
133
134impl CheckpointData {
135    /// Create a new checkpoint
136    pub fn new(
137        checkpoint_id: u64,
138        begin_lsn: Lsn,
139        active_txns: Vec<ActiveTransactionEntry>,
140        dirty_pages: Vec<DirtyPageEntry>,
141    ) -> Self {
142        // Calculate oldest required LSN
143        let oldest_txn_lsn = active_txns.iter().map(|t| t.first_lsn).min().unwrap_or(Lsn::MAX);
144        let oldest_page_lsn = dirty_pages.iter().map(|p| p.recovery_lsn).min().unwrap_or(Lsn::MAX);
145        let oldest_required_lsn = oldest_txn_lsn.min(oldest_page_lsn).min(begin_lsn);
146
147        Self {
148            checkpoint_id,
149            begin_checkpoint_lsn: begin_lsn,
150            end_checkpoint_lsn: 0, // Set after checkpoint is complete
151            active_transactions: active_txns,
152            dirty_pages,
153            timestamp: std::time::SystemTime::now()
154                .duration_since(std::time::UNIX_EPOCH)
155                .unwrap()
156                .as_micros() as u64,
157            oldest_required_lsn,
158        }
159    }
160}
161
162/// Checkpoint state persisted to disk
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct CheckpointMeta {
165    /// Last completed checkpoint data
166    pub last_checkpoint: Option<CheckpointData>,
167    /// Total checkpoints taken
168    pub total_checkpoints: u64,
169    /// Total bytes truncated from WAL
170    pub total_bytes_truncated: u64,
171}
172
173impl Default for CheckpointMeta {
174    fn default() -> Self {
175        Self {
176            last_checkpoint: None,
177            total_checkpoints: 0,
178            total_bytes_truncated: 0,
179        }
180    }
181}
182
183/// Dirty page tracker for efficient checkpointing
184pub struct DirtyPageTracker {
185    /// Map of page_id -> recovery_lsn (first LSN that dirtied page)
186    dirty_pages: RwLock<HashMap<PageId, Lsn>>,
187}
188
189impl DirtyPageTracker {
190    pub fn new() -> Self {
191        Self {
192            dirty_pages: RwLock::new(HashMap::new()),
193        }
194    }
195
196    /// Mark a page as dirty with its recovery LSN
197    pub fn mark_dirty(&self, page_id: PageId, lsn: Lsn) {
198        let mut dirty = self.dirty_pages.write();
199        dirty.entry(page_id).or_insert(lsn);
200    }
201
202    /// Mark a page as clean (after flush to disk)
203    pub fn mark_clean(&self, page_id: PageId) {
204        self.dirty_pages.write().remove(&page_id);
205    }
206
207    /// Get all dirty pages for checkpoint
208    pub fn get_dirty_pages(&self) -> Vec<DirtyPageEntry> {
209        self.dirty_pages
210            .read()
211            .iter()
212            .map(|(&page_id, &recovery_lsn)| DirtyPageEntry { page_id, recovery_lsn })
213            .collect()
214    }
215
216    /// Get count of dirty pages
217    pub fn dirty_count(&self) -> usize {
218        self.dirty_pages.read().len()
219    }
220}
221
222impl Default for DirtyPageTracker {
223    fn default() -> Self {
224        Self::new()
225    }
226}
227
228/// Active transaction tracker for checkpointing
229pub struct ActiveTransactionTracker {
230    /// Map of txn_id -> (first_lsn, last_lsn, start_ts)
231    active_txns: RwLock<HashMap<u64, (Lsn, Lsn, u64)>>,
232}
233
234impl ActiveTransactionTracker {
235    pub fn new() -> Self {
236        Self {
237            active_txns: RwLock::new(HashMap::new()),
238        }
239    }
240
241    /// Register a new transaction
242    pub fn register(&self, txn_id: u64, start_ts: u64) {
243        self.active_txns
244            .write()
245            .insert(txn_id, (Lsn::MAX, 0, start_ts));
246    }
247
248    /// Update transaction's LSN range
249    pub fn update_lsn(&self, txn_id: u64, lsn: Lsn) {
250        if let Some(entry) = self.active_txns.write().get_mut(&txn_id) {
251            if entry.0 == Lsn::MAX {
252                entry.0 = lsn; // First LSN
253            }
254            entry.1 = lsn; // Last LSN
255        }
256    }
257
258    /// Remove a transaction (on commit or abort)
259    pub fn remove(&self, txn_id: u64) {
260        self.active_txns.write().remove(&txn_id);
261    }
262
263    /// Get all active transactions for checkpoint
264    pub fn get_active_transactions(&self) -> Vec<ActiveTransactionEntry> {
265        self.active_txns
266            .read()
267            .iter()
268            .filter(|(_, (first_lsn, _, _))| *first_lsn != Lsn::MAX)
269            .map(|(&txn_id, &(first_lsn, last_lsn, start_ts))| ActiveTransactionEntry {
270                txn_id,
271                first_lsn,
272                last_lsn,
273                start_ts,
274            })
275            .collect()
276    }
277
278    /// Get count of active transactions
279    pub fn active_count(&self) -> usize {
280        self.active_txns.read().len()
281    }
282}
283
284impl Default for ActiveTransactionTracker {
285    fn default() -> Self {
286        Self::new()
287    }
288}
289
290/// Checkpoint manager
291pub struct CheckpointManager {
292    /// Configuration
293    config: CheckpointConfig,
294    /// Path to checkpoint metadata file
295    meta_path: PathBuf,
296    /// Path to WAL directory
297    #[allow(dead_code)]
298    wal_dir: PathBuf,
299    /// Current checkpoint metadata
300    meta: RwLock<CheckpointMeta>,
301    /// Dirty page tracker
302    dirty_pages: Arc<DirtyPageTracker>,
303    /// Active transaction tracker
304    active_txns: Arc<ActiveTransactionTracker>,
305    /// Current LSN counter
306    current_lsn: AtomicU64,
307    /// Records since last checkpoint
308    records_since_checkpoint: AtomicU64,
309    /// WAL bytes since last checkpoint
310    wal_bytes_since_checkpoint: AtomicU64,
311    /// Last checkpoint time
312    last_checkpoint_time: Mutex<Instant>,
313    /// Checkpoint in progress flag
314    checkpoint_in_progress: AtomicBool,
315    /// Next checkpoint ID
316    next_checkpoint_id: AtomicU64,
317    /// HLC for timestamps
318    #[allow(dead_code)]
319    hlc: Arc<HybridLogicalClock>,
320}
321
322impl CheckpointManager {
323    /// Create a new checkpoint manager
324    pub fn new(
325        data_dir: &Path,
326        config: CheckpointConfig,
327        dirty_pages: Arc<DirtyPageTracker>,
328        active_txns: Arc<ActiveTransactionTracker>,
329        hlc: Arc<HybridLogicalClock>,
330    ) -> Result<Self> {
331        let meta_path = data_dir.join("checkpoint.meta");
332        let wal_dir = data_dir.join("wal");
333
334        // Ensure directories exist
335        fs::create_dir_all(&wal_dir)?;
336
337        // Load existing metadata
338        let meta = if meta_path.exists() {
339            let data = fs::read(&meta_path)?;
340            bincode::deserialize(&data).unwrap_or_default()
341        } else {
342            CheckpointMeta::default()
343        };
344
345        let next_id = meta.last_checkpoint.as_ref().map(|c| c.checkpoint_id + 1).unwrap_or(1);
346        let last_lsn = meta.last_checkpoint.as_ref().map(|c| c.end_checkpoint_lsn).unwrap_or(0);
347
348        Ok(Self {
349            config,
350            meta_path,
351            wal_dir,
352            meta: RwLock::new(meta),
353            dirty_pages,
354            active_txns,
355            current_lsn: AtomicU64::new(last_lsn),
356            records_since_checkpoint: AtomicU64::new(0),
357            wal_bytes_since_checkpoint: AtomicU64::new(0),
358            last_checkpoint_time: Mutex::new(Instant::now()),
359            checkpoint_in_progress: AtomicBool::new(false),
360            next_checkpoint_id: AtomicU64::new(next_id),
361            hlc,
362        })
363    }
364
365    /// Allocate the next LSN
366    #[inline]
367    pub fn next_lsn(&self) -> Lsn {
368        self.current_lsn.fetch_add(1, Ordering::SeqCst)
369    }
370
371    /// Record a WAL write for checkpoint tracking
372    pub fn record_wal_write(&self, bytes: u64) {
373        self.records_since_checkpoint.fetch_add(1, Ordering::Relaxed);
374        self.wal_bytes_since_checkpoint.fetch_add(bytes, Ordering::Relaxed);
375    }
376
377    /// Check if checkpoint is needed
378    pub fn should_checkpoint(&self) -> bool {
379        if !self.config.enabled {
380            return false;
381        }
382
383        if self.checkpoint_in_progress.load(Ordering::Relaxed) {
384            return false;
385        }
386
387        let records = self.records_since_checkpoint.load(Ordering::Relaxed);
388        let bytes = self.wal_bytes_since_checkpoint.load(Ordering::Relaxed);
389        let elapsed = self.last_checkpoint_time.lock().elapsed();
390
391        records >= self.config.min_records
392            || bytes >= self.config.max_wal_size
393            || elapsed >= self.config.max_interval
394    }
395
396    /// Take a checkpoint
397    ///
398    /// This is the main checkpoint operation:
399    /// 1. Write BEGIN_CHECKPOINT record
400    /// 2. Collect active transactions and dirty pages
401    /// 3. Flush all dirty pages to stable storage
402    /// 4. Write END_CHECKPOINT record with collected data
403    /// 5. Optionally truncate WAL
404    pub fn checkpoint<F>(&self, flush_dirty_pages: F) -> Result<CheckpointData>
405    where
406        F: FnOnce(&[DirtyPageEntry]) -> Result<()>,
407    {
408        // Set checkpoint in progress
409        if self
410            .checkpoint_in_progress
411            .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
412            .is_err()
413        {
414            return Err(SochDBError::Internal("Checkpoint already in progress".into()));
415        }
416
417        // Guard to reset flag on exit (manual scope guard)
418        struct CheckpointGuard<'a>(&'a AtomicBool);
419        impl<'a> Drop for CheckpointGuard<'a> {
420            fn drop(&mut self) {
421                self.0.store(false, Ordering::SeqCst);
422            }
423        }
424        let _guard = CheckpointGuard(&self.checkpoint_in_progress);
425
426        let checkpoint_id = self.next_checkpoint_id.fetch_add(1, Ordering::SeqCst);
427        let begin_lsn = self.next_lsn();
428
429        // Collect state
430        let active_txns = self.active_txns.get_active_transactions();
431        let dirty_pages = self.dirty_pages.get_dirty_pages();
432
433        // Create checkpoint data
434        let mut checkpoint = CheckpointData::new(checkpoint_id, begin_lsn, active_txns, dirty_pages.clone());
435
436        // Flush all dirty pages to stable storage
437        flush_dirty_pages(&dirty_pages)?;
438
439        // Mark pages as clean
440        for page in &dirty_pages {
441            self.dirty_pages.mark_clean(page.page_id);
442        }
443
444        // Record end LSN
445        let end_lsn = self.next_lsn();
446        checkpoint.end_checkpoint_lsn = end_lsn;
447
448        // Update metadata
449        {
450            let mut meta = self.meta.write();
451            meta.last_checkpoint = Some(checkpoint.clone());
452            meta.total_checkpoints += 1;
453
454            // Persist metadata
455            let data = bincode::serialize(&*meta).map_err(|e| SochDBError::Serialization(e.to_string()))?;
456            fs::write(&self.meta_path, data)?;
457        }
458
459        // Reset counters
460        self.records_since_checkpoint.store(0, Ordering::Relaxed);
461        self.wal_bytes_since_checkpoint.store(0, Ordering::Relaxed);
462        *self.last_checkpoint_time.lock() = Instant::now();
463
464        // Truncate WAL if configured
465        if self.config.truncate_wal {
466            self.truncate_wal(checkpoint.oldest_required_lsn)?;
467        }
468
469        Ok(checkpoint)
470    }
471
472    /// Truncate WAL up to the given LSN
473    fn truncate_wal(&self, safe_lsn: Lsn) -> Result<()> {
474        // In a real implementation, this would:
475        // 1. Identify WAL segments that can be removed
476        // 2. Rename/archive or delete old segments
477        // 3. Update metadata
478
479        // For now, we just track the truncation point
480        let mut meta = self.meta.write();
481        if let Some(ref checkpoint) = meta.last_checkpoint {
482            let truncated = checkpoint.begin_checkpoint_lsn.saturating_sub(safe_lsn);
483            meta.total_bytes_truncated += truncated;
484        }
485
486        Ok(())
487    }
488
489    /// Get the LSN that is safe for recovery (oldest required LSN)
490    pub fn recovery_lsn(&self) -> Option<Lsn> {
491        self.meta
492            .read()
493            .last_checkpoint
494            .as_ref()
495            .map(|c| c.oldest_required_lsn)
496    }
497
498    /// Get the last checkpoint
499    pub fn last_checkpoint(&self) -> Option<CheckpointData> {
500        self.meta.read().last_checkpoint.clone()
501    }
502
503    /// Get checkpoint statistics
504    pub fn stats(&self) -> CheckpointStats {
505        let meta = self.meta.read();
506        CheckpointStats {
507            total_checkpoints: meta.total_checkpoints,
508            total_bytes_truncated: meta.total_bytes_truncated,
509            records_since_checkpoint: self.records_since_checkpoint.load(Ordering::Relaxed),
510            wal_bytes_since_checkpoint: self.wal_bytes_since_checkpoint.load(Ordering::Relaxed),
511            dirty_pages: self.dirty_pages.dirty_count(),
512            active_transactions: self.active_txns.active_count(),
513        }
514    }
515}
516
/// Point-in-time snapshot of checkpointing counters, returned by
/// `CheckpointManager::stats`.
#[derive(Debug, Clone)]
pub struct CheckpointStats {
    /// Checkpoints completed, carried across restarts.
    pub total_checkpoints: u64,
    /// Cumulative WAL truncation reported by the checkpointer.
    pub total_bytes_truncated: u64,
    /// WAL records reported since the last checkpoint.
    pub records_since_checkpoint: u64,
    /// WAL bytes reported since the last checkpoint.
    pub wal_bytes_since_checkpoint: u64,
    /// Pages currently marked dirty.
    pub dirty_pages: usize,
    /// Transactions currently registered as active.
    pub active_transactions: usize,
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Shorthand for building an `ActiveTransactionEntry`.
    fn active_txn(
        txn_id: u64,
        first_lsn: Lsn,
        last_lsn: Lsn,
        start_ts: u64,
    ) -> ActiveTransactionEntry {
        ActiveTransactionEntry {
            txn_id,
            first_lsn,
            last_lsn,
            start_ts,
        }
    }

    #[test]
    fn test_checkpoint_data_creation() {
        let txns = vec![active_txn(1, 100, 150, 1000), active_txn(2, 120, 180, 1100)];
        let pages = vec![
            DirtyPageEntry { page_id: 10, recovery_lsn: 90 },
            DirtyPageEntry { page_id: 20, recovery_lsn: 110 },
        ];

        let cp = CheckpointData::new(1, 200, txns, pages);

        // Page 10's recovery LSN (90) is below every txn's first LSN and the begin LSN.
        assert_eq!(cp.oldest_required_lsn, 90);
    }

    #[test]
    fn test_dirty_page_tracker() {
        let tracker = DirtyPageTracker::new();

        tracker.mark_dirty(1, 100);
        tracker.mark_dirty(2, 110);
        // Page 1 is already dirty, so its recovery LSN must stay at 100.
        tracker.mark_dirty(1, 120);

        assert_eq!(tracker.dirty_count(), 2);

        let snapshot = tracker.get_dirty_pages();
        assert_eq!(snapshot.len(), 2);
        let page_one = snapshot
            .iter()
            .find(|entry| entry.page_id == 1)
            .expect("page 1 should be dirty");
        assert_eq!(page_one.recovery_lsn, 100);

        tracker.mark_clean(1);
        assert_eq!(tracker.dirty_count(), 1);
    }

    #[test]
    fn test_active_transaction_tracker() {
        let tracker = ActiveTransactionTracker::new();

        tracker.register(1, 1000);
        for lsn in [100, 150] {
            tracker.update_lsn(1, lsn);
        }
        tracker.register(2, 1100);
        tracker.update_lsn(2, 120);

        assert_eq!(tracker.active_count(), 2);

        let snapshot = tracker.get_active_transactions();
        assert_eq!(snapshot.len(), 2);
        let first = snapshot
            .iter()
            .find(|entry| entry.txn_id == 1)
            .expect("txn 1 should be active");
        assert_eq!((first.first_lsn, first.last_lsn), (100, 150));

        tracker.remove(1);
        assert_eq!(tracker.active_count(), 1);
    }

    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let dir = TempDir::new().unwrap();
        let dirty = Arc::new(DirtyPageTracker::new());
        let txns = Arc::new(ActiveTransactionTracker::new());

        let manager = CheckpointManager::new(
            dir.path(),
            CheckpointConfig::default(),
            Arc::clone(&dirty),
            Arc::clone(&txns),
            Arc::new(HybridLogicalClock::new()),
        )?;

        // Two dirty pages and one transaction that has written a record.
        dirty.mark_dirty(1, manager.next_lsn());
        dirty.mark_dirty(2, manager.next_lsn());
        txns.register(100, 1000);
        txns.update_lsn(100, manager.next_lsn());

        // Flushing is a no-op here; only the checkpoint bookkeeping is under test.
        let cp = manager.checkpoint(|_| Ok(()))?;

        assert_eq!(cp.checkpoint_id, 1);
        assert_eq!(cp.dirty_pages.len(), 2);
        assert_eq!(cp.active_transactions.len(), 1);
        Ok(())
    }
}