1use std::collections::HashMap;
51use std::fs;
52use std::path::{Path, PathBuf};
53use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
54use std::sync::Arc;
55use std::time::{Duration, Instant};
56
57use parking_lot::{Mutex, RwLock};
58use serde::{Deserialize, Serialize};
59
60use crate::hlc::HybridLogicalClock;
61use sochdb_core::{Result, SochDBError};
62
/// Log sequence number: position of a record in the write-ahead log.
///
/// New values are handed out in increasing order by `CheckpointManager::next_lsn`.
pub type Lsn = u64;

/// Identifier of a storage page whose dirtiness is tracked for checkpoints.
pub type PageId = u64;
68
/// Thresholds and switches that decide when a checkpoint should be taken.
///
/// `CheckpointManager::should_checkpoint` reports `true` as soon as any one of the
/// size, interval, or record-count thresholds is reached.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// WAL bytes written since the last checkpoint at which a new one is due.
    pub max_wal_size: u64,
    /// Time since the last checkpoint after which a new one is due.
    pub max_interval: Duration,
    /// WAL records written since the last checkpoint at which a new one is due.
    pub min_records: u64,
    /// Whether a completed checkpoint runs the WAL truncation step.
    pub truncate_wal: bool,
    /// Master switch: when `false`, `should_checkpoint` always returns `false`.
    pub enabled: bool,
}
83
84impl Default for CheckpointConfig {
85 fn default() -> Self {
86 Self {
87 max_wal_size: 64 * 1024 * 1024, max_interval: Duration::from_secs(60),
89 min_records: 100_000,
90 truncate_wal: true,
91 enabled: true,
92 }
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct ActiveTransactionEntry {
99 pub txn_id: u64,
101 pub first_lsn: Lsn,
103 pub last_lsn: Lsn,
105 pub start_ts: u64,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct DirtyPageEntry {
112 pub page_id: PageId,
114 pub recovery_lsn: Lsn,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct CheckpointData {
121 pub checkpoint_id: u64,
123 pub begin_checkpoint_lsn: Lsn,
125 pub end_checkpoint_lsn: Lsn,
127 pub active_transactions: Vec<ActiveTransactionEntry>,
129 pub dirty_pages: Vec<DirtyPageEntry>,
131 pub timestamp: u64,
133 pub oldest_required_lsn: Lsn,
135}
136
137impl CheckpointData {
138 pub fn new(
140 checkpoint_id: u64,
141 begin_lsn: Lsn,
142 active_txns: Vec<ActiveTransactionEntry>,
143 dirty_pages: Vec<DirtyPageEntry>,
144 ) -> Self {
145 let oldest_txn_lsn = active_txns.iter().map(|t| t.first_lsn).min().unwrap_or(Lsn::MAX);
147 let oldest_page_lsn = dirty_pages.iter().map(|p| p.recovery_lsn).min().unwrap_or(Lsn::MAX);
148 let oldest_required_lsn = oldest_txn_lsn.min(oldest_page_lsn).min(begin_lsn);
149
150 Self {
151 checkpoint_id,
152 begin_checkpoint_lsn: begin_lsn,
153 end_checkpoint_lsn: 0, active_transactions: active_txns,
155 dirty_pages,
156 timestamp: std::time::SystemTime::now()
157 .duration_since(std::time::UNIX_EPOCH)
158 .unwrap()
159 .as_micros() as u64,
160 oldest_required_lsn,
161 }
162 }
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct CheckpointMeta {
168 pub last_checkpoint: Option<CheckpointData>,
170 pub total_checkpoints: u64,
172 pub total_bytes_truncated: u64,
174}
175
176impl Default for CheckpointMeta {
177 fn default() -> Self {
178 Self {
179 last_checkpoint: None,
180 total_checkpoints: 0,
181 total_bytes_truncated: 0,
182 }
183 }
184}
185
186pub struct DirtyPageTracker {
188 dirty_pages: RwLock<HashMap<PageId, Lsn>>,
190}
191
192impl DirtyPageTracker {
193 pub fn new() -> Self {
194 Self {
195 dirty_pages: RwLock::new(HashMap::new()),
196 }
197 }
198
199 pub fn mark_dirty(&self, page_id: PageId, lsn: Lsn) {
201 let mut dirty = self.dirty_pages.write();
202 dirty.entry(page_id).or_insert(lsn);
203 }
204
205 pub fn mark_clean(&self, page_id: PageId) {
207 self.dirty_pages.write().remove(&page_id);
208 }
209
210 pub fn get_dirty_pages(&self) -> Vec<DirtyPageEntry> {
212 self.dirty_pages
213 .read()
214 .iter()
215 .map(|(&page_id, &recovery_lsn)| DirtyPageEntry { page_id, recovery_lsn })
216 .collect()
217 }
218
219 pub fn dirty_count(&self) -> usize {
221 self.dirty_pages.read().len()
222 }
223}
224
225impl Default for DirtyPageTracker {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231pub struct ActiveTransactionTracker {
233 active_txns: RwLock<HashMap<u64, (Lsn, Lsn, u64)>>,
235}
236
237impl ActiveTransactionTracker {
238 pub fn new() -> Self {
239 Self {
240 active_txns: RwLock::new(HashMap::new()),
241 }
242 }
243
244 pub fn register(&self, txn_id: u64, start_ts: u64) {
246 self.active_txns
247 .write()
248 .insert(txn_id, (Lsn::MAX, 0, start_ts));
249 }
250
251 pub fn update_lsn(&self, txn_id: u64, lsn: Lsn) {
253 if let Some(entry) = self.active_txns.write().get_mut(&txn_id) {
254 if entry.0 == Lsn::MAX {
255 entry.0 = lsn; }
257 entry.1 = lsn; }
259 }
260
261 pub fn remove(&self, txn_id: u64) {
263 self.active_txns.write().remove(&txn_id);
264 }
265
266 pub fn get_active_transactions(&self) -> Vec<ActiveTransactionEntry> {
268 self.active_txns
269 .read()
270 .iter()
271 .filter(|(_, (first_lsn, _, _))| *first_lsn != Lsn::MAX)
272 .map(|(&txn_id, &(first_lsn, last_lsn, start_ts))| ActiveTransactionEntry {
273 txn_id,
274 first_lsn,
275 last_lsn,
276 start_ts,
277 })
278 .collect()
279 }
280
281 pub fn active_count(&self) -> usize {
283 self.active_txns.read().len()
284 }
285}
286
287impl Default for ActiveTransactionTracker {
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
293pub struct CheckpointManager {
295 config: CheckpointConfig,
297 meta_path: PathBuf,
299 #[allow(dead_code)]
301 wal_dir: PathBuf,
302 meta: RwLock<CheckpointMeta>,
304 dirty_pages: Arc<DirtyPageTracker>,
306 active_txns: Arc<ActiveTransactionTracker>,
308 current_lsn: AtomicU64,
310 records_since_checkpoint: AtomicU64,
312 wal_bytes_since_checkpoint: AtomicU64,
314 last_checkpoint_time: Mutex<Instant>,
316 checkpoint_in_progress: AtomicBool,
318 next_checkpoint_id: AtomicU64,
320 #[allow(dead_code)]
322 hlc: Arc<HybridLogicalClock>,
323}
324
325impl CheckpointManager {
326 pub fn new(
328 data_dir: &Path,
329 config: CheckpointConfig,
330 dirty_pages: Arc<DirtyPageTracker>,
331 active_txns: Arc<ActiveTransactionTracker>,
332 hlc: Arc<HybridLogicalClock>,
333 ) -> Result<Self> {
334 let meta_path = data_dir.join("checkpoint.meta");
335 let wal_dir = data_dir.join("wal");
336
337 fs::create_dir_all(&wal_dir)?;
339
340 let meta = if meta_path.exists() {
342 let data = fs::read(&meta_path)?;
343 bincode::deserialize(&data).unwrap_or_default()
344 } else {
345 CheckpointMeta::default()
346 };
347
348 let next_id = meta.last_checkpoint.as_ref().map(|c| c.checkpoint_id + 1).unwrap_or(1);
349 let last_lsn = meta.last_checkpoint.as_ref().map(|c| c.end_checkpoint_lsn).unwrap_or(0);
350
351 Ok(Self {
352 config,
353 meta_path,
354 wal_dir,
355 meta: RwLock::new(meta),
356 dirty_pages,
357 active_txns,
358 current_lsn: AtomicU64::new(last_lsn),
359 records_since_checkpoint: AtomicU64::new(0),
360 wal_bytes_since_checkpoint: AtomicU64::new(0),
361 last_checkpoint_time: Mutex::new(Instant::now()),
362 checkpoint_in_progress: AtomicBool::new(false),
363 next_checkpoint_id: AtomicU64::new(next_id),
364 hlc,
365 })
366 }
367
368 #[inline]
370 pub fn next_lsn(&self) -> Lsn {
371 self.current_lsn.fetch_add(1, Ordering::SeqCst)
372 }
373
374 pub fn record_wal_write(&self, bytes: u64) {
376 self.records_since_checkpoint.fetch_add(1, Ordering::Relaxed);
377 self.wal_bytes_since_checkpoint.fetch_add(bytes, Ordering::Relaxed);
378 }
379
380 pub fn should_checkpoint(&self) -> bool {
382 if !self.config.enabled {
383 return false;
384 }
385
386 if self.checkpoint_in_progress.load(Ordering::Relaxed) {
387 return false;
388 }
389
390 let records = self.records_since_checkpoint.load(Ordering::Relaxed);
391 let bytes = self.wal_bytes_since_checkpoint.load(Ordering::Relaxed);
392 let elapsed = self.last_checkpoint_time.lock().elapsed();
393
394 records >= self.config.min_records
395 || bytes >= self.config.max_wal_size
396 || elapsed >= self.config.max_interval
397 }
398
399 pub fn checkpoint<F>(&self, flush_dirty_pages: F) -> Result<CheckpointData>
408 where
409 F: FnOnce(&[DirtyPageEntry]) -> Result<()>,
410 {
411 if self
413 .checkpoint_in_progress
414 .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
415 .is_err()
416 {
417 return Err(SochDBError::Internal("Checkpoint already in progress".into()));
418 }
419
420 struct CheckpointGuard<'a>(&'a AtomicBool);
422 impl<'a> Drop for CheckpointGuard<'a> {
423 fn drop(&mut self) {
424 self.0.store(false, Ordering::SeqCst);
425 }
426 }
427 let _guard = CheckpointGuard(&self.checkpoint_in_progress);
428
429 let checkpoint_id = self.next_checkpoint_id.fetch_add(1, Ordering::SeqCst);
430 let begin_lsn = self.next_lsn();
431
432 let active_txns = self.active_txns.get_active_transactions();
434 let dirty_pages = self.dirty_pages.get_dirty_pages();
435
436 let mut checkpoint = CheckpointData::new(checkpoint_id, begin_lsn, active_txns, dirty_pages.clone());
438
439 flush_dirty_pages(&dirty_pages)?;
441
442 for page in &dirty_pages {
444 self.dirty_pages.mark_clean(page.page_id);
445 }
446
447 let end_lsn = self.next_lsn();
449 checkpoint.end_checkpoint_lsn = end_lsn;
450
451 {
453 let mut meta = self.meta.write();
454 meta.last_checkpoint = Some(checkpoint.clone());
455 meta.total_checkpoints += 1;
456
457 let data = bincode::serialize(&*meta).map_err(|e| SochDBError::Serialization(e.to_string()))?;
459 fs::write(&self.meta_path, data)?;
460 }
461
462 self.records_since_checkpoint.store(0, Ordering::Relaxed);
464 self.wal_bytes_since_checkpoint.store(0, Ordering::Relaxed);
465 *self.last_checkpoint_time.lock() = Instant::now();
466
467 if self.config.truncate_wal {
469 self.truncate_wal(checkpoint.oldest_required_lsn)?;
470 }
471
472 Ok(checkpoint)
473 }
474
475 fn truncate_wal(&self, safe_lsn: Lsn) -> Result<()> {
477 let mut meta = self.meta.write();
484 if let Some(ref checkpoint) = meta.last_checkpoint {
485 let truncated = checkpoint.begin_checkpoint_lsn.saturating_sub(safe_lsn);
486 meta.total_bytes_truncated += truncated;
487 }
488
489 Ok(())
490 }
491
492 pub fn recovery_lsn(&self) -> Option<Lsn> {
494 self.meta
495 .read()
496 .last_checkpoint
497 .as_ref()
498 .map(|c| c.oldest_required_lsn)
499 }
500
501 pub fn last_checkpoint(&self) -> Option<CheckpointData> {
503 self.meta.read().last_checkpoint.clone()
504 }
505
506 pub fn stats(&self) -> CheckpointStats {
508 let meta = self.meta.read();
509 CheckpointStats {
510 total_checkpoints: meta.total_checkpoints,
511 total_bytes_truncated: meta.total_bytes_truncated,
512 records_since_checkpoint: self.records_since_checkpoint.load(Ordering::Relaxed),
513 wal_bytes_since_checkpoint: self.wal_bytes_since_checkpoint.load(Ordering::Relaxed),
514 dirty_pages: self.dirty_pages.dirty_count(),
515 active_transactions: self.active_txns.active_count(),
516 }
517 }
518}
519
/// Point-in-time counters reported by `CheckpointManager::stats`.
#[derive(Clone, Debug)]
pub struct CheckpointStats {
    /// Checkpoints completed, as recorded in the checkpoint metadata.
    pub total_checkpoints: u64,
    /// Running total maintained by the WAL truncation step.
    pub total_bytes_truncated: u64,
    /// WAL records counted since the last checkpoint.
    pub records_since_checkpoint: u64,
    /// WAL bytes counted since the last checkpoint.
    pub wal_bytes_since_checkpoint: u64,
    /// Pages currently marked dirty.
    pub dirty_pages: usize,
    /// Transactions currently registered.
    pub active_transactions: usize,
}
530
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_checkpoint_data_creation() {
        let txns = vec![
            ActiveTransactionEntry { txn_id: 1, first_lsn: 100, last_lsn: 150, start_ts: 1000 },
            ActiveTransactionEntry { txn_id: 2, first_lsn: 120, last_lsn: 180, start_ts: 1100 },
        ];
        let pages = vec![
            DirtyPageEntry { page_id: 10, recovery_lsn: 90 },
            DirtyPageEntry { page_id: 20, recovery_lsn: 110 },
        ];

        let data = CheckpointData::new(1, 200, txns, pages);

        // Page 10's recovery LSN (90) is older than both transactions and the begin LSN.
        assert_eq!(data.oldest_required_lsn, 90);
    }

    #[test]
    fn test_dirty_page_tracker() {
        let pages = DirtyPageTracker::new();

        pages.mark_dirty(1, 100);
        pages.mark_dirty(2, 110);
        // Re-dirtying page 1 must not move its recovery LSN forward.
        pages.mark_dirty(1, 120);
        assert_eq!(pages.dirty_count(), 2);

        let snapshot = pages.get_dirty_pages();
        assert_eq!(snapshot.len(), 2);

        let page_one = snapshot
            .iter()
            .find(|entry| entry.page_id == 1)
            .expect("page 1 should be dirty");
        assert_eq!(page_one.recovery_lsn, 100);

        pages.mark_clean(1);
        assert_eq!(pages.dirty_count(), 1);
    }

    #[test]
    fn test_active_transaction_tracker() {
        let txns = ActiveTransactionTracker::new();

        txns.register(1, 1000);
        txns.update_lsn(1, 100);
        txns.update_lsn(1, 150);

        txns.register(2, 1100);
        txns.update_lsn(2, 120);

        assert_eq!(txns.active_count(), 2);

        let listed = txns.get_active_transactions();
        assert_eq!(listed.len(), 2);

        let txn_one = listed
            .iter()
            .find(|entry| entry.txn_id == 1)
            .expect("txn 1 should be listed");
        assert_eq!(txn_one.first_lsn, 100);
        assert_eq!(txn_one.last_lsn, 150);

        txns.remove(1);
        assert_eq!(txns.active_count(), 1);
    }

    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let dir = TempDir::new().unwrap();
        let dirty = Arc::new(DirtyPageTracker::new());
        let active = Arc::new(ActiveTransactionTracker::new());
        let clock = Arc::new(HybridLogicalClock::new());

        let manager = CheckpointManager::new(
            dir.path(),
            CheckpointConfig::default(),
            Arc::clone(&dirty),
            Arc::clone(&active),
            clock,
        )?;

        // Two dirty pages plus one transaction that has logged an LSN.
        dirty.mark_dirty(1, manager.next_lsn());
        dirty.mark_dirty(2, manager.next_lsn());
        active.register(100, 1000);
        active.update_lsn(100, manager.next_lsn());

        let taken = manager.checkpoint(|_flushed| Ok(()))?;

        assert_eq!(taken.checkpoint_id, 1);
        assert_eq!(taken.dirty_pages.len(), 2);
        assert_eq!(taken.active_transactions.len(), 1);

        Ok(())
    }
}