Skip to main content

set_associative/
lib.rs

1//! Set-associative cache with CLOCK eviction, SIMD tag matching, and custom allocator support.
2
3#![allow(clippy::identity_op)]
4
5use std::hash::{BuildHasher, Hash};
6use std::marker::PhantomData;
7use std::mem::MaybeUninit;
8use std::ptr::NonNull;
9
10pub mod alloc;
11use alloc::{Allocator, Global};
12
/// Empty function whose only purpose is to carry `#[cold]`: a branch that calls it
/// is treated by the optimizer as the rarely-taken side. This is how `likely` and
/// `unlikely` are emulated on stable Rust.
#[inline(always)]
#[cold]
fn cold_path() {}

/// Branch hint: returns `b` unchanged while marking the `false` outcome as cold.
#[inline(always)]
#[allow(unused)] // kept for symmetry with `unlikely`; it may have no callers
fn likely(b: bool) -> bool {
    if !b {
        cold_path();
    }
    b
}

/// Branch hint: returns `b` unchanged while marking the `true` outcome as cold.
#[inline(always)]
fn unlikely(b: bool) -> bool {
    if b {
        cold_path();
    }
    b
}
41
/// How to recover the lookup key from a value held in the cache.
///
/// Implement this when the key is embedded in the value, so only the value has to
/// be stored. For plain `(K, V)` tuples, [`PairExtract`] is ready-made.
pub trait KeyExtract {
    /// Key type compared during lookups.
    type Key: Hash + Eq;
    /// Value type held in each cache slot.
    type Value;

    /// Borrow the key embedded in `value`.
    fn extract(value: &Self::Value) -> &Self::Key;
}

/// [`KeyExtract`] for `(K, V)` tuples: the key is the tuple's first field.
///
/// A zero-sized marker; `PhantomData<fn() -> (K, V)>` means it never owns a `K` or
/// a `V`.
pub struct PairExtract<K, V>(PhantomData<fn() -> (K, V)>);

impl<K, V> KeyExtract for PairExtract<K, V>
where
    K: Hash + Eq,
{
    type Key = K;
    type Value = (K, V);

    #[inline]
    fn extract(value: &(K, V)) -> &K {
        let (key, _) = value;
        key
    }
}
68
/// Comparison between a lookup argument and a key stored in the cache.
///
/// Lookups such as [`SetAssociativeCache::get`] accept any `Q: Equivalent<K>`, so
/// callers can probe with a borrowed form of the key (for example `&str` against a
/// `String` key). A blanket implementation covers every `Q` that the key type can
/// [`Borrow`](core::borrow::Borrow) as.
///
/// # Correctness
///
/// Values that are equivalent must produce the same hash.
pub trait Equivalent<K: ?Sized> {
    /// Returns `true` when `self` matches `key`.
    fn equivalent(&self, key: &K) -> bool;
}

impl<Q, K> Equivalent<K> for Q
where
    Q: Eq + ?Sized,
    K: core::borrow::Borrow<Q> + ?Sized,
{
    #[inline(always)]
    fn equivalent(&self, key: &K) -> bool {
        let borrowed: &Q = key.borrow();
        *self == *borrowed
    }
}
94
95// =============================================================================
96// CacheLayout
97// =============================================================================
98
/// Cache geometry, fixed at compile time.
pub trait CacheLayout {
    /// Ways (slots) in each set: 2, 4, or 16.
    const WAYS: u64;
    /// Width of one stored tag, in bits: 8 or 16.
    const TAG_BITS: u64;
    /// Width of each slot's CLOCK usage counter, in bits: 1, 2, or 4.
    const CLOCK_BITS: u64;
    /// Hardware cache line size in bytes; must be a power of two.
    const CACHE_LINE_SIZE: u64;
}

/// Stock geometry: 16 ways, 8-bit tags, 2-bit CLOCK counters and 64-byte cache lines.
pub struct DefaultLayout;

impl CacheLayout for DefaultLayout {
    const CACHE_LINE_SIZE: u64 = 64;
    const CLOCK_BITS: u64 = 2;
    const TAG_BITS: u64 = 8;
    const WAYS: u64 = 16;
}
120
/// Outcome classification for an upsert: was an existing key overwritten, or was a
/// new entry placed?
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateOrInsert {
    /// The key was already cached and its value was replaced.
    Update,
    /// The key was not cached and a slot was claimed for it.
    Insert,
}

