1use std::collections::hash_map::DefaultHasher;
19use std::collections::HashSet;
20use std::fs::{self, File};
21use std::hash::{Hash, Hasher};
22use std::io::{BufReader, BufWriter, Read, Write};
23use std::path::Path;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::sync::RwLock;
26use std::time::SystemTime;
27
28use dashmap::DashMap;
29use serde::{de::DeserializeOwned, Deserialize, Serialize};
30use tldr_core::Language;
31
32use super::error::DaemonResult;
33use super::types::SalsaCacheStats;
34
/// Default entry-count limit; `QueryCache::with_defaults` passes this to `QueryCache::new`.
pub const DEFAULT_MAX_ENTRIES: usize = 10_000;

/// Default estimated-byte limit (512 MiB) used by `QueryCache::new` and
/// therefore by `QueryCache::with_defaults`.
pub const DEFAULT_MAX_BYTES: usize = 512 * 1024 * 1024;

/// Magic bytes written at the very start of every cache file and checked on load.
const CACHE_MAGIC: &[u8; 4] = b"TLDR";

/// Version of the binary header (magic, this version byte, checksum). A file with
/// a different header version is rejected on load and discarded.
const CACHE_VERSION: u8 = 1;

/// Version of the JSON payload schema (`CacheFileData`). On load, a payload whose
/// `schema_version` differs is discarded and an empty cache is returned instead.
pub const CACHE_SCHEMA_VERSION: u32 = 2;
65
66#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
86pub struct QueryKey {
87 pub query_name: String,
89 pub args_hash: u64,
91 pub language: Language,
95}
96
97impl QueryKey {
98 pub fn new(query_name: impl Into<String>, args_hash: u64, language: Language) -> Self {
100 Self {
101 query_name: query_name.into(),
102 args_hash,
103 language,
104 }
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct CacheEntry {
111 pub value: Vec<u8>,
113 pub revision: u64,
115 pub input_hashes: Vec<u64>,
117 #[serde(with = "system_time_serde")]
119 pub created_at: SystemTime,
120 #[serde(with = "system_time_serde")]
122 pub last_accessed: SystemTime,
123}
124
125impl CacheEntry {
126 pub fn new(value: Vec<u8>, revision: u64, input_hashes: Vec<u64>) -> Self {
128 let now = SystemTime::now();
129 Self {
130 value,
131 revision,
132 input_hashes,
133 created_at: now,
134 last_accessed: now,
135 }
136 }
137
138 pub fn estimated_bytes(&self) -> usize {
140 self.value.len()
141 + self.input_hashes.len() * std::mem::size_of::<u64>()
142 + std::mem::size_of::<Self>()
143 }
144}
145
/// Cache of serialized query results with input-based invalidation and
/// LRU eviction under entry-count and estimated-byte limits.
///
/// All state lives in concurrent maps, atomics and an `RwLock`, so every
/// operation takes `&self`.
pub struct QueryCache {
    /// Cached results keyed by query identity.
    entries: DashMap<QueryKey, CacheEntry>,
    /// Reverse index: input hash -> keys inserted with that hash in their
    /// `input_hashes`; consulted by `invalidate_by_input`.
    dependents: DashMap<u64, HashSet<QueryKey>>,
    /// Counter bumped by every `invalidate_by_input` and stamped into new entries.
    revision: AtomicU64,
    /// Hit / miss / invalidation counters.
    stats: RwLock<SalsaCacheStats>,
    /// Maximum number of entries kept after an insert.
    max_entries: usize,
    /// Maximum total of `CacheEntry::estimated_bytes` kept after an insert.
    max_bytes: usize,
    /// Running total of `CacheEntry::estimated_bytes` over `entries`.
    current_bytes: AtomicU64,
}
163
164impl QueryCache {
169 pub fn new(max_entries: usize) -> Self {
171 Self::with_limits(max_entries, DEFAULT_MAX_BYTES)
172 }
173
174 pub fn with_limits(max_entries: usize, max_bytes: usize) -> Self {
176 Self {
177 entries: DashMap::new(),
178 dependents: DashMap::new(),
179 revision: AtomicU64::new(0),
180 stats: RwLock::new(SalsaCacheStats::default()),
181 max_entries,
182 max_bytes,
183 current_bytes: AtomicU64::new(0),
184 }
185 }
186
187 pub fn with_defaults() -> Self {
189 Self::new(DEFAULT_MAX_ENTRIES)
190 }
191
192 pub fn get<T: DeserializeOwned>(&self, key: &QueryKey) -> Option<T> {
198 if let Some(mut entry) = self.entries.get_mut(key) {
199 entry.last_accessed = SystemTime::now();
201
202 if let Ok(mut stats) = self.stats.write() {
204 stats.hits += 1;
205 }
206
207 match serde_json::from_slice(&entry.value) {
209 Ok(value) => Some(value),
210 Err(_) => {
211 drop(entry);
213 self.entries.remove(key);
214 None
215 }
216 }
217 } else {
218 if let Ok(mut stats) = self.stats.write() {
220 stats.misses += 1;
221 }
222 None
223 }
224 }
225
226 pub fn insert<T: Serialize>(&self, key: QueryKey, value: &T, input_hashes: Vec<u64>) {
231 let serialized = match serde_json::to_vec(value) {
233 Ok(v) => v,
234 Err(_) => return, };
236
237 let revision = self.revision.load(Ordering::Acquire);
238 let entry = CacheEntry::new(serialized, revision, input_hashes.clone());
239
240 for &hash in &input_hashes {
242 self.dependents.entry(hash).or_default().insert(key.clone());
243 }
244
245 if let Some(old) = self.entries.get(&key) {
247 self.current_bytes
248 .fetch_sub(old.estimated_bytes() as u64, Ordering::Relaxed);
249 }
250
251 self.current_bytes
253 .fetch_add(entry.estimated_bytes() as u64, Ordering::Relaxed);
254
255 self.entries.insert(key, entry);
257
258 self.maybe_evict();
260 }
261
262 pub fn invalidate_by_input(&self, input_hash: u64) -> usize {
266 self.revision.fetch_add(1, Ordering::Release);
268
269 let mut invalidated = 0;
270
271 if let Some((_, keys)) = self.dependents.remove(&input_hash) {
273 for key in keys {
274 if let Some((_, entry)) = self.entries.remove(&key) {
275 self.current_bytes
276 .fetch_sub(entry.estimated_bytes() as u64, Ordering::Relaxed);
277 invalidated += 1;
278 }
279 }
280 }
281
282 if let Ok(mut stats) = self.stats.write() {
284 stats.invalidations += invalidated as u64;
285 }
286
287 invalidated
288 }
289
290 pub fn invalidate(&self, key: &QueryKey) -> bool {
294 if let Some((_, entry)) = self.entries.remove(key) {
295 self.current_bytes
297 .fetch_sub(entry.estimated_bytes() as u64, Ordering::Relaxed);
298
299 for hash in entry.input_hashes {
301 if let Some(mut deps) = self.dependents.get_mut(&hash) {
302 deps.remove(key);
303 }
304 }
305
306 if let Ok(mut stats) = self.stats.write() {
307 stats.invalidations += 1;
308 }
309
310 true
311 } else {
312 false
313 }
314 }
315
316 pub fn stats(&self) -> SalsaCacheStats {
318 self.stats.read().map(|s| s.clone()).unwrap_or_default()
319 }
320
321 pub fn len(&self) -> usize {
323 self.entries.len()
324 }
325
326 pub fn is_empty(&self) -> bool {
328 self.entries.is_empty()
329 }
330
331 pub fn revision(&self) -> u64 {
333 self.revision.load(Ordering::Acquire)
334 }
335
336 pub fn clear(&self) {
338 self.entries.clear();
339 self.dependents.clear();
340 self.revision.store(0, Ordering::Release);
341 self.current_bytes.store(0, Ordering::Relaxed);
342
343 if let Ok(mut stats) = self.stats.write() {
344 *stats = SalsaCacheStats::default();
345 }
346 }
347
348 pub fn total_bytes(&self) -> usize {
350 self.current_bytes.load(Ordering::Relaxed) as usize
351 }
352
353 fn maybe_evict(&self) {
355 let over_entries = self.entries.len() > self.max_entries;
356 let over_bytes = self.total_bytes() > self.max_bytes;
357
358 if !over_entries && !over_bytes {
359 return;
360 }
361
362 let mut entries_by_time: Vec<(QueryKey, SystemTime, usize)> = self
364 .entries
365 .iter()
366 .map(|e| {
367 (
368 e.key().clone(),
369 e.value().last_accessed,
370 e.value().estimated_bytes(),
371 )
372 })
373 .collect();
374
375 entries_by_time.sort_by(|a, b| a.1.cmp(&b.1));
377
378 for (key, _, _) in entries_by_time {
380 if self.entries.len() <= self.max_entries && self.total_bytes() <= self.max_bytes {
381 break;
382 }
383 self.invalidate(&key);
384 }
385 }
386
387 pub fn save_to_file(&self, path: &Path) -> DaemonResult<()> {
395 let entries: Vec<(QueryKey, CacheEntry)> = self
397 .entries
398 .iter()
399 .map(|e| (e.key().clone(), e.value().clone()))
400 .collect();
401
402 let dependents: Vec<(u64, Vec<QueryKey>)> = self
403 .dependents
404 .iter()
405 .map(|e| (*e.key(), e.value().iter().cloned().collect()))
406 .collect();
407
408 let stats = self.stats();
409 let revision = self.revision();
410
411 let cache_data = CacheFileData {
412 schema_version: CACHE_SCHEMA_VERSION,
415 entries,
416 dependents,
417 stats,
418 revision,
419 };
420
421 let json = serde_json::to_vec(&cache_data)?;
423
424 let checksum = calculate_checksum(&json);
426
427 let temp_path = path.with_extension("tmp");
429 {
430 let file = File::create(&temp_path)?;
431 let mut writer = BufWriter::new(file);
432
433 writer.write_all(CACHE_MAGIC)?;
435 writer.write_all(&[CACHE_VERSION])?;
436 writer.write_all(&checksum.to_le_bytes())?;
437 writer.write_all(&json)?;
438 writer.flush()?;
439 }
440
441 fs::rename(&temp_path, path)?;
443
444 Ok(())
445 }
446
447 pub fn load_from_file(path: &Path) -> DaemonResult<Self> {
460 let file = File::open(path)?;
461 match Self::try_load_payload(&file) {
462 Ok(cache_data) if cache_data.schema_version == CACHE_SCHEMA_VERSION => {
463 Ok(Self::from_cache_data(cache_data))
464 }
465 Ok(stale) => {
466 eprintln!(
467 "tldr-cli: cache schema mismatch on {} (found schema_version={}, expected {}); discarding and starting fresh",
468 path.display(),
469 stale.schema_version,
470 CACHE_SCHEMA_VERSION,
471 );
472 let _ = fs::remove_file(path);
473 Ok(Self::with_defaults())
474 }
475 Err(reason) => {
476 eprintln!(
477 "tldr-cli: cache file at {} could not be loaded ({}); discarding and starting fresh",
478 path.display(),
479 reason,
480 );
481 let _ = fs::remove_file(path);
482 Ok(Self::with_defaults())
483 }
484 }
485 }
486
487 fn try_load_payload(file: &File) -> Result<CacheFileData, String> {
492 let mut reader = BufReader::new(file);
493
494 let mut magic = [0u8; 4];
495 reader
496 .read_exact(&mut magic)
497 .map_err(|e| format!("read magic: {}", e))?;
498 if &magic != CACHE_MAGIC {
499 return Err("invalid cache file magic".to_string());
500 }
501
502 let mut version = [0u8; 1];
503 reader
504 .read_exact(&mut version)
505 .map_err(|e| format!("read version: {}", e))?;
506 if version[0] != CACHE_VERSION {
507 return Err(format!("unsupported cache header version: {}", version[0]));
508 }
509
510 let mut checksum_bytes = [0u8; 8];
511 reader
512 .read_exact(&mut checksum_bytes)
513 .map_err(|e| format!("read checksum: {}", e))?;
514 let stored_checksum = u64::from_le_bytes(checksum_bytes);
515
516 let mut data = Vec::new();
517 reader
518 .read_to_end(&mut data)
519 .map_err(|e| format!("read payload: {}", e))?;
520
521 let actual_checksum = calculate_checksum(&data);
522 if stored_checksum != actual_checksum {
523 return Err("cache file checksum mismatch".to_string());
524 }
525
526 serde_json::from_slice::<CacheFileData>(&data)
527 .map_err(|e| format!("deserialize payload: {}", e))
528 }
529
530 fn from_cache_data(cache_data: CacheFileData) -> Self {
533 let cache = Self::with_defaults();
534
535 let mut total_bytes: u64 = 0;
536 for (key, entry) in cache_data.entries {
537 total_bytes += entry.estimated_bytes() as u64;
538 cache.entries.insert(key, entry);
539 }
540 cache.current_bytes.store(total_bytes, Ordering::Relaxed);
541
542 for (hash, keys) in cache_data.dependents {
543 cache.dependents.insert(hash, keys.into_iter().collect());
544 }
545
546 cache.revision.store(cache_data.revision, Ordering::Release);
547
548 if let Ok(mut stats) = cache.stats.write() {
549 *stats = cache_data.stats;
550 }
551
552 cache
553 }
554}
555
556impl Default for QueryCache {
557 fn default() -> Self {
558 Self::with_defaults()
559 }
560}
561
/// JSON payload of a cache file; `QueryCache::save_to_file` writes it after the
/// binary header, and `QueryCache::load_from_file` decodes it after validation.
#[derive(Serialize, Deserialize, Default)]
pub struct CacheFileData {
    /// Payload schema version, compared against `CACHE_SCHEMA_VERSION` on load.
    /// A payload without this field decodes as 0 and is therefore treated as a
    /// schema mismatch.
    #[serde(default)]
    pub schema_version: u32,
    /// Snapshot of all cached entries.
    pub entries: Vec<(QueryKey, CacheEntry)>,
    /// Snapshot of the input-hash -> dependent-keys index.
    pub dependents: Vec<(u64, Vec<QueryKey>)>,
    /// Statistics at save time; restored on load.
    pub stats: SalsaCacheStats,
    /// Revision counter at save time; restored on load.
    pub revision: u64,
}
592
/// Checksum of a cache payload, computed with the standard library's
/// `DefaultHasher` over the byte slice.
///
/// Note: std documents `DefaultHasher`'s algorithm as unspecified across Rust
/// releases, so a toolchain change may make older cache files fail this check;
/// such files are then discarded by `QueryCache::load_from_file`.
fn calculate_checksum(data: &[u8]) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(data, &mut state);
    state.finish()
}
599
600mod system_time_serde {
602 use serde::{Deserialize, Deserializer, Serialize, Serializer};
603 use std::time::{Duration, SystemTime, UNIX_EPOCH};
604
605 pub fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
606 where
607 S: Serializer,
608 {
609 let duration = time.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO);
610 duration.as_secs().serialize(serializer)
611 }
612
613 pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
614 where
615 D: Deserializer<'de>,
616 {
617 let secs = u64::deserialize(deserializer)?;
618 Ok(UNIX_EPOCH + Duration::from_secs(secs))
619 }
620}
621
/// Hashes any `Hash` value (typically a tuple of query arguments) into a `u64`
/// using the standard library's `DefaultHasher`.
///
/// Equal inputs hash equally within one build; the algorithm is not guaranteed
/// stable across Rust releases.
pub fn hash_args<T: Hash>(args: &T) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(args, &mut state);
    state.finish()
}
632
/// Hashes a filesystem path into a `u64` using the standard library's
/// `DefaultHasher` and `Path`'s own `Hash` implementation.
pub fn hash_path(path: &Path) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(path, &mut state);
    state.finish()
}
639
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a Python-language key; every test in this module uses Python.
    fn py_key(name: &str, args_hash: u64) -> QueryKey {
        QueryKey::new(name, args_hash, Language::Python)
    }

    /// Sleeps briefly so consecutive operations get distinct timestamps.
    fn pause() {
        std::thread::sleep(std::time::Duration::from_millis(10));
    }

    #[test]
    fn test_query_cache_new() {
        let cache = QueryCache::new(100);
        assert_eq!(cache.max_entries, 100);
        assert!(cache.is_empty());
        assert_eq!(cache.revision(), 0);
    }

    #[test]
    fn test_query_cache_insert_and_get() {
        let cache = QueryCache::new(100);
        let key = py_key("test", 12345);
        let words = vec!["hello", "world"];
        cache.insert(key.clone(), &words, vec![]);

        let fetched: Option<Vec<String>> = cache.get(&key);
        assert_eq!(fetched, Some(vec!["hello".to_string(), "world".to_string()]));
    }

    #[test]
    fn test_query_cache_miss() {
        let cache = QueryCache::new(100);

        let fetched: Option<String> = cache.get(&py_key("nonexistent", 99999));
        assert!(fetched.is_none());

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    #[test]
    fn test_query_cache_hit_tracking() {
        let cache = QueryCache::new(100);
        let key = py_key("test", 12345);
        cache.insert(key.clone(), &"value", vec![]);

        for _ in 0..2 {
            let _: Option<String> = cache.get(&key);
        }

        assert_eq!(cache.stats().hits, 2);
    }

    #[test]
    fn test_query_cache_invalidate_by_input() {
        let cache = QueryCache::new(100);
        let input_hash = hash_path(Path::new("/test/file.rs"));
        let key1 = py_key("query1", 1);
        let key3 = py_key("query3", 3);

        cache.insert(key1.clone(), &"value1", vec![input_hash]);
        cache.insert(py_key("query2", 2), &"value2", vec![input_hash]);
        cache.insert(key3.clone(), &"value3", vec![]);
        assert_eq!(cache.len(), 3);

        assert_eq!(cache.invalidate_by_input(input_hash), 2);
        assert_eq!(cache.len(), 1);

        let survivor: Option<String> = cache.get(&key3);
        assert!(survivor.is_some());

        let dropped: Option<String> = cache.get(&key1);
        assert!(dropped.is_none());
    }

    #[test]
    fn test_query_cache_invalidation_stats() {
        let cache = QueryCache::new(100);
        cache.insert(py_key("test", 1), &"value", vec![12345]);

        cache.invalidate_by_input(12345);

        assert_eq!(cache.stats().invalidations, 1);
    }

    #[test]
    fn test_query_cache_clear() {
        let cache = QueryCache::new(100);
        cache.insert(py_key("q1", 1), &"v1", vec![]);
        cache.insert(py_key("q2", 2), &"v2", vec![]);
        assert_eq!(cache.len(), 2);

        cache.clear();

        assert!(cache.is_empty());
        assert_eq!(cache.revision(), 0);
    }

    #[test]
    fn test_query_cache_lru_eviction() {
        let cache = QueryCache::new(3);
        for (name, id, value) in [("q1", 1, "v1"), ("q2", 2, "v2"), ("q3", 3, "v3")] {
            cache.insert(py_key(name, id), &value, vec![]);
            pause();
        }

        // Touch q1 so it becomes the most recently used entry.
        let _: Option<String> = cache.get(&py_key("q1", 1));
        pause();

        cache.insert(py_key("q4", 4), &"v4", vec![]);
        assert!(cache.len() <= 3);

        let recently_used: Option<String> = cache.get(&py_key("q1", 1));
        assert!(recently_used.is_some());
    }

    #[test]
    fn test_query_cache_persistence() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let cache = QueryCache::new(100);
        cache.insert(py_key("test", 12345), &"hello world", vec![1, 2, 3]);
        cache.insert(py_key("test2", 67890), &vec![1, 2, 3], vec![]);
        cache.save_to_file(&cache_path).unwrap();

        let loaded = QueryCache::load_from_file(&cache_path).unwrap();
        assert_eq!(loaded.len(), 2);

        let text: Option<String> = loaded.get(&py_key("test", 12345));
        assert_eq!(text, Some("hello world".to_string()));

        let numbers: Option<Vec<i32>> = loaded.get(&py_key("test2", 67890));
        assert_eq!(numbers, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_query_cache_persistence_checksum_validation() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("test_cache.bin");

        let cache = QueryCache::new(100);
        cache.insert(py_key("test", 1), &"value", vec![]);
        cache.save_to_file(&cache_path).unwrap();

        // Flip a byte past the 13-byte header so the payload no longer matches
        // its stored checksum.
        let mut bytes = fs::read(&cache_path).unwrap();
        if let Some(byte) = bytes.get_mut(20) {
            *byte ^= 0xFF;
        }
        fs::write(&cache_path, bytes).unwrap();

        let loaded = QueryCache::load_from_file(&cache_path)
            .expect("checksum mismatch must be discarded gracefully");
        assert_eq!(loaded.len(), 0, "discarded cache must be empty");
        assert!(
            !cache_path.exists(),
            "corrupted cache file must be removed by graceful-discard path"
        );
    }

    #[test]
    fn test_hash_args() {
        let original = ("query", "/path/to/file.rs", 42);
        let duplicate = ("query", "/path/to/file.rs", 42);
        let different = ("query", "/path/to/other.rs", 42);

        assert_eq!(hash_args(&original), hash_args(&duplicate));
        assert_ne!(hash_args(&original), hash_args(&different));
    }

    #[test]
    fn test_hash_path() {
        let bar = Path::new("/foo/bar.rs");
        let bar_again = Path::new("/foo/bar.rs");
        let baz = Path::new("/foo/baz.rs");

        assert_eq!(hash_path(bar), hash_path(bar_again));
        assert_ne!(hash_path(bar), hash_path(baz));
    }

    #[test]
    fn test_query_key_equality() {
        let base = py_key("test", 12345);

        assert_eq!(base, py_key("test", 12345));
        assert_ne!(base, py_key("test", 99999));
        assert_ne!(base, py_key("other", 12345));
    }

    #[test]
    fn test_cache_entry_creation() {
        let entry = CacheEntry::new(vec![1, 2, 3], 5, vec![100, 200]);

        assert_eq!(entry.value, vec![1, 2, 3]);
        assert_eq!(entry.revision, 5);
        assert_eq!(entry.input_hashes, vec![100, 200]);

        let now = SystemTime::now();
        assert!(entry.created_at <= now);
        assert!(entry.last_accessed <= now);
    }

    #[test]
    fn test_stats_hit_rate_calculation() {
        let cache = QueryCache::new(100);
        assert_eq!(cache.stats().hit_rate(), 0.0);

        cache.insert(py_key("test", 1), &"value", vec![]);
        // hit, miss, hit
        for args_hash in [1, 2, 1] {
            let _: Option<String> = cache.get(&py_key("test", args_hash));
        }

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 66.67).abs() < 0.1);
    }

    #[test]
    fn test_revision_increments_on_invalidation() {
        let cache = QueryCache::new(100);
        assert_eq!(cache.revision(), 0);

        for (expected_revision, input) in [(1, 12345), (2, 67890)] {
            cache.invalidate_by_input(input);
            assert_eq!(cache.revision(), expected_revision);
        }
    }

    #[test]
    fn test_multiple_entries_same_input() {
        let cache = QueryCache::new(100);
        let shared_input = 12345u64;

        for (name, id, value) in [("q1", 1, "v1"), ("q2", 2, "v2"), ("q3", 3, "v3")] {
            cache.insert(py_key(name, id), &value, vec![shared_input]);
        }
        assert_eq!(cache.len(), 3);

        assert_eq!(cache.invalidate_by_input(shared_input), 3);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_entry_with_multiple_inputs() {
        let cache = QueryCache::new(100);
        let (input1, input2) = (111u64, 222u64);

        cache.insert(py_key("q1", 1), &"v1", vec![input1, input2]);
        assert_eq!(cache.len(), 1);

        cache.invalidate_by_input(input1);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_total_bytes_tracking() {
        let cache = QueryCache::new(100);
        assert_eq!(cache.total_bytes(), 0);

        cache.insert(py_key("q1", 1), &"hello", vec![]);
        let after_one = cache.total_bytes();
        assert!(after_one > 0, "total_bytes should increase after insert");

        cache.insert(py_key("q2", 2), &"world", vec![]);
        let after_two = cache.total_bytes();
        assert!(
            after_two > after_one,
            "total_bytes should increase with more entries"
        );

        cache.clear();
        assert_eq!(cache.total_bytes(), 0);
    }

    #[test]
    fn test_bytes_decrease_on_invalidate() {
        let cache = QueryCache::new(100);
        cache.insert(py_key("q1", 1), &"value1", vec![]);
        cache.insert(py_key("q2", 2), &"value2", vec![]);
        let before = cache.total_bytes();

        cache.invalidate(&py_key("q1", 1));

        assert!(
            cache.total_bytes() < before,
            "total_bytes should decrease after invalidation"
        );
    }

    #[test]
    fn test_bytes_decrease_on_invalidate_by_input() {
        let cache = QueryCache::new(100);
        let input_hash = 42u64;

        for (name, id, value) in [("q1", 1, "value1"), ("q2", 2, "value2")] {
            cache.insert(py_key(name, id), &value, vec![input_hash]);
        }
        assert!(cache.total_bytes() > 0);

        cache.invalidate_by_input(input_hash);
        assert_eq!(
            cache.total_bytes(),
            0,
            "total_bytes should be 0 after all entries invalidated"
        );
    }

    #[test]
    fn test_byte_limit_eviction() {
        let cache = QueryCache::with_limits(10_000, 1024);
        let payload = "x".repeat(200);

        for i in 0..20 {
            cache.insert(py_key("q", i), &payload, vec![]);
        }

        let total = cache.total_bytes();
        assert!(
            total <= 1024,
            "total_bytes ({}) should be <= 1024 after eviction",
            total
        );
        let count = cache.len();
        assert!(
            count < 20,
            "entry count ({}) should be < 20 after byte-based eviction",
            count
        );
    }

    #[test]
    fn test_large_entry_evicts_many_small() {
        let cache = QueryCache::with_limits(10_000, 2048);

        for i in 0..10 {
            cache.insert(py_key("small", i), &"tiny", vec![]);
        }
        assert_eq!(cache.len(), 10);

        let big_payload = "x".repeat(1500);
        cache.insert(py_key("big", 0), &big_payload, vec![]);

        let total = cache.total_bytes();
        assert!(total <= 2048, "total_bytes ({}) should be <= 2048", total);

        let big: Option<String> = cache.get(&py_key("big", 0));
        assert!(big.is_some(), "large entry should survive eviction");
    }

    #[test]
    fn test_byte_tracking_on_replace() {
        let cache = QueryCache::new(100);
        let key = py_key("q1", 1);

        cache.insert(key.clone(), &"small", vec![]);
        let small_bytes = cache.total_bytes();

        cache.insert(key, &"x".repeat(10_000), vec![]);
        let big_bytes = cache.total_bytes();

        assert!(
            big_bytes > small_bytes,
            "bytes should increase when replacing small with large"
        );
        assert_eq!(cache.len(), 1, "should still be one entry after replace");
    }

    #[test]
    fn test_memory_bounded_cache_under_stress() {
        let limit = 100 * 1024;
        let cache = QueryCache::with_limits(10_000, limit);

        for i in 0..1000u64 {
            // Payload sizes cycle through 100..=1000 bytes.
            let size = ((i % 10) + 1) as usize * 100;
            cache.insert(py_key("stress", i), &"x".repeat(size), vec![]);
        }

        let total = cache.total_bytes();
        assert!(
            total <= limit,
            "total_bytes ({}) should be <= 102400 after stress test",
            total
        );

        let newest: Option<String> = cache.get(&py_key("stress", 999));
        assert!(newest.is_some(), "most recent entry should be cached");
    }

    #[test]
    fn test_estimated_bytes_accuracy() {
        let small = CacheEntry::new(vec![1, 2, 3], 0, vec![]);
        let large = CacheEntry::new(vec![0u8; 10_000], 0, vec![1, 2, 3]);

        assert!(small.estimated_bytes() < large.estimated_bytes());
        assert!(small.estimated_bytes() > 0);
        assert!(
            large.estimated_bytes() >= 10_000,
            "estimated_bytes ({}) should be >= payload size",
            large.estimated_bytes()
        );
    }

    #[test]
    fn test_default_max_bytes() {
        let cache = QueryCache::with_defaults();
        assert_eq!(cache.max_bytes, DEFAULT_MAX_BYTES);
        assert_eq!(cache.max_bytes, 512 * 1024 * 1024);
    }

    mod proptest_cache {
        use super::*;
        use proptest::prelude::*;

        /// Recomputes the byte total from scratch, for comparison with the
        /// incrementally tracked counter.
        fn recompute_bytes(cache: &QueryCache) -> usize {
            cache
                .entries
                .iter()
                .map(|e| e.value().estimated_bytes())
                .sum()
        }

        /// One randomly generated cache operation.
        #[derive(Debug, Clone)]
        enum CacheOp {
            Insert {
                key_id: u8,
                payload_len: usize,
                input_hash: u64,
            },
            InvalidateByInput(u64),
            InvalidateByKey(u8),
            Clear,
        }

        /// Strategy over operations; input hashes are folded into 16 buckets so
        /// that invalidations actually hit inserted entries.
        fn arb_cache_op() -> impl Strategy<Value = CacheOp> {
            prop_oneof![
                (any::<u8>(), 0..2000usize, any::<u64>()).prop_map(|(k, p, h)| CacheOp::Insert {
                    key_id: k,
                    payload_len: p,
                    input_hash: h % 16,
                }),
                (any::<u64>()).prop_map(|h| CacheOp::InvalidateByInput(h % 16)),
                (any::<u8>()).prop_map(CacheOp::InvalidateByKey),
                Just(CacheOp::Clear),
            ]
        }

        /// Applies a single generated operation to `cache`.
        fn apply_op(cache: &QueryCache, op: CacheOp) {
            match op {
                CacheOp::Insert {
                    key_id,
                    payload_len,
                    input_hash,
                } => {
                    let payload = vec![0u8; payload_len];
                    cache.insert(py_key("prop", u64::from(key_id)), &payload, vec![input_hash]);
                }
                CacheOp::InvalidateByInput(hash) => {
                    cache.invalidate_by_input(hash);
                }
                CacheOp::InvalidateByKey(key_id) => {
                    cache.invalidate(&py_key("prop", u64::from(key_id)));
                }
                CacheOp::Clear => {
                    cache.clear();
                }
            }
        }

        proptest! {
            #[test]
            fn bytes_tracking_consistent(ops in prop::collection::vec(arb_cache_op(), 1..150)) {
                let cache = QueryCache::with_limits(500, 10_000_000);
                for op in ops {
                    apply_op(&cache, op);
                }

                let tracked = cache.total_bytes();
                let actual = recompute_bytes(&cache);
                prop_assert_eq!(tracked, actual,
                    "tracked bytes ({}) != recomputed bytes ({})", tracked, actual);
            }

            #[test]
            fn entry_count_bounded(ops in prop::collection::vec(arb_cache_op(), 1..200)) {
                let max = 50;
                let cache = QueryCache::with_limits(max, 10_000_000);
                for op in ops {
                    apply_op(&cache, op);
                }

                prop_assert!(cache.len() <= max,
                    "cache size {} exceeds max {}", cache.len(), max);
            }

            #[test]
            fn byte_limit_bounded(ops in prop::collection::vec(arb_cache_op(), 1..200)) {
                let max_bytes = 50_000;
                let cache = QueryCache::with_limits(500, max_bytes);
                for op in ops {
                    apply_op(&cache, op);
                }

                prop_assert!(cache.total_bytes() <= max_bytes,
                    "total bytes {} exceeds max {}", cache.total_bytes(), max_bytes);
            }

            #[test]
            fn clear_resets_everything(
                inserts in prop::collection::vec((any::<u8>(), 0..500usize), 1..50)
            ) {
                let cache = QueryCache::with_limits(500, 10_000_000);
                for (key_id, payload_len) in inserts {
                    cache.insert(py_key("prop", u64::from(key_id)), &vec![0u8; payload_len], vec![]);
                }

                cache.clear();

                prop_assert_eq!(cache.len(), 0);
                prop_assert_eq!(cache.total_bytes(), 0);
                prop_assert_eq!(recompute_bytes(&cache), 0);
            }

            #[test]
            fn replace_in_place_no_leak(
                sizes in prop::collection::vec(0..5000usize, 2..20)
            ) {
                let cache = QueryCache::with_limits(500, 10_000_000);
                let key = py_key("same", 42);
                for size in &sizes {
                    cache.insert(key.clone(), &vec![0u8; *size], vec![]);
                }

                prop_assert_eq!(cache.len(), 1);
                let tracked = cache.total_bytes();
                let actual = recompute_bytes(&cache);
                prop_assert_eq!(tracked, actual);
            }
        }
    }
}