1use std::collections::HashMap;
42use std::fs::File;
43use std::io::Write as IoWrite;
44use std::path::{Path, PathBuf};
45use std::sync::OnceLock;
46use std::time::{SystemTime, UNIX_EPOCH};
47
48use memmap2::Mmap;
49use parking_lot::Mutex;
50use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
51use std::os::unix::fs::MetadataExt;
52
/// Magic number stored in every shard header: the ASCII bytes `ZRSC` read as a
/// big-endian `u32` (`0x5A525343`).
pub const SHARD_MAGIC: u32 = u32::from_be_bytes(*b"ZRSC");

/// Revision of the on-disk shard layout. Readers in this file reject shards
/// whose header carries any other value.
pub const SHARD_FORMAT_VERSION: u32 = 1;
58
59#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
60#[archive(check_bytes)]
61pub struct ShardHeader {
62 pub magic: u32,
63 pub format_version: u32,
64 pub zshrs_version: String,
65 pub pointer_width: u32,
66 pub built_at_secs: u64,
67}
68
69#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
70#[archive(check_bytes)]
71pub struct ScriptEntry {
72 pub mtime_secs: i64,
73 pub mtime_nsecs: i64,
74 pub binary_mtime_at_cache: i64,
75 pub cached_at_secs: i64,
76 pub chunk_blob: Vec<u8>,
77}
78
79#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
80#[archive(check_bytes)]
81pub struct ScriptShard {
82 pub header: ShardHeader,
83 pub entries: HashMap<String, ScriptEntry>,
84}
85
86pub struct MmappedShard {
89 _mmap: Mmap,
90 archived: *const ArchivedScriptShard,
91}
92
93unsafe impl Send for MmappedShard {}
96unsafe impl Sync for MmappedShard {}
97
98impl MmappedShard {
99 pub fn open(path: &Path) -> Option<Self> {
100 let file = File::open(path).ok()?;
101 let mmap = unsafe { Mmap::map(&file).ok()? };
102 let archived = rkyv::check_archived_root::<ScriptShard>(&mmap[..]).ok()?;
103 let archived_ptr = archived as *const ArchivedScriptShard;
104 Some(Self {
105 _mmap: mmap,
106 archived: archived_ptr,
107 })
108 }
109
110 fn shard(&self) -> &ArchivedScriptShard {
111 unsafe { &*self.archived }
113 }
114
115 fn header_ok(&self) -> bool {
116 let h = &self.shard().header;
117 let magic: u32 = h.magic.into();
118 let fv: u32 = h.format_version.into();
119 let pw: u32 = h.pointer_width.into();
120 magic == SHARD_MAGIC
121 && fv == SHARD_FORMAT_VERSION
122 && pw as usize == std::mem::size_of::<usize>()
123 && h.zshrs_version.as_str() == env!("CARGO_PKG_VERSION")
124 }
125
126 fn lookup(&self, path: &str) -> Option<&ArchivedScriptEntry> {
127 self.shard().entries.get(path)
128 }
129
130 fn entry_count(&self) -> usize {
131 self.shard().entries.len()
132 }
133}
134
135pub struct ScriptCache {
137 path: PathBuf,
138 lock_path: PathBuf,
139 mmap: Mutex<Option<MmappedShard>>,
140}
141
142impl ScriptCache {
143 pub fn open(path: &Path) -> std::io::Result<Self> {
144 if let Some(parent) = path.parent() {
145 std::fs::create_dir_all(parent)?;
146 }
147 let parent = path.parent().unwrap_or_else(|| Path::new("/tmp"));
148 let lock_path = parent.join(format!(
149 "{}.lock",
150 path.file_name()
151 .and_then(|s| s.to_str())
152 .unwrap_or("scripts.rkyv")
153 ));
154 Ok(Self {
155 path: path.to_path_buf(),
156 lock_path,
157 mmap: Mutex::new(None),
158 })
159 }
160
161 fn ensure_mmap(&self) {
162 let mut guard = self.mmap.lock();
163 if guard.is_none() {
164 *guard = MmappedShard::open(&self.path);
165 }
166 }
167
168 fn invalidate_mmap(&self) {
169 let mut guard = self.mmap.lock();
170 *guard = None;
171 }
172
173 pub fn get(&self, path: &str, mtime_secs: i64, mtime_nsecs: i64) -> Option<Vec<u8>> {
176 self.ensure_mmap();
177 let guard = self.mmap.lock();
178 let shard = guard.as_ref()?;
179 if !shard.header_ok() {
180 return None;
181 }
182 let entry = shard.lookup(path)?;
183
184 let entry_mtime_s: i64 = entry.mtime_secs.into();
185 let entry_mtime_ns: i64 = entry.mtime_nsecs.into();
186 if entry_mtime_s != mtime_secs || entry_mtime_ns != mtime_nsecs {
187 return None;
188 }
189
190 if let Some(bin_mtime) = current_binary_mtime_secs() {
191 let cached_bin_mtime: i64 = entry.binary_mtime_at_cache.into();
192 if cached_bin_mtime < bin_mtime {
193 return None;
194 }
195 }
196
197 Some(entry.chunk_blob.as_slice().to_vec())
198 }
199
200 pub fn put(
202 &self,
203 path: &str,
204 mtime_secs: i64,
205 mtime_nsecs: i64,
206 chunk_blob: Vec<u8>,
207 ) -> Result<(), String> {
208 let _lock = match acquire_lock(&self.lock_path) {
209 Some(l) => l,
210 None => return Ok(()),
211 };
212
213 let mut shard = match read_owned_shard(&self.path) {
214 Some(s)
215 if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
216 && s.header.pointer_width as usize == std::mem::size_of::<usize>()
217 && s.header.format_version == SHARD_FORMAT_VERSION =>
218 {
219 s
220 }
221 _ => fresh_shard(),
222 };
223
224 let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
225 let entry = ScriptEntry {
226 mtime_secs,
227 mtime_nsecs,
228 binary_mtime_at_cache: bin_mtime,
229 cached_at_secs: now_secs(),
230 chunk_blob,
231 };
232 shard.entries.insert(path.to_string(), entry);
233 shard.header.built_at_secs = now_secs() as u64;
234
235 write_shard_atomic(&self.path, &shard)?;
236 self.invalidate_mmap();
237 Ok(())
238 }
239
240 pub fn stats(&self) -> (i64, i64) {
242 self.ensure_mmap();
243 let guard = self.mmap.lock();
244 let Some(shard) = guard.as_ref() else {
245 return (0, 0);
246 };
247 let count = shard.entry_count() as i64;
248 let bytes: i64 = shard
249 .shard()
250 .entries
251 .values()
252 .map(|e| e.chunk_blob.len() as i64)
253 .sum();
254 (count, bytes)
255 }
256
257 pub fn list_scripts(&self) -> Vec<(String, f64, String, String)> {
260 self.ensure_mmap();
261 let guard = self.mmap.lock();
262 let Some(shard) = guard.as_ref() else {
263 return Vec::new();
264 };
265 let v = shard.shard().header.zshrs_version.as_str().to_string();
266 let mut out: Vec<(String, f64, String, String, i64)> = shard
267 .shard()
268 .entries
269 .iter()
270 .map(|(k, e)| {
271 let chunk_kb = e.chunk_blob.len() as f64 / 1024.0;
272 let cached_at: i64 = e.cached_at_secs.into();
273 let ts = format_local_ts(cached_at);
274 (
275 k.as_str().to_string(),
276 chunk_kb,
277 v.clone(),
278 ts,
279 cached_at,
280 )
281 })
282 .collect();
283 out.sort_by_key(|x| std::cmp::Reverse(x.4));
284 out.into_iter()
285 .map(|(p, ck, ver, ts, _)| (p, ck, ver, ts))
286 .collect()
287 }
288
289 pub fn evict_stale(&self) -> usize {
291 let _lock = match acquire_lock(&self.lock_path) {
292 Some(l) => l,
293 None => return 0,
294 };
295 let mut shard = match read_owned_shard(&self.path) {
296 Some(s) => s,
297 None => return 0,
298 };
299 let before = shard.entries.len();
300 shard.entries.retain(|p, e| match file_mtime(Path::new(p)) {
301 Some((s, ns)) => s == e.mtime_secs && ns == e.mtime_nsecs,
302 None => false,
303 });
304 let evicted = before - shard.entries.len();
305 if evicted > 0 {
306 let _ = write_shard_atomic(&self.path, &shard);
307 self.invalidate_mmap();
308 }
309 evicted
310 }
311
312 pub fn clear(&self) -> std::io::Result<()> {
313 let _lock = acquire_lock(&self.lock_path);
314 let res = match std::fs::remove_file(&self.path) {
315 Ok(()) => Ok(()),
316 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
317 Err(e) => Err(e),
318 };
319 self.invalidate_mmap();
320 res
321 }
322}
323
324fn acquire_lock(path: &Path) -> Option<nix::fcntl::Flock<File>> {
325 let f = File::options()
326 .read(true)
327 .write(true)
328 .create(true)
329 .truncate(false)
330 .open(path)
331 .ok()?;
332 nix::fcntl::Flock::lock(f, nix::fcntl::FlockArg::LockExclusive).ok()
333}
334
335fn fresh_shard() -> ScriptShard {
336 ScriptShard {
337 header: ShardHeader {
338 magic: SHARD_MAGIC,
339 format_version: SHARD_FORMAT_VERSION,
340 zshrs_version: env!("CARGO_PKG_VERSION").to_string(),
341 pointer_width: std::mem::size_of::<usize>() as u32,
342 built_at_secs: now_secs() as u64,
343 },
344 entries: HashMap::new(),
345 }
346}
347
348fn read_owned_shard(path: &Path) -> Option<ScriptShard> {
349 let bytes = std::fs::read(path).ok()?;
350 let archived = rkyv::check_archived_root::<ScriptShard>(&bytes[..]).ok()?;
351 archived.deserialize(&mut rkyv::Infallible).ok()
352}
353
354fn write_shard_atomic(path: &Path, shard: &ScriptShard) -> Result<(), String> {
355 let bytes = rkyv::to_bytes::<_, 4096>(shard)
356 .map_err(|e| format!("rkyv serialize: {}", e))?;
357
358 let parent = path.parent().expect("cache path has parent");
359 let _ = std::fs::create_dir_all(parent);
360
361 let pid = std::process::id();
362 let nanos = SystemTime::now()
363 .duration_since(UNIX_EPOCH)
364 .map(|d| d.as_nanos())
365 .unwrap_or(0);
366 let tmp_path = parent.join(format!(
367 "{}.tmp.{}.{}",
368 path.file_name()
369 .and_then(|s| s.to_str())
370 .unwrap_or("scripts.rkyv"),
371 pid,
372 nanos
373 ));
374
375 {
376 let mut f = File::create(&tmp_path).map_err(|e| e.to_string())?;
377 f.write_all(&bytes).map_err(|e| e.to_string())?;
378 f.sync_all().map_err(|e| e.to_string())?;
379 }
380
381 std::fs::rename(&tmp_path, path).map_err(|e| e.to_string())?;
382 Ok(())
383}
384
/// Current wall-clock time as whole seconds since the Unix epoch, or `0` if
/// the system clock reads earlier than the epoch.
fn now_secs() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
391
392fn format_local_ts(secs: i64) -> String {
393 let dt = chrono::DateTime::<chrono::Local>::from(
394 UNIX_EPOCH + std::time::Duration::from_secs(secs.max(0) as u64),
395 );
396 dt.format("%Y-%m-%d %H:%M:%S").to_string()
397}
398
/// Returns the modification time of `path` as `(seconds, nanoseconds)`, or
/// `None` if its metadata cannot be read.
pub fn file_mtime(path: &Path) -> Option<(i64, i64)> {
    match std::fs::metadata(path) {
        Ok(meta) => Some((meta.mtime(), meta.mtime_nsec())),
        Err(_) => None,
    }
}
403
404fn current_binary_mtime_secs() -> Option<i64> {
405 static BIN_MTIME: OnceLock<Option<i64>> = OnceLock::new();
406 *BIN_MTIME.get_or_init(|| {
407 let exe = std::env::current_exe().ok()?;
408 let (secs, _) = file_mtime(&exe)?;
409 Some(secs)
410 })
411}
412
413pub fn default_cache_path() -> PathBuf {
415 dirs::home_dir()
416 .unwrap_or_else(|| PathBuf::from("/tmp"))
417 .join(".zshrs/scripts.rkyv")
418}
419
/// Reports whether caching is enabled.
///
/// Caching is on unless the `ZSHRS_CACHE` environment variable is exactly
/// `0`, `false` or `no`. An unset or unreadable variable leaves it on.
pub fn cache_enabled() -> bool {
    match std::env::var("ZSHRS_CACHE") {
        Ok(value) => !["0", "false", "no"].contains(&value.as_str()),
        Err(_) => true,
    }
}
427
428pub static CACHE: once_cell::sync::Lazy<Option<ScriptCache>> = once_cell::sync::Lazy::new(|| {
431 if !cache_enabled() {
432 return None;
433 }
434 ScriptCache::open(&default_cache_path()).ok()
435});
436
437pub fn try_load_bytes(path: &Path) -> Option<Vec<u8>> {
439 let cache = CACHE.as_ref()?;
440 let canonical = path.canonicalize().ok()?;
441 let path_str = canonical.to_string_lossy();
442 let (mtime_s, mtime_ns) = file_mtime(&canonical)?;
443 cache.get(&path_str, mtime_s, mtime_ns)
444}
445
446pub fn try_save_bytes(path: &Path, chunk_blob: &[u8]) -> Result<(), String> {
450 let Some(cache) = CACHE.as_ref() else {
451 return Ok(());
452 };
453 let canonical = match path.canonicalize() {
454 Ok(p) => p,
455 Err(_) => return Ok(()),
456 };
457 let path_str = canonical.to_string_lossy();
458 let (mtime_s, mtime_ns) = match file_mtime(&canonical) {
459 Some(m) => m,
460 None => return Ok(()),
461 };
462 cache.put(&path_str, mtime_s, mtime_ns, chunk_blob.to_vec())
463}
464
465pub fn stats() -> Option<(i64, i64)> {
466 CACHE.as_ref().map(|c| c.stats())
467}
468
469pub fn evict_stale() -> usize {
470 CACHE.as_ref().map(|c| c.evict_stale()).unwrap_or(0)
471}
472
473pub fn clear() -> bool {
474 CACHE.as_ref().map(|c| c.clear().is_ok()).unwrap_or(false)
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Creates a temp directory and a cache handle on `scripts.rkyv` inside it.
    /// The directory guard is returned so it outlives the test body.
    fn scratch_cache() -> (TempDir, PathBuf, ScriptCache) {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("scripts.rkyv");
        let cache = ScriptCache::open(&cache_path).unwrap();
        (dir, cache_path, cache)
    }

    /// Writes `body` to `name` inside `dir` and returns the script's path
    /// string together with its mtime seconds and nanoseconds.
    fn write_script(dir: &TempDir, name: &str, body: &str) -> (String, i64, i64) {
        let script = dir.path().join(name);
        std::fs::write(&script, body).unwrap();
        let (secs, nsecs) = file_mtime(&script).unwrap();
        (script.to_string_lossy().to_string(), secs, nsecs)
    }

    #[test]
    fn round_trip() {
        let (dir, _, cache) = scratch_cache();
        let (key, secs, nsecs) = write_script(&dir, "test.zsh", "echo hi");

        let blob = vec![1u8, 2, 3, 4, 5];
        cache.put(&key, secs, nsecs, blob.clone()).unwrap();

        let loaded = cache.get(&key, secs, nsecs).unwrap();
        assert_eq!(loaded, blob);

        let (count, _bytes) = cache.stats();
        assert_eq!(count, 1);
    }

    #[test]
    fn mtime_invalidation() {
        let (dir, _, cache) = scratch_cache();
        let (key, secs, nsecs) = write_script(&dir, "test.zsh", "echo hi");
        cache.put(&key, secs, nsecs, vec![9u8]).unwrap();

        assert!(cache.get(&key, secs + 1, nsecs).is_none());
    }

    #[test]
    fn second_put_replaces_first() {
        let (dir, _, cache) = scratch_cache();
        let (key_a, a_secs, a_nsecs) = write_script(&dir, "a.zsh", "1");
        let (key_b, b_secs, b_nsecs) = write_script(&dir, "b.zsh", "2");

        cache.put(&key_a, a_secs, a_nsecs, vec![1u8]).unwrap();
        cache.put(&key_b, b_secs, b_nsecs, vec![2u8]).unwrap();

        let (count, _) = cache.stats();
        assert_eq!(count, 2);
        assert!(cache.get(&key_a, a_secs, a_nsecs).is_some());
        assert!(cache.get(&key_b, b_secs, b_nsecs).is_some());
    }

    #[test]
    fn corrupt_file_returns_no_mmap() {
        // The garbage file must exist before the cache handle is created.
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("scripts.rkyv");
        std::fs::write(&cache_path, b"this is not a valid rkyv archive").unwrap();

        let cache = ScriptCache::open(&cache_path).unwrap();
        assert!(cache.get("/nope", 0, 0).is_none());
    }

    #[test]
    fn clear_removes_file() {
        let (dir, cache_path, cache) = scratch_cache();
        let (key, secs, nsecs) = write_script(&dir, "test.zsh", "echo hi");

        cache.put(&key, secs, nsecs, vec![7u8]).unwrap();
        assert!(cache_path.exists());

        cache.clear().unwrap();
        assert!(!cache_path.exists());
    }
}