/// Everything [`SetAssociativeCache::upsert`] reports back to the caller.
#[derive(Debug)]
pub struct UpsertResult<V> {
    /// Slot index that now holds the upserted value.
    pub index: usize,
    /// Whether the call replaced an existing entry or inserted a new one.
    pub updated: UpdateOrInsert,
    /// The value pushed out of the slot, if there was one.
    pub evicted: Option<V>,
}
140
/// Cache hit/miss counters.
///
/// `hits` and `misses` are updated by key lookups; `value_count` is the number of
/// slots currently holding a value. Derives `PartialEq`/`Eq` so snapshots can be
/// compared directly (for example in tests or when diffing metrics over time).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Metrics {
    /// Number of lookups that found their key.
    pub hits: u64,
    /// Number of lookups that did not find their key.
    pub misses: u64,
    /// Current number of live entries.
    pub value_count: u64,
}
151
/// Base-2 logarithm of a power of two; usable in `const` contexts.
///
/// # Panics
/// If `x` is zero or not a power of two.
#[inline(always)]
const fn log2(x: u64) -> u64 {
    assert!(x.is_power_of_two() && x > 0);
    // A power of two has a single set bit, at index 63 - leading_zeros.
    (u64::BITS - 1 - x.leading_zeros()) as u64
}
157
/// Maps a uniformly distributed `word` onto `0..p` without a division.
///
/// Computes `floor(word * p / 2^64)` from the full 128-bit product, so for `p > 0`
/// the result always lies in `0..p` (and is 0 when `p == 0`).
/// See <https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/>
#[inline(always)]
pub fn fastrange(word: u64, p: u64) -> u64 {
    // A u64 * u64 product always fits in u128, so this cannot overflow.
    let wide = u128::from(word) * u128::from(p);
    (wide >> u64::BITS) as u64
}
164
/// Divides `numerator` by `denominator`, rounding the quotient up.
///
/// # Panics
/// If `denominator` is zero.
#[inline]
pub fn div_ceil(numerator: u64, denominator: u64) -> u64 {
    assert!(denominator > 0);
    let quotient = numerator / denominator;
    let has_remainder = numerator % denominator != 0;
    quotient + u64::from(has_remainder)
}
174
175// =============================================================================
176// AlignedBuf — raw aligned allocation via Allocator
177// =============================================================================
178
179struct AlignedBuf<T> {
180    ptr: NonNull<T>,
181    len: usize,
182    layout: std::alloc::Layout,
183}
184
185impl<T> AlignedBuf<T> {
186    /// Allocate `len` elements of type `T` with at least `align` byte alignment.
187    /// All bytes are zeroed.
188    fn alloc_zeroed(len: usize, align: usize, alloc: &impl Allocator) -> Self {
189        if len == 0 || std::mem::size_of::<T>() == 0 {
190            return Self {
191                ptr: NonNull::dangling(),
192                len,
193                layout: std::alloc::Layout::from_size_align(
194                    0,
195                    align.max(std::mem::align_of::<T>()),
196                )
197                .unwrap(),
198            };
199        }
200        let size = len * std::mem::size_of::<T>();
201        let align = align.max(std::mem::align_of::<T>());
202        let layout = std::alloc::Layout::from_size_align(size, align).unwrap();
203        let slice = alloc::do_alloc(alloc, layout).expect("allocation failed");
204        let ptr = slice.as_ptr().cast::<u8>();
205        // SAFETY: ptr is valid for `size` bytes, freshly allocated.
206        unsafe { std::ptr::write_bytes(ptr, 0, size) };
207        Self {
208            // SAFETY: ptr is non-null (from successful allocation) and properly aligned.
209            ptr: unsafe { NonNull::new_unchecked(ptr.cast::<T>()) },
210            len,
211            layout,
212        }
213    }
214    #[inline(always)]
215    fn as_slice(&self) -> &[T] {
216        if unlikely(self.len == 0) {
217            return &[];
218        }
219        // SAFETY: ptr is valid for `len` elements, allocated and zeroed in alloc_zeroed.
220        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
221    }
222    #[inline(always)]
223    fn as_mut_slice(&mut self) -> &mut [T] {
224        if unlikely(self.len == 0) {
225            return &mut [];
226        }
227        // SAFETY: ptr is valid for `len` elements, we have exclusive access via &mut self.
228        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
229    }
230
231    /// Deallocate using the given allocator. Must use the same allocator that allocated.
232    ///
233    /// # SAFETY:
234    /// Must only be called once, with the same allocator used for allocation.
235    unsafe fn dealloc(&self, alloc: &impl Allocator) {
236        if unlikely(self.layout.size() == 0) {
237            return;
238        }
239        // SAFETY: caller guarantees this is the matching allocator and single dealloc call.
240        unsafe {
241            alloc.deallocate(
242                NonNull::new_unchecked(self.ptr.as_ptr().cast::<u8>()),
243                self.layout,
244            );
245        }
246    }
247    #[inline(always)]
248    fn fill(&mut self, val: T)
249    where
250        T: Copy,
251    {
252        self.as_mut_slice().fill(val);
253    }
254}
255
/// Array of small unsigned integers (1, 2, or 4 bits each) packed into `u64` words,
/// lowest bits first.
#[derive(Debug)]
pub struct PackedArray {
    uint_bits: u32,
    words: PackedWords,
}

/// Backing storage for [`PackedArray`]: either a global-allocator `Vec`, or a raw
/// buffer taken over from an `AlignedBuf` that must be freed explicitly.
#[derive(Debug)]
enum PackedWords {
    Vec(Vec<u64>),
    Buf {
        ptr: NonNull<u64>,
        len: usize,
        layout: std::alloc::Layout,
    },
}

impl PackedWords {
    /// Borrow the backing words.
    #[inline]
    fn as_slice(&self) -> &[u64] {
        match *self {
            PackedWords::Vec(ref words) => words.as_slice(),
            PackedWords::Buf { len: 0, .. } => &[],
            PackedWords::Buf { ptr, len, .. } => {
                // SAFETY: ptr is valid for `len` u64 elements, from AlignedBuf allocation.
                unsafe { std::slice::from_raw_parts(ptr.as_ptr(), len) }
            }
        }
    }

