1use std::collections::HashMap;
35use std::fs;
36use std::io::{self, Read, Write};
37use std::path::Path;
38use std::time::{SystemTime, UNIX_EPOCH};
39
40use rkyv::{rancor, Archive, Deserialize, Serialize};
41
/// On-disk format version written into every [`CacheSnapshot`].
///
/// `CachePersistence::load` / `load_unchecked` reject files whose stored
/// version differs from this value, so bump it whenever the serialized layout
/// of the cache types changes.
pub const CACHE_VERSION: u32 = 1;
47
48#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
59pub struct CacheSnapshot {
60 pub version: u32,
62
63 pub toolchain_version: String,
68
69 pub dependency_hash: u64,
74
75 pub cells: HashMap<String, CachedCell>,
77
78 pub created_at: u64,
80}
81
82#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
84pub struct CachedCell {
85 pub name: String,
87
88 pub source_hash: u64,
92
93 pub dylib_path: String,
97
98 pub status: CachedCompilationStatus,
100}
101
102#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
104pub enum CachedCompilationStatus {
105 Success,
107
108 Cached,
110
111 Failed { error: String },
113}
114
/// Errors produced while saving, loading or invalidating the on-disk cache.
#[derive(Debug)]
pub enum CacheError {
    /// Underlying filesystem failure.
    Io(io::Error),

    /// The file's stored format version differs from [`CACHE_VERSION`].
    VersionMismatch { expected: u32, found: u32 },

    /// The file was written by a different toolchain than the one requested.
    ToolchainMismatch { expected: String, found: String },

    /// Dependency hash mismatch (expected vs. stored hash).
    DependencyMismatch { expected: u64, found: u64 },

    /// The bytes could not be validated or deserialized; carries rkyv's message.
    Deserialize(String),

