1mod checkpoint_store;
13mod error;
14pub mod file_store;
15mod preview;
16mod retention;
17mod state;
18
19use std::sync::Arc;
20
21use parking_lot::{Mutex, RwLock};
22
23pub use checkpoint_store::CheckpointStore;
24pub use error::{CheckpointError, Result};
25pub use file_store::FileCheckpointStore;
26pub use preview::{format_confirmation_prompt, format_warning, PreviewGenerator};
27pub use retention::RetentionManager;
28pub use state::{
29 CheckpointInfo, CheckpointMetadata, CheckpointState, CheckpointTrigger, DestructiveOp,
30 GraphMeta, OperationPreview, RelationalMeta, VectorMeta,
31};
32use tensor_store::TensorStore;
33
/// Tunable settings for a [`CheckpointManager`].
///
/// Start from [`CheckpointConfig::new`] (equivalent to [`Default::default`]) and
/// adjust with the `with_*` builder methods. Every constructor and builder is a
/// `const fn`, so a configuration can be assembled inside a `const` item.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointConfig {
    /// Retention limit handed to the retention manager when a
    /// [`CheckpointManager`] is constructed.
    pub max_checkpoints: usize,
    /// Flag reported by `CheckpointManager::auto_checkpoint_enabled`; presumably
    /// callers check it before calling `create_auto` — confirm at call sites.
    pub auto_checkpoint: bool,
    /// When `false`, `CheckpointManager::request_confirmation` approves every
    /// operation without consulting the installed handler.
    pub interactive_confirm: bool,
    /// Sample size handed to the preview generator.
    pub preview_sample_size: usize,
}