    /// Mutably borrow the backing words.
    #[inline]
    fn as_mut_slice(&mut self) -> &mut [u64] {
        match *self {
            PackedWords::Vec(ref mut words) => words.as_mut_slice(),
            PackedWords::Buf { len: 0, .. } => &mut [],
            PackedWords::Buf { ptr, len, .. } => {
                // SAFETY: ptr is valid for `len` u64 elements, we have exclusive access.
                unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr(), len) }
            }
        }
    }
}
301
302impl PackedArray {
303    /// Create a new packed array using the global allocator (Vec-backed).
304    pub fn new(uint_bits: u32, count: u64) -> Self {
305        assert!(uint_bits == 1 || uint_bits == 2 || uint_bits == 4);
306        let total_bits = count * uint_bits as u64;
307        let num_words = div_ceil(total_bits, 64);
308        Self {
309            uint_bits,
310            words: PackedWords::Vec(vec![0u64; num_words as usize]),
311        }
312    }
313
314    fn new_aligned(uint_bits: u32, count: u64, align: usize, alloc: &impl Allocator) -> Self {
315        assert!(uint_bits == 1 || uint_bits == 2 || uint_bits == 4);
316        let total_bits = count * uint_bits as u64;
317        let num_words = div_ceil(total_bits, 64) as usize;
318        let buf = AlignedBuf::<u64>::alloc_zeroed(num_words, align, alloc);
319        Self {
320            uint_bits,
321            words: PackedWords::Buf {
322                ptr: buf.ptr,
323                len: buf.len,
324                layout: buf.layout,
325            },
326        }
327    }
328
329    /// Get the value at `index`.
330    #[inline]
331    pub fn get(&self, index: u64) -> u64 {
332        let words = self.words.as_slice();
333        let uint_bits = self.uint_bits;
334        let uints_per_word = 64 / uint_bits;
335        let word_idx = (index / uints_per_word as u64) as usize;
336        let bit_offset = (index % uints_per_word as u64) * uint_bits as u64;
337        let mask = (1u64 << uint_bits) - 1;
338        (words[word_idx] >> bit_offset) & mask
339    }
340
341    /// Set the value at `index`.
342    #[inline]
343    pub fn set(&mut self, index: u64, value: u64) {
344        let words = self.words.as_mut_slice();
345        let uint_bits = self.uint_bits;
346        let uints_per_word = 64 / uint_bits;
347        let word_idx = (index / uints_per_word as u64) as usize;
348        let bit_offset = (index % uints_per_word as u64) * uint_bits as u64;
349        let mask = (1u64 << uint_bits) - 1;
350        words[word_idx] &= !(mask << bit_offset);
351        words[word_idx] |= (value & mask) << bit_offset;
352    }
353
354    /// Zero all entries.
355    pub fn clear(&mut self) {
356        self.words.as_mut_slice().fill(0);
357    }
358
359    /// View the underlying u64 words.
360    pub fn words(&self) -> &[u64] {
361        self.words.as_slice()
362    }
363
364    /// Deallocate the buffer variant. No-op for Vec variant.
365    unsafe fn dealloc(&self, alloc: &impl Allocator) {
366        if let PackedWords::Buf { ptr, layout, .. } = &self.words
367            && layout.size() > 0
368        {
369            // SAFETY: caller guarantees matching allocator and single dealloc.
370            unsafe {
371                alloc.deallocate(NonNull::new_unchecked(ptr.as_ptr().cast::<u8>()), *layout);
372            }
373        }
374    }
375}
376
mod simd {
    //! Set-wide tag matching: compare every tag of one set against a needle and
    //! return a bitmask in which bit `i` is set when way `i` matches.
    //!
    //! The SIMD paths use unaligned loads and are only taken when the slice really
    //! holds 16 tags, so the safe entry points are sound for any input slice.

