1use std::collections::HashMap;
37use std::fs::File;
38use std::io::Write as IoWrite;
39use std::path::{Path, PathBuf};
40use std::sync::OnceLock;
41use std::time::{SystemTime, UNIX_EPOCH};
42
43use memmap2::Mmap;
44use parking_lot::Mutex;
45use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
46use std::os::unix::fs::MetadataExt;
47
/// Magic number written into every shard header; readers reject a shard whose
/// header carries any other value. Read as big-endian ASCII it spells `ZRAL`.
pub const SHARD_MAGIC: u32 = 0x5A52_414C;
/// Version of the on-disk shard layout; readers reject shards with a different
/// value, so this should change whenever the archived structs change shape.
pub const SHARD_FORMAT_VERSION: u32 = 1;
51
/// Identification block stored at the top of every shard.
///
/// Readers (`MmappedShard::header_ok`) and writers (`AutoloadCache::put_one` /
/// `merge_in`) compare these fields against the running build and ignore or
/// rebuild a shard that does not match. Field order is part of the archived
/// layout, so changing it should go together with a `SHARD_FORMAT_VERSION` bump.
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct ShardHeader {
    /// Must equal `SHARD_MAGIC`.
    pub magic: u32,
    /// Must equal `SHARD_FORMAT_VERSION`.
    pub format_version: u32,
    /// `CARGO_PKG_VERSION` of the build that wrote the shard.
    pub zshrs_version: String,
    /// `size_of::<usize>()` of the build that wrote the shard.
    pub pointer_width: u32,
    /// Unix time (seconds) of the most recent write.
    pub built_at_secs: u64,
}
61
/// One cached autoload payload plus the metadata used to decide freshness.
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct AutoloadEntry {
    /// Modification time (whole seconds) of the running executable when this
    /// entry was written, or 0 if it could not be determined. `get` treats the
    /// entry as stale when this is older than the current executable's mtime.
    pub binary_mtime_at_cache: i64,
    /// Unix time (seconds) when the entry was written.
    pub cached_at_secs: i64,
    /// Opaque bytes supplied by the caller and returned verbatim by `get`.
    /// NOTE(review): presumably a serialized compiled chunk — confirm with callers.
    pub chunk_blob: Vec<u8>,
}
69
/// Root object of the shard file: a header plus every cached entry.
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct AutoloadShard {
    /// Build/format identification checked before the entries are trusted.
    pub header: ShardHeader,
    /// Entries keyed by the name passed to `put_one` / looked up by `get`.
    pub entries: HashMap<String, AutoloadEntry>,
}
76
/// A read-only memory mapping of the shard file plus a pointer to its
/// validated rkyv root.
///
/// `archived` points into the bytes owned by `_mmap`. It is produced by
/// `rkyv::check_archived_root` in `MmappedShard::open` and is only
/// dereferenced through `&self`, i.e. while `_mmap` is still alive.
pub struct MmappedShard {
    /// Owns the mapping; kept only so the pages stay mapped.
    _mmap: Mmap,
    /// Validated archive root located inside `_mmap`.
    archived: *const ArchivedAutoloadShard,
}