impl Default for CheckpointConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl CheckpointConfig {
    /// Returns the default configuration.
    ///
    /// Defaults: 10 checkpoints retained, automatic checkpoints enabled,
    /// interactive confirmation enabled, 5 preview samples. This is a `const fn`
    /// so it can be chained with the `const` builders in constant contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_checkpoints: 10,
            auto_checkpoint: true,
            interactive_confirm: true,
            preview_sample_size: 5,
        }
    }

    /// Sets the maximum number of retained checkpoints.
    #[must_use]
    pub const fn with_max_checkpoints(mut self, max: usize) -> Self {
        self.max_checkpoints = max;
        self
    }

    /// Enables or disables the automatic-checkpoint flag.
    #[must_use]
    pub const fn with_auto_checkpoint(mut self, enabled: bool) -> Self {
        self.auto_checkpoint = enabled;
        self
    }

    /// Enables or disables consulting the confirmation handler.
    #[must_use]
    pub const fn with_interactive_confirm(mut self, enabled: bool) -> Self {
        self.interactive_confirm = enabled;
        self
    }

    /// Sets how many samples operation previews include.
    #[must_use]
    pub const fn with_preview_sample_size(mut self, size: usize) -> Self {
        self.preview_sample_size = size;
        self
    }
}
93
/// Decides whether a destructive operation may proceed.
///
/// The manager stores handlers as `Arc<dyn ConfirmationHandler>` behind an
/// `RwLock`, which is why implementations must be `Send + Sync`.
pub trait ConfirmationHandler: Send + Sync {
    /// Returns `true` to approve `op` (described by `preview`), `false` to reject it.
    fn confirm(&self, op: &DestructiveOp, preview: &OperationPreview) -> bool;
}
99
100pub struct AutoConfirm;
102
103impl ConfirmationHandler for AutoConfirm {
104 fn confirm(&self, _op: &DestructiveOp, _preview: &OperationPreview) -> bool {
105 true
106 }
107}
108
109pub struct AutoReject;
111
112impl ConfirmationHandler for AutoReject {
113 fn confirm(&self, _op: &DestructiveOp, _preview: &OperationPreview) -> bool {
114 false
115 }
116}
117
/// Creates, lists, restores and deletes checkpoints of a [`TensorStore`],
/// persisting them through a [`CheckpointStore`] backend.
pub struct CheckpointManager {
    /// Backend that persists and loads checkpoint states.
    store: Arc<dyn CheckpointStore>,
    /// Configuration supplied at construction time.
    config: CheckpointConfig,
    /// Enforced after every checkpoint creation; built from `config.max_checkpoints`.
    retention: RetentionManager,
    /// Builds operation previews; built from `config.preview_sample_size`.
    preview_gen: PreviewGenerator,
    /// Handler consulted by `request_confirmation`; `None` means "approve".
    confirm_handler: RwLock<Option<Arc<dyn ConfirmationHandler>>>,
    /// Serializes the mutating operations (create, create_auto, rollback, delete).
    op_lock: Mutex<()>,
}
131
132impl CheckpointManager {
133 #[must_use]
135 pub fn new(store: Arc<dyn CheckpointStore>, config: CheckpointConfig) -> Self {
136 let retention = RetentionManager::new(config.max_checkpoints);
137 let preview_gen = PreviewGenerator::new(config.preview_sample_size);
138
139 Self {
140 store,
141 config,
142 retention,
143 preview_gen,
144 confirm_handler: RwLock::new(None),
145 op_lock: Mutex::new(()),
146 }
147 }
148
149 pub fn set_confirmation_handler(&self, handler: Arc<dyn ConfirmationHandler>) {
151 *self.confirm_handler.write() = Some(handler);
152 }
153
154 #[must_use]
156 pub const fn config(&self) -> &CheckpointConfig {
157 &self.config
158 }
159
160 pub fn create(&self, name: Option<&str>, tensor_store: &TensorStore) -> Result<String> {
166 let _guard = self.op_lock.lock();
167
168 let id = uuid::Uuid::new_v4().to_string();
169 let name = name.map_or_else(
170 || {
171 let now = std::time::SystemTime::now()
172 .duration_since(std::time::UNIX_EPOCH)
173 .map(|d| d.as_secs())
174 .unwrap_or(0);
175 format!("checkpoint-{now}")
176 },
177 String::from,
178 );
179
180 let metadata = Self::collect_metadata(tensor_store);
181 let snapshot_bytes = tensor_store
182 .snapshot_bytes()
183 .map_err(|e| CheckpointError::Snapshot(e.to_string()))?;
184
185 let state = CheckpointState::new(id.clone(), name, snapshot_bytes, metadata);
186
187 self.store.store(&state)?;
188 self.retention.enforce(self.store.as_ref())?;
189
190 Ok(id)
191 }
192
193 pub fn create_auto(
199 &self,
200 command: &str,
201 op: DestructiveOp,
202 preview: OperationPreview,
203 tensor_store: &TensorStore,
204 ) -> Result<String> {
205 let _guard = self.op_lock.lock();
206
207 let id = uuid::Uuid::new_v4().to_string();
208 let name = format!(
209 "auto-before-{}",
210 op.operation_name().to_lowercase().replace(' ', "-")
211 );
212
213 let trigger = CheckpointTrigger::new(command.to_string(), op, preview);
214 let metadata = Self::collect_metadata(tensor_store);
215 let snapshot_bytes = tensor_store
216 .snapshot_bytes()
217 .map_err(|e| CheckpointError::Snapshot(e.to_string()))?;
218
219 let state =
220 CheckpointState::new(id.clone(), name, snapshot_bytes, metadata).with_trigger(trigger);
221
222 self.store.store(&state)?;
223 self.retention.enforce(self.store.as_ref())?;
224
225 Ok(id)
226 }
227
228 #[must_use]
230 pub fn request_confirmation(&self, op: &DestructiveOp, preview: &OperationPreview) -> bool {
231 if !self.config.interactive_confirm {
232 return true;
233 }
234
235 self.confirm_handler
236 .read()
237 .as_ref()
238 .map_or(true, |handler| handler.confirm(op, preview))
239 }
240
241 #[must_use]
243 pub fn generate_preview(
244 &self,
245 op: &DestructiveOp,
246 sample_data: Vec<String>,
247 ) -> OperationPreview {
248 self.preview_gen.generate(op, sample_data)
249 }
250
251 pub fn list(&self, limit: Option<usize>) -> Result<Vec<CheckpointInfo>> {
257 self.store.list(limit)
258 }
259
260 pub fn rollback(&self, id_or_name: &str, tensor_store: &TensorStore) -> Result<()> {
266 let _guard = self.op_lock.lock();
267
268 let state = self.store.load(id_or_name)?;
269
270 tensor_store
271 .restore_from_bytes(&state.store_snapshot)
272 .map_err(|e| CheckpointError::Snapshot(e.to_string()))?;
273
274 Ok(())
275 }
276
277 pub fn delete(&self, id_or_name: &str) -> Result<()> {
283 let _guard = self.op_lock.lock();
284 self.store.delete(id_or_name)
285 }
286
287 #[must_use]
289 pub const fn auto_checkpoint_enabled(&self) -> bool {
290 self.config.auto_checkpoint
291 }
292
293 #[must_use]
295 pub const fn interactive_confirm_enabled(&self) -> bool {
296 self.config.interactive_confirm
297 }
298
299 fn collect_metadata(store: &TensorStore) -> CheckpointMetadata {
300 let store_key_count = store.len();
301
302 let table_keys: Vec<_> = store.scan("_schema:");
304 let table_count = table_keys.len();
305 let mut total_rows = 0;
306
307 for key in &table_keys {
308 if let Some(table_name) = key.strip_prefix("_schema:") {
309 total_rows += store.scan_count(&format!("{table_name}:"));
310 }
311 }
312
313 let node_count = store.scan_count("node:");
315 let edge_count = store.scan_count("edge:");
316
317 let embedding_count = store.scan_count("_embed:");
319
320 CheckpointMetadata::new(
321 RelationalMeta::new(table_count, total_rows),
322 GraphMeta::new(node_count, edge_count),
323 VectorMeta::new(embedding_count),
324 store_key_count,
325 )
326 }
327}
328
#[cfg(test)]
mod tests {
    use tensor_store::{ScalarValue, TensorData, TensorValue};

