// zsh/extensions/script_cache.rs

//! rkyv-backed bytecode cache for zsh scripts.
//!
//! Single-file shard at `~/.zshrs/scripts.rkyv`. From the second run of a
//! given script onward, lex/parse/compile is skipped: a cache hit is `mmap` +
//! zero-copy `ArchivedHashMap` lookup, after which the caller bincode-decodes
//! the returned `fusevm::Chunk` blob.
//!
//! Storage layout (rkyv archived):
//!
//! ```text
//! ScriptShard {
//!   header: { magic, format_version, zshrs_version, pointer_width, built_at_secs },
//!   entries: HashMap<canonical_path, ScriptEntry>,
//! }
//! ScriptEntry { mtime_secs, mtime_nsecs, binary_mtime_at_cache, cached_at_secs,
//!               chunk_blob: Vec<u8> }
//! ```
//!
//! The inner `chunk_blob` is bincode for now: `fusevm::Chunk` is owned by the
//! upstream `fusevm` crate and only derives `serde::Serialize`/`Deserialize`,
//! not `rkyv::Archive`, so the inner codec stays bincode inside the rkyv outer
//! container. Archiving `Chunk` directly would require either forking fusevm or
//! maintaining a mirror archived type — both are large refactors and not
//! needed for the current "kill SQLite for bytecode" goal.
//!
//! Read path:
//!   - The shard is `mmap`ed lazily and the mapping is reused until this
//!     process next writes or clears the shard, so repeat lookups pay
//!     validation once.
//!   - `rkyv::check_archived_root::<ScriptShard>` validates the byte image.
//!   - The header is validated for magic / format_version / zshrs_version /
//!     pointer_width.
//!   - Per entry: the source mtime (seconds and nanoseconds) must match, and
//!     `binary_mtime_at_cache` must be ≥ the running zshrs binary's mtime, so
//!     any rebuild of zshrs silently invalidates older entries.
//!
//! Write path:
//!   - An exclusive `flock(2)` (via `nix::fcntl::Flock`) on `scripts.rkyv.lock`
//!     serializes concurrent writers.
//!   - The existing shard is read into owned form, mutated, serialized with
//!     `rkyv::to_bytes`, written to `scripts.rkyv.tmp.<pid>.<nanos>`, fsynced,
//!     and atomically renamed into place.
//!   - The in-process `mmap` is then dropped so the next read picks up the new
//!     shard.
//!
//! Ported from `strykelang/strykelang/script_cache.rs` (the stryke language
//! uses the same caching pattern). This version differs in its `ZRSC` magic,
//! its pin on the zshrs version, and its single `chunk_blob` per entry — zshrs
//! has no separate AST cache.
40
41use std::collections::HashMap;
42use std::fs::File;
43use std::io::Write as IoWrite;
44use std::path::{Path, PathBuf};
45use std::sync::OnceLock;
46use std::time::{SystemTime, UNIX_EPOCH};
47
48use memmap2::Mmap;
49use parking_lot::Mutex;
50use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
51use std::os::unix::fs::MetadataExt;
52
/// Magic number stored in the shard header, used to fail fast when a file that
/// is not a zshrs script shard gets mapped.
///
/// The value is the ASCII bytes `ZRSC` read as a *big-endian* `u32`
/// (`0x5A52_5343`). The archive stores it as an ordinary integer, so with
/// rkyv's default native-endian integers a little-endian host writes the bytes
/// `CSRZ` to disk.
pub const SHARD_MAGIC: u32 = u32::from_be_bytes(*b"ZRSC");
/// Version of the archived shard layout. Bump on any incompatible rkyv schema
/// change (a field added, removed, reordered, or retyped).
pub const SHARD_FORMAT_VERSION: u32 = 1;
58
59#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
60#[archive(check_bytes)]
61pub struct ShardHeader {
62    pub magic: u32,
63    pub format_version: u32,
64    pub zshrs_version: String,
65    pub pointer_width: u32,
66    pub built_at_secs: u64,
67}
68
69#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
70#[archive(check_bytes)]
71pub struct ScriptEntry {
72    pub mtime_secs: i64,
73    pub mtime_nsecs: i64,
74    pub binary_mtime_at_cache: i64,
75    pub cached_at_secs: i64,
76    pub chunk_blob: Vec<u8>,
77}
78
79#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
80#[archive(check_bytes)]
81pub struct ScriptShard {
82    pub header: ShardHeader,
83    pub entries: HashMap<String, ScriptEntry>,
84}
85
86/// mmap + validated `*const ArchivedScriptShard`. Self-referential — the pointer
87/// is valid for the lifetime of the wrapping struct.
88pub struct MmappedShard {
89    _mmap: Mmap,
90    archived: *const ArchivedScriptShard,
91}
92
93// SAFETY: the pointer aliases an immutable mmap that lives as long as Self.
94// rkyv-validated reads are immutable.
95unsafe impl Send for MmappedShard {}
96unsafe impl Sync for MmappedShard {}
97
98impl MmappedShard {
99    pub fn open(path: &Path) -> Option<Self> {
100        let file = File::open(path).ok()?;
101        let mmap = unsafe { Mmap::map(&file).ok()? };
102        let archived = rkyv::check_archived_root::<ScriptShard>(&mmap[..]).ok()?;
103        let archived_ptr = archived as *const ArchivedScriptShard;
104        Some(Self {
105            _mmap: mmap,
106            archived: archived_ptr,
107        })
108    }
109
110    fn shard(&self) -> &ArchivedScriptShard {
111        // SAFETY: see Self impl comment.
112        unsafe { &*self.archived }
113    }
114
115    fn header_ok(&self) -> bool {
116        let h = &self.shard().header;
117        let magic: u32 = h.magic.into();
118        let fv: u32 = h.format_version.into();
119        let pw: u32 = h.pointer_width.into();
120        magic == SHARD_MAGIC
121            && fv == SHARD_FORMAT_VERSION
122            && pw as usize == std::mem::size_of::<usize>()
123            && h.zshrs_version.as_str() == env!("CARGO_PKG_VERSION")
124    }
125
126    fn lookup(&self, path: &str) -> Option<&ArchivedScriptEntry> {
127        self.shard().entries.get(path)
128    }
129
130    fn entry_count(&self) -> usize {
131        self.shard().entries.len()
132    }
133}
134
135/// Shard cache keyed by canonical script path. One per shard file.
136pub struct ScriptCache {
137    path: PathBuf,
138    lock_path: PathBuf,
139    mmap: Mutex<Option<MmappedShard>>,
140}
141
142impl ScriptCache {
143    pub fn open(path: &Path) -> std::io::Result<Self> {
144        if let Some(parent) = path.parent() {
145            std::fs::create_dir_all(parent)?;
146        }
147        let parent = path.parent().unwrap_or_else(|| Path::new("/tmp"));
148        let lock_path = parent.join(format!(
149            "{}.lock",
150            path.file_name()
151                .and_then(|s| s.to_str())
152                .unwrap_or("scripts.rkyv")
153        ));
154        Ok(Self {
155            path: path.to_path_buf(),
156            lock_path,
157            mmap: Mutex::new(None),
158        })
159    }
160
161    fn ensure_mmap(&self) {
162        let mut guard = self.mmap.lock();
163        if guard.is_none() {
164            *guard = MmappedShard::open(&self.path);
165        }
166    }
167
168    fn invalidate_mmap(&self) {
169        let mut guard = self.mmap.lock();
170        *guard = None;
171    }
172
173    /// Cache lookup. Returns `None` on miss, mtime mismatch, version drift, or
174    /// zshrs binary newer than the cached entry.
175    pub fn get(&self, path: &str, mtime_secs: i64, mtime_nsecs: i64) -> Option<Vec<u8>> {
176        self.ensure_mmap();
177        let guard = self.mmap.lock();
178        let shard = guard.as_ref()?;
179        if !shard.header_ok() {
180            return None;
181        }
182        let entry = shard.lookup(path)?;
183
184        let entry_mtime_s: i64 = entry.mtime_secs.into();
185        let entry_mtime_ns: i64 = entry.mtime_nsecs.into();
186        if entry_mtime_s != mtime_secs || entry_mtime_ns != mtime_nsecs {
187            return None;
188        }
189
190        if let Some(bin_mtime) = current_binary_mtime_secs() {
191            let cached_bin_mtime: i64 = entry.binary_mtime_at_cache.into();
192            if cached_bin_mtime < bin_mtime {
193                return None;
194            }
195        }
196
197        Some(entry.chunk_blob.as_slice().to_vec())
198    }
199
200    /// Insert / replace an entry. Serializes the whole shard and atomic-renames.
201    pub fn put(
202        &self,
203        path: &str,
204        mtime_secs: i64,
205        mtime_nsecs: i64,
206        chunk_blob: Vec<u8>,
207    ) -> Result<(), String> {
208        let _lock = match acquire_lock(&self.lock_path) {
209            Some(l) => l,
210            None => return Ok(()),
211        };
212
213        let mut shard = match read_owned_shard(&self.path) {
214            Some(s)
215                if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
216                    && s.header.pointer_width as usize == std::mem::size_of::<usize>()
217                    && s.header.format_version == SHARD_FORMAT_VERSION =>
218            {
219                s
220            }
221            _ => fresh_shard(),
222        };
223
224        let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
225        let entry = ScriptEntry {
226            mtime_secs,
227            mtime_nsecs,
228            binary_mtime_at_cache: bin_mtime,
229            cached_at_secs: now_secs(),
230            chunk_blob,
231        };
232        shard.entries.insert(path.to_string(), entry);
233        shard.header.built_at_secs = now_secs() as u64;
234
235        write_shard_atomic(&self.path, &shard)?;
236        self.invalidate_mmap();
237        Ok(())
238    }
239
240    /// `(count, total_blob_bytes)` snapshot.
241    pub fn stats(&self) -> (i64, i64) {
242        self.ensure_mmap();
243        let guard = self.mmap.lock();
244        let Some(shard) = guard.as_ref() else {
245            return (0, 0);
246        };
247        let count = shard.entry_count() as i64;
248        let bytes: i64 = shard
249            .shard()
250            .entries
251            .values()
252            .map(|e| e.chunk_blob.len() as i64)
253            .sum();
254        (count, bytes)
255    }
256
257    /// `(path, chunk_kb, version, cached_at_localstr)` per entry,
258    /// sorted by `cached_at` desc.
259    pub fn list_scripts(&self) -> Vec<(String, f64, String, String)> {
260        self.ensure_mmap();
261        let guard = self.mmap.lock();
262        let Some(shard) = guard.as_ref() else {
263            return Vec::new();
264        };
265        let v = shard.shard().header.zshrs_version.as_str().to_string();
266        let mut out: Vec<(String, f64, String, String, i64)> = shard
267            .shard()
268            .entries
269            .iter()
270            .map(|(k, e)| {
271                let chunk_kb = e.chunk_blob.len() as f64 / 1024.0;
272                let cached_at: i64 = e.cached_at_secs.into();
273                let ts = format_local_ts(cached_at);
274                (
275                    k.as_str().to_string(),
276                    chunk_kb,
277                    v.clone(),
278                    ts,
279                    cached_at,
280                )
281            })
282            .collect();
283        out.sort_by_key(|x| std::cmp::Reverse(x.4));
284        out.into_iter()
285            .map(|(p, ck, ver, ts, _)| (p, ck, ver, ts))
286            .collect()
287    }
288
289    /// Drop entries whose source file vanished or whose mtime changed.
290    pub fn evict_stale(&self) -> usize {
291        let _lock = match acquire_lock(&self.lock_path) {
292            Some(l) => l,
293            None => return 0,
294        };
295        let mut shard = match read_owned_shard(&self.path) {
296            Some(s) => s,
297            None => return 0,
298        };
299        let before = shard.entries.len();
300        shard.entries.retain(|p, e| match file_mtime(Path::new(p)) {
301            Some((s, ns)) => s == e.mtime_secs && ns == e.mtime_nsecs,
302            None => false,
303        });
304        let evicted = before - shard.entries.len();
305        if evicted > 0 {
306            let _ = write_shard_atomic(&self.path, &shard);
307            self.invalidate_mmap();
308        }
309        evicted
310    }
311
312    pub fn clear(&self) -> std::io::Result<()> {
313        let _lock = acquire_lock(&self.lock_path);
314        let res = match std::fs::remove_file(&self.path) {
315            Ok(()) => Ok(()),
316            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
317            Err(e) => Err(e),
318        };
319        self.invalidate_mmap();
320        res
321    }
322}
323
324fn acquire_lock(path: &Path) -> Option<nix::fcntl::Flock<File>> {
325    let f = File::options()
326        .read(true)
327        .write(true)
328        .create(true)
329        .truncate(false)
330        .open(path)
331        .ok()?;
332    nix::fcntl::Flock::lock(f, nix::fcntl::FlockArg::LockExclusive).ok()
333}
334
335fn fresh_shard() -> ScriptShard {
336    ScriptShard {
337        header: ShardHeader {
338            magic: SHARD_MAGIC,
339            format_version: SHARD_FORMAT_VERSION,
340            zshrs_version: env!("CARGO_PKG_VERSION").to_string(),
341            pointer_width: std::mem::size_of::<usize>() as u32,
342            built_at_secs: now_secs() as u64,
343        },
344        entries: HashMap::new(),
345    }
346}
347
348fn read_owned_shard(path: &Path) -> Option<ScriptShard> {
349    let bytes = std::fs::read(path).ok()?;
350    let archived = rkyv::check_archived_root::<ScriptShard>(&bytes[..]).ok()?;
351    archived.deserialize(&mut rkyv::Infallible).ok()
352}
353
354fn write_shard_atomic(path: &Path, shard: &ScriptShard) -> Result<(), String> {
355    let bytes = rkyv::to_bytes::<_, 4096>(shard)
356        .map_err(|e| format!("rkyv serialize: {}", e))?;
357
358    let parent = path.parent().expect("cache path has parent");
359    let _ = std::fs::create_dir_all(parent);
360
361    let pid = std::process::id();
362    let nanos = SystemTime::now()
363        .duration_since(UNIX_EPOCH)
364        .map(|d| d.as_nanos())
365        .unwrap_or(0);
366    let tmp_path = parent.join(format!(
367        "{}.tmp.{}.{}",
368        path.file_name()
369            .and_then(|s| s.to_str())
370            .unwrap_or("scripts.rkyv"),
371        pid,
372        nanos
373    ));
374
375    {
376        let mut f = File::create(&tmp_path).map_err(|e| e.to_string())?;
377        f.write_all(&bytes).map_err(|e| e.to_string())?;
378        f.sync_all().map_err(|e| e.to_string())?;
379    }
380
381    std::fs::rename(&tmp_path, path).map_err(|e| e.to_string())?;
382    Ok(())
383}
384
/// Current Unix time in whole seconds; 0 if the clock is before the epoch.
fn now_secs() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
391
392fn format_local_ts(secs: i64) -> String {
393    let dt = chrono::DateTime::<chrono::Local>::from(
394        UNIX_EPOCH + std::time::Duration::from_secs(secs.max(0) as u64),
395    );
396    dt.format("%Y-%m-%d %H:%M:%S").to_string()
397}
398
/// `(mtime_secs, mtime_nsecs)` of `path` (following symlinks), or `None` if
/// it cannot be stat'ed.
pub fn file_mtime(path: &Path) -> Option<(i64, i64)> {
    std::fs::metadata(path)
        .ok()
        .map(|meta| (meta.mtime(), meta.mtime_nsec()))
}
403
404fn current_binary_mtime_secs() -> Option<i64> {
405    static BIN_MTIME: OnceLock<Option<i64>> = OnceLock::new();
406    *BIN_MTIME.get_or_init(|| {
407        let exe = std::env::current_exe().ok()?;
408        let (secs, _) = file_mtime(&exe)?;
409        Some(secs)
410    })
411}
412
413/// Default shard path: `~/.zshrs/scripts.rkyv`.
414pub fn default_cache_path() -> PathBuf {
415    dirs::home_dir()
416        .unwrap_or_else(|| PathBuf::from("/tmp"))
417        .join(".zshrs/scripts.rkyv")
418}
419
/// Whether the script cache should be used at all.
///
/// Setting `ZSHRS_CACHE` to exactly `0`, `false`, or `no` (case-sensitive)
/// disables it. Any other value, an unset variable, or a non-UTF-8 value
/// leaves it enabled.
pub fn cache_enabled() -> bool {
    match std::env::var("ZSHRS_CACHE") {
        Ok(value) => !matches!(value.as_str(), "0" | "false" | "no"),
        Err(_) => true,
    }
}
427
428/// Process-wide `ScriptCache` rooted at `default_cache_path()`. `None` when the
429/// cache is disabled or the path could not be opened.
430pub static CACHE: once_cell::sync::Lazy<Option<ScriptCache>> = once_cell::sync::Lazy::new(|| {
431    if !cache_enabled() {
432        return None;
433    }
434    ScriptCache::open(&default_cache_path()).ok()
435});
436
437/// Try to load cached chunk-bytes by source path. Returns `None` on any miss.
438pub fn try_load_bytes(path: &Path) -> Option<Vec<u8>> {
439    let cache = CACHE.as_ref()?;
440    let canonical = path.canonicalize().ok()?;
441    let path_str = canonical.to_string_lossy();
442    let (mtime_s, mtime_ns) = file_mtime(&canonical)?;
443    cache.get(&path_str, mtime_s, mtime_ns)
444}
445
446/// Store bincode-encoded `fusevm::Chunk` bytes for a script path. Best-effort —
447/// cache disabled / canonicalize failure / mtime stat failure all return
448/// `Ok(())` silently so the caller can fire-and-forget.
449pub fn try_save_bytes(path: &Path, chunk_blob: &[u8]) -> Result<(), String> {
450    let Some(cache) = CACHE.as_ref() else {
451        return Ok(());
452    };
453    let canonical = match path.canonicalize() {
454        Ok(p) => p,
455        Err(_) => return Ok(()),
456    };
457    let path_str = canonical.to_string_lossy();
458    let (mtime_s, mtime_ns) = match file_mtime(&canonical) {
459        Some(m) => m,
460        None => return Ok(()),
461    };
462    cache.put(&path_str, mtime_s, mtime_ns, chunk_blob.to_vec())
463}
464
465pub fn stats() -> Option<(i64, i64)> {
466    CACHE.as_ref().map(|c| c.stats())
467}
468
469pub fn evict_stale() -> usize {
470    CACHE.as_ref().map(|c| c.evict_stale()).unwrap_or(0)
471}
472
473pub fn clear() -> bool {
474    CACHE.as_ref().map(|c| c.clear().is_ok()).unwrap_or(false)
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A scratch directory plus a cache whose shard lives inside it. Keep the
    /// `TempDir` alive for the whole test: dropping it deletes the directory.
    fn scratch_cache() -> (tempfile::TempDir, ScriptCache) {
        let dir = tempdir().unwrap();
        let cache = ScriptCache::open(&dir.path().join("scripts.rkyv")).unwrap();
        (dir, cache)
    }

    /// Write `body` to `dir/name`. Returns the path string used as the cache
    /// key together with the file's `(mtime_secs, mtime_nsecs)`.
    fn script(dir: &std::path::Path, name: &str, body: &str) -> (String, i64, i64) {
        let p = dir.join(name);
        std::fs::write(&p, body).unwrap();
        let (s, ns) = file_mtime(&p).unwrap();
        (p.to_string_lossy().into_owned(), s, ns)
    }

    #[test]
    fn round_trip() {
        let (dir, cache) = scratch_cache();
        let (key, s, ns) = script(dir.path(), "test.zsh", "echo hi");

        let blob = vec![1u8, 2, 3, 4, 5];
        cache.put(&key, s, ns, blob.clone()).unwrap();

        assert_eq!(cache.get(&key, s, ns).unwrap(), blob);
        assert_eq!(cache.stats(), (1, 5));
    }

    #[test]
    fn mtime_invalidation() {
        let (dir, cache) = scratch_cache();
        let (key, s, ns) = script(dir.path(), "test.zsh", "echo hi");
        cache.put(&key, s, ns, vec![9u8]).unwrap();

        assert!(cache.get(&key, s + 1, ns).is_none());
        assert!(cache.get(&key, s, ns + 1).is_none());
        assert!(cache.get(&key, s, ns).is_some());
    }

    #[test]
    fn puts_for_different_paths_coexist() {
        let (dir, cache) = scratch_cache();
        let (a, a_s, a_ns) = script(dir.path(), "a.zsh", "1");
        let (b, b_s, b_ns) = script(dir.path(), "b.zsh", "2");

        cache.put(&a, a_s, a_ns, vec![1u8]).unwrap();
        cache.put(&b, b_s, b_ns, vec![2u8]).unwrap();

        assert_eq!(cache.stats().0, 2);
        assert_eq!(cache.get(&a, a_s, a_ns).unwrap(), vec![1u8]);
        assert_eq!(cache.get(&b, b_s, b_ns).unwrap(), vec![2u8]);
    }

    #[test]
    fn second_put_for_same_path_replaces_entry() {
        let (dir, cache) = scratch_cache();
        let (key, s, ns) = script(dir.path(), "a.zsh", "1");

        cache.put(&key, s, ns, vec![1u8]).unwrap();
        cache.put(&key, s, ns, vec![2u8]).unwrap();

        assert_eq!(cache.stats().0, 1);
        assert_eq!(cache.get(&key, s, ns).unwrap(), vec![2u8]);
    }

    #[test]
    fn corrupt_file_returns_no_mmap_and_is_overwritten() {
        let (dir, cache) = scratch_cache();
        let cache_path = dir.path().join("scripts.rkyv");
        std::fs::write(&cache_path, b"this is not a valid rkyv archive").unwrap();
        assert!(cache.get("/nope", 0, 0).is_none());

        // A put over a corrupt shard starts a fresh one.
        let (key, s, ns) = script(dir.path(), "test.zsh", "echo hi");
        cache.put(&key, s, ns, vec![3u8]).unwrap();
        assert_eq!(cache.get(&key, s, ns).unwrap(), vec![3u8]);
    }

    #[test]
    fn evict_stale_drops_vanished_sources() {
        let (dir, cache) = scratch_cache();
        let (keep, k_s, k_ns) = script(dir.path(), "keep.zsh", "echo keep");
        let (gone, g_s, g_ns) = script(dir.path(), "gone.zsh", "echo gone");
        cache.put(&keep, k_s, k_ns, vec![1u8]).unwrap();
        cache.put(&gone, g_s, g_ns, vec![2u8]).unwrap();
        std::fs::remove_file(&gone).unwrap();

        assert_eq!(cache.evict_stale(), 1);
        assert_eq!(cache.stats().0, 1);
        assert!(cache.get(&keep, k_s, k_ns).is_some());
        assert!(cache.get(&gone, g_s, g_ns).is_none());
        assert_eq!(cache.evict_stale(), 0);
    }

    #[test]
    fn list_scripts_reports_entries() {
        let (dir, cache) = scratch_cache();
        let (key, s, ns) = script(dir.path(), "test.zsh", "echo hi");
        cache.put(&key, s, ns, vec![0u8; 2048]).unwrap();

        let rows = cache.list_scripts();
        assert_eq!(rows.len(), 1);
        let (path, kb, version, _ts) = &rows[0];
        assert_eq!(path.as_str(), key.as_str());
        assert_eq!(*kb, 2.0);
        assert_eq!(version.as_str(), env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn clear_removes_file() {
        let (dir, cache) = scratch_cache();
        let cache_path = dir.path().join("scripts.rkyv");
        let (key, s, ns) = script(dir.path(), "test.zsh", "echo hi");
        cache.put(&key, s, ns, vec![7u8]).unwrap();
        assert!(cache_path.exists());

        cache.clear().unwrap();
        assert!(!cache_path.exists());
        assert!(cache.get(&key, s, ns).is_none());

        // Clearing an already-missing shard is not an error.
        cache.clear().unwrap();
    }
}