    /// The snapshot could not be serialized; carries rkyv's message.
    Serialize(String),
}
136
137impl std::fmt::Display for CacheError {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 match self {
140 CacheError::Io(e) => write!(f, "cache IO error: {}", e),
141 CacheError::VersionMismatch { expected, found } => {
142 write!(
143 f,
144 "cache version mismatch: expected {}, found {}",
145 expected, found
146 )
147 }
148 CacheError::ToolchainMismatch { expected, found } => {
149 write!(
150 f,
151 "toolchain mismatch: expected '{}', found '{}'",
152 expected, found
153 )
154 }
155 CacheError::DependencyMismatch { expected, found } => {
156 write!(
157 f,
158 "dependency hash mismatch: expected {:#x}, found {:#x}",
159 expected, found
160 )
161 }
162 CacheError::Deserialize(e) => write!(f, "cache deserialize error: {}", e),
163 CacheError::Serialize(e) => write!(f, "cache serialize error: {}", e),
164 }
165 }
166}
167
168impl std::error::Error for CacheError {
169 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
170 match self {
171 CacheError::Io(e) => Some(e),
172 _ => None,
173 }
174 }
175}
176
177impl From<io::Error> for CacheError {
178 fn from(e: io::Error) -> Self {
179 CacheError::Io(e)
180 }
181}
182
183pub struct CachePersistence;
185
186impl CachePersistence {
187 pub fn save(path: &Path, snapshot: &CacheSnapshot) -> Result<(), CacheError> {
192 if let Some(parent) = path.parent() {
194 fs::create_dir_all(parent)?;
195 }
196
197 let bytes = rkyv::to_bytes::<rancor::Error>(snapshot)
199 .map_err(|e| CacheError::Serialize(e.to_string()))?;
200
201 let pid = std::process::id();
204 let temp_path = path.with_extension(format!("tmp.{}", pid));
205
206 let mut file = fs::File::create(&temp_path)?;
207 file.write_all(&bytes)?;
208 file.sync_all()?;
209
210 let rename_result = fs::rename(&temp_path, path);
212
213 if rename_result.is_err() {
215 let _ = fs::remove_file(&temp_path);
216 }
217
218 rename_result?;
219
220 tracing::debug!(
221 "Saved cache snapshot: {} cells, {} bytes",
222 snapshot.cells.len(),
223 bytes.len()
224 );
225
226 Ok(())
227 }
228
229 pub fn load(path: &Path, expected_toolchain: &str) -> Result<Option<CacheSnapshot>, CacheError> {
239 if !path.exists() {
241 tracing::debug!("No cache file at {:?}", path);
242 return Ok(None);
243 }
244
245 let mut file = fs::File::open(path)?;
247 let mut bytes = Vec::new();
248 file.read_to_end(&mut bytes)?;
249
250 let archived = rkyv::access::<ArchivedCacheSnapshot, rancor::Error>(&bytes)
252 .map_err(|e| CacheError::Deserialize(e.to_string()))?;
253
254 let found_version: u32 = archived.version.into();
256 if found_version != CACHE_VERSION {
257 return Err(CacheError::VersionMismatch {
258 expected: CACHE_VERSION,
259 found: found_version,
260 });
261 }
262
263 let snapshot: CacheSnapshot =
265 rkyv::deserialize::<CacheSnapshot, rancor::Error>(archived)
266 .map_err(|e| CacheError::Deserialize(e.to_string()))?;
267
268 if snapshot.toolchain_version != expected_toolchain {
270 return Err(CacheError::ToolchainMismatch {
271 expected: expected_toolchain.to_string(),
272 found: snapshot.toolchain_version.clone(),
273 });
274 }
275
276 tracing::debug!(
277 "Loaded cache snapshot: {} cells, created at {}",
278 snapshot.cells.len(),
279 snapshot.created_at
280 );
281
282 Ok(Some(snapshot))
283 }
284
285 pub fn load_unchecked(path: &Path) -> Result<Option<CacheSnapshot>, CacheError> {
290 if !path.exists() {
291 return Ok(None);
292 }
293
294 let mut file = fs::File::open(path)?;
295 let mut bytes = Vec::new();
296 file.read_to_end(&mut bytes)?;
297
298 let archived = rkyv::access::<ArchivedCacheSnapshot, rancor::Error>(&bytes)
299 .map_err(|e| CacheError::Deserialize(e.to_string()))?;
300
301 let found_version: u32 = archived.version.into();
302 if found_version != CACHE_VERSION {
303 return Err(CacheError::VersionMismatch {
304 expected: CACHE_VERSION,
305 found: found_version,
306 });
307 }
308
309 let snapshot: CacheSnapshot =
310 rkyv::deserialize::<CacheSnapshot, rancor::Error>(archived)
311 .map_err(|e| CacheError::Deserialize(e.to_string()))?;
312
313 Ok(Some(snapshot))
314 }
315
316 pub fn invalidate(path: &Path) -> Result<(), CacheError> {
318 if path.exists() {
319 fs::remove_file(path)?;
320 tracing::debug!("Invalidated cache at {:?}", path);
321 }
322 Ok(())
323 }
324}
325
326impl CacheSnapshot {
327 pub fn new(toolchain_version: String, dependency_hash: u64) -> Self {
329 let created_at = SystemTime::now()
330 .duration_since(UNIX_EPOCH)
331 .map(|d| d.as_secs())
332 .unwrap_or(0);
333
334 Self {
335 version: CACHE_VERSION,
336 toolchain_version,
337 dependency_hash,
338 cells: HashMap::new(),
339 created_at,
340 }
341 }
342
343 pub fn add_cell(&mut self, cell: CachedCell) {
345 self.cells.insert(cell.name.clone(), cell);
346 }
347
348 pub fn get_cell(&self, name: &str) -> Option<&CachedCell> {
350 self.cells.get(name)
351 }
352
353 pub fn is_cell_valid(&self, name: &str, current_source_hash: u64) -> bool {
357 self.cells
358 .get(name)
359 .map(|c| c.source_hash == current_source_hash)
360 .unwrap_or(false)
361 }
362
363 pub fn is_dependency_valid(&self, current_hash: u64) -> bool {
365 self.dependency_hash == current_hash
366 }
367}
368
369impl CachedCell {
370 pub fn success(name: String, source_hash: u64, dylib_path: String) -> Self {
372 Self {
373 name,
374 source_hash,
375 dylib_path,
376 status: CachedCompilationStatus::Success,
377 }
378 }
379
380 pub fn cached(name: String, source_hash: u64, dylib_path: String) -> Self {
382 Self {
383 name,
384 source_hash,
385 dylib_path,
386 status: CachedCompilationStatus::Cached,
387 }
388 }
389
390 pub fn failed(name: String, source_hash: u64, error: String) -> Self {
392 Self {
393 name,
394 source_hash,
395 dylib_path: String::new(),
396 status: CachedCompilationStatus::Failed { error },
397 }
398 }
399
400 pub fn is_success(&self) -> bool {
402 matches!(
403 self.status,
404 CachedCompilationStatus::Success | CachedCompilationStatus::Cached
405 )
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_cache_round_trip() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let mut snapshot = CacheSnapshot::new("rustc 1.76.0-nightly".to_string(), 0x12345678);

        snapshot.add_cell(CachedCell::success(
            "cell_a".to_string(),
            0xAABBCCDD,
            "cell_a.so".to_string(),
        ));

        snapshot.add_cell(CachedCell::failed(
            "cell_b".to_string(),
            0x11223344,
            "type mismatch".to_string(),
        ));

        CachePersistence::save(&cache_path, &snapshot).unwrap();

        let loaded = CachePersistence::load(&cache_path, "rustc 1.76.0-nightly")
            .unwrap()
            .unwrap();

        assert_eq!(loaded.version, CACHE_VERSION);
        assert_eq!(loaded.toolchain_version, "rustc 1.76.0-nightly");
        assert_eq!(loaded.dependency_hash, 0x12345678);
        assert_eq!(loaded.cells.len(), 2);

        let cell_a = loaded.get_cell("cell_a").unwrap();
        assert_eq!(cell_a.source_hash, 0xAABBCCDD);
        assert!(cell_a.is_success());

        let cell_b = loaded.get_cell("cell_b").unwrap();
        assert!(!cell_b.is_success());
        assert!(matches!(
            &cell_b.status,
            CachedCompilationStatus::Failed { error } if error == "type mismatch"
        ));
    }

    #[test]
    fn test_cache_missing_file() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("nonexistent.bin");

        let result = CachePersistence::load(&cache_path, "rustc 1.76.0-nightly").unwrap();
        assert!(result.is_none());

        let unchecked = CachePersistence::load_unchecked(&cache_path).unwrap();
        assert!(unchecked.is_none());
    }

    #[test]
    fn test_cache_toolchain_mismatch() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let snapshot = CacheSnapshot::new("rustc 1.76.0-nightly".to_string(), 0);
        CachePersistence::save(&cache_path, &snapshot).unwrap();

        let result = CachePersistence::load(&cache_path, "rustc 1.77.0-nightly");

        assert!(matches!(
            result,
            Err(CacheError::ToolchainMismatch { .. })
        ));
    }

    #[test]
    fn test_cache_version_mismatch() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let mut snapshot = CacheSnapshot::new("test".to_string(), 0);
        snapshot.version = CACHE_VERSION + 1;
        CachePersistence::save(&cache_path, &snapshot).unwrap();

        let result = CachePersistence::load(&cache_path, "test");
        assert!(matches!(
            result,
            Err(CacheError::VersionMismatch { expected, found })
                if expected == CACHE_VERSION && found == CACHE_VERSION + 1
        ));

        // The unchecked loader still enforces the format version.
        assert!(matches!(
            CachePersistence::load_unchecked(&cache_path),
            Err(CacheError::VersionMismatch { .. })
        ));
    }

    #[test]
    fn test_load_unchecked_ignores_toolchain() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let snapshot = CacheSnapshot::new("rustc 1.76.0-nightly".to_string(), 7);
        CachePersistence::save(&cache_path, &snapshot).unwrap();

        let loaded = CachePersistence::load_unchecked(&cache_path).unwrap().unwrap();
        assert_eq!(loaded.toolchain_version, "rustc 1.76.0-nightly");
        assert_eq!(loaded.dependency_hash, 7);
    }

    #[test]
    fn test_load_corrupt_file() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        std::fs::write(&cache_path, b"not an rkyv archive").unwrap();

        assert!(matches!(
            CachePersistence::load(&cache_path, "test"),
            Err(CacheError::Deserialize(_))
        ));
    }

    #[test]
    fn test_save_creates_parent_directories() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("nested").join("deeper").join("cache.bin");

        let snapshot = CacheSnapshot::new("test".to_string(), 0);
        CachePersistence::save(&cache_path, &snapshot).unwrap();

        assert!(cache_path.exists());
        assert!(CachePersistence::load(&cache_path, "test").unwrap().is_some());
    }

    #[test]
    fn test_save_leaves_no_temp_file() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let snapshot = CacheSnapshot::new("test".to_string(), 0);
        CachePersistence::save(&cache_path, &snapshot).unwrap();

        let names: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(names, vec![std::ffi::OsString::from("test_cache.bin")]);
    }

    #[test]
    fn test_save_overwrites_existing() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let mut first = CacheSnapshot::new("test".to_string(), 1);
        first.add_cell(CachedCell::success("a".to_string(), 1, "a.so".to_string()));
        CachePersistence::save(&cache_path, &first).unwrap();

        let mut second = CacheSnapshot::new("test".to_string(), 2);
        second.add_cell(CachedCell::cached("b".to_string(), 2, "b.so".to_string()));
        CachePersistence::save(&cache_path, &second).unwrap();

        let loaded = CachePersistence::load(&cache_path, "test").unwrap().unwrap();
        assert_eq!(loaded.dependency_hash, 2);
        assert_eq!(loaded.cells.len(), 1);
        assert!(loaded.get_cell("a").is_none());
        assert_eq!(
            loaded.get_cell("b").unwrap().status,
            CachedCompilationStatus::Cached
        );
    }

    #[test]
    fn test_cache_invalidation() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let snapshot = CacheSnapshot::new("test".to_string(), 0);
        CachePersistence::save(&cache_path, &snapshot).unwrap();
        assert!(cache_path.exists());

        CachePersistence::invalidate(&cache_path).unwrap();
        assert!(!cache_path.exists());
    }

    #[test]
    fn test_invalidate_missing_file_is_ok() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("missing.bin");

        CachePersistence::invalidate(&cache_path).unwrap();
        assert!(!cache_path.exists());
    }

    #[test]
    fn test_cell_validity() {
        let mut snapshot = CacheSnapshot::new("test".to_string(), 0);

        snapshot.add_cell(CachedCell::success("test".to_string(), 0x1234, "".to_string()));

        assert!(snapshot.is_cell_valid("test", 0x1234));

        assert!(!snapshot.is_cell_valid("test", 0x5678));

        assert!(!snapshot.is_cell_valid("unknown", 0x1234));
    }

    #[test]
    fn test_dependency_validity() {
        let snapshot = CacheSnapshot::new("test".to_string(), 0xABCD);

        assert!(snapshot.is_dependency_valid(0xABCD));
        assert!(!snapshot.is_dependency_valid(0x1234));
    }
}