1use std::collections::{HashMap, HashSet};
18use std::fs;
19use std::path::{Path, PathBuf};
20use std::sync::Arc;
21
22use crate::error::{Error, Result};
23use crate::graph::CellId;
24use crate::salsa_db::{CellOutputData, ExecutionStatus};
25
26use super::output::BoxedOutput;
27use super::schema::{SchemaChange, TypeFingerprint};
28
29pub struct StateManager {
31 state_dir: PathBuf,
33
34 outputs: HashMap<CellId, Arc<BoxedOutput>>,
36
37 fingerprints: HashMap<CellId, TypeFingerprint>,
39
40 dirty: HashSet<CellId>,
43}
44
45impl StateManager {
46 pub fn new(state_dir: impl AsRef<Path>) -> Result<Self> {
48 let state_dir = state_dir.as_ref().to_path_buf();
49 fs::create_dir_all(&state_dir)?;
50
51 Ok(Self {
52 state_dir,
53 outputs: HashMap::new(),
54 fingerprints: HashMap::new(),
55 dirty: HashSet::new(),
56 })
57 }
58
59 pub fn save<T: super::output::CellOutput>(&mut self, cell_id: CellId, value: &T) -> Result<()> {
61 let boxed = BoxedOutput::new(value)?;
62 self.outputs.insert(cell_id, Arc::new(boxed));
63 self.dirty.insert(cell_id);
64 Ok(())
65 }
66
67 pub fn load<T>(&self, cell_id: CellId) -> Result<T>
69 where
70 T: super::output::CellOutput + rkyv::Archive,
71 T::Archived: rkyv::Deserialize<T, rkyv::rancor::Strategy<rkyv::de::Pool, rkyv::rancor::Error>>,
72 {
73 if let Some(boxed) = self.outputs.get(&cell_id) {
75 return boxed.deserialize();
76 }
77
78 let path = self.output_path(cell_id);
80 if path.exists() {
81 let bytes = fs::read(&path)?;
82 let boxed: BoxedOutput = rkyv::from_bytes::<BoxedOutput, rkyv::rancor::Error>(&bytes)
83 .map_err(|e| Error::Deserialization(e.to_string()))?;
84 return boxed.deserialize();
85 }
86
87 Err(Error::CellNotFound(format!(
88 "No output for cell {:?}",
89 cell_id
90 )))
91 }
92
93 pub fn get_output(&self, cell_id: CellId) -> Option<Arc<BoxedOutput>> {
95 self.outputs.get(&cell_id).cloned()
96 }
97
98 pub fn store_output(&mut self, cell_id: CellId, output: BoxedOutput) {
102 self.outputs.insert(cell_id, Arc::new(output));
103 self.dirty.insert(cell_id);
104 }
105
106 pub fn has_output(&self, cell_id: CellId) -> bool {
108 self.outputs.contains_key(&cell_id) || self.output_path(cell_id).exists()
109 }
110
111 pub fn invalidate(&mut self, cell_id: CellId) {
113 self.outputs.remove(&cell_id);
114 self.fingerprints.remove(&cell_id);
115
116 let path = self.output_path(cell_id);
118 let _ = fs::remove_file(path);
119 }
120
121 pub fn invalidate_many(&mut self, cell_ids: &[CellId]) {
123 for &cell_id in cell_ids {
124 self.invalidate(cell_id);
125 }
126 }
127
128 pub fn on_cell_modified(&mut self, cell_id: CellId, dependents: &[CellId]) -> Vec<CellId> {
132 let mut invalidated = vec![cell_id];
133 invalidated.extend_from_slice(dependents);
134
135 for &id in &invalidated {
136 self.invalidate(id);
137 }
138
139 invalidated
140 }
141
142 pub fn update_fingerprint(
144 &mut self,
145 cell_id: CellId,
146 new_fingerprint: TypeFingerprint,
147 ) -> SchemaChange {
148 if let Some(old) = self.fingerprints.get(&cell_id) {
149 let change = old.compare(&new_fingerprint);
150
151 if change.is_breaking() {
152 self.invalidate(cell_id);
154 tracing::warn!(
155 "Schema change for cell {:?}: {}",
156 cell_id,
157 change.description()
158 );
159 }
160
161 self.fingerprints.insert(cell_id, new_fingerprint);
162 change
163 } else {
164 self.fingerprints.insert(cell_id, new_fingerprint);
165 SchemaChange::None
166 }
167 }
168
169 pub fn flush(&mut self) -> Result<()> {
174 let dirty_cells: Vec<_> = self.dirty.drain().collect();
175 let mut failed_cells = Vec::new();
176 let mut last_error = None;
177
178 for cell_id in dirty_cells {
179 if let Some(boxed) = self.outputs.get(&cell_id) {
180 let path = self.output_path(cell_id);
181
182 let result = (|| -> Result<()> {
184 if let Some(parent) = path.parent() {
186 fs::create_dir_all(parent)?;
187 }
188
189 let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(boxed.as_ref())
190 .map_err(|e| Error::Serialization(e.to_string()))?;
191
192 let temp_path = path.with_extension("tmp");
194 fs::write(&temp_path, &bytes)?;
195 fs::rename(&temp_path, &path)?;
196
197 Ok(())
198 })();
199
200 if let Err(e) = result {
201 failed_cells.push(cell_id);
203 last_error = Some(e);
204 }
205 }
206 }
207
208 for cell_id in failed_cells {
210 self.dirty.insert(cell_id);
211 }
212
213 if let Some(e) = last_error {
215 return Err(e);
216 }
217
218 Ok(())
219 }
220
221 pub fn restore(&mut self) -> Result<usize> {
223 let outputs_dir = self.state_dir.join("outputs");
224 if !outputs_dir.exists() {
225 return Ok(0);
226 }
227
228 let mut count = 0;
229 for entry in fs::read_dir(&outputs_dir)? {
230 let entry = entry?;
231 let path = entry.path();
232
233 if path.extension().is_some_and(|e| e == "bin")
234 && let Some(stem) = path.file_stem().and_then(|s| s.to_str())
235 && let Ok(id) = stem.parse::<usize>()
236 {
237 let cell_id = CellId::new(id);
238 let bytes = fs::read(&path)?;
239
240 match rkyv::from_bytes::<BoxedOutput, rkyv::rancor::Error>(&bytes) {
241 Ok(boxed) => {
242 self.outputs.insert(cell_id, Arc::new(boxed));
243 count += 1;
244 }
245 Err(e) => {
246 tracing::warn!("Failed to restore output for cell {}: {}", id, e);
247 }
248 }
249 }
250 }
251
252 tracing::info!("Restored {} cached outputs", count);
253 Ok(count)
254 }
255
256 fn output_path(&self, cell_id: CellId) -> PathBuf {
258 self.state_dir
259 .join("outputs")
260 .join(format!("{}.bin", cell_id.as_usize()))
261 }
262
263 pub fn sync_output_to_salsa(
277 &self,
278 cell_id: CellId,
279 inputs_hash: u64,
280 execution_time_ms: u64,
281 ) -> Option<CellOutputData> {
282 self.outputs.get(&cell_id).map(|boxed| {
283 CellOutputData::from_boxed(cell_id.as_usize(), boxed, inputs_hash, execution_time_ms)
284 })
285 }
286
287 pub fn sync_all_to_salsa<F, G>(
298 &self,
299 cell_count: usize,
300 get_inputs_hash: F,
301 get_execution_time: G,
302 ) -> Vec<ExecutionStatus>
303 where
304 F: Fn(CellId) -> u64,
305 G: Fn(CellId) -> u64,
306 {
307 (0..cell_count)
308 .map(|idx| {
309 let cell_id = CellId::new(idx);
310 if let Some(boxed) = self.outputs.get(&cell_id) {
311 let output_data = CellOutputData::from_boxed(
312 idx,
313 boxed,
314 get_inputs_hash(cell_id),
315 get_execution_time(cell_id),
316 );
317 ExecutionStatus::Success(output_data)
318 } else {
319 ExecutionStatus::Pending
320 }
321 })
322 .collect()
323 }
324
325 pub fn load_from_salsa(&mut self, output_data: &CellOutputData) {
329 let cell_id = CellId::new(output_data.cell_id);
330 let boxed = output_data.to_boxed();
331 self.outputs.insert(cell_id, Arc::new(boxed));
332 }
334
335 pub fn load_all_from_salsa(&mut self, statuses: &[ExecutionStatus]) -> usize {
339 let mut count = 0;
340 for status in statuses {
341 if let ExecutionStatus::Success(output_data) = status {
342 self.load_from_salsa(output_data);
343 count += 1;
344 }
345 }
346 count
347 }
348
349 pub fn is_salsa_output_valid(&self, cell_id: CellId, _current_inputs_hash: u64) -> bool {
358 self.has_output(cell_id)
359 }
360
361 pub fn clear(&mut self) -> Result<()> {
363 self.outputs.clear();
364 self.fingerprints.clear();
365 self.dirty.clear();
366
367 let outputs_dir = self.state_dir.join("outputs");
368 if outputs_dir.exists() {
369 fs::remove_dir_all(&outputs_dir)?;
370 }
371
372 Ok(())
373 }
374
375 pub fn stats(&self) -> StateStats {
377 StateStats {
378 cached_outputs: self.outputs.len(),
379 dirty_outputs: self.dirty.len(),
380 fingerprints: self.fingerprints.len(),
381 }
382 }
383}
384
/// Point-in-time counters describing a state manager's in-memory contents.
///
/// Derives `Default`, `PartialEq` and `Eq` so snapshots can be compared
/// directly (for example in assertions) and zero-initialized.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StateStats {
    /// Number of outputs held in memory.
    pub cached_outputs: usize,

    /// Number of cells marked as awaiting a flush to disk.
    pub dirty_outputs: usize,

    /// Number of cells with a recorded type fingerprint.
    pub fingerprints: usize,
}
397
398#[cfg(test)]
399mod tests {
400 use super::*;
401 use rkyv::{Archive, Deserialize, Serialize};
402 use tempfile::TempDir;
403
404 #[derive(Debug, Clone, PartialEq, Archive, Serialize, Deserialize)]
405 struct TestOutput {
406 value: i32,
407 }
408
409 fn setup() -> (StateManager, TempDir) {
410 let temp = TempDir::new().unwrap();
411 let manager = StateManager::new(temp.path()).unwrap();
412 (manager, temp)
413 }
414
/// A saved value can be loaded back from the in-memory cache unchanged.
#[test]
fn test_save_and_load() {
    let (mut manager, _temp) = setup();
    let id = CellId::new(0);
    let original = TestOutput { value: 42 };

    manager.save(id, &original).unwrap();
    let round_tripped = manager.load::<TestOutput>(id).unwrap();

    assert_eq!(original, round_tripped);
}
426
/// Invalidating a cell makes its output disappear.
#[test]
fn test_invalidate() {
    let (mut manager, _temp) = setup();
    let id = CellId::new(0);

    manager.save(id, &TestOutput { value: 42 }).unwrap();
    assert!(manager.has_output(id));

    manager.invalidate(id);
    assert!(!manager.has_output(id));
}
439
/// An output flushed by one manager is restored by a second manager that is
/// rooted at the same directory.
#[test]
fn test_persist_and_restore() {
    let temp = TempDir::new().unwrap();
    let id = CellId::new(0);

    // First manager: write the output to disk, then go away.
    let mut writer = StateManager::new(temp.path()).unwrap();
    writer.save(id, &TestOutput { value: 42 }).unwrap();
    writer.flush().unwrap();
    drop(writer);

    // Second manager: pick the output up from disk.
    let mut reader = StateManager::new(temp.path()).unwrap();
    reader.restore().unwrap();
    let loaded: TestOutput = reader.load(id).unwrap();
    assert_eq!(loaded.value, 42);
}
459
/// Modifying a cell invalidates it and all of its listed dependents.
#[test]
fn test_on_cell_modified() {
    let (mut manager, _temp) = setup();
    let cells = [CellId::new(0), CellId::new(1), CellId::new(2)];

    for (value, &cell) in (0..).zip(&cells) {
        manager.save(cell, &TestOutput { value }).unwrap();
    }

    let invalidated = manager.on_cell_modified(cells[0], &cells[1..]);

    assert_eq!(invalidated.len(), 3);
    for &cell in &cells {
        assert!(!manager.has_output(cell));
    }
}
481
/// The first and an identical fingerprint are non-breaking and keep the
/// output; changing the field type is breaking and drops it.
#[test]
fn test_schema_change_detection() {
    let (mut manager, _temp) = setup();
    let cell_id = CellId::new(0);
    let fingerprint_with = |field_type: &str| {
        TypeFingerprint::new(
            "TestOutput",
            vec![("value".to_string(), field_type.to_string())],
        )
    };

    manager.save(cell_id, &TestOutput { value: 42 }).unwrap();

    // First fingerprint ever recorded for the cell.
    let first = manager.update_fingerprint(cell_id, fingerprint_with("i32"));
    assert!(!first.is_breaking());

    // Same schema again: the output must survive.
    let unchanged = manager.update_fingerprint(cell_id, fingerprint_with("i32"));
    assert!(!unchanged.is_breaking());
    assert!(manager.has_output(cell_id));

    // Field type changed: the output must be invalidated.
    let changed = manager.update_fingerprint(cell_id, fingerprint_with("i64"));
    assert!(changed.is_breaking());
    assert!(!manager.has_output(cell_id));
}
512
/// Syncing yields nothing for an unknown cell and a populated record once an
/// output has been saved.
#[test]
fn test_sync_output_to_salsa() {
    let (mut manager, _temp) = setup();
    let id = CellId::new(0);

    let before_save = manager.sync_output_to_salsa(id, 12345, 100);
    assert!(before_save.is_none());

    manager.save(id, &TestOutput { value: 42 }).unwrap();

    let synced = manager.sync_output_to_salsa(id, 12345, 100).unwrap();
    assert_eq!(synced.cell_id, 0);
    assert_eq!(synced.inputs_hash, 12345);
    assert_eq!(synced.execution_time_ms, 100);
    assert!(!synced.bytes.is_empty());
}
531
/// Cells with outputs become `Success` carrying the callback values; cells
/// without outputs become `Pending`.
#[test]
fn test_sync_all_to_salsa() {
    let (mut manager, _temp) = setup();

    // Cell 1 is deliberately left without an output.
    manager.save(CellId::new(0), &TestOutput { value: 0 }).unwrap();
    manager.save(CellId::new(2), &TestOutput { value: 2 }).unwrap();

    let inputs_hash = |cell_id: CellId| cell_id.as_usize() as u64 * 100;
    let execution_time = |cell_id: CellId| cell_id.as_usize() as u64 * 10;
    let statuses = manager.sync_all_to_salsa(3, inputs_hash, execution_time);

    assert_eq!(statuses.len(), 3);
    assert!(matches!(statuses[0], ExecutionStatus::Success(_)));
    assert!(matches!(statuses[1], ExecutionStatus::Pending));
    assert!(matches!(statuses[2], ExecutionStatus::Success(_)));

    if let ExecutionStatus::Success(first) = &statuses[0] {
        assert_eq!(first.inputs_hash, 0);
        assert_eq!(first.execution_time_ms, 0);
    }

    if let ExecutionStatus::Success(third) = &statuses[2] {
        assert_eq!(third.inputs_hash, 200);
        assert_eq!(third.execution_time_ms, 20);
    }
}
563
/// An output loaded from a salsa record is readable and is not marked dirty.
#[test]
fn test_load_from_salsa() {
    let (mut manager, _temp) = setup();
    let id = CellId::new(0);

    let boxed = BoxedOutput::new(&TestOutput { value: 99 }).unwrap();
    let record = CellOutputData::from_boxed(0, &boxed, 12345, 50);

    manager.load_from_salsa(&record);

    assert!(manager.has_output(id));
    let loaded: TestOutput = manager.load(id).unwrap();
    assert_eq!(loaded.value, 99);

    // Loading from salsa must not schedule the output for a flush.
    assert!(!manager.dirty.contains(&id));
}
585
/// Only `Success` statuses are loaded; `Pending` and `Failed` entries leave
/// their cells without an output.
#[test]
fn test_load_all_from_salsa() {
    let (mut manager, _temp) = setup();

    let success = |cell_index: usize, value: i32| {
        let boxed = BoxedOutput::new(&TestOutput { value }).unwrap();
        ExecutionStatus::Success(CellOutputData::from_boxed(cell_index, &boxed, 0, 0))
    };

    let statuses = vec![
        success(0, 100),
        ExecutionStatus::Pending,
        success(2, 200),
        ExecutionStatus::Failed("error".to_string()),
    ];

    let count = manager.load_all_from_salsa(&statuses);
    assert_eq!(count, 2);

    assert!(manager.has_output(CellId::new(0)));
    assert!(!manager.has_output(CellId::new(1)));
    assert!(manager.has_output(CellId::new(2)));
    assert!(!manager.has_output(CellId::new(3)));

    let loaded0: TestOutput = manager.load(CellId::new(0)).unwrap();
    assert_eq!(loaded0.value, 100);

    let loaded2: TestOutput = manager.load(CellId::new(2)).unwrap();
    assert_eq!(loaded2.value, 200);
}
622}