Skip to main content

tensor_checkpoint/
lib.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! `TensorCheckpoint` - Rollback/Checkpoint System for Neumann
3//!
4//! Provides checkpoint and rollback capabilities for the Neumann database:
5//! - Auto-checkpoints before destructive operations
6//! - Manual CHECKPOINT command for user-initiated snapshots
7//! - Interactive confirmation with preview of affected data
8//! - Count-based retention with automatic purge
9//!
10//! Checkpoints are stored on disk via `FileCheckpointStore`.
11
12mod checkpoint_store;
13mod error;
14pub mod file_store;
15mod preview;
16mod retention;
17mod state;
18
19use std::sync::Arc;
20
21use parking_lot::{Mutex, RwLock};
22
23pub use checkpoint_store::CheckpointStore;
24pub use error::{CheckpointError, Result};
25pub use file_store::FileCheckpointStore;
26pub use preview::{format_confirmation_prompt, format_warning, PreviewGenerator};
27pub use retention::RetentionManager;
28pub use state::{
29    CheckpointInfo, CheckpointMetadata, CheckpointState, CheckpointTrigger, DestructiveOp,
30    GraphMeta, OperationPreview, RelationalMeta, VectorMeta,
31};
32use tensor_store::TensorStore;
33
/// Configuration for the checkpoint manager.
///
/// Defaults (see [`CheckpointConfig::new`]): retain 10 checkpoints, auto-checkpoint
/// before destructive operations, prompt for confirmation, and show up to 5 sample
/// items in operation previews.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointConfig {
    /// Maximum number of checkpoints to retain (oldest are purged).
    pub max_checkpoints: usize,
    /// Whether to auto-checkpoint before destructive operations.
    pub auto_checkpoint: bool,
    /// Whether to prompt the user for confirmation before destructive operations.
    pub interactive_confirm: bool,
    /// Maximum number of sample data items shown in operation previews.
    pub preview_sample_size: usize,
}