    /// Compare the first `ways` u8 tags against `needle`; bit `i` of the result is
    /// set when `tags[i] == needle`.
    ///
    /// On x86_64, a 16-way set backed by at least 16 tags uses one SSE2 compare;
    /// every other case uses the scalar loop.
    #[inline]
    pub(crate) fn search_tags(tags: &[u8], needle: u8, ways: u64) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            if ways == 16 && tags.len() >= 16 && is_x86_feature_detected!("sse2") {
                // SAFETY: `tags` holds at least 16 bytes (checked above), the load is
                // unaligned so no alignment is required, and SSE2 was detected.
                return unsafe { search_tags_16_sse2(tags, needle) };
            }
        }
        search_tags_scalar(tags, needle, ways)
    }

    /// Compare the first `ways` u16 tags against `needle`; bit `i` of the result is
    /// set when `tags[i] == needle`.
    ///
    /// On x86_64, a 16-way set backed by at least 16 tags uses AVX2 (or SSE2 when
    /// AVX2 is unavailable); every other case uses the scalar loop.
    #[inline]
    pub(crate) fn search_tags_u16(tags: &[u16], needle: u16, ways: u64) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            if ways == 16 && tags.len() >= 16 {
                if is_x86_feature_detected!("avx2") {
                    // SAFETY: `tags` holds at least 16 u16s (checked above), the load
                    // is unaligned, and AVX2 was detected.
                    return unsafe { search_tags_u16_16_avx2(tags, needle) };
                }
                if is_x86_feature_detected!("sse2") {
                    // SAFETY: `tags` holds at least 16 u16s (checked above), the loads
                    // are unaligned, and SSE2 was detected.
                    return unsafe { search_tags_u16_16_sse2(tags, needle) };
                }
            }
        }
        search_tags_scalar(tags, needle, ways)
    }

    // ---- x86_64 SSE2: 16 × u8 ----

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse2")]
    unsafe fn search_tags_16_sse2(tags: &[u8], needle: u8) -> u64 {
        use std::arch::x86_64::*;
        debug_assert!(tags.len() >= 16);
        // SAFETY: caller guarantees `tags` has >= 16 elements; `loadu` accepts any alignment.
        unsafe {
            let data = _mm_loadu_si128(tags.as_ptr().cast::<__m128i>());
            let splat = _mm_set1_epi8(needle as i8);
            let cmp = _mm_cmpeq_epi8(data, splat);
            let mask = _mm_movemask_epi8(cmp) as u32;
            (mask & 0xFFFF) as u64
        }
    }

    // ---- x86_64 SSE2: 16 × u16 (two 128-bit passes) ----

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse2")]
    unsafe fn search_tags_u16_16_sse2(tags: &[u16], needle: u16) -> u64 {
        use std::arch::x86_64::*;
        debug_assert!(tags.len() >= 16);
        // SAFETY: caller guarantees `tags` has >= 16 elements, so both 8-lane halves
        // are in bounds; `loadu` accepts any alignment.
        unsafe {
            let splat = _mm_set1_epi16(needle as i16);

            // Saturating-pack the 16-bit compare lanes to bytes so movemask yields
            // one bit per tag.
            let lo = _mm_loadu_si128(tags.as_ptr().cast::<__m128i>());
            let cmp_lo = _mm_cmpeq_epi16(lo, splat);
            let packed_lo = _mm_packs_epi16(cmp_lo, _mm_setzero_si128());
            let mask_lo = _mm_movemask_epi8(packed_lo) as u32 & 0xFF;

            let hi = _mm_loadu_si128(tags.as_ptr().add(8).cast::<__m128i>());
            let cmp_hi = _mm_cmpeq_epi16(hi, splat);
            let packed_hi = _mm_packs_epi16(cmp_hi, _mm_setzero_si128());
            let mask_hi = _mm_movemask_epi8(packed_hi) as u32 & 0xFF;

            (mask_lo | (mask_hi << 8)) as u64
        }
    }

    // ---- x86_64 AVX2: 16 × u16 (single 256-bit pass) ----

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn search_tags_u16_16_avx2(tags: &[u16], needle: u16) -> u64 {
        use std::arch::x86_64::*;
        debug_assert!(tags.len() >= 16);
        // SAFETY: caller guarantees `tags` has >= 16 elements; `loadu` accepts any alignment.
        unsafe {
            let data = _mm256_loadu_si256(tags.as_ptr().cast::<__m256i>());
            let splat = _mm256_set1_epi16(needle as i16);
            let cmp = _mm256_cmpeq_epi16(data, splat);
            // packs works per 128-bit lane, leaving 64-bit quads [lo, 0, hi, 0];
            // the permute gathers them into [lo, hi, 0, 0] before movemask.
            let packed = _mm256_packs_epi16(cmp, _mm256_setzero_si256());
            let permuted = _mm256_permute4x64_epi64(packed, 0b11_01_10_00);
            let mask = _mm256_movemask_epi8(permuted) as u32;
            (mask & 0xFFFF) as u64
        }
    }

    // ---- Scalar fallback (any tag width, any number of ways) ----

    #[inline]
    fn search_tags_scalar<T: Copy + PartialEq>(tags: &[T], needle: T, ways: u64) -> u64 {
        tags.iter()
            .take(ways as usize)
            .enumerate()
            .filter(|&(_, &tag)| tag == needle)
            .fold(0u64, |bits, (i, _)| bits | (1u64 << i))
    }
}
495
496enum TagStore {
497    U8(AlignedBuf<u8>),
498    U16(AlignedBuf<u16>),
499}
500
501impl TagStore {
502    fn clear(&mut self) {
503        match self {
504            TagStore::U8(buf) => buf.fill(0),
505            TagStore::U16(buf) => buf.fill(0),
506        }
507    }
508
509    #[cfg(test)]
510    fn all_zero(&self) -> bool {
511        match self {
512            TagStore::U8(buf) => buf.as_slice().iter().all(|&t| t == 0),
513            TagStore::U16(buf) => buf.as_slice().iter().all(|&t| t == 0),
514        }
515    }
516
517    /// # Safety
518    /// Must only be called once, with the same allocator used for allocation.
519    unsafe fn dealloc(&self, alloc: &impl Allocator) {
520        match self {
521            // SAFETY: forwarded from caller's safety contract.
522            TagStore::U8(buf) => unsafe { buf.dealloc(alloc) },
523            // SAFETY: forwarded from caller's safety contract.
524            TagStore::U16(buf) => unsafe { buf.dealloc(alloc) },
525        }
526    }
527}
528
/// A tag as held internally. 16 bits wide so it can carry either tag width; with
/// 8-bit tags only the low byte is used.
type Tag = u16;

