// vyre_wgpu/runtime/cache/tiered_cache.rs
1use crate::runtime::cache::lru::AccessTracker;
2use rustc_hash::FxHashMap;
3
/// Metadata for a cached entry.
///
/// `Copy` so callers can snapshot an entry without holding a borrow on
/// the owning [`CacheTier`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct CacheEntry {
    /// Unique identifier for the entry.
    pub key: u64,
    /// Size of the entry in bytes.
    pub size: u64,
    /// Index of the tier the entry currently resides in.
    pub tier: usize,
}
15
/// A single cache tier with a fixed capacity.
#[non_exhaustive]
pub struct CacheTier {
    /// Human-readable name for the tier.
    pub name: String,
    /// Total capacity of the tier in bytes.
    pub capacity: u64,
    /// Currently used bytes in the tier.
    ///
    /// Invariant: equals the sum of the resident entries' `size` fields;
    /// maintained by [`TieredCache`], not by this type.
    pub used: u64,
    // Resident entries keyed by `CacheEntry::key`.
    pub(crate) entries: FxHashMap<u64, CacheEntry>,
}
27
28impl CacheTier {
29    /// Create a new empty tier.
30    #[inline]
31    pub fn new(name: impl Into<String>, capacity: u64) -> Self {
32        Self {
33            name: name.into(),
34            capacity,
35            used: 0,
36            entries: FxHashMap::default(),
37        }
38    }
39}
40
/// Access statistics used by [`TierPolicy`] implementations.
///
/// Produced by the cache's `AccessTracker` (see `TieredCache::promote`,
/// which passes these to [`TierPolicy::should_promote`]).
#[non_exhaustive]
pub struct AccessStats {
    /// Number of recorded accesses.
    pub frequency: u32,
    /// Position in the recency queue (0 = most recent).
    pub recency_rank: usize,
    /// Size of the entry in bytes.
    pub size: u64,
}
51
/// Policy that decides promotion and eviction behavior.
///
/// `Send + Sync` so a [`TieredCache`] parameterized over a policy can be
/// shared across threads.
pub trait TierPolicy: Send + Sync {
    /// Return `true` if the entry should be promoted to a faster tier.
    fn should_promote(&self, key: u64, stats: &AccessStats) -> bool;

