venus_core/state/
manager.rs

1//! State manager for Venus notebooks.
2//!
3//! Handles saving and loading cell outputs with automatic format selection.
4//!
5//! # Salsa Integration
6//!
7//! The StateManager can sync its outputs with Salsa's incremental computation
8//! system via the [`CellOutputData`] and [`ExecutionStatus`] types. Use:
9//!
10//! - [`sync_output_to_salsa()`](StateManager::sync_output_to_salsa) to convert
11//!   a single output for Salsa tracking
12//! - [`sync_all_to_salsa()`](StateManager::sync_all_to_salsa) to export all
13//!   outputs as a vector of execution statuses
14//! - [`load_from_salsa()`](StateManager::load_from_salsa) to import an output
15//!   from Salsa's cached data
16
17use std::collections::{HashMap, HashSet};
18use std::fs;
19use std::path::{Path, PathBuf};
20use std::sync::Arc;
21
22use crate::error::{Error, Result};
23use crate::graph::CellId;
24use crate::salsa_db::{CellOutputData, ExecutionStatus};
25
26use super::output::BoxedOutput;
27use super::schema::{SchemaChange, TypeFingerprint};
28
29/// Manages cell state persistence and invalidation.
30pub struct StateManager {
31    /// Base directory for state storage
32    state_dir: PathBuf,
33
34    /// In-memory cache of cell outputs
35    outputs: HashMap<CellId, Arc<BoxedOutput>>,
36
37    /// Type fingerprints for schema validation
38    fingerprints: HashMap<CellId, TypeFingerprint>,
39
40    /// Dirty cells that need to be persisted.
41    /// Uses HashSet to avoid duplicate writes when save() is called multiple times.
42    dirty: HashSet<CellId>,
43}
44
45impl StateManager {
46    /// Create a new state manager with the given state directory.
47    pub fn new(state_dir: impl AsRef<Path>) -> Result<Self> {
48        let state_dir = state_dir.as_ref().to_path_buf();
49        fs::create_dir_all(&state_dir)?;
50
51        Ok(Self {
52            state_dir,
53            outputs: HashMap::new(),
54            fingerprints: HashMap::new(),
55            dirty: HashSet::new(),
56        })
57    }
58
59    /// Save a cell output.
60    pub fn save<T: super::output::CellOutput>(&mut self, cell_id: CellId, value: &T) -> Result<()> {
61        let boxed = BoxedOutput::new(value)?;
62        self.outputs.insert(cell_id, Arc::new(boxed));
63        self.dirty.insert(cell_id);
64        Ok(())
65    }
66
67    /// Load a cell output.
68    pub fn load<T>(&self, cell_id: CellId) -> Result<T>
69    where
70        T: super::output::CellOutput + rkyv::Archive,
71        T::Archived: rkyv::Deserialize<T, rkyv::rancor::Strategy<rkyv::de::Pool, rkyv::rancor::Error>>,
72    {
73        // Try in-memory cache first
74        if let Some(boxed) = self.outputs.get(&cell_id) {
75            return boxed.deserialize();
76        }
77
78        // Try loading from disk
79        let path = self.output_path(cell_id);
80        if path.exists() {
81            let bytes = fs::read(&path)?;
82            let boxed: BoxedOutput = rkyv::from_bytes::<BoxedOutput, rkyv::rancor::Error>(&bytes)
83                .map_err(|e| Error::Deserialization(e.to_string()))?;
84            return boxed.deserialize();
85        }
86
87        Err(Error::CellNotFound(format!(
88            "No output for cell {:?}",
89            cell_id
90        )))
91    }
92
93    /// Get a reference to a cached output without deserializing.
94    pub fn get_output(&self, cell_id: CellId) -> Option<Arc<BoxedOutput>> {
95        self.outputs.get(&cell_id).cloned()
96    }
97
98    /// Store a pre-serialized output directly.
99    ///
100    /// Used by the execution engine to store outputs from FFI calls.
101    pub fn store_output(&mut self, cell_id: CellId, output: BoxedOutput) {
102        self.outputs.insert(cell_id, Arc::new(output));
103        self.dirty.insert(cell_id);
104    }
105
106    /// Check if a cell has a cached output.
107    pub fn has_output(&self, cell_id: CellId) -> bool {
108        self.outputs.contains_key(&cell_id) || self.output_path(cell_id).exists()
109    }
110
111    /// Invalidate a cell's output (e.g., when its source changes).
112    pub fn invalidate(&mut self, cell_id: CellId) {
113        self.outputs.remove(&cell_id);
114        self.fingerprints.remove(&cell_id);
115
116        // Remove from disk
117        let path = self.output_path(cell_id);
118        let _ = fs::remove_file(path);
119    }
120
121    /// Invalidate multiple cells.
122    pub fn invalidate_many(&mut self, cell_ids: &[CellId]) {
123        for &cell_id in cell_ids {
124            self.invalidate(cell_id);
125        }
126    }
127
128    /// Called when a cell is modified - invalidates it and all dependents.
129    ///
130    /// Returns the list of invalidated cell IDs.
131    pub fn on_cell_modified(&mut self, cell_id: CellId, dependents: &[CellId]) -> Vec<CellId> {
132        let mut invalidated = vec![cell_id];
133        invalidated.extend_from_slice(dependents);
134
135        for &id in &invalidated {
136            self.invalidate(id);
137        }
138
139        invalidated
140    }
141
142    /// Update the type fingerprint for a cell and check for schema changes.
143    pub fn update_fingerprint(
144        &mut self,
145        cell_id: CellId,
146        new_fingerprint: TypeFingerprint,
147    ) -> SchemaChange {
148        if let Some(old) = self.fingerprints.get(&cell_id) {
149            let change = old.compare(&new_fingerprint);
150
151            if change.is_breaking() {
152                // Invalidate cached output on breaking change
153                self.invalidate(cell_id);
154                tracing::warn!(
155                    "Schema change for cell {:?}: {}",
156                    cell_id,
157                    change.description()
158                );
159            }
160
161            self.fingerprints.insert(cell_id, new_fingerprint);
162            change
163        } else {
164            self.fingerprints.insert(cell_id, new_fingerprint);
165            SchemaChange::None
166        }
167    }
168
169    /// Persist all dirty outputs to disk.
170    ///
171    /// Uses atomic write pattern to prevent partial state corruption.
172    /// If any write fails, failed cells remain marked as dirty.
173    pub fn flush(&mut self) -> Result<()> {
174        let dirty_cells: Vec<_> = self.dirty.drain().collect();
175        let mut failed_cells = Vec::new();
176        let mut last_error = None;
177
178        for cell_id in dirty_cells {
179            if let Some(boxed) = self.outputs.get(&cell_id) {
180                let path = self.output_path(cell_id);
181
182                // Attempt to write this cell
183                let result = (|| -> Result<()> {
184                    // Ensure parent directory exists
185                    if let Some(parent) = path.parent() {
186                        fs::create_dir_all(parent)?;
187                    }
188
189                    let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(boxed.as_ref())
190                        .map_err(|e| Error::Serialization(e.to_string()))?;
191
192                    // Use atomic write pattern: write to temp file, then rename
193                    let temp_path = path.with_extension("tmp");
194                    fs::write(&temp_path, &bytes)?;
195                    fs::rename(&temp_path, &path)?;
196
197                    Ok(())
198                })();
199
200                if let Err(e) = result {
201                    // Track failed cell to re-add to dirty set
202                    failed_cells.push(cell_id);
203                    last_error = Some(e);
204                }
205            }
206        }
207
208        // Re-add failed cells to dirty set
209        for cell_id in failed_cells {
210            self.dirty.insert(cell_id);
211        }
212
213        // Return error if any writes failed
214        if let Some(e) = last_error {
215            return Err(e);
216        }
217
218        Ok(())
219    }
220
221    /// Load all cached outputs from disk.
222    pub fn restore(&mut self) -> Result<usize> {
223        let outputs_dir = self.state_dir.join("outputs");
224        if !outputs_dir.exists() {
225            return Ok(0);
226        }
227
228        let mut count = 0;
229        for entry in fs::read_dir(&outputs_dir)? {
230            let entry = entry?;
231            let path = entry.path();
232
233            if path.extension().is_some_and(|e| e == "bin")
234                && let Some(stem) = path.file_stem().and_then(|s| s.to_str())
235                && let Ok(id) = stem.parse::<usize>()
236            {
237                let cell_id = CellId::new(id);
238                let bytes = fs::read(&path)?;
239
240                match rkyv::from_bytes::<BoxedOutput, rkyv::rancor::Error>(&bytes) {
241                    Ok(boxed) => {
242                        self.outputs.insert(cell_id, Arc::new(boxed));
243                        count += 1;
244                    }
245                    Err(e) => {
246                        tracing::warn!("Failed to restore output for cell {}: {}", id, e);
247                    }
248                }
249            }
250        }
251
252        tracing::info!("Restored {} cached outputs", count);
253        Ok(count)
254    }
255
256    /// Get the path for a cell's output file.
257    fn output_path(&self, cell_id: CellId) -> PathBuf {
258        self.state_dir
259            .join("outputs")
260            .join(format!("{}.bin", cell_id.as_usize()))
261    }
262
263    // =========================================================================
264    // Salsa Integration
265    // =========================================================================
266
267    /// Convert a single cell output to Salsa-compatible format.
268    ///
269    /// Returns `None` if the cell has no cached output.
270    ///
271    /// # Arguments
272    ///
273    /// * `cell_id` - The cell to export
274    /// * `inputs_hash` - Hash of the cell's input values (for staleness detection)
275    /// * `execution_time_ms` - How long the cell took to execute
276    pub fn sync_output_to_salsa(
277        &self,
278        cell_id: CellId,
279        inputs_hash: u64,
280        execution_time_ms: u64,
281    ) -> Option<CellOutputData> {
282        self.outputs.get(&cell_id).map(|boxed| {
283            CellOutputData::from_boxed(cell_id.as_usize(), boxed, inputs_hash, execution_time_ms)
284        })
285    }
286
287    /// Export all outputs to a vector of execution statuses for Salsa.
288    ///
289    /// Creates a vector sized to `cell_count` where each index corresponds
290    /// to a cell ID. Cells without outputs are marked as `Pending`.
291    ///
292    /// # Arguments
293    ///
294    /// * `cell_count` - Total number of cells in the notebook
295    /// * `get_inputs_hash` - Closure to get the inputs hash for each cell
296    /// * `get_execution_time` - Closure to get execution time for each cell (0 if unknown)
297    pub fn sync_all_to_salsa<F, G>(
298        &self,
299        cell_count: usize,
300        get_inputs_hash: F,
301        get_execution_time: G,
302    ) -> Vec<ExecutionStatus>
303    where
304        F: Fn(CellId) -> u64,
305        G: Fn(CellId) -> u64,
306    {
307        (0..cell_count)
308            .map(|idx| {
309                let cell_id = CellId::new(idx);
310                if let Some(boxed) = self.outputs.get(&cell_id) {
311                    let output_data = CellOutputData::from_boxed(
312                        idx,
313                        boxed,
314                        get_inputs_hash(cell_id),
315                        get_execution_time(cell_id),
316                    );
317                    ExecutionStatus::Success(output_data)
318                } else {
319                    ExecutionStatus::Pending
320                }
321            })
322            .collect()
323    }
324
325    /// Import an output from Salsa's cached data.
326    ///
327    /// Converts a `CellOutputData` back to a `BoxedOutput` and stores it.
328    pub fn load_from_salsa(&mut self, output_data: &CellOutputData) {
329        let cell_id = CellId::new(output_data.cell_id);
330        let boxed = output_data.to_boxed();
331        self.outputs.insert(cell_id, Arc::new(boxed));
332        // Don't mark as dirty - this came from Salsa, not from execution
333    }
334
335    /// Import all successful outputs from Salsa's execution statuses.
336    ///
337    /// Returns the number of outputs imported.
338    pub fn load_all_from_salsa(&mut self, statuses: &[ExecutionStatus]) -> usize {
339        let mut count = 0;
340        for status in statuses {
341            if let ExecutionStatus::Success(output_data) = status {
342                self.load_from_salsa(output_data);
343                count += 1;
344            }
345        }
346        count
347    }
348
349    /// Check if a Salsa output is still valid for the current inputs.
350    ///
351    /// Note: The StateManager doesn't store input hashes, so this method only
352    /// checks if an output exists. For actual validation, use the `is_valid_for()`
353    /// method on `CellOutputData` with the Salsa-cached output.
354    ///
355    /// The `_current_inputs_hash` parameter is reserved for future use when
356    /// input hash tracking is added to the StateManager.
357    pub fn is_salsa_output_valid(&self, cell_id: CellId, _current_inputs_hash: u64) -> bool {
358        self.has_output(cell_id)
359    }
360
361    /// Clear all state (for testing or reset).
362    pub fn clear(&mut self) -> Result<()> {
363        self.outputs.clear();
364        self.fingerprints.clear();
365        self.dirty.clear();
366
367        let outputs_dir = self.state_dir.join("outputs");
368        if outputs_dir.exists() {
369            fs::remove_dir_all(&outputs_dir)?;
370        }
371
372        Ok(())
373    }
374
375    /// Get statistics about the state manager.
376    pub fn stats(&self) -> StateStats {
377        StateStats {
378            cached_outputs: self.outputs.len(),
379            dirty_outputs: self.dirty.len(),
380            fingerprints: self.fingerprints.len(),
381        }
382    }
383}
384
/// Snapshot of a state manager's bookkeeping counters.
///
/// All counts describe in-memory state at the moment the snapshot was taken.
/// The struct is a plain value: it is cheap to copy and compare, and its
/// `Default` is all zeros.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct StateStats {
    /// Number of outputs in the in-memory cache
    pub cached_outputs: usize,

    /// Number of outputs pending persistence (not yet flushed)
    pub dirty_outputs: usize,

    /// Number of type fingerprints tracked for schema validation
    pub fingerprints: usize,
}
397
#[cfg(test)]
mod tests {
    use super::*;
    use rkyv::{Archive, Deserialize, Serialize};
    use tempfile::TempDir;