/// Where a key lands in the cache: its tag and the slot index of its set's first way.
struct SetView {
    tag: Tag,
    offset: u64,
}
536
537/// A set-associative cache with CLOCK eviction, SIMD tag matching, and custom allocator support.
538pub struct SetAssociativeCache<E: KeyExtract, S: BuildHasher, L: CacheLayout, A: Allocator = Global>
539{
540    sets: u64,
541    tag_store: TagStore,
542    values: AlignedBuf<MaybeUninit<E::Value>>,
543    counts: PackedArray,
544    clocks: PackedArray,
545    /// Cache hit/miss metrics.
546    pub metrics: Metrics,
547    hash_builder: S,
548    alloc: A,
549    _extract: PhantomData<E>,
550    _layout: PhantomData<L>,
551}
552
553impl<E: KeyExtract, S: BuildHasher + Default, L: CacheLayout> SetAssociativeCache<E, S, L>
554where
555    E::Key: Hash + Eq,
556{
557    /// Create a new cache with default hasher and global allocator.
558    pub fn new(value_count_max: u64) -> Self {
559        Self::with_hasher(value_count_max, S::default())
560    }
561}
562
563impl<E: KeyExtract, S: BuildHasher, L: CacheLayout> SetAssociativeCache<E, S, L>
564where
565    E::Key: Hash + Eq,
566{
567    /// Create a new cache with the given hasher and global allocator.
568    pub fn with_hasher(value_count_max: u64, hash_builder: S) -> Self {
569        Self::with_hasher_and_alloc(value_count_max, hash_builder, Global)
570    }
571}
572
573impl<E: KeyExtract, S: BuildHasher, L: CacheLayout, A: Allocator> SetAssociativeCache<E, S, L, A>
574where
575    E::Key: Hash + Eq,
576{
577    /// Create a new cache with the given hasher and allocator.
578    pub fn with_hasher_and_alloc(value_count_max: u64, hash_builder: S, alloc: A) -> Self {
579        const { assert!(L::WAYS == 2 || L::WAYS == 4 || L::WAYS == 16) };
580        const { assert!(L::TAG_BITS == 8 || L::TAG_BITS == 16) };
581        const { assert!(L::CLOCK_BITS == 1 || L::CLOCK_BITS == 2 || L::CLOCK_BITS == 4) };
582        const { assert!(L::CACHE_LINE_SIZE.is_power_of_two()) };
583
584        let ways = L::WAYS;
585        let sets = value_count_max / ways;
586        let cache_line_size = L::CACHE_LINE_SIZE as usize;
587
588        assert!(value_count_max > 0);
589        assert!(value_count_max >= ways);
590        assert!(value_count_max.is_multiple_of(ways));
591
592        let value_count_max_multiple = Self::value_count_max_multiple();
593        assert!(
594            value_count_max.is_multiple_of(value_count_max_multiple),
595            "value_count_max ({}) must be a multiple of {}",
596            value_count_max,
597            value_count_max_multiple,
598        );
599
600        // Tags: align to max(cache_line_size, 32) for AVX2 _mm256_load_si256
601        let tag_align = cache_line_size.max(32);
602        let tag_store = match L::TAG_BITS {
603            8 => TagStore::U8(AlignedBuf::alloc_zeroed(
604                value_count_max as usize,
605                tag_align,
606                &alloc,
607            )),
608            16 => TagStore::U16(AlignedBuf::alloc_zeroed(
609                value_count_max as usize,
610                tag_align,
611                &alloc,
612            )),
613            _ => unreachable!(),
614        };
615
616        let values = AlignedBuf::<MaybeUninit<E::Value>>::alloc_zeroed(
617            value_count_max as usize,
618            cache_line_size,
619            &alloc,
620        );
621        let counts = PackedArray::new_aligned(
622            L::CLOCK_BITS as u32,
623            value_count_max,
624            cache_line_size,
625            &alloc,
626        );
627        let clock_hand_bits = log2(L::WAYS);
628        let clocks =
629            PackedArray::new_aligned(clock_hand_bits as u32, sets, cache_line_size, &alloc);
630
631        Self {
632            sets,
633            tag_store,
634            values,
635            counts,
636            clocks,
637            metrics: Metrics::default(),
638            hash_builder,
639            alloc,
640            _extract: PhantomData,
641            _layout: PhantomData,
642        }
643    }
644
645    /// Minimum alignment multiple that `value_count_max` must satisfy.
646    pub fn value_count_max_multiple() -> u64 {
647        let cache_line_size = L::CACHE_LINE_SIZE;
648        let ways = L::WAYS;
649        let clock_bits = L::CLOCK_BITS;
650        let value_size = std::mem::size_of::<E::Value>() as u64;
651        let values_part =
652            (value_size.max(cache_line_size) / value_size.min(cache_line_size)) * ways;
653        let counts_part = (cache_line_size * 8) / clock_bits;
654        values_part.max(counts_part)
655    }
656
657    /// Reset the cache, clearing all entries and metrics.
658    pub fn reset(&mut self) {
659        let total_slots = self.sets * L::WAYS;
660        for i in 0..total_slots {
661            if self.counts.get(i) > 0 {
662                // SAFETY: count > 0 means the slot was initialized via MaybeUninit::write.
663                unsafe { self.values.as_mut_slice()[i as usize].assume_init_drop() };
664            }
665        }
666        self.tag_store.clear();
667        self.counts.clear();
668        self.clocks.clear();
669        self.metrics = Metrics::default();
670    }
671
672    /// Look up a key, returning its index if found.
673    pub fn get_index<Q>(&mut self, key: &Q) -> Option<usize>
674    where
675        Q: Hash + Equivalent<E::Key> + ?Sized,
676    {
677        let set = self.associate(key);
678        if let Some(way) = self.search(&set, key) {
679            self.metrics.hits += 1;
680            let idx = set.offset + way as u64;
681            let count = self.counts.get(idx);
682            let max = (1u64 << L::CLOCK_BITS) - 1;
683            self.counts.set(idx, count.saturating_add(1).min(max));
684            Some(idx as usize)
685        } else {
686            self.metrics.misses += 1;
687            None
688        }
689    }
690
691    /// Look up a key, returning a reference to the value if found.
692    pub fn get<Q>(&mut self, key: &Q) -> Option<&E::Value>
693    where
694        Q: Hash + Equivalent<E::Key> + ?Sized,
695    {
696        let index = self.get_index(key)?;
697        // SAFETY: get_index only returns an index where count > 0, meaning the slot
698        // was initialized via MaybeUninit::write in upsert.
699        Some(unsafe { self.values.as_slice()[index].assume_init_ref() })
700    }
701
702    /// Look up a key, returning a mutable reference to the value if found.
703    pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut E::Value>
704    where
705        Q: Hash + Equivalent<E::Key> + ?Sized,
706    {
707        let index = self.get_index(key)?;
708        // SAFETY: get_index only returns an index where count > 0, meaning the slot
709        // was initialized via MaybeUninit::write in upsert.
710        Some(unsafe { self.values.as_mut_slice()[index].assume_init_mut() })
711    }
712
713    /// Remove a key from the cache if present.
714    pub fn remove<Q>(&mut self, key: &Q) -> Option<E::Value>
715    where
716        Q: Hash + Equivalent<E::Key> + ?Sized,
717    {
718        let set = self.associate(key);
719        let way = self.search(&set, key)?;
720        let idx = set.offset + way as u64;
721        // SAFETY: search only returns a way where count > 0, meaning initialized.
722        // assume_init_read moves the value out; we set count to 0 so the slot
723        // is treated as uninitialized from here on.
724        let removed = unsafe { self.values.as_slice()[idx as usize].assume_init_read() };
725        self.counts.set(idx, 0);
726        self.metrics.value_count -= 1;
727        Some(removed)
728    }
729
730    /// Hint that the key is less likely to be accessed in the future.
731    pub fn demote<Q>(&mut self, key: &Q)
732    where
733        Q: Hash + Equivalent<E::Key> + ?Sized,
734    {
735        let set = self.associate(key);
736        if let Some(way) = self.search(&set, key) {
737            self.counts.set(set.offset + way as u64, 1);
738        }
739    }
740
741    /// Upsert a value, evicting an older entry if needed.
742    pub fn upsert(&mut self, value: E::Value) -> UpsertResult<E::Value> {
743        // Extract key twice (inline, cheap) to avoid requiring Key: Copy.
744        // Each temporary borrow of `value` is released before the next statement.
745        let set = self.associate(E::extract(&value));
746        let existing_way = self.search(&set, E::extract(&value));
747
748        if let Some(way) = existing_way {
749            let idx = (set.offset + way as u64) as usize;
750            self.counts.set(idx as u64, 1);
751            let slot = &mut self.values.as_mut_slice()[idx];
752            // SAFETY: search found this slot with count > 0, so it is initialized.
753            let evicted = unsafe { slot.assume_init_read() };
754            slot.write(value);
755            return UpsertResult {
756                index: idx,
757                updated: UpdateOrInsert::Update,
758                evicted: Some(evicted),
759            };
760        }
761
762        let ways = L::WAYS;
763        let max_count = (1u64 << L::CLOCK_BITS) - 1;
764        let clock_index = set.offset / ways;
765
766        let mut way = self.clocks.get(clock_index);
767        let way_mask = ways - 1;
768
769        // Maximum iterations: every slot at max count, decrementing all down to 1,
770        // then one more iteration to decrement to 0 and break.
771        let clock_iterations_max = ways * (max_count - 1);
772
773        let mut evicted: Option<E::Value> = None;
774        let mut safety_count = 0u64;
775        loop {
776            if safety_count > clock_iterations_max {
777                unreachable!("CLOCK algorithm exceeded maximum iterations");
778            }
779            let idx = set.offset + way;
780            let mut count = self.counts.get(idx);
781            if count == 0 {
782                break; // Way is already free.
783            }
784            count -= 1;
785            self.counts.set(idx, count);
786            if count == 0 {
787                // SAFETY: count was > 0 before decrement, so the slot is initialized.
788                evicted = Some(unsafe { self.values.as_slice()[idx as usize].assume_init_read() });
789                break;
790            }
791            safety_count += 1;
792            way = (way + 1) & way_mask;
793        }
794
795        debug_assert!(self.counts.get(set.offset + way) == 0);
796
797        let idx = (set.offset + way) as usize;
798        match &mut self.tag_store {
799            TagStore::U8(buf) => buf.as_mut_slice()[idx] = set.tag as u8,
800            TagStore::U16(buf) => buf.as_mut_slice()[idx] = set.tag,
801        }
802        self.values.as_mut_slice()[idx].write(value);
803        self.counts.set(set.offset + way, 1);
804        self.clocks.set(clock_index, (way + 1) & way_mask);
805        if evicted.is_none() {
806            self.metrics.value_count += 1;
807        }
808
809        UpsertResult {
810            index: idx,
811            updated: UpdateOrInsert::Insert,
812            evicted,
813        }
814    }
815
816    #[inline]
817    fn associate<Q: Hash + ?Sized>(&self, key: &Q) -> SetView {
818        let entropy = self.hash_builder.hash_one(key);
819        let tag = (entropy & ((1u64 << L::TAG_BITS) - 1)) as Tag;
820        let index = fastrange(entropy, self.sets);
821        let offset = index * L::WAYS;
822        SetView { tag, offset }
823    }
824
825    #[inline]
826    fn search<Q>(&self, set: &SetView, key: &Q) -> Option<u16>
827    where
828        Q: Equivalent<E::Key> + ?Sized,
829    {
830        let ways = L::WAYS;
831        let offset = set.offset;
832
833        let matching_ways: u64 = match &self.tag_store {
834            TagStore::U8(buf) => {
835                let tags = buf.as_slice();
836                let slice = &tags[offset as usize..(offset + ways) as usize];
837                simd::search_tags(slice, set.tag as u8, ways)
838            }
839            TagStore::U16(buf) => {
840                let tags = buf.as_slice();
841                let slice = &tags[offset as usize..(offset + ways) as usize];
842                simd::search_tags_u16(slice, set.tag, ways)
843            }
844        };
845
846        if matching_ways == 0 {
847            return None;
848        }
849
850        for way in 0..ways {
851            if (matching_ways >> way) & 1 == 1 && self.counts.get(offset + way) > 0 {
852                // SAFETY: count > 0 means the slot was initialized via MaybeUninit::write.
853                let val =
854                    unsafe { self.values.as_slice()[(offset + way) as usize].assume_init_ref() };
855                if key.equivalent(E::extract(val)) {
856                    return Some(way as u16);
857                }
858            }
859        }
860        None
861    }
862}
863
864impl<E: KeyExtract, S: BuildHasher, L: CacheLayout, A: Allocator> Drop
865    for SetAssociativeCache<E, S, L, A>
866{
867    fn drop(&mut self) {
868        // Drop all live values before deallocating.
869        let total_slots = self.sets * L::WAYS;
870        for i in 0..total_slots {
871            if self.counts.get(i) > 0 {
872                // SAFETY: count > 0 means the slot was initialized via MaybeUninit::write.
873                unsafe { self.values.as_mut_slice()[i as usize].assume_init_drop() };
874            }
875        }
876        // SAFETY: each buffer is deallocated exactly once with the same allocator
877        // that was used for allocation, stored in self.alloc.
878        unsafe {
879            self.tag_store.dealloc(&self.alloc);
880            self.values.dealloc(&self.alloc);
881            self.counts.dealloc(&self.alloc);
882            self.clocks.dealloc(&self.alloc);
883        }
884    }
885}
886
#[cfg(test)]
mod tests {
    use super::*;
    use std::hash::Hasher;