// SAFETY: `archived` only ever points into `_mmap`, which this struct owns, and
// nothing writes through the mapping. NOTE(review): soundness also assumes the
// underlying file is not modified in place while mapped.
unsafe impl Send for MmappedShard {}
// SAFETY: every access through `archived` is a shared, read-only borrow taken
// via `&self`, so concurrent readers cannot race on it.
unsafe impl Sync for MmappedShard {}
84
85impl MmappedShard {
86 pub fn open(path: &Path) -> Option<Self> {
87 let file = File::open(path).ok()?;
88 let mmap = unsafe { Mmap::map(&file).ok()? };
89 let archived = rkyv::check_archived_root::<AutoloadShard>(&mmap[..]).ok()?;
90 let archived_ptr = archived as *const ArchivedAutoloadShard;
91 Some(Self {
92 _mmap: mmap,
93 archived: archived_ptr,
94 })
95 }
96
97 fn shard(&self) -> &ArchivedAutoloadShard {
98 unsafe { &*self.archived }
99 }
100
101 fn header_ok(&self) -> bool {
102 let h = &self.shard().header;
103 let magic: u32 = h.magic.into();
104 let fv: u32 = h.format_version.into();
105 let pw: u32 = h.pointer_width.into();
106 magic == SHARD_MAGIC
107 && fv == SHARD_FORMAT_VERSION
108 && pw as usize == std::mem::size_of::<usize>()
109 && h.zshrs_version.as_str() == env!("CARGO_PKG_VERSION")
110 }
111
112 fn lookup(&self, name: &str) -> Option<&ArchivedAutoloadEntry> {
113 self.shard().entries.get(name)
114 }
115}
116
/// Handle to the on-disk autoload shard.
///
/// Reads go through a lazily created memory mapping. Writes take an exclusive
/// `flock` on `lock_path`, rewrite the whole shard, and then drop the mapping
/// so the next read maps the new file.
pub struct AutoloadCache {
    /// Location of the shard file.
    path: PathBuf,
    /// Sibling lock file (`<shard file name>.lock`) that serializes writers.
    lock_path: PathBuf,
    /// Current mapping of `path`; `None` before the first read and after a write.
    mmap: Mutex<Option<MmappedShard>>,
}
122
123impl AutoloadCache {
124 pub fn open(path: &Path) -> std::io::Result<Self> {
125 if let Some(parent) = path.parent() {
126 std::fs::create_dir_all(parent)?;
127 }
128 let parent = path.parent().unwrap_or_else(|| Path::new("/tmp"));
129 let lock_path = parent.join(format!(
130 "{}.lock",
131 path.file_name()
132 .and_then(|s| s.to_str())
133 .unwrap_or("autoloads.rkyv")
134 ));
135 Ok(Self {
136 path: path.to_path_buf(),
137 lock_path,
138 mmap: Mutex::new(None),
139 })
140 }
141
142 fn ensure_mmap(&self) {
143 let mut guard = self.mmap.lock();
144 if guard.is_none() {
145 *guard = MmappedShard::open(&self.path);
146 }
147 }
148
149 fn invalidate_mmap(&self) {
150 let mut guard = self.mmap.lock();
151 *guard = None;
152 }
153
154 pub fn get(&self, name: &str) -> Option<Vec<u8>> {
155 self.ensure_mmap();
156 let guard = self.mmap.lock();
157 let shard = guard.as_ref()?;
158 if !shard.header_ok() {
159 return None;
160 }
161 let entry = shard.lookup(name)?;
162 if let Some(bin_mtime) = current_binary_mtime_secs() {
163 let cached_bin_mtime: i64 = entry.binary_mtime_at_cache.into();
164 if cached_bin_mtime < bin_mtime {
165 return None;
166 }
167 }
168 Some(entry.chunk_blob.as_slice().to_vec())
169 }
170
171 pub fn put_one(&self, name: &str, chunk_blob: Vec<u8>) -> Result<(), String> {
175 let _lock = match acquire_lock(&self.lock_path) {
176 Some(l) => l,
177 None => return Ok(()),
178 };
179 let mut shard = match read_owned_shard(&self.path) {
180 Some(s)
181 if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
182 && s.header.pointer_width as usize == std::mem::size_of::<usize>()
183 && s.header.format_version == SHARD_FORMAT_VERSION =>
184 {
185 s
186 }
187 _ => fresh_shard(),
188 };
189 let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
190 shard.entries.insert(
191 name.to_string(),
192 AutoloadEntry {
193 binary_mtime_at_cache: bin_mtime,
194 cached_at_secs: now_secs(),
195 chunk_blob,
196 },
197 );
198 shard.header.built_at_secs = now_secs() as u64;
199 write_shard_atomic(&self.path, &shard)?;
200 self.invalidate_mmap();
201 Ok(())
202 }
203
204 pub fn merge_in(&self, entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
211 if entries.is_empty() {
212 return Ok(());
213 }
214 let _lock = match acquire_lock(&self.lock_path) {
215 Some(l) => l,
216 None => return Ok(()),
217 };
218 let mut shard = match read_owned_shard(&self.path) {
219 Some(s)
220 if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
221 && s.header.pointer_width as usize == std::mem::size_of::<usize>()
222 && s.header.format_version == SHARD_FORMAT_VERSION =>
223 {
224 s
225 }
226 _ => fresh_shard(),
227 };
228 let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
229 let now = now_secs();
230 for (name, chunk_blob) in entries {
231 shard.entries.insert(
232 name,
233 AutoloadEntry {
234 binary_mtime_at_cache: bin_mtime,
235 cached_at_secs: now,
236 chunk_blob,
237 },
238 );
239 }
240 shard.header.built_at_secs = now as u64;
241 write_shard_atomic(&self.path, &shard)?;
242 self.invalidate_mmap();
243 Ok(())
244 }
245
246 pub fn replace_all(&self, entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
250 let _lock = match acquire_lock(&self.lock_path) {
251 Some(l) => l,
252 None => return Ok(()),
253 };
254 let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
255 let now = now_secs();
256 let mut shard = fresh_shard();
257 for (name, chunk_blob) in entries {
258 shard.entries.insert(
259 name,
260 AutoloadEntry {
261 binary_mtime_at_cache: bin_mtime,
262 cached_at_secs: now,
263 chunk_blob,
264 },
265 );
266 }
267 write_shard_atomic(&self.path, &shard)?;
268 self.invalidate_mmap();
269 Ok(())
270 }
271
272 pub fn entry_count(&self) -> usize {
273 self.ensure_mmap();
274 let guard = self.mmap.lock();
275 guard.as_ref().map(|s| s.shard().entries.len()).unwrap_or(0)
276 }
277
278 pub fn cached_names(&self) -> std::collections::HashSet<String> {
282 self.ensure_mmap();
283 let guard = self.mmap.lock();
284 let Some(shard) = guard.as_ref() else {
285 return std::collections::HashSet::new();
286 };
287 shard
288 .shard()
289 .entries
290 .keys()
291 .map(|k| k.as_str().to_string())
292 .collect()
293 }
294
295 pub fn stats(&self) -> (i64, i64) {
296 self.ensure_mmap();
297 let guard = self.mmap.lock();
298 let Some(shard) = guard.as_ref() else {
299 return (0, 0);
300 };
301 let count = shard.shard().entries.len() as i64;
302 let bytes: i64 = shard
303 .shard()
304 .entries
305 .values()
306 .map(|e| e.chunk_blob.len() as i64)
307 .sum();
308 (count, bytes)
309 }
310
311 pub fn clear(&self) -> std::io::Result<()> {
312 let _lock = acquire_lock(&self.lock_path);
313 let res = match std::fs::remove_file(&self.path) {
314 Ok(()) => Ok(()),
315 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
316 Err(e) => Err(e),
317 };
318 self.invalidate_mmap();
319 res
320 }
321}
322
323fn acquire_lock(path: &Path) -> Option<nix::fcntl::Flock<File>> {
324 let f = File::options()
325 .read(true)
326 .write(true)
327 .create(true)
328 .truncate(false)
329 .open(path)
330 .ok()?;
331 nix::fcntl::Flock::lock(f, nix::fcntl::FlockArg::LockExclusive).ok()
332}
333
334fn fresh_shard() -> AutoloadShard {
335 AutoloadShard {
336 header: ShardHeader {
337 magic: SHARD_MAGIC,
338 format_version: SHARD_FORMAT_VERSION,
339 zshrs_version: env!("CARGO_PKG_VERSION").to_string(),
340 pointer_width: std::mem::size_of::<usize>() as u32,
341 built_at_secs: now_secs() as u64,
342 },
343 entries: HashMap::new(),
344 }
345}
346
347fn read_owned_shard(path: &Path) -> Option<AutoloadShard> {
348 let bytes = std::fs::read(path).ok()?;
349 let archived = rkyv::check_archived_root::<AutoloadShard>(&bytes[..]).ok()?;
350 archived.deserialize(&mut rkyv::Infallible).ok()
351}
352
353fn write_shard_atomic(path: &Path, shard: &AutoloadShard) -> Result<(), String> {
354 let bytes = rkyv::to_bytes::<_, 4096>(shard)
355 .map_err(|e| format!("rkyv serialize: {}", e))?;
356 let parent = path.parent().expect("cache path has parent");
357 let _ = std::fs::create_dir_all(parent);
358 let pid = std::process::id();
359 let nanos = SystemTime::now()
360 .duration_since(UNIX_EPOCH)
361 .map(|d| d.as_nanos())
362 .unwrap_or(0);
363 let tmp_path = parent.join(format!(
364 "{}.tmp.{}.{}",
365 path.file_name()
366 .and_then(|s| s.to_str())
367 .unwrap_or("autoloads.rkyv"),
368 pid,
369 nanos
370 ));
371 {
372 let mut f = File::create(&tmp_path).map_err(|e| e.to_string())?;
373 f.write_all(&bytes).map_err(|e| e.to_string())?;
374 f.sync_all().map_err(|e| e.to_string())?;
375 }
376 std::fs::rename(&tmp_path, path).map_err(|e| e.to_string())?;
377 Ok(())
378}
379
/// Current Unix time in whole seconds, or 0 if the clock reads before the epoch.
fn now_secs() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
386
/// Returns `(seconds, nanoseconds)` of `path`'s modification time.
///
/// Symlinks are followed. Returns `None` if the metadata cannot be read.
fn file_mtime(path: &Path) -> Option<(i64, i64)> {
    match std::fs::metadata(path) {
        Ok(info) => Some((info.mtime(), info.mtime_nsec())),
        Err(_) => None,
    }
}
391
392fn current_binary_mtime_secs() -> Option<i64> {
393 static BIN_MTIME: OnceLock<Option<i64>> = OnceLock::new();
394 *BIN_MTIME.get_or_init(|| {
395 let exe = std::env::current_exe().ok()?;
396 let (secs, _) = file_mtime(&exe)?;
397 Some(secs)
398 })
399}
400
401pub fn default_cache_path() -> PathBuf {
402 dirs::home_dir()
403 .unwrap_or_else(|| PathBuf::from("/tmp"))
404 .join(".cache/zshrs/autoloads.rkyv")
405}
406
/// Whether the autoload cache should be used.
///
/// It is disabled only when `ZSHRS_CACHE` is exactly `0`, `false`, or `no`
/// (case-sensitive). An unset or non-Unicode value leaves the cache enabled.
pub fn cache_enabled() -> bool {
    let disabled = match std::env::var("ZSHRS_CACHE") {
        Ok(value) => value == "0" || value == "false" || value == "no",
        Err(_) => false,
    };
    !disabled
}
413
414pub static CACHE: once_cell::sync::Lazy<Option<AutoloadCache>> =
415 once_cell::sync::Lazy::new(|| {
416 if !cache_enabled() {
417 return None;
418 }
419 AutoloadCache::open(&default_cache_path()).ok()
420 });
421
422pub fn try_load(name: &str) -> Option<Vec<u8>> {
423 let cache = CACHE.as_ref()?;
424 cache.get(name)
425}
426
427pub fn try_save_one(name: &str, chunk_blob: &[u8]) -> Result<(), String> {
428 let Some(cache) = CACHE.as_ref() else {
429 return Ok(());
430 };
431 cache.put_one(name, chunk_blob.to_vec())
432}
433
434pub fn try_replace_all(entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
438 let Some(cache) = CACHE.as_ref() else {
439 return Ok(());
440 };
441 cache.replace_all(entries)
442}
443
444pub fn try_merge_in(entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
448 let Some(cache) = CACHE.as_ref() else {
449 return Ok(());
450 };
451 cache.merge_in(entries)
452}
453
454pub fn cached_names() -> std::collections::HashSet<String> {
455 CACHE
456 .as_ref()
457 .map(|c| c.cached_names())
458 .unwrap_or_default()
459}
460
461pub fn entry_count() -> usize {
462 CACHE.as_ref().map(|c| c.entry_count()).unwrap_or(0)
463}
464
465pub fn stats() -> Option<(i64, i64)> {
466 CACHE.as_ref().map(|c| c.stats())
467}
468
469pub fn clear() -> bool {
470 CACHE.as_ref().map(|c| c.clear().is_ok()).unwrap_or(false)
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Builds a cache whose shard lives at `<scratch dir>/autoloads.rkyv`.
    ///
    /// The `TempDir` guard is returned so the directory stays alive for the
    /// whole test; dropping it deletes everything.
    fn scratch_cache() -> (TempDir, PathBuf, AutoloadCache) {
        let scratch = tempdir().unwrap();
        let shard_file = scratch.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&shard_file).unwrap();
        (scratch, shard_file, cache)
    }

    #[test]
    fn round_trip_one() {
        let (_scratch, _, cache) = scratch_cache();
        cache.put_one("foo", vec![1, 2, 3]).unwrap();
        assert_eq!(cache.get("foo"), Some(vec![1, 2, 3]));
        assert_eq!(cache.entry_count(), 1);
    }

    #[test]
    fn replace_all_overwrites() {
        let (_scratch, _, cache) = scratch_cache();
        for (name, byte) in [("a", 10u8), ("b", 20u8)] {
            cache.put_one(name, vec![byte]).unwrap();
        }
        assert_eq!(cache.entry_count(), 2);

        let replacement = HashMap::from([
            ("c".to_string(), vec![30]),
            ("d".to_string(), vec![40]),
        ]);
        cache.replace_all(replacement).unwrap();

        assert_eq!(cache.entry_count(), 2);
        for gone in ["a", "b"] {
            assert!(cache.get(gone).is_none());
        }
        assert_eq!(cache.get("c"), Some(vec![30]));
        assert_eq!(cache.get("d"), Some(vec![40]));
    }

    #[test]
    fn cached_names_returns_keys() {
        let (_scratch, _, cache) = scratch_cache();
        cache.put_one("alpha", vec![1]).unwrap();
        cache.put_one("beta", vec![2]).unwrap();
        let listed = cache.cached_names();
        assert!(listed.contains("alpha"));
        assert!(listed.contains("beta"));
        assert_eq!(listed.len(), 2);
    }

    #[test]
    fn corrupt_shard_returns_none() {
        // `AutoloadCache::open` never reads the shard, so writing the garbage
        // after opening is equivalent to writing it before.
        let (_scratch, shard_file, cache) = scratch_cache();
        std::fs::write(&shard_file, b"garbage").unwrap();
        assert!(cache.get("anything").is_none());
        assert_eq!(cache.entry_count(), 0);
    }

    #[test]
    fn clear_removes_file() {
        let (_scratch, shard_file, cache) = scratch_cache();
        cache.put_one("x", vec![1]).unwrap();
        assert!(shard_file.exists());
        cache.clear().unwrap();
        assert!(!shard_file.exists());
    }
}