impl Default for CheckpointConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl CheckpointConfig {
    /// Create a default checkpoint configuration.
    ///
    /// This is a `const fn` so that, together with the `const` builder methods
    /// below, a customised configuration can be built in a `const` context.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_checkpoints: 10,
            auto_checkpoint: true,
            interactive_confirm: true,
            preview_sample_size: 5,
        }
    }

    /// Set the maximum number of checkpoints to retain.
    #[must_use]
    pub const fn with_max_checkpoints(mut self, max: usize) -> Self {
        self.max_checkpoints = max;
        self
    }

    /// Enable or disable auto-checkpoints before destructive operations.
    #[must_use]
    pub const fn with_auto_checkpoint(mut self, enabled: bool) -> Self {
        self.auto_checkpoint = enabled;
        self
    }

    /// Enable or disable interactive confirmation prompts.
    #[must_use]
    pub const fn with_interactive_confirm(mut self, enabled: bool) -> Self {
        self.interactive_confirm = enabled;
        self
    }

    /// Set the number of sample data items shown in operation previews.
    #[must_use]
    pub const fn with_preview_sample_size(mut self, size: usize) -> Self {
        self.preview_sample_size = size;
        self
    }
}
93
/// Trait for handling confirmation prompts before destructive operations.
///
/// A handler is registered with `CheckpointManager::set_confirmation_handler`
/// and stored behind an `Arc`; the `Send + Sync` bound lets that shared handle
/// be used from any thread that holds the manager.
pub trait ConfirmationHandler: Send + Sync {
    /// Decide whether a destructive operation may proceed.
    ///
    /// `op` describes the operation about to run and `preview` is the preview
    /// generated for it. Return `true` to proceed with the operation, `false`
    /// to cancel.
    fn confirm(&self, op: &DestructiveOp, preview: &OperationPreview) -> bool;
}
99
100/// No-op confirmation handler that always confirms.
101pub struct AutoConfirm;
102
103impl ConfirmationHandler for AutoConfirm {
104    fn confirm(&self, _op: &DestructiveOp, _preview: &OperationPreview) -> bool {
105        true
106    }
107}
108
109/// Confirmation handler that always rejects (for testing).
110pub struct AutoReject;
111
112impl ConfirmationHandler for AutoReject {
113    fn confirm(&self, _op: &DestructiveOp, _preview: &OperationPreview) -> bool {
114        false
115    }
116}
117
/// Central coordinator for creating, listing, restoring, and deleting checkpoints.
///
/// All state-changing operations are serialized through an internal mutex.
/// `list()` is lock-free (reads only, tolerant of concurrent atomic file creation).
pub struct CheckpointManager {
    /// Backing storage that persists, lists, loads, and deletes checkpoint states.
    store: Arc<dyn CheckpointStore>,
    /// Configuration captured at construction time.
    config: CheckpointConfig,
    /// Built from `config.max_checkpoints`; enforced after every new checkpoint is stored.
    retention: RetentionManager,
    /// Built from `config.preview_sample_size`; produces operation previews.
    preview_gen: PreviewGenerator,
    /// Handler consulted for destructive-operation confirmation; `None` means
    /// operations are confirmed automatically.
    confirm_handler: RwLock<Option<Arc<dyn ConfirmationHandler>>>,
    /// Serializes create, rollback, delete operations.
    op_lock: Mutex<()>,
}
131
132impl CheckpointManager {
133    /// Create a checkpoint manager backed by the given store and configuration.
134    #[must_use]
135    pub fn new(store: Arc<dyn CheckpointStore>, config: CheckpointConfig) -> Self {
136        let retention = RetentionManager::new(config.max_checkpoints);
137        let preview_gen = PreviewGenerator::new(config.preview_sample_size);
138
139        Self {
140            store,
141            config,
142            retention,
143            preview_gen,
144            confirm_handler: RwLock::new(None),
145            op_lock: Mutex::new(()),
146        }
147    }
148
149    /// Register a handler to be called for destructive operation confirmation.
150    pub fn set_confirmation_handler(&self, handler: Arc<dyn ConfirmationHandler>) {
151        *self.confirm_handler.write() = Some(handler);
152    }
153
154    /// Returns a reference to the current configuration.
155    #[must_use]
156    pub const fn config(&self) -> &CheckpointConfig {
157        &self.config
158    }
159
160    /// Create a manual checkpoint with optional name.
161    ///
162    /// # Errors
163    ///
164    /// Returns an error if the snapshot cannot be created or the checkpoint cannot be stored.
165    pub fn create(&self, name: Option<&str>, tensor_store: &TensorStore) -> Result<String> {
166        let _guard = self.op_lock.lock();
167
168        let id = uuid::Uuid::new_v4().to_string();
169        let name = name.map_or_else(
170            || {
171                let now = std::time::SystemTime::now()
172                    .duration_since(std::time::UNIX_EPOCH)
173                    .map(|d| d.as_secs())
174                    .unwrap_or(0);
175                format!("checkpoint-{now}")
176            },
177            String::from,
178        );
179
180        let metadata = Self::collect_metadata(tensor_store);
181        let snapshot_bytes = tensor_store
182            .snapshot_bytes()
183            .map_err(|e| CheckpointError::Snapshot(e.to_string()))?;
184
185        let state = CheckpointState::new(id.clone(), name, snapshot_bytes, metadata);
186
187        self.store.store(&state)?;
188        self.retention.enforce(self.store.as_ref())?;
189
190        Ok(id)
191    }
192
193    /// Create an auto-checkpoint before a destructive operation.
194    ///
195    /// # Errors
196    ///
197    /// Returns an error if the snapshot cannot be created or the checkpoint cannot be stored.
198    pub fn create_auto(
199        &self,
200        command: &str,
201        op: DestructiveOp,
202        preview: OperationPreview,
203        tensor_store: &TensorStore,
204    ) -> Result<String> {
205        let _guard = self.op_lock.lock();
206
207        let id = uuid::Uuid::new_v4().to_string();
208        let name = format!(
209            "auto-before-{}",
210            op.operation_name().to_lowercase().replace(' ', "-")
211        );
212
213        let trigger = CheckpointTrigger::new(command.to_string(), op, preview);
214        let metadata = Self::collect_metadata(tensor_store);
215        let snapshot_bytes = tensor_store
216            .snapshot_bytes()
217            .map_err(|e| CheckpointError::Snapshot(e.to_string()))?;
218
219        let state =
220            CheckpointState::new(id.clone(), name, snapshot_bytes, metadata).with_trigger(trigger);
221
222        self.store.store(&state)?;
223        self.retention.enforce(self.store.as_ref())?;
224
225        Ok(id)
226    }
227
228    /// Request confirmation for a destructive operation.
229    #[must_use]
230    pub fn request_confirmation(&self, op: &DestructiveOp, preview: &OperationPreview) -> bool {
231        if !self.config.interactive_confirm {
232            return true;
233        }
234
235        self.confirm_handler
236            .read()
237            .as_ref()
238            .map_or(true, |handler| handler.confirm(op, preview))
239    }
240
241    /// Generate a preview for a destructive operation.
242    #[must_use]
243    pub fn generate_preview(
244        &self,
245        op: &DestructiveOp,
246        sample_data: Vec<String>,
247    ) -> OperationPreview {
248        self.preview_gen.generate(op, sample_data)
249    }
250
251    /// List checkpoints, most recent first.
252    ///
253    /// # Errors
254    ///
255    /// Returns an error if the backing store cannot be enumerated.
256    pub fn list(&self, limit: Option<usize>) -> Result<Vec<CheckpointInfo>> {
257        self.store.list(limit)
258    }
259
260    /// Rollback to a checkpoint by ID or name.
261    ///
262    /// # Errors
263    ///
264    /// Returns an error if the checkpoint is not found or the snapshot cannot be restored.
265    pub fn rollback(&self, id_or_name: &str, tensor_store: &TensorStore) -> Result<()> {
266        let _guard = self.op_lock.lock();
267
268        let state = self.store.load(id_or_name)?;
269
270        tensor_store
271            .restore_from_bytes(&state.store_snapshot)
272            .map_err(|e| CheckpointError::Snapshot(e.to_string()))?;
273
274        Ok(())
275    }
276
277    /// Delete a checkpoint by ID or name.
278    ///
279    /// # Errors
280    ///
281    /// Returns an error if the checkpoint is not found or cannot be removed.
282    pub fn delete(&self, id_or_name: &str) -> Result<()> {
283        let _guard = self.op_lock.lock();
284        self.store.delete(id_or_name)
285    }
286
287    /// Returns whether auto-checkpoints are enabled for destructive operations.
288    #[must_use]
289    pub const fn auto_checkpoint_enabled(&self) -> bool {
290        self.config.auto_checkpoint
291    }
292
293    /// Returns whether interactive confirmation prompts are enabled.
294    #[must_use]
295    pub const fn interactive_confirm_enabled(&self) -> bool {
296        self.config.interactive_confirm
297    }
298
299    fn collect_metadata(store: &TensorStore) -> CheckpointMetadata {
300        let store_key_count = store.len();
301
302        // Count relational tables
303        let table_keys: Vec<_> = store.scan("_schema:");
304        let table_count = table_keys.len();
305        let mut total_rows = 0;
306
307        for key in &table_keys {
308            if let Some(table_name) = key.strip_prefix("_schema:") {
309                total_rows += store.scan_count(&format!("{table_name}:"));
310            }
311        }
312
313        // Count graph entities
314        let node_count = store.scan_count("node:");
315        let edge_count = store.scan_count("edge:");
316
317        // Count embeddings
318        let embedding_count = store.scan_count("_embed:");
319
320        CheckpointMetadata::new(
321            RelationalMeta::new(table_count, total_rows),
322            GraphMeta::new(node_count, edge_count),
323            VectorMeta::new(embedding_count),
324            store_key_count,
325        )
326    }
327}
328
#[cfg(test)]
mod tests {
    use tensor_store::{ScalarValue, TensorData, TensorValue};