    // --- PackedArray tests ---

    #[test]
    fn packed_array_unit() {
        let mut words = [0u64; 8];
        words[1] = 0b10110010;

        let mut p = PackedArray {
            uint_bits: 2,
            words: PackedWords::Vec(words.to_vec()),
        };

        assert_eq!(p.get(32 + 0), 0b10);
        assert_eq!(p.get(32 + 1), 0b00);
        assert_eq!(p.get(32 + 2), 0b11);
        assert_eq!(p.get(32 + 3), 0b10);

        p.set(0, 0b01);
        assert_eq!(p.words()[0], 0b00000001);
        assert_eq!(p.get(0), 0b01);

        p.set(1, 0b10);
        assert_eq!(p.words()[0], 0b00001001);
        assert_eq!(p.get(1), 0b10);

        p.set(2, 0b11);
        assert_eq!(p.words()[0], 0b00111001);
        assert_eq!(p.get(2), 0b11);

        p.set(3, 0b11);
        assert_eq!(p.words()[0], 0b11111001);
        assert_eq!(p.get(3), 0b11);

        p.set(3, 0b01);
        assert_eq!(p.words()[0], 0b01111001);
        assert_eq!(p.get(3), 0b01);

        p.set(3, 0b00);
        assert_eq!(p.words()[0], 0b00111001);
        assert_eq!(p.get(3), 0b00);

        p.set(4, 0b11);
        assert_eq!(
            p.words()[0],
            0b0000000000000000000000000000000000000000000000000000001100111001
        );

        p.set(31, 0b11);
        assert_eq!(
            p.words()[0],
            0b1100000000000000000000000000000000000000000000000000001100111001
        );
    }