    /// Small rkyv-serializable payload used as a cell output throughout.
    #[derive(Debug, Clone, PartialEq, Archive, Serialize, Deserialize)]
    struct TestOutput {
        value: i32,
    }

    /// A manager rooted in a fresh temp dir; the guard keeps the dir alive.
    fn fresh_manager() -> (StateManager, TempDir) {
        let dir = TempDir::new().unwrap();
        let manager = StateManager::new(dir.path()).unwrap();
        (manager, dir)
    }

    /// Build Salsa output data for `cell` wrapping a `TestOutput { value }`.
    fn salsa_data(cell: usize, value: i32, inputs_hash: u64, time_ms: u64) -> CellOutputData {
        let boxed = BoxedOutput::new(&TestOutput { value }).unwrap();
        CellOutputData::from_boxed(cell, &boxed, inputs_hash, time_ms)
    }

    /// Fingerprint of `TestOutput` whose single `value` field has type `field_ty`.
    fn fingerprint(field_ty: &str) -> TypeFingerprint {
        TypeFingerprint::new("TestOutput", vec![("value".to_string(), field_ty.to_string())])
    }

    #[test]
    fn test_save_and_load() {
        let (mut manager, _dir) = fresh_manager();
        let cell = CellId::new(0);
        let original = TestOutput { value: 42 };

        manager.save(cell, &original).unwrap();

        let round_tripped: TestOutput = manager.load(cell).unwrap();
        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_invalidate() {
        let (mut manager, _dir) = fresh_manager();
        let cell = CellId::new(0);

        manager.save(cell, &TestOutput { value: 42 }).unwrap();
        assert!(manager.has_output(cell));

        manager.invalidate(cell);
        assert!(!manager.has_output(cell));
    }

    #[test]
    fn test_persist_and_restore() {
        let dir = TempDir::new().unwrap();
        let cell = CellId::new(0);

        // First session: save and flush to disk.
        let mut writer = StateManager::new(dir.path()).unwrap();
        writer.save(cell, &TestOutput { value: 42 }).unwrap();
        writer.flush().unwrap();
        drop(writer);

        // Second session: a brand-new manager restores from disk.
        let mut reader = StateManager::new(dir.path()).unwrap();
        reader.restore().unwrap();
        let restored: TestOutput = reader.load(cell).unwrap();
        assert_eq!(restored.value, 42);
    }

    #[test]
    fn test_on_cell_modified() {
        let (mut manager, _dir) = fresh_manager();
        let cells: Vec<CellId> = (0..3).map(CellId::new).collect();

        for (value, &cell) in (0..).zip(&cells) {
            manager.save(cell, &TestOutput { value }).unwrap();
        }

        // Cell 0 changes; cells 1 and 2 depend on it.
        let invalidated = manager.on_cell_modified(cells[0], &cells[1..]);

        assert_eq!(invalidated.len(), 3);
        for &cell in &cells {
            assert!(!manager.has_output(cell));
        }
    }

    #[test]
    fn test_schema_change_detection() {
        let (mut manager, _dir) = fresh_manager();
        let cell = CellId::new(0);
        manager.save(cell, &TestOutput { value: 42 }).unwrap();

        // The first fingerprint for a cell is never a breaking change.
        assert!(!manager.update_fingerprint(cell, fingerprint("i32")).is_breaking());

        // An identical fingerprint keeps the cached output.
        assert!(!manager.update_fingerprint(cell, fingerprint("i32")).is_breaking());
        assert!(manager.has_output(cell));

        // Changing the field type is breaking and drops the output.
        assert!(manager.update_fingerprint(cell, fingerprint("i64")).is_breaking());
        assert!(!manager.has_output(cell));
    }

    #[test]
    fn test_sync_output_to_salsa() {
        let (mut manager, _dir) = fresh_manager();
        let cell = CellId::new(0);

        assert!(manager.sync_output_to_salsa(cell, 12345, 100).is_none());

        manager.save(cell, &TestOutput { value: 42 }).unwrap();

        let synced = manager
            .sync_output_to_salsa(cell, 12345, 100)
            .expect("a saved output should be exportable");
        assert_eq!(synced.cell_id, 0);
        assert_eq!(synced.inputs_hash, 12345);
        assert_eq!(synced.execution_time_ms, 100);
        assert!(!synced.bytes.is_empty());
    }

    #[test]
    fn test_sync_all_to_salsa() {
        let (mut manager, _dir) = fresh_manager();

        // Cells 0 and 2 have outputs; cell 1 is intentionally left empty.
        for idx in [0usize, 2] {
            let value = idx as i32;
            manager.save(CellId::new(idx), &TestOutput { value }).unwrap();
        }

        let inputs_hash = |cell_id: CellId| cell_id.as_usize() as u64 * 100;
        let exec_time = |cell_id: CellId| cell_id.as_usize() as u64 * 10;
        let statuses = manager.sync_all_to_salsa(3, inputs_hash, exec_time);

        assert_eq!(statuses.len(), 3);
        assert!(matches!(statuses[0], ExecutionStatus::Success(_)));
        assert!(matches!(statuses[1], ExecutionStatus::Pending));
        assert!(matches!(statuses[2], ExecutionStatus::Success(_)));

        let ExecutionStatus::Success(first) = &statuses[0] else {
            panic!("cell 0 should have synced successfully");
        };
        assert_eq!(first.inputs_hash, 0);
        assert_eq!(first.execution_time_ms, 0);

        let ExecutionStatus::Success(third) = &statuses[2] else {
            panic!("cell 2 should have synced successfully");
        };
        assert_eq!(third.inputs_hash, 200);
        assert_eq!(third.execution_time_ms, 20);
    }

    #[test]
    fn test_load_from_salsa() {
        let (mut manager, _dir) = fresh_manager();
        let cell = CellId::new(0);

        manager.load_from_salsa(&salsa_data(0, 99, 12345, 50));

        assert!(manager.has_output(cell));
        let imported: TestOutput = manager.load(cell).unwrap();
        assert_eq!(imported.value, 99);

        // Data imported from Salsa is not execution output, so it stays clean.
        assert!(!manager.dirty.contains(&cell));
    }

    #[test]
    fn test_load_all_from_salsa() {
        let (mut manager, _dir) = fresh_manager();

        let statuses = vec![
            ExecutionStatus::Success(salsa_data(0, 100, 0, 0)),
            ExecutionStatus::Pending,
            ExecutionStatus::Success(salsa_data(2, 200, 0, 0)),
            ExecutionStatus::Failed("error".to_string()),
        ];

        // Only the two `Success` entries are imported.
        assert_eq!(manager.load_all_from_salsa(&statuses), 2);

        let present = [true, false, true, false];
        for (idx, expected) in present.into_iter().enumerate() {
            assert_eq!(manager.has_output(CellId::new(idx)), expected);
        }

        let first: TestOutput = manager.load(CellId::new(0)).unwrap();
        assert_eq!(first.value, 100);

        let third: TestOutput = manager.load(CellId::new(2)).unwrap();
        assert_eq!(third.value, 200);
    }
}