    use super::*;

    /// Build a single-field tensor holding a string scalar.
    fn make_tensor(key: &str, value: &str) -> TensorData {
        let mut t = TensorData::new();
        t.set(
            key,
            TensorValue::Scalar(ScalarValue::String(value.to_string())),
        );
        t
    }

    /// Manager with default config over a fresh on-disk store; the returned
    /// `TempDir` must be kept alive for the duration of the test.
    fn setup_with_dir() -> (CheckpointManager, TensorStore, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let store = TensorStore::new();
        let file_store = Arc::new(FileCheckpointStore::new(dir.path()).unwrap());
        let config = CheckpointConfig::default();
        let manager = CheckpointManager::new(file_store, config);
        (manager, store, dir)
    }

    fn sample_delete_op() -> DestructiveOp {
        DestructiveOp::Delete {
            table: "test".to_string(),
            row_count: 1,
        }
    }

    #[test]
    fn test_create_manual_checkpoint() {
        let (manager, store, _dir) = setup_with_dir();

        store.put("user:1", make_tensor("name", "Alice")).unwrap();

        let id = manager.create(Some("my-checkpoint"), &store).unwrap();
        assert!(!id.is_empty());

        let list = manager.list(None).unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].name, "my-checkpoint");
    }

    #[test]
    fn test_create_auto_checkpoint() {
        let (manager, store, _dir) = setup_with_dir();

        let op = DestructiveOp::Delete {
            table: "users".to_string(),
            row_count: 5,
        };
        let preview = OperationPreview::new("Deleting 5 rows".to_string(), vec![], 5);

        let id = manager
            .create_auto("DELETE FROM users", op, preview, &store)
            .unwrap();
        assert!(!id.is_empty());

        let list = manager.list(None).unwrap();
        assert_eq!(list.len(), 1);
        assert!(list[0].name.starts_with("auto-before-"));
    }

    #[test]
    fn test_rollback() {
        let (manager, store, _dir) = setup_with_dir();

        store.put("user:1", make_tensor("name", "Alice")).unwrap();

        let id = manager.create(Some("before-delete"), &store).unwrap();

        store.delete("user:1").unwrap();
        assert!(!store.exists("user:1"));

        manager.rollback(&id, &store).unwrap();

        assert!(store.exists("user:1"));
        let data = store.get("user:1").unwrap();
        assert_eq!(
            data.get("name"),
            Some(&TensorValue::Scalar(ScalarValue::String(
                "Alice".to_string()
            )))
        );
    }

    #[test]
    fn test_rollback_by_name() {
        let (manager, store, _dir) = setup_with_dir();

        store.put("key", make_tensor("val", "original")).unwrap();

        manager.create(Some("named-checkpoint"), &store).unwrap();

        store.delete("key").unwrap();

        manager.rollback("named-checkpoint", &store).unwrap();

        assert!(store.exists("key"));
    }

    #[test]
    fn test_retention() {
        let dir = tempfile::tempdir().unwrap();
        let store = TensorStore::new();
        let file_store = Arc::new(FileCheckpointStore::new(dir.path()).unwrap());
        let config = CheckpointConfig::default().with_max_checkpoints(2);
        let manager = CheckpointManager::new(file_store, config);

        for i in 0..5 {
            manager.create(Some(&format!("cp-{i}")), &store).unwrap();
        }

        let list = manager.list(None).unwrap();
        assert_eq!(list.len(), 2);

        for cp in &list {
            assert!(cp.name.starts_with("cp-"));
        }
    }

    #[test]
    fn test_confirmation_handler() {
        let (manager, _store, _dir) = setup_with_dir();

        manager.set_confirmation_handler(Arc::new(AutoReject));

        let preview = OperationPreview::empty("test");

        assert!(!manager.request_confirmation(&sample_delete_op(), &preview));
    }

    #[test]
    fn test_set_confirmation_handler_replaces_previous() {
        let (manager, _store, _dir) = setup_with_dir();
        let op = sample_delete_op();
        let preview = OperationPreview::empty("test");

        manager.set_confirmation_handler(Arc::new(AutoReject));
        assert!(!manager.request_confirmation(&op, &preview));

        manager.set_confirmation_handler(Arc::new(AutoConfirm));
        assert!(manager.request_confirmation(&op, &preview));
    }

    #[test]
    fn test_metadata_collection() {
        let (manager, store, _dir) = setup_with_dir();

        store
            .put("_schema:users", make_tensor("name", "users"))
            .unwrap();
        store.put("users:1", make_tensor("name", "Alice")).unwrap();
        store.put("users:2", make_tensor("name", "Bob")).unwrap();
        store.put("node:1", make_tensor("label", "Person")).unwrap();
        store.put("edge:1", make_tensor("type", "KNOWS")).unwrap();

        let mut embed_data = TensorData::new();
        embed_data.set("vec", TensorValue::Vector(vec![1.0, 2.0]));
        store.put("_embed:doc1", embed_data).unwrap();

        // Checkpointing a store with every kind of data must succeed.
        let id = manager.create(None, &store).unwrap();
        assert!(!id.is_empty());
        assert_eq!(manager.list(None).unwrap().len(), 1);

        let metadata = CheckpointManager::collect_metadata(&store);
        assert_eq!(metadata.relational.table_count, 1);
        assert_eq!(metadata.relational.total_rows, 2);
        assert_eq!(metadata.graph.node_count, 1);
        assert_eq!(metadata.graph.edge_count, 1);
        assert_eq!(metadata.vector.embedding_count, 1);
    }

    #[test]
    fn test_delete_checkpoint() {
        let (manager, store, _dir) = setup_with_dir();

        let id = manager.create(Some("to-delete"), &store).unwrap();
        assert_eq!(manager.list(None).unwrap().len(), 1);

        manager.delete(&id).unwrap();
        assert_eq!(manager.list(None).unwrap().len(), 0);
    }

    #[test]
    fn test_delete_by_name() {
        let (manager, store, _dir) = setup_with_dir();

        manager.create(Some("named-cp"), &store).unwrap();
        assert_eq!(manager.list(None).unwrap().len(), 1);

        manager.delete("named-cp").unwrap();
        assert_eq!(manager.list(None).unwrap().len(), 0);
    }

    #[test]
    fn test_delete_not_found() {
        let (manager, _store, _dir) = setup_with_dir();

        let result = manager.delete("non-existent");
        assert!(matches!(result, Err(CheckpointError::NotFound(_))));
    }

    #[test]
    fn test_rollback_not_found() {
        let (manager, store, _dir) = setup_with_dir();

        let result = manager.rollback("non-existent", &store);
        assert!(matches!(result, Err(CheckpointError::NotFound(_))));
    }

    #[test]
    fn test_config_methods() {
        let config = CheckpointConfig::new()
            .with_max_checkpoints(5)
            .with_auto_checkpoint(false)
            .with_interactive_confirm(false)
            .with_preview_sample_size(10);

        assert_eq!(config.max_checkpoints, 5);
        assert!(!config.auto_checkpoint);
        assert!(!config.interactive_confirm);
        assert_eq!(config.preview_sample_size, 10);
    }

    #[test]
    fn test_auto_checkpoint_enabled() {
        let dir = tempfile::tempdir().unwrap();
        let file_store = Arc::new(FileCheckpointStore::new(dir.path()).unwrap());
        let config = CheckpointConfig::default().with_auto_checkpoint(false);
        let manager = CheckpointManager::new(file_store, config);

        assert!(!manager.auto_checkpoint_enabled());
    }

    #[test]
    fn test_interactive_confirm_enabled() {
        let dir = tempfile::tempdir().unwrap();
        let file_store = Arc::new(FileCheckpointStore::new(dir.path()).unwrap());
        let config = CheckpointConfig::default().with_interactive_confirm(false);
        let manager = CheckpointManager::new(file_store, config);

        assert!(!manager.interactive_confirm_enabled());
    }

    #[test]
    fn test_request_confirmation_without_handler() {
        let (manager, _store, _dir) = setup_with_dir();

        let preview = OperationPreview::empty("test");

        assert!(manager.request_confirmation(&sample_delete_op(), &preview));
    }

    #[test]
    fn test_request_confirmation_disabled() {
        let dir = tempfile::tempdir().unwrap();
        let file_store = Arc::new(FileCheckpointStore::new(dir.path()).unwrap());
        let config = CheckpointConfig::default().with_interactive_confirm(false);
        let manager = CheckpointManager::new(file_store, config);

        let preview = OperationPreview::empty("test");

        assert!(manager.request_confirmation(&sample_delete_op(), &preview));
    }

    #[test]
    fn test_request_confirmation_disabled_ignores_handler() {
        let dir = tempfile::tempdir().unwrap();
        let file_store = Arc::new(FileCheckpointStore::new(dir.path()).unwrap());
        let config = CheckpointConfig::default().with_interactive_confirm(false);
        let manager = CheckpointManager::new(file_store, config);

        // With prompts disabled, even a rejecting handler must not be consulted.
        manager.set_confirmation_handler(Arc::new(AutoReject));
        let preview = OperationPreview::empty("test");

        assert!(manager.request_confirmation(&sample_delete_op(), &preview));
    }

    #[test]
    fn test_auto_confirm_handler() {
        let (manager, _store, _dir) = setup_with_dir();

        manager.set_confirmation_handler(Arc::new(AutoConfirm));

        let preview = OperationPreview::empty("test");

        assert!(manager.request_confirmation(&sample_delete_op(), &preview));
    }

    #[test]
    fn test_generate_preview() {
        let (manager, _store, _dir) = setup_with_dir();

        let op = DestructiveOp::Delete {
            table: "users".to_string(),
            row_count: 10,
        };
        let sample = vec!["row1".to_string(), "row2".to_string()];

        let preview = manager.generate_preview(&op, sample);
        assert_eq!(preview.affected_count, 10);
        assert_eq!(preview.sample_data.len(), 2);
    }

    #[test]
    fn test_list_with_limit() {
        let (manager, store, _dir) = setup_with_dir();

        for i in 0..5 {
            manager.create(Some(&format!("cp-{i}")), &store).unwrap();
        }

        let list = manager.list(Some(3)).unwrap();
        assert_eq!(list.len(), 3);
    }

    #[test]
    fn test_config_accessor() {
        let (manager, _store, _dir) = setup_with_dir();

        let config = manager.config();
        assert_eq!(config.max_checkpoints, 10);
    }

    #[test]
    fn test_create_unnamed_checkpoint() {
        let (manager, store, _dir) = setup_with_dir();

        let id = manager.create(None, &store).unwrap();
        assert!(!id.is_empty());

        let list = manager.list(None).unwrap();
        assert_eq!(list.len(), 1);
        assert!(list[0].name.starts_with("checkpoint-"));
    }
}
662}