    // --- BuildHasher implementations for tests ---

    /// A hasher that returns the u64 as-is (identity hash).
    struct IdentityHasher(u64);

    impl Hasher for IdentityHasher {
        fn finish(&self) -> u64 {
            self.0
        }
        fn write(&mut self, _bytes: &[u8]) {
            unimplemented!("IdentityHasher only supports write_u64");
        }
        fn write_u64(&mut self, i: u64) {
            self.0 = i;
        }
    }

    #[derive(Clone)]
    struct IdentityBuildHasher;

    impl BuildHasher for IdentityBuildHasher {
        type Hasher = IdentityHasher;
        fn build_hasher(&self) -> IdentityHasher {
            IdentityHasher(0)
        }
    }

    /// A hasher that always returns 0.
    struct ZeroHasher;

    impl Hasher for ZeroHasher {
        fn finish(&self) -> u64 {
            0
        }
        fn write(&mut self, _bytes: &[u8]) {}
        fn write_u64(&mut self, _i: u64) {}
    }

    #[derive(Clone)]
    struct ZeroBuildHasher;

    impl BuildHasher for ZeroBuildHasher {
        type Hasher = ZeroHasher;
        fn build_hasher(&self) -> ZeroHasher {
            ZeroHasher
        }
    }

    // --- SetAssociativeCache: KeyExtract for u64 identity ---

    struct IdentityExtract;

    impl KeyExtract for IdentityExtract {
        type Key = u64;
        type Value = u64;

        #[inline]
        fn extract(value: &u64) -> &u64 {
            value
        }
    }

    // --- Test layout types ---

    struct Ways2Layout;
    impl CacheLayout for Ways2Layout {
        const WAYS: u64 = 2;
        const TAG_BITS: u64 = 8;
        const CLOCK_BITS: u64 = 2;
        const CACHE_LINE_SIZE: u64 = 64;
    }

    struct Ways4Layout;
    impl CacheLayout for Ways4Layout {
        const WAYS: u64 = 4;
        const TAG_BITS: u64 = 8;
        const CLOCK_BITS: u64 = 2;
        const CACHE_LINE_SIZE: u64 = 64;
    }

    struct Tag16Layout;
    impl CacheLayout for Tag16Layout {
        const WAYS: u64 = 16;
        const TAG_BITS: u64 = 16;
        const CLOCK_BITS: u64 = 2;
        const CACHE_LINE_SIZE: u64 = 64;
    }

    struct Clock1Layout;
    impl CacheLayout for Clock1Layout {
        const WAYS: u64 = 16;
        const TAG_BITS: u64 = 8;
        const CLOCK_BITS: u64 = 1;
        const CACHE_LINE_SIZE: u64 = 64;
    }

    struct Clock4Layout;
    impl CacheLayout for Clock4Layout {
        const WAYS: u64 = 16;
        const TAG_BITS: u64 = 8;
        const CLOCK_BITS: u64 = 4;
        const CACHE_LINE_SIZE: u64 = 64;
    }