    /// Select a candidate for eviction from the given tier.
    ///
    /// Implementations should return a key that is present in `entries`;
    /// `None` means nothing in this tier can be evicted.
    fn eviction_candidate(
        &self,
        tier: usize,
        entries: &FxHashMap<u64, CacheEntry>,
        tracker: &AccessTracker,
    ) -> Option<u64>;
}
65
/// LRU eviction policy with frequency-based promotion.
///
/// Evicts the coldest tracked entry; promotes entries whose access count
/// reaches [`LruPolicy::promote_threshold`].
#[non_exhaustive]
pub struct LruPolicy {
    /// Minimum access frequency required for promotion.
    pub promote_threshold: u32,
}
72
73impl LruPolicy {
74    /// Default access threshold for promotion.
75    pub const DEFAULT_THRESHOLD: u32 = 3;
76
77    /// Create a new policy with the given promotion threshold.
78    #[inline]
79    pub fn new(promote_threshold: u32) -> Self {
80        Self { promote_threshold }
81    }
82}
83
84impl Default for LruPolicy {
85    fn default() -> Self {
86        Self::new(Self::DEFAULT_THRESHOLD)
87    }
88}
89
90impl TierPolicy for LruPolicy {
91    fn should_promote(&self, _key: u64, stats: &AccessStats) -> bool {
92        stats.frequency >= self.promote_threshold
93    }
94
95    fn eviction_candidate(
96        &self,
97        _tier: usize,
98        entries: &FxHashMap<u64, CacheEntry>,
99        tracker: &AccessTracker,
100    ) -> Option<u64> {
101        for (key, _meta) in tracker.iter_coldest() {
102            if entries.contains_key(key) {
103                return Some(*key);
104            }
105        }
106        entries.keys().next().copied()
107    }
108}
109
/// Errors that can occur during cache operations.
///
/// The `Display` impl includes a suggested fix alongside each message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CacheError {
    /// The requested key does not exist in the cache.
    KeyNotFound,
    /// The entry is too large to fit in any tier.
    EntryTooLarge,
}
119
120impl std::fmt::Display for CacheError {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            Self::KeyNotFound => write!(
124                f,
125                "Key not found in cache. Fix: verify the key was inserted before operating on it."
126            ),
127            Self::EntryTooLarge => write!(
128                f,
129                "Entry size exceeds the capacity of the largest tier. Fix: reduce the buffer size or increase the tier capacity."
130            ),
131        }
132    }
133}
134
// Marker impl: `Display` + `Debug` above provide the full reporting surface.
impl std::error::Error for CacheError {}
136
/// Generic tiered cache for GPU buffers.
///
/// Tracks hot/cold buffers. [`TierPolicy`] decides promotion and eviction.
/// This is the vyre primitive that helix builds inference intelligence on top of.
#[non_exhaustive]
pub struct TieredCache<P: TierPolicy = LruPolicy> {
    // Tiers ordered from index 0 upward; `insert` starts at 0 and
    // `promote` moves entries toward higher indices ("faster" tiers).
    pub(crate) tiers: Vec<CacheTier>,
    // Recency/frequency bookkeeping shared by all tiers.
    pub(crate) tracker: AccessTracker,
    // Promotion/eviction strategy.
    pub(crate) policy: P,
}
147
148impl TieredCache<LruPolicy> {
149    /// Create a new cache with the given tiers and a default [`LruPolicy`].
150    #[inline]
151    pub fn new(tiers: Vec<CacheTier>) -> Self {
152        Self::with_policy(tiers, LruPolicy::default())
153    }
154}
155
156impl<P: TierPolicy> TieredCache<P> {
157    /// Create a new cache with a custom policy implementation.
158    #[inline]
159    pub fn with_policy(tiers: Vec<CacheTier>, policy: P) -> Self {
160        Self {
161            tiers,
162            tracker: AccessTracker::new(),
163            policy,
164        }
165    }
166
167    /// Return a reference to the entry with the given key, if it exists.
168    #[inline]
169    pub fn get(&self, key: u64) -> Option<&CacheEntry> {
170        self.tiers.iter().find_map(|tier| tier.entries.get(&key))
171    }
172
173    /// Insert a new entry into the lowest tier that can fit it.
174    ///
175    /// # Errors
176    ///
177    /// Returns [`CacheError::EntryTooLarge`] when no tier can hold the entry.
178    #[inline]
179    pub fn insert(&mut self, key: u64, size: u64) -> Result<(), CacheError> {
180        if self.get(key).is_some() {
181            self.evict(key);
182        }
183        self.tracker.set_size(key, size);
184        self.insert_into_tier(key, size, 0)
185    }
186
187    /// Record an access for the given key.
188    #[inline]
189    pub fn record_access(&mut self, key: u64) {
190        if self.get(key).is_some() {
191            self.tracker.record(key);
192        }
193    }
194
195    /// Promote the entry to the next faster tier if the policy allows it.
196    ///
197    /// # Errors
198    ///
199    /// Returns [`CacheError::KeyNotFound`] when the key does not exist.
200    #[inline]
201    pub fn promote(&mut self, key: u64) -> Result<(), CacheError> {
202        let entry = self.get(key).copied().ok_or(CacheError::KeyNotFound)?;
203        let stats = self.tracker.stats(key).ok_or(CacheError::KeyNotFound)?;
204        if !self.policy.should_promote(key, &stats) {
205            return Ok(());
206        }
207        let target = entry.tier.saturating_add(1);
208        if target >= self.tiers.len() {
209            return Ok(());
210        }
211        let size = entry.size;
212        self.remove_entry(key);
213        self.move_into_tier(key, size, target, entry.tier)
214    }
215
216    /// Demote the entry to the next slower tier.
217    ///
218    /// # Errors
219    ///
220    /// Returns [`CacheError::KeyNotFound`] when the key does not exist.
221    #[inline]
222    pub fn demote(&mut self, key: u64) -> Result<(), CacheError> {
223        let entry = self.get(key).copied().ok_or(CacheError::KeyNotFound)?;
224        if entry.tier == 0 {
225            return Ok(());
226        }
227        let target = entry.tier - 1;
228        let size = entry.size;
229        self.remove_entry(key);
230        self.move_into_tier(key, size, target, entry.tier)
231    }
232
233    fn insert_into_tier(
234        &mut self,
235        key: u64,
236        size: u64,
237        mut start: usize,
238    ) -> Result<(), CacheError> {
239        while start < self.tiers.len() {
240            if size > self.tiers[start].capacity {
241                start += 1;
242                continue;
243            }
244            if self.make_room(start, size) {
245                self.tiers[start].used = self.tiers[start].used.saturating_add(size);
246                self.tiers[start].entries.insert(
247                    key,
248                    CacheEntry {
249                        key,
250                        size,
251                        tier: start,
252                    },
253                );
254                return Ok(());
255            }
256            start += 1;
257        }
258        Err(CacheError::EntryTooLarge)
259    }
260
261    fn move_into_tier(
262        &mut self,
263        key: u64,
264        size: u64,
265        target: usize,
266        fallback: usize,
267    ) -> Result<(), CacheError> {
268        if self.make_room(target, size) {
269            self.tiers[target].used = self.tiers[target].used.saturating_add(size);
270            self.tiers[target].entries.insert(
271                key,
272                CacheEntry {
273                    key,
274                    size,
275                    tier: target,
276                },
277            );
278            Ok(())
279        } else {
280            self.insert_into_tier(key, size, fallback)
281        }
282    }
283
284    fn make_room(&mut self, tier: usize, size: u64) -> bool {
285        loop {
286            let used = self.tiers[tier].used;
287            let cap = self.tiers[tier].capacity;
288            if used.saturating_add(size) <= cap {
289                return true;
290            }
291            let candidate = {
292                let entries = &self.tiers[tier].entries;
293                self.policy.eviction_candidate(tier, entries, &self.tracker)
294            };
295            if let Some(key) = candidate {
296                self.evict_from_tier(key, tier);
297            } else {
298                return false;
299            }
300        }
301    }
302
303    fn remove_entry(&mut self, key: u64) -> Option<CacheEntry> {
304        for tier in &mut self.tiers {
305            if let Some(entry) = tier.entries.remove(&key) {
306                tier.used = tier.used.saturating_sub(entry.size);
307                return Some(entry);
308            }
309        }
310        None
311    }
312
313    fn evict(&mut self, key: u64) -> Option<CacheEntry> {
314        for tier in &mut self.tiers {
315            if let Some(entry) = tier.entries.remove(&key) {
316                tier.used = tier.used.saturating_sub(entry.size);
317                self.tracker.remove(key);
318                return Some(entry);
319            }
320        }
321        None
322    }
323
324    fn evict_from_tier(&mut self, key: u64, tier: usize) -> Option<CacheEntry> {
325        let tier = &mut self.tiers[tier];
326        let entry = tier.entries.remove(&key)?;
327        tier.used = tier.used.saturating_sub(entry.size);
328        self.tracker.remove(key);
329        Some(entry)
330    }
331}