    use super::*;

    /// Builds a `TensorData` holding one string scalar under `key`.
    fn make_tensor(key: &str, value: &str) -> TensorData {
        let mut t = TensorData::new();
        t.set(
            key,
            TensorValue::Scalar(ScalarValue::String(value.to_string())),
        );
        t
    }

    /// Builds a manager over a fresh temp-dir file store using `config`.
    /// Keep the returned `TempDir` alive for as long as the manager is used.
    fn setup_with_config(
        config: CheckpointConfig,
    ) -> (CheckpointManager, TensorStore, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let store = TensorStore::new();
        let file_store = Arc::new(FileCheckpointStore::new(dir.path()).unwrap());
        let manager = CheckpointManager::new(file_store, config);
        (manager, store, dir)
    }

    /// Same as `setup_with_config` with the default configuration.
    fn setup_with_dir() -> (CheckpointManager, TensorStore, tempfile::TempDir) {
        setup_with_config(CheckpointConfig::default())
    }

    /// A one-row delete on table `test`, used by the confirmation tests.
    fn delete_one_row() -> DestructiveOp {
        DestructiveOp::Delete {
            table: "test".to_string(),
            row_count: 1,
        }
    }

    #[test]
    fn test_create_manual_checkpoint() {
        let (manager, store, _dir) = setup_with_dir();

        store.put("user:1", make_tensor("name", "Alice")).unwrap();

        let id = manager.create(Some("my-checkpoint"), &store).unwrap();
        assert!(!id.is_empty());

        let list = manager.list(None).unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].name, "my-checkpoint");
    }

    #[test]
    fn test_create_auto_checkpoint() {
        let (manager, store, _dir) = setup_with_dir();

        let op = DestructiveOp::Delete {
            table: "users".to_string(),
            row_count: 5,
        };
        let preview = OperationPreview::new("Deleting 5 rows".to_string(), vec![], 5);

        let id = manager
            .create_auto("DELETE FROM users", op, preview, &store)
            .unwrap();
        assert!(!id.is_empty());

        let list = manager.list(None).unwrap();
        assert_eq!(list.len(), 1);
        assert!(list[0].name.starts_with("auto-before-"));
    }

    #[test]
    fn test_rollback() {
        let (manager, store, _dir) = setup_with_dir();

        store.put("user:1", make_tensor("name", "Alice")).unwrap();

        let id = manager.create(Some("before-delete"), &store).unwrap();

        store.delete("user:1").unwrap();
        assert!(!store.exists("user:1"));

        manager.rollback(&id, &store).unwrap();

        assert!(store.exists("user:1"));
        let data = store.get("user:1").unwrap();
        assert_eq!(
            data.get("name"),
            Some(&TensorValue::Scalar(ScalarValue::String(
                "Alice".to_string()
            )))
        );
    }

    #[test]
    fn test_rollback_by_name() {
        let (manager, store, _dir) = setup_with_dir();

        store.put("key", make_tensor("val", "original")).unwrap();

        manager.create(Some("named-checkpoint"), &store).unwrap();

        store.delete("key").unwrap();

        manager.rollback("named-checkpoint", &store).unwrap();

        assert!(store.exists("key"));
    }

    #[test]
    fn test_retention() {
        let (manager, store, _dir) =
            setup_with_config(CheckpointConfig::default().with_max_checkpoints(2));

        for i in 0..5 {
            manager.create(Some(&format!("cp-{i}")), &store).unwrap();
        }

        let list = manager.list(None).unwrap();
        assert_eq!(list.len(), 2);

        for cp in &list {
            assert!(cp.name.starts_with("cp-"));
        }
    }

    #[test]
    fn test_confirmation_handler() {
        let (manager, _store, _dir) = setup_with_dir();

        manager.set_confirmation_handler(Arc::new(AutoReject));

        let preview = OperationPreview::empty("test");
        assert!(!manager.request_confirmation(&delete_one_row(), &preview));
    }

    #[test]
    fn test_metadata_collection() {
        let (manager, store, _dir) = setup_with_dir();

        store
            .put("_schema:users", make_tensor("name", "users"))
            .unwrap();
        store.put("users:1", make_tensor("name", "Alice")).unwrap();
        store.put("users:2", make_tensor("name", "Bob")).unwrap();
        store.put("node:1", make_tensor("label", "Person")).unwrap();
        store.put("edge:1", make_tensor("type", "KNOWS")).unwrap();

        let mut embed_data = TensorData::new();
        embed_data.set("vec", TensorValue::Vector(vec![1.0, 2.0]));
        store.put("_embed:doc1", embed_data).unwrap();

        let id = manager.create(None, &store).unwrap();
        assert!(!id.is_empty());
        assert_eq!(manager.list(None).unwrap().len(), 1);

        let metadata = CheckpointManager::collect_metadata(&store);
        assert_eq!(metadata.relational.table_count, 1);
        assert_eq!(metadata.relational.total_rows, 2);
        assert_eq!(metadata.graph.node_count, 1);
        assert_eq!(metadata.graph.edge_count, 1);
        assert_eq!(metadata.vector.embedding_count, 1);
    }

    #[test]
    fn test_delete_checkpoint() {
        let (manager, store, _dir) = setup_with_dir();

        let id = manager.create(Some("to-delete"), &store).unwrap();
        assert_eq!(manager.list(None).unwrap().len(), 1);

        manager.delete(&id).unwrap();
        assert_eq!(manager.list(None).unwrap().len(), 0);
    }

    #[test]
    fn test_delete_by_name() {
        let (manager, store, _dir) = setup_with_dir();

        manager.create(Some("named-cp"), &store).unwrap();
        assert_eq!(manager.list(None).unwrap().len(), 1);

        manager.delete("named-cp").unwrap();
        assert_eq!(manager.list(None).unwrap().len(), 0);
    }

    #[test]
    fn test_delete_not_found() {
        let (manager, _store, _dir) = setup_with_dir();

        let result = manager.delete("non-existent");
        assert!(matches!(result, Err(CheckpointError::NotFound(_))));
    }

    #[test]
    fn test_rollback_not_found() {
        let (manager, store, _dir) = setup_with_dir();

        let result = manager.rollback("non-existent", &store);
        assert!(matches!(result, Err(CheckpointError::NotFound(_))));
    }

    #[test]
    fn test_config_methods() {
        let config = CheckpointConfig::new()
            .with_max_checkpoints(5)
            .with_auto_checkpoint(false)
            .with_interactive_confirm(false)
            .with_preview_sample_size(10);

        assert_eq!(config.max_checkpoints, 5);
        assert!(!config.auto_checkpoint);
        assert!(!config.interactive_confirm);
        assert_eq!(config.preview_sample_size, 10);
    }

    #[test]
    fn test_auto_checkpoint_enabled() {
        let (manager, _store, _dir) =
            setup_with_config(CheckpointConfig::default().with_auto_checkpoint(false));

        assert!(!manager.auto_checkpoint_enabled());
    }

    #[test]
    fn test_interactive_confirm_enabled() {
        let (manager, _store, _dir) =
            setup_with_config(CheckpointConfig::default().with_interactive_confirm(false));

        assert!(!manager.interactive_confirm_enabled());
    }

    #[test]
    fn test_request_confirmation_without_handler() {
        let (manager, _store, _dir) = setup_with_dir();

        let preview = OperationPreview::empty("test");
        assert!(manager.request_confirmation(&delete_one_row(), &preview));
    }

    #[test]
    fn test_request_confirmation_disabled() {
        let (manager, _store, _dir) =
            setup_with_config(CheckpointConfig::default().with_interactive_confirm(false));

        let preview = OperationPreview::empty("test");
        assert!(manager.request_confirmation(&delete_one_row(), &preview));
    }

    #[test]
    fn test_auto_confirm_handler() {
        let (manager, _store, _dir) = setup_with_dir();

        manager.set_confirmation_handler(Arc::new(AutoConfirm));

        let preview = OperationPreview::empty("test");
        assert!(manager.request_confirmation(&delete_one_row(), &preview));
    }

    #[test]
    fn test_generate_preview() {
        let (manager, _store, _dir) = setup_with_dir();

        let op = DestructiveOp::Delete {
            table: "users".to_string(),
            row_count: 10,
        };
        let sample = vec!["row1".to_string(), "row2".to_string()];

        let preview = manager.generate_preview(&op, sample);
        assert_eq!(preview.affected_count, 10);
        assert_eq!(preview.sample_data.len(), 2);
    }

    #[test]
    fn test_list_with_limit() {
        let (manager, store, _dir) = setup_with_dir();

        for i in 0..5 {
            manager.create(Some(&format!("cp-{i}")), &store).unwrap();
        }

        let list = manager.list(Some(3)).unwrap();
        assert_eq!(list.len(), 3);
    }

    #[test]
    fn test_config_accessor() {
        let (manager, _store, _dir) = setup_with_dir();

        let config = manager.config();
        assert_eq!(config.max_checkpoints, 10);
    }

    #[test]
    fn test_create_unnamed_checkpoint() {
        let (manager, store, _dir) = setup_with_dir();

        let id = manager.create(None, &store).unwrap();
        assert!(!id.is_empty());

        let list = manager.list(None).unwrap();
        assert_eq!(list.len(), 1);
        assert!(list[0].name.starts_with("checkpoint-"));
    }
}