    fn run_cache_test_with_hasher<S: BuildHasher, L: CacheLayout>(hash_builder: S) {
        let ways = L::WAYS;
        let value_count_max = 16 * 16 * 8;

        let mut sac = SetAssociativeCache::<IdentityExtract, S, L>::with_hasher(
            value_count_max,
            hash_builder,
        );

        // Verify initial state
        assert!(sac.tag_store.all_zero());
        assert!(sac.counts.words().iter().all(|&w| w == 0));
        assert!(sac.clocks.words().iter().all(|&w| w == 0));
        assert_eq!(sac.metrics.value_count, 0);

        let clock_bits = L::CLOCK_BITS;
        let max_count = (1u64 << clock_bits) - 1;

        let count_after_get = max_count.min(2);

        // Fill up the first set entirely.
        for i in 0..ways {
            assert_eq!(sac.clocks.get(0), i);
            let key = i * sac.sets;
            sac.upsert(key);
            assert_eq!(sac.counts.get(i), 1);
            assert_eq!(*sac.get(&key).unwrap(), key);
            assert_eq!(sac.counts.get(i), count_after_get);
        }
        assert_eq!(sac.clocks.get(0), 0);
        assert_eq!(sac.metrics.value_count, ways);

        // Insert another element into the first set, causing key 0 to be evicted.
        {
            let key = ways * sac.sets;
            sac.upsert(key);
            assert_eq!(sac.counts.get(0), 1);
            assert_eq!(*sac.get(&key).unwrap(), key);
            assert_eq!(sac.counts.get(0), count_after_get);

            assert!(sac.get(&0).is_none());

            for i in 1..ways {
                assert_eq!(sac.counts.get(i), 1);
            }
            assert_eq!(sac.metrics.value_count, ways);
        }

        // Ensure removal works.
        {
            let remove_way = ways - 1;
            let key = remove_way * sac.sets;
            assert_eq!(*sac.get(&key).unwrap(), key);

            sac.remove(&key);
            assert!(sac.get(&key).is_none());
            assert_eq!(sac.counts.get(remove_way), 0);
            assert_eq!(sac.metrics.value_count, ways - 1);
        }

        sac.reset();

        assert!(sac.tag_store.all_zero());
        assert!(sac.counts.words().iter().all(|&w| w == 0));
        assert!(sac.clocks.words().iter().all(|&w| w == 0));
        assert_eq!(sac.metrics.value_count, 0);

        // Fill up the first set entirely, maxing out the count for each slot.
        for i in 0..ways {
            assert_eq!(sac.clocks.get(0), i);
            let key = i * sac.sets;
            sac.upsert(key);
            assert_eq!(sac.counts.get(i), 1);
            for j in 2..=max_count {
                assert_eq!(*sac.get(&key).unwrap(), key);
                assert_eq!(sac.counts.get(i), j);
            }
            // One more get should stay at max.
            assert_eq!(*sac.get(&key).unwrap(), key);
            assert_eq!(sac.counts.get(i), max_count);
        }
        assert_eq!(sac.clocks.get(0), 0);
        assert_eq!(sac.metrics.value_count, ways);

        // Insert another element into the first set, causing key 0 to be evicted.
        {
            let key = ways * sac.sets;
            sac.upsert(key);
            assert_eq!(sac.counts.get(0), 1);
            assert_eq!(*sac.get(&key).unwrap(), key);
            assert_eq!(sac.counts.get(0), count_after_get);

            assert!(sac.get(&0).is_none());

            for i in 1..ways {
                assert_eq!(sac.counts.get(i), 1);
            }
            assert_eq!(sac.metrics.value_count, ways);
        }
    }

    #[test]
    fn set_associative_cache_eviction() {
        run_cache_test_with_hasher::<_, DefaultLayout>(IdentityBuildHasher);
    }

    #[test]
    fn set_associative_cache_hash_collision() {
        run_cache_test_with_hasher::<_, DefaultLayout>(ZeroBuildHasher);
    }

    #[test]
    fn set_associative_cache_ways_2() {
        run_cache_test_with_hasher::<_, Ways2Layout>(IdentityBuildHasher);
    }

    #[test]
    fn set_associative_cache_ways_4() {
        run_cache_test_with_hasher::<_, Ways4Layout>(IdentityBuildHasher);
    }

    #[test]
    fn set_associative_cache_tag_bits_16() {
        run_cache_test_with_hasher::<_, Tag16Layout>(IdentityBuildHasher);
    }

    #[test]
    fn set_associative_cache_clock_bits_1() {
        run_cache_test_with_hasher::<_, Clock1Layout>(IdentityBuildHasher);
    }

    #[test]
    fn set_associative_cache_clock_bits_4() {
        run_cache_test_with_hasher::<_, Clock4Layout>(IdentityBuildHasher);
    }

    // --- SIMD search_tags correctness ---

    /// Overwrites a random subset of `tags` with `needle`. The subset may be empty
    /// or may cover every slot, so each search sees guaranteed matches in addition
    /// to accidental ones.
    fn plant_needle<T: Copy, R: rand::Rng>(rng: &mut R, tags: &mut [T], needle: T) {
        let ways = tags.len();
        let forced = rng.random_range(0..=ways);
        let mut indices: Vec<usize> = (0..ways).collect();
        // Simple Fisher-Yates shuffle; the first `forced` indices get the needle.
        for i in (1..ways).rev() {
            let j = rng.random_range(0..=i);
            indices.swap(i, j);
        }
        for &idx in &indices[..forced] {
            tags[idx] = needle;
        }
    }

    /// Scalar reference: bit `i` of the result is set iff `tags[i] == *needle`.
    fn reference_mask<T: PartialEq>(tags: &[T], needle: &T) -> u64 {
        tags.iter()
            .enumerate()
            .filter(|&(_, tag)| tag == needle)
            .fold(0u64, |mask, (i, _)| mask | (1u64 << i))
    }

    #[test]
    fn search_tags_correctness() {
        use rand::rngs::SmallRng;
        use rand::{Rng, SeedableRng};

        let mut rng = SmallRng::seed_from_u64(42);

        for ways in [2u64, 4, 16] {
            for _ in 0..10_000 {
                let mut tags: Vec<u8> = (0..ways).map(|_| rng.random()).collect();
                let needle: u8 = rng.random();
                plant_needle(&mut rng, &mut tags, needle);

                let expected = reference_mask(&tags, &needle);
                let actual = simd::search_tags(&tags, needle, ways);
                assert_eq!(
                    expected, actual,
                    "ways={ways} needle={needle} tags={tags:?}"
                );
            }
        }
    }

    /// Same property as `search_tags_correctness`, for the 16-bit tag path used by
    /// `TagStore::U16`.
    #[test]
    fn search_tags_u16_correctness() {
        use rand::rngs::SmallRng;
        use rand::{Rng, SeedableRng};

        let mut rng = SmallRng::seed_from_u64(43);

        for ways in [2u64, 4, 16] {
            for _ in 0..10_000 {
                let mut tags: Vec<u16> = (0..ways).map(|_| rng.random()).collect();
                let needle: u16 = rng.random();
                plant_needle(&mut rng, &mut tags, needle);

                let expected = reference_mask(&tags, &needle);
                let actual = simd::search_tags_u16(&tags, needle, ways);
                assert_eq!(
                    expected, actual,
                    "ways={ways} needle={needle} tags={tags:?}"
                );
            }
        }
    }

    // --- PairExtract test ---

    #[test]
    fn pair_extract_works() {
        type E = PairExtract<u32, String>;
        let val = (42u32, "hello".to_string());
        assert_eq!(E::extract(&val), &42u32);
    }
}