1use std::collections::HashMap;
48use std::fs;
49use std::path::{Path, PathBuf};
50use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
51use std::sync::Arc;
52use std::time::{Duration, Instant};
53
54use parking_lot::{Mutex, RwLock};
55use serde::{Deserialize, Serialize};
56
57use crate::hlc::HybridLogicalClock;
58use sochdb_core::{Result, SochDBError};
59
/// Log sequence number: a `u64` position handed out by
/// `CheckpointManager::next_lsn` and recorded against transactions and pages.
pub type Lsn = u64;

/// Identifier of a page tracked by `DirtyPageTracker`.
pub type PageId = u64;
65
/// Tunables that decide when a checkpoint becomes due and what it does.
///
/// `CheckpointManager::should_checkpoint` reports `true` as soon as ANY of the
/// three thresholds below is reached (and checkpointing is enabled).
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// WAL bytes written since the last checkpoint that make a checkpoint due.
    pub max_wal_size: u64,
    /// Time since the last checkpoint after which a checkpoint is due.
    pub max_interval: Duration,
    /// WAL records written since the last checkpoint that make a checkpoint due.
    pub min_records: u64,
    /// Whether `CheckpointManager::checkpoint` runs its WAL truncation step.
    pub truncate_wal: bool,
    /// Master switch: when `false`, `should_checkpoint` always returns `false`.
    pub enabled: bool,
}

impl Default for CheckpointConfig {
    /// 64 MiB of WAL, 60 seconds, or 100 000 records — whichever comes first —
    /// with truncation and checkpointing both switched on.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;

        let max_wal_size = 64 * MIB;
        let max_interval = Duration::from_secs(60);

        Self {
            max_wal_size,
            max_interval,
            min_records: 100_000,
            truncate_wal: true,
            enabled: true,
        }
    }
}
92
/// Snapshot of one in-flight transaction as stored inside a checkpoint.
///
/// Produced by `ActiveTransactionTracker::get_active_transactions`. The field
/// order is part of the serialized layout, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActiveTransactionEntry {
    /// Identifier the transaction was registered under.
    pub txn_id: u64,
    /// First LSN reported for this transaction.
    pub first_lsn: Lsn,
    /// Most recently reported LSN for this transaction.
    pub last_lsn: Lsn,
    /// Start timestamp supplied at registration (units are defined by the caller).
    pub start_ts: u64,
}
105
/// Snapshot of one dirty page as stored inside a checkpoint.
///
/// The field order is part of the serialized layout, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirtyPageEntry {
    /// Identifier of the dirty page.
    pub page_id: PageId,
    /// LSN recorded when the page was first marked dirty.
    pub recovery_lsn: Lsn,
}
114
/// Everything recorded about a single checkpoint.
///
/// The field order is part of the serialized layout, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointData {
    /// Identifier of this checkpoint.
    pub checkpoint_id: u64,
    /// LSN allocated when the checkpoint began.
    pub begin_checkpoint_lsn: Lsn,
    /// LSN allocated once the checkpoint finished flushing; `0` until then.
    pub end_checkpoint_lsn: Lsn,
    /// Transactions that were in flight when the checkpoint began.
    pub active_transactions: Vec<ActiveTransactionEntry>,
    /// Pages that were dirty when the checkpoint began.
    pub dirty_pages: Vec<DirtyPageEntry>,
    /// Wall-clock creation time, in microseconds since the Unix epoch.
    pub timestamp: u64,
    /// Smallest of: every active transaction's `first_lsn`, every dirty
    /// page's `recovery_lsn`, and `begin_checkpoint_lsn`.
    pub oldest_required_lsn: Lsn,
}
133
134impl CheckpointData {
135 pub fn new(
137 checkpoint_id: u64,
138 begin_lsn: Lsn,
139 active_txns: Vec<ActiveTransactionEntry>,
140 dirty_pages: Vec<DirtyPageEntry>,
141 ) -> Self {
142 let oldest_txn_lsn = active_txns.iter().map(|t| t.first_lsn).min().unwrap_or(Lsn::MAX);
144 let oldest_page_lsn = dirty_pages.iter().map(|p| p.recovery_lsn).min().unwrap_or(Lsn::MAX);
145 let oldest_required_lsn = oldest_txn_lsn.min(oldest_page_lsn).min(begin_lsn);
146
147 Self {
148 checkpoint_id,
149 begin_checkpoint_lsn: begin_lsn,
150 end_checkpoint_lsn: 0, active_transactions: active_txns,
152 dirty_pages,
153 timestamp: std::time::SystemTime::now()
154 .duration_since(std::time::UNIX_EPOCH)
155 .unwrap()
156 .as_micros() as u64,
157 oldest_required_lsn,
158 }
159 }
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct CheckpointMeta {
165 pub last_checkpoint: Option<CheckpointData>,
167 pub total_checkpoints: u64,
169 pub total_bytes_truncated: u64,
171}
172
173impl Default for CheckpointMeta {
174 fn default() -> Self {
175 Self {
176 last_checkpoint: None,
177 total_checkpoints: 0,
178 total_bytes_truncated: 0,
179 }
180 }
181}
182
183pub struct DirtyPageTracker {
185 dirty_pages: RwLock<HashMap<PageId, Lsn>>,
187}
188
189impl DirtyPageTracker {
190 pub fn new() -> Self {
191 Self {
192 dirty_pages: RwLock::new(HashMap::new()),
193 }
194 }
195
196 pub fn mark_dirty(&self, page_id: PageId, lsn: Lsn) {
198 let mut dirty = self.dirty_pages.write();
199 dirty.entry(page_id).or_insert(lsn);
200 }
201
202 pub fn mark_clean(&self, page_id: PageId) {
204 self.dirty_pages.write().remove(&page_id);
205 }
206
207 pub fn get_dirty_pages(&self) -> Vec<DirtyPageEntry> {
209 self.dirty_pages
210 .read()
211 .iter()
212 .map(|(&page_id, &recovery_lsn)| DirtyPageEntry { page_id, recovery_lsn })
213 .collect()
214 }
215
216 pub fn dirty_count(&self) -> usize {
218 self.dirty_pages.read().len()
219 }
220}
221
222impl Default for DirtyPageTracker {
223 fn default() -> Self {
224 Self::new()
225 }
226}
227
228pub struct ActiveTransactionTracker {
230 active_txns: RwLock<HashMap<u64, (Lsn, Lsn, u64)>>,
232}
233
234impl ActiveTransactionTracker {
235 pub fn new() -> Self {
236 Self {
237 active_txns: RwLock::new(HashMap::new()),
238 }
239 }
240
241 pub fn register(&self, txn_id: u64, start_ts: u64) {
243 self.active_txns
244 .write()
245 .insert(txn_id, (Lsn::MAX, 0, start_ts));
246 }
247
248 pub fn update_lsn(&self, txn_id: u64, lsn: Lsn) {
250 if let Some(entry) = self.active_txns.write().get_mut(&txn_id) {
251 if entry.0 == Lsn::MAX {
252 entry.0 = lsn; }
254 entry.1 = lsn; }
256 }
257
258 pub fn remove(&self, txn_id: u64) {
260 self.active_txns.write().remove(&txn_id);
261 }
262
263 pub fn get_active_transactions(&self) -> Vec<ActiveTransactionEntry> {
265 self.active_txns
266 .read()
267 .iter()
268 .filter(|(_, (first_lsn, _, _))| *first_lsn != Lsn::MAX)
269 .map(|(&txn_id, &(first_lsn, last_lsn, start_ts))| ActiveTransactionEntry {
270 txn_id,
271 first_lsn,
272 last_lsn,
273 start_ts,
274 })
275 .collect()
276 }
277
278 pub fn active_count(&self) -> usize {
280 self.active_txns.read().len()
281 }
282}
283
284impl Default for ActiveTransactionTracker {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
/// Decides when checkpoints are due, runs them, and persists their metadata.
pub struct CheckpointManager {
    /// Thresholds and switches controlling checkpoint behaviour.
    config: CheckpointConfig,
    /// Location of the persisted metadata file (`<data_dir>/checkpoint.meta`).
    meta_path: PathBuf,
    // Only created in `new`; nothing in this file reads it afterwards.
    #[allow(dead_code)]
    /// WAL directory (`<data_dir>/wal`).
    wal_dir: PathBuf,
    /// In-memory copy of the persisted checkpoint metadata.
    meta: RwLock<CheckpointMeta>,
    /// Shared dirty-page table, snapshotted at each checkpoint.
    dirty_pages: Arc<DirtyPageTracker>,
    /// Shared transaction table, snapshotted at each checkpoint.
    active_txns: Arc<ActiveTransactionTracker>,
    /// The LSN that the next call to `next_lsn` will return.
    current_lsn: AtomicU64,
    /// WAL records reported since the last checkpoint.
    records_since_checkpoint: AtomicU64,
    /// WAL bytes reported since the last checkpoint.
    wal_bytes_since_checkpoint: AtomicU64,
    /// When the last checkpoint completed (or when the manager was created).
    last_checkpoint_time: Mutex<Instant>,
    /// Set while a checkpoint is running so that only one runs at a time.
    checkpoint_in_progress: AtomicBool,
    /// Identifier to assign to the next checkpoint.
    next_checkpoint_id: AtomicU64,
    // Stored at construction; nothing in this file reads it afterwards.
    #[allow(dead_code)]
    /// Hybrid logical clock handle supplied by the caller.
    hlc: Arc<HybridLogicalClock>,
}
321
322impl CheckpointManager {
323 pub fn new(
325 data_dir: &Path,
326 config: CheckpointConfig,
327 dirty_pages: Arc<DirtyPageTracker>,
328 active_txns: Arc<ActiveTransactionTracker>,
329 hlc: Arc<HybridLogicalClock>,
330 ) -> Result<Self> {
331 let meta_path = data_dir.join("checkpoint.meta");
332 let wal_dir = data_dir.join("wal");
333
334 fs::create_dir_all(&wal_dir)?;
336
337 let meta = if meta_path.exists() {
339 let data = fs::read(&meta_path)?;
340 bincode::deserialize(&data).unwrap_or_default()
341 } else {
342 CheckpointMeta::default()
343 };
344
345 let next_id = meta.last_checkpoint.as_ref().map(|c| c.checkpoint_id + 1).unwrap_or(1);
346 let last_lsn = meta.last_checkpoint.as_ref().map(|c| c.end_checkpoint_lsn).unwrap_or(0);
347
348 Ok(Self {
349 config,
350 meta_path,
351 wal_dir,
352 meta: RwLock::new(meta),
353 dirty_pages,
354 active_txns,
355 current_lsn: AtomicU64::new(last_lsn),
356 records_since_checkpoint: AtomicU64::new(0),
357 wal_bytes_since_checkpoint: AtomicU64::new(0),
358 last_checkpoint_time: Mutex::new(Instant::now()),
359 checkpoint_in_progress: AtomicBool::new(false),
360 next_checkpoint_id: AtomicU64::new(next_id),
361 hlc,
362 })
363 }
364
365 #[inline]
367 pub fn next_lsn(&self) -> Lsn {
368 self.current_lsn.fetch_add(1, Ordering::SeqCst)
369 }
370
371 pub fn record_wal_write(&self, bytes: u64) {
373 self.records_since_checkpoint.fetch_add(1, Ordering::Relaxed);
374 self.wal_bytes_since_checkpoint.fetch_add(bytes, Ordering::Relaxed);
375 }
376
377 pub fn should_checkpoint(&self) -> bool {
379 if !self.config.enabled {
380 return false;
381 }
382
383 if self.checkpoint_in_progress.load(Ordering::Relaxed) {
384 return false;
385 }
386
387 let records = self.records_since_checkpoint.load(Ordering::Relaxed);
388 let bytes = self.wal_bytes_since_checkpoint.load(Ordering::Relaxed);
389 let elapsed = self.last_checkpoint_time.lock().elapsed();
390
391 records >= self.config.min_records
392 || bytes >= self.config.max_wal_size
393 || elapsed >= self.config.max_interval
394 }
395
396 pub fn checkpoint<F>(&self, flush_dirty_pages: F) -> Result<CheckpointData>
405 where
406 F: FnOnce(&[DirtyPageEntry]) -> Result<()>,
407 {
408 if self
410 .checkpoint_in_progress
411 .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
412 .is_err()
413 {
414 return Err(SochDBError::Internal("Checkpoint already in progress".into()));
415 }
416
417 struct CheckpointGuard<'a>(&'a AtomicBool);
419 impl<'a> Drop for CheckpointGuard<'a> {
420 fn drop(&mut self) {
421 self.0.store(false, Ordering::SeqCst);
422 }
423 }
424 let _guard = CheckpointGuard(&self.checkpoint_in_progress);
425
426 let checkpoint_id = self.next_checkpoint_id.fetch_add(1, Ordering::SeqCst);
427 let begin_lsn = self.next_lsn();
428
429 let active_txns = self.active_txns.get_active_transactions();
431 let dirty_pages = self.dirty_pages.get_dirty_pages();
432
433 let mut checkpoint = CheckpointData::new(checkpoint_id, begin_lsn, active_txns, dirty_pages.clone());
435
436 flush_dirty_pages(&dirty_pages)?;
438
439 for page in &dirty_pages {
441 self.dirty_pages.mark_clean(page.page_id);
442 }
443
444 let end_lsn = self.next_lsn();
446 checkpoint.end_checkpoint_lsn = end_lsn;
447
448 {
450 let mut meta = self.meta.write();
451 meta.last_checkpoint = Some(checkpoint.clone());
452 meta.total_checkpoints += 1;
453
454 let data = bincode::serialize(&*meta).map_err(|e| SochDBError::Serialization(e.to_string()))?;
456 fs::write(&self.meta_path, data)?;
457 }
458
459 self.records_since_checkpoint.store(0, Ordering::Relaxed);
461 self.wal_bytes_since_checkpoint.store(0, Ordering::Relaxed);
462 *self.last_checkpoint_time.lock() = Instant::now();
463
464 if self.config.truncate_wal {
466 self.truncate_wal(checkpoint.oldest_required_lsn)?;
467 }
468
469 Ok(checkpoint)
470 }
471
472 fn truncate_wal(&self, safe_lsn: Lsn) -> Result<()> {
474 let mut meta = self.meta.write();
481 if let Some(ref checkpoint) = meta.last_checkpoint {
482 let truncated = checkpoint.begin_checkpoint_lsn.saturating_sub(safe_lsn);
483 meta.total_bytes_truncated += truncated;
484 }
485
486 Ok(())
487 }
488
489 pub fn recovery_lsn(&self) -> Option<Lsn> {
491 self.meta
492 .read()
493 .last_checkpoint
494 .as_ref()
495 .map(|c| c.oldest_required_lsn)
496 }
497
498 pub fn last_checkpoint(&self) -> Option<CheckpointData> {
500 self.meta.read().last_checkpoint.clone()
501 }
502
503 pub fn stats(&self) -> CheckpointStats {
505 let meta = self.meta.read();
506 CheckpointStats {
507 total_checkpoints: meta.total_checkpoints,
508 total_bytes_truncated: meta.total_bytes_truncated,
509 records_since_checkpoint: self.records_since_checkpoint.load(Ordering::Relaxed),
510 wal_bytes_since_checkpoint: self.wal_bytes_since_checkpoint.load(Ordering::Relaxed),
511 dirty_pages: self.dirty_pages.dirty_count(),
512 active_transactions: self.active_txns.active_count(),
513 }
514 }
515}
516
/// Point-in-time counters returned by `CheckpointManager::stats`.
#[derive(Debug, Clone)]
pub struct CheckpointStats {
    /// Number of checkpoints completed so far.
    pub total_checkpoints: u64,
    /// Running total maintained by the WAL truncation step.
    pub total_bytes_truncated: u64,
    /// WAL records reported since the last checkpoint.
    pub records_since_checkpoint: u64,
    /// WAL bytes reported since the last checkpoint.
    pub wal_bytes_since_checkpoint: u64,
    /// Number of pages currently tracked as dirty.
    pub dirty_pages: usize,
    /// Number of registered transactions, including those with no LSN yet.
    pub active_transactions: usize,
}
527
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// The oldest required LSN is the minimum across transactions, dirty
    /// pages, and the begin LSN.
    #[test]
    fn test_checkpoint_data_creation() {
        let txn = |txn_id, first_lsn, last_lsn, start_ts| ActiveTransactionEntry {
            txn_id,
            first_lsn,
            last_lsn,
            start_ts,
        };
        let page = |page_id, recovery_lsn| DirtyPageEntry { page_id, recovery_lsn };

        let in_flight = vec![txn(1, 100, 150, 1000), txn(2, 120, 180, 1100)];
        let unflushed = vec![page(10, 90), page(20, 110)];

        let data = CheckpointData::new(1, 200, in_flight, unflushed);

        // Page 10's recovery LSN (90) is older than everything else.
        assert_eq!(data.oldest_required_lsn, 90);
    }

    /// Re-marking a dirty page keeps its first LSN; cleaning removes it.
    #[test]
    fn test_dirty_page_tracker() {
        let pages = DirtyPageTracker::new();

        pages.mark_dirty(1, 100);
        pages.mark_dirty(2, 110);
        // Second mark of page 1 must not add an entry.
        pages.mark_dirty(1, 120);
        assert_eq!(pages.dirty_count(), 2);

        let snapshot = pages.get_dirty_pages();
        assert_eq!(snapshot.len(), 2);

        let first_page = snapshot
            .iter()
            .find(|entry| entry.page_id == 1)
            .expect("page 1 should be tracked");
        // The original LSN wins over the later one.
        assert_eq!(first_page.recovery_lsn, 100);

        pages.mark_clean(1);
        assert_eq!(pages.dirty_count(), 1);
    }

    /// First and last LSNs are tracked per transaction; removal shrinks the table.
    #[test]
    fn test_active_transaction_tracker() {
        let txns = ActiveTransactionTracker::new();

        txns.register(1, 1000);
        txns.update_lsn(1, 100);
        txns.update_lsn(1, 150);

        txns.register(2, 1100);
        txns.update_lsn(2, 120);

        assert_eq!(txns.active_count(), 2);

        let snapshot = txns.get_active_transactions();
        assert_eq!(snapshot.len(), 2);

        let first_txn = snapshot
            .iter()
            .find(|entry| entry.txn_id == 1)
            .expect("transaction 1 should be tracked");
        assert_eq!(first_txn.first_lsn, 100);
        assert_eq!(first_txn.last_lsn, 150);

        txns.remove(1);
        assert_eq!(txns.active_count(), 1);
    }

    /// End-to-end: a checkpoint captures the dirty pages and active
    /// transactions present when it starts.
    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let workspace = TempDir::new().unwrap();
        let page_tracker = Arc::new(DirtyPageTracker::new());
        let txn_tracker = Arc::new(ActiveTransactionTracker::new());
        let clock = Arc::new(HybridLogicalClock::new());

        let manager = CheckpointManager::new(
            workspace.path(),
            CheckpointConfig::default(),
            Arc::clone(&page_tracker),
            Arc::clone(&txn_tracker),
            clock,
        )?;

        // Two dirty pages, each stamped with a freshly allocated LSN.
        page_tracker.mark_dirty(1, manager.next_lsn());
        page_tracker.mark_dirty(2, manager.next_lsn());

        // One transaction that has written a record.
        txn_tracker.register(100, 1000);
        txn_tracker.update_lsn(100, manager.next_lsn());

        // The flush callback is a no-op for this test.
        let taken = manager.checkpoint(|_pages| Ok(()))?;

        assert_eq!(taken.checkpoint_id, 1);
        assert_eq!(taken.dirty_pages.len(), 2);
        assert_eq!(taken.active_transactions.len(), 1);

        Ok(())
    }
}