viceroy_lib/
cache.rs

1use core::str;
2use std::{collections::HashSet, sync::Arc, time::Duration};
3
4use bytes::Bytes;
5#[cfg(test)]
6use proptest_derive::Arbitrary;
7
8use crate::{
9    body::Body,
10    component::bindings::fastly::compute::types::Error as ComponentError,
11    wiggle_abi::types::{BodyHandle, CacheOverrideTag, FastlyStatus},
12};
13
14use http::{HeaderMap, HeaderValue};
15
16mod store;
17mod variance;
18
19use store::{CacheData, CacheKeyObjects, GetBodyBuilder, ObjectMeta, Obligation};
20pub use variance::VaryRule;
21
22#[derive(Debug, thiserror::Error)]
23#[non_exhaustive]
24pub enum Error {
25    #[error("invalid key")]
26    InvalidKey,
27
28    #[error("handle is not writeable")]
29    CannotWrite,
30
31    #[error("no entry for key in cache")]
32    Missing,
33
34    #[error("invalid argument: {0}")]
35    InvalidArgument(&'static str),
36
37    #[error("cache entry's body is currently being read by another body")]
38    HandleBodyUsed,
39
40    #[error("cache entry is not revalidatable")]
41    NotRevalidatable,
42}
43
44impl From<Error> for crate::Error {
45    fn from(value: Error) -> Self {
46        crate::Error::CacheError(value)
47    }
48}
49
50impl From<&Error> for FastlyStatus {
51    fn from(value: &Error) -> Self {
52        match value {
53            // TODO: cceckman-at-fastly: These may not correspond to the same errors as the compute
54            // platform uses. Check!
55            Error::InvalidKey => FastlyStatus::Inval,
56            Error::InvalidArgument(_) => FastlyStatus::Inval,
57            Error::CannotWrite => FastlyStatus::Badf,
58            Error::Missing => FastlyStatus::None,
59            Error::HandleBodyUsed => FastlyStatus::Badf,
60            Error::NotRevalidatable => FastlyStatus::Badf,
61        }
62    }
63}
64
65impl From<Error> for ComponentError {
66    fn from(value: Error) -> Self {
67        match value {
68            // TODO: cceckman-at-fastly: These may not correspond to the same errors as the compute
69            // platform uses. Check!
70            Error::InvalidKey => ComponentError::InvalidArgument,
71            Error::InvalidArgument(_) => ComponentError::InvalidArgument,
72            Error::CannotWrite => ComponentError::AuxiliaryError,
73            Error::Missing => ComponentError::CannotRead,
74            Error::HandleBodyUsed => ComponentError::AuxiliaryError,
75            Error::NotRevalidatable => ComponentError::AuxiliaryError,
76        }
77    }
78}
79
/// Primary cache key: an opaque byte string of at most [`CacheKey::MAX_LENGTH`] (4 KiB) bytes.
///
/// Outside this module the field is private, so keys are built through the `TryFrom` impls,
/// which enforce the length limit.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct CacheKey(
    // Property-test keys must respect the same length limit as real ones.
    #[cfg_attr(test, proptest(filter = "|f| f.len() <= CacheKey::MAX_LENGTH"))]
    Vec<u8>,
);
86
87impl CacheKey {
88    /// The maximum size of a cache key is 4KiB.
89    pub const MAX_LENGTH: usize = 4096;
90}
91
92impl TryFrom<&Vec<u8>> for CacheKey {
93    type Error = Error;
94
95    fn try_from(value: &Vec<u8>) -> Result<Self, Self::Error> {
96        value.as_slice().try_into()
97    }
98}
99
100impl TryFrom<Vec<u8>> for CacheKey {
101    type Error = Error;
102
103    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
104        if value.len() > Self::MAX_LENGTH {
105            Err(Error::InvalidKey)
106        } else {
107            Ok(CacheKey(value))
108        }
109    }
110}
111
112impl TryFrom<&[u8]> for CacheKey {
113    type Error = Error;
114
115    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
116        if value.len() > CacheKey::MAX_LENGTH {
117            Err(Error::InvalidKey)
118        } else {
119            Ok(CacheKey(value.to_owned()))
120        }
121    }
122}
123
124impl TryFrom<&str> for CacheKey {
125    type Error = Error;
126
127    fn try_from(value: &str) -> Result<Self, Self::Error> {
128        value.as_bytes().try_into()
129    }
130}
131
132/// The result of a lookup: the object (if found), and/or an obligation to fetch.
133#[derive(Debug)]
134pub struct CacheEntry {
135    key: CacheKey,
136    found: Option<Found>,
137    go_get: Option<Obligation>,
138
139    /// Respect the range in body() even when the body length is not yet known.
140    ///
141    /// When a cached item is Found, the length of the cached item may or may not be known:
142    /// if no expected length was provided and the body is still streaming, the length is unknown.
143    ///
144    /// When always_use_requested_range is false, and the length is unknown,
145    /// body() returns the full body regardless of the requested range.
146    /// When always_use_requested_range is true, and the length is unknown,
147    /// body() blocks until the start of the range is available.
148    always_use_requested_range: bool,
149}
150
151impl CacheEntry {
152    /// Set the always_use_requested_range flag.
153    /// This applies to all subsequent lookups from this CacheEntry or future entries derived from it.
154    pub fn with_always_use_requested_range(self, always_use_requested_range: bool) -> Self {
155        Self {
156            always_use_requested_range,
157            ..self
158        }
159    }
160
161    /// Return a stub entry to hold in CacheBusy.
162    pub fn stub(&self) -> CacheEntry {
163        Self {
164            key: self.key.clone(),
165            found: None,
166            go_get: None,
167            always_use_requested_range: false,
168        }
169    }
170
171    /// Returns the key used to generate this CacheEntry.
172    pub fn key(&self) -> &CacheKey {
173        &self.key
174    }
175    /// Returns the data found in the cache, if any was present.
176    pub fn found(&self) -> Option<&Found> {
177        self.found.as_ref()
178    }
179
180    /// Returns the data found in the cache, if any was present.
181    pub fn found_mut(&mut self) -> Option<&mut Found> {
182        self.found.as_mut()
183    }
184
185    /// Returns the obligation to fetch, if required
186    pub fn go_get(&self) -> Option<&Obligation> {
187        self.go_get.as_ref()
188    }
189
190    /// Ignore the obligation to fetch, if present.
191    /// Returns true if it actually canceled an obligation.
192    pub fn cancel(&mut self) -> bool {
193        self.go_get.take().is_some()
194    }
195
196    /// Access the body of the cached item, if available.
197    pub async fn body(&self, from: Option<u64>, to: Option<u64>) -> Result<Body, crate::Error> {
198        let found = self
199            .found
200            .as_ref()
201            .ok_or(crate::Error::CacheError(Error::Missing))?;
202        found
203            .get_body()
204            .with_range(from, to)
205            .with_always_use_requested_range(self.always_use_requested_range)
206            .build()
207            .await
208    }
209
210    /// Insert the provided body into the cache.
211    ///
212    /// Returns a CacheEntry where the new item is Found.
213    pub fn insert(
214        &mut self,
215        options: WriteOptions,
216        body: Body,
217    ) -> Result<CacheEntry, crate::Error> {
218        let go_get = self.go_get.take().ok_or(Error::NotRevalidatable)?;
219        let found = go_get.insert(options, body);
220        Ok(CacheEntry {
221            key: self.key.clone(),
222            found: Some(found),
223            go_get: None,
224            always_use_requested_range: self.always_use_requested_range,
225        })
226    }
227
228    /// Freshen the existing cache item according to the new write options,
229    /// without changing the body.
230    pub async fn update(&mut self, options: WriteOptions) -> Result<(), crate::Error> {
231        let go_get = self.go_get.take().ok_or(Error::NotRevalidatable)?;
232        match go_get.update(options).await {
233            Ok(()) => Ok(()),
234            Err((go_get, err)) => {
235                // On failure, preserve the obligation.
236                self.go_get = Some(go_get);
237                Err(err)
238            }
239        }
240    }
241}
242
243/// A successful retrieval of an item from the cache.
244#[derive(Debug)]
245pub struct Found {
246    data: Arc<CacheData>,
247
248    /// The handle for the last body used to read from this Found.
249    ///
250    /// Only one Body may be outstanding from a given Found at a time.
251    /// (This is an implementation restriction within the compute platform).
252    /// We mirror the BodyHandle here when we create it; we can later check whether the handle is
253    /// still valid, to find an outstanding read.
254    pub last_body_handle: Option<BodyHandle>,
255}
256
257impl From<Arc<CacheData>> for Found {
258    fn from(data: Arc<CacheData>) -> Self {
259        Found {
260            data,
261            last_body_handle: None,
262        }
263    }
264}
265
266impl Found {
267    fn get_body(&self) -> GetBodyBuilder<'_> {
268        self.data.as_ref().body()
269    }
270
271    /// Access the metadata of the cached object.
272    pub fn meta(&self) -> &ObjectMeta {
273        self.data.get_meta()
274    }
275
276    /// The length of the cached object, if known.
277    pub fn length(&self) -> Option<u64> {
278        self.data.length()
279    }
280}
281
282/// Cache for a service.
283///
284// TODO: cceckman-at-fastly:
285// Explain some about how this works:
286// - Request collapsing
287// - Stale-while-revalidate
288pub struct Cache {
289    inner: moka::future::Cache<CacheKey, Arc<CacheKeyObjects>>,
290}
291
292impl Default for Cache {
293    fn default() -> Self {
294        // TODO: cceckman-at-fastly:
295        // Weight by size, allow a cap on max size?
296        let inner = moka::future::Cache::builder()
297            .eviction_listener(|key, _value, cause| {
298                tracing::info!("cache eviction of {key:?}: {cause:?}")
299            })
300            .build();
301        Cache { inner }
302    }
303}
304
305impl Cache {
306    /// Perform a non-transactional lookup.
307    pub async fn lookup(&self, key: &CacheKey, headers: &HeaderMap) -> CacheEntry {
308        let found = self
309            .inner
310            .get_with_by_ref(key, async { Default::default() })
311            .await
312            .get(headers)
313            .map(|data| Found {
314                data,
315                last_body_handle: None,
316            });
317        CacheEntry {
318            key: key.clone(),
319            found,
320            go_get: None,
321            always_use_requested_range: false,
322        }
323    }
324
325    /// Perform a transactional lookup.
326    pub async fn transaction_lookup(
327        &self,
328        key: &CacheKey,
329        headers: &HeaderMap,
330        ok_to_wait: bool,
331    ) -> CacheEntry {
332        let (found, obligation) = self
333            .inner
334            .get_with_by_ref(key, async { Default::default() })
335            .await
336            .transaction_get(headers, ok_to_wait)
337            .await;
338        CacheEntry {
339            key: key.clone(),
340            found: found.map(|v| v.into()),
341            go_get: obligation,
342            always_use_requested_range: false,
343        }
344    }
345
346    /// Perform a non-transactional lookup for the given cache key.
347    /// Note: races with other insertions, including transactional insertions.
348    /// Last writer wins!
349    pub async fn insert(
350        &self,
351        key: &CacheKey,
352        request_headers: HeaderMap,
353        options: WriteOptions,
354        body: Body,
355    ) {
356        self.inner
357            .get_with_by_ref(key, async { Default::default() })
358            .await
359            .insert(request_headers, options, body, None);
360    }
361
362    /// Purge/soft-purge all cache entries corresponding to the given surrogate key.
363    /// Returns the number of entries (variants) purged.
364    ///
365    /// Note: this does not block concurrent reads _or inserts_; an insertion can race with the
366    /// purge.
367    pub fn purge(&self, key: SurrogateKey, soft_purge: bool) -> usize {
368        self.inner
369            .iter()
370            .map(|(_, entry)| entry.purge(&key, soft_purge))
371            .sum()
372    }
373}
374
375/// Options that can be applied to a write, e.g. insert or transaction_insert.
376#[derive(Default, Clone)]
377pub struct WriteOptions {
378    pub max_age: Duration,
379    pub initial_age: Duration,
380    pub stale_while_revalidate: Duration,
381    pub vary_rule: VaryRule,
382    pub user_metadata: Bytes,
383    pub length: Option<u64>,
384    pub sensitive_data: bool,
385    pub edge_max_age: Duration,
386    pub surrogate_keys: SurrogateKeySet,
387}
388
389impl WriteOptions {
390    pub fn new(max_age: Duration) -> Self {
391        WriteOptions {
392            max_age,
393            initial_age: Duration::ZERO,
394            stale_while_revalidate: Duration::ZERO,
395            vary_rule: VaryRule::default(),
396            user_metadata: Bytes::new(),
397            length: None,
398            sensitive_data: false,
399            edge_max_age: max_age,
400            surrogate_keys: Default::default(),
401        }
402    }
403}
404
405/// Optional override for response caching behavior.
406#[derive(Clone, Debug, Default)]
407pub enum CacheOverride {
408    /// Do not override the behavior specified in the origin response's cache control headers.
409    #[default]
410    None,
411    /// Do not cache the response to this request, regardless of the origin response's headers.
412    Pass,
413    /// Override particular cache control settings.
414    ///
415    /// The origin response's cache control headers will be used for ttl and stale_while_revalidate if `None`.
416    Override {
417        ttl: Option<u32>,
418        stale_while_revalidate: Option<u32>,
419        pci: bool,
420        surrogate_key: Option<HeaderValue>,
421    },
422}
423
424impl CacheOverride {
425    pub fn is_pass(&self) -> bool {
426        if let Self::Pass = self {
427            true
428        } else {
429            false
430        }
431    }
432
433    /// Convert from the representation suitable for passing across the ABI boundary.
434    ///
435    /// Returns `None` if the tag is not recognized. Depending on the tag, some of the values may be
436    /// ignored.
437    pub fn from_abi(
438        tag: u32,
439        ttl: u32,
440        swr: u32,
441        surrogate_key: Option<HeaderValue>,
442    ) -> Option<Self> {
443        CacheOverrideTag::from_bits(tag).map(|tag| {
444            if tag.contains(CacheOverrideTag::PASS) {
445                return CacheOverride::Pass;
446            }
447            if tag.is_empty() && surrogate_key.is_none() {
448                return CacheOverride::None;
449            }
450            let ttl = if tag.contains(CacheOverrideTag::TTL) {
451                Some(ttl)
452            } else {
453                None
454            };
455            let stale_while_revalidate = if tag.contains(CacheOverrideTag::STALE_WHILE_REVALIDATE) {
456                Some(swr)
457            } else {
458                None
459            };
460            let pci = tag.contains(CacheOverrideTag::PCI);
461            CacheOverride::Override {
462                ttl,
463                stale_while_revalidate,
464                pci,
465                surrogate_key,
466            }
467        })
468    }
469}
470
/// Upper bound, in bytes, on a whole space-separated surrogate key list (16 KiB).
const MAX_SURROGATE_KEYS_LENGTH: usize = 16_384;
/// Upper bound, in bytes, on any single surrogate key (1 KiB).
const MAX_SURROGATE_KEY_LENGTH: usize = 1_024;
475
476#[derive(Debug, Default, Clone)]
477pub struct SurrogateKeySet(HashSet<SurrogateKey>);
478
479impl std::fmt::Display for SurrogateKeySet {
480    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481        write!(f, "{{")?;
482        for (i, item) in self.0.iter().enumerate() {
483            if i == 0 {
484                write!(f, "{item}")?;
485            } else {
486                write!(f, " {item}")?;
487            }
488        }
489        write!(f, "}}")
490    }
491}
492
493impl TryFrom<&[u8]> for SurrogateKeySet {
494    type Error = crate::Error;
495
496    fn try_from(s: &[u8]) -> Result<Self, Self::Error> {
497        if s.len() > MAX_SURROGATE_KEYS_LENGTH {
498            return Err(
499                Error::InvalidArgument("surrogate key set exceeds maximum length (16Ki)").into(),
500            );
501        }
502        let result: Result<HashSet<_>, _> = s
503            .split(|c| *c == b' ')
504            .filter(|sk| !sk.is_empty())
505            .map(SurrogateKey::try_from)
506            .collect();
507        Ok(SurrogateKeySet(result?))
508    }
509}
510
511impl std::str::FromStr for SurrogateKeySet {
512    type Err = crate::Error;
513
514    fn from_str(s: &str) -> Result<Self, Self::Err> {
515        s.as_bytes().try_into()
516    }
517}
518
/// A validated surrogate key: graphic (visible) ASCII only, at most `MAX_SURROGATE_KEY_LENGTH`
/// bytes.
///
/// Keys are meant to be non-empty; parsing a `SurrogateKeySet` skips empty segments.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SurrogateKey(String);
522
523impl std::fmt::Display for SurrogateKey {
524    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
525        self.0.fmt(f)
526    }
527}
528
529impl TryFrom<&[u8]> for SurrogateKey {
530    type Error = Error;
531
532    fn try_from(b: &[u8]) -> Result<Self, Self::Error> {
533        if b.len() > MAX_SURROGATE_KEY_LENGTH {
534            return Err(Error::InvalidArgument(
535                "surrogate key exceeds maximum length (1024)",
536            ));
537        }
538
539        if !b.iter().all(|c| c.is_ascii_graphic()) {
540            return Err(Error::InvalidArgument(
541                "surrogate key contains characters other than graphical ASCII",
542            ));
543        }
544        let s = unsafe {
545            // All graphic ASCII characters are equivalent to the corresponding UTF-8 codepoints.
546            str::from_utf8_unchecked(b)
547        };
548        Ok(SurrogateKey(s.to_owned()))
549    }
550}
551
552impl std::str::FromStr for SurrogateKey {
553    type Err = Error;
554
555    fn from_str(s: &str) -> Result<Self, Self::Err> {
556        s.as_bytes().try_into()
557    }
558}
559
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use http::HeaderName;
    use proptest::prelude::*;

    use super::*;

    proptest! {
        // Any key strictly longer than the limit is rejected.
        #[test]
        fn reject_cache_key_too_long(l in (CacheKey::MAX_LENGTH + 1)..5000) {
            let mut v : Vec<u8> = Vec::new();
            v.resize(l, 0);
            CacheKey::try_from(&v).unwrap_err();
        }
    }

    proptest! {
        // Every length up to *and including* the limit is accepted; the inclusive range
        // exercises the MAX_LENGTH boundary itself.
        #[test]
        fn accept_valid_cache_key_len(l in 0usize..=CacheKey::MAX_LENGTH) {
            let mut v : Vec<u8> = Vec::new();
            v.resize(l, 0);
            let _ = CacheKey::try_from(&v).unwrap();
        }
    }

    proptest! {
        #[test]
        fn nontransactional_insert_lookup(
                key in any::<CacheKey>(),
                max_age in any::<u32>(),
                initial_age in any::<u32>(),
                value in any::<Vec<u8>>()) {
            let cache = Cache::default();

            // We can't use tokio::test and proptest! together; both alter the signature of the
            // test function, and are not aware of each other enough for it to pass.
            let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
            rt.block_on(async {
                let empty = cache.lookup(&key, &HeaderMap::default()).await;
                assert!(empty.found().is_none());
                // The non-transactional case does not produce an obligation (go_get)
                assert!(empty.go_get().is_none());

                let write_options = WriteOptions {
                    max_age: Duration::from_secs(max_age as u64),
                    initial_age: Duration::from_secs(initial_age as u64),
                    ..Default::default()
                };

                cache.insert(&key, HeaderMap::default(), write_options, value.clone().into()).await;

                let nonempty = cache.lookup(&key, &HeaderMap::default()).await;
                let found = nonempty.found().expect("should have found inserted key");
                let got = found.get_body().build().await.unwrap().read_into_vec().await.unwrap();
                assert_eq!(got, value);
            });
        }
    }

    #[test]
    fn parse_surrogate_keys() {
        let keys: SurrogateKeySet = "keyA keyB".parse().unwrap();
        let key_a: SurrogateKey = "keyA".parse().unwrap();
        let key_b: SurrogateKey = "keyB".parse().unwrap();
        assert_eq!(keys.0.len(), 2);
        assert!(keys.0.contains(&key_a));
        assert!(keys.0.contains(&key_b));
    }

    #[tokio::test]
    async fn insert_immediately_stale() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        // Insert an already-stale entry:
        let write_options = WriteOptions {
            max_age: Duration::from_secs(1),
            initial_age: Duration::from_secs(2),
            ..Default::default()
        };

        let mut body = Body::empty();
        body.push_back([1u8].as_slice());

        cache
            .insert(&key, HeaderMap::default(), write_options, body)
            .await;

        let nonempty = cache.lookup(&key, &HeaderMap::default()).await;
        let found = nonempty.found().expect("should have found inserted key");
        assert!(!found.meta().is_fresh());
    }

    #[tokio::test]
    async fn test_vary() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        let header_name = HeaderName::from_static("x-viceroy-test");
        let request_headers: HeaderMap = [(header_name.clone(), HeaderValue::from_static("test"))]
            .into_iter()
            .collect();

        let write_options = WriteOptions {
            max_age: Duration::from_secs(100),
            vary_rule: VaryRule::new([&header_name].into_iter()),
            ..Default::default()
        };
        let body = Body::empty();
        cache
            .insert(&key, request_headers.clone(), write_options, body)
            .await;

        let empty_headers = cache.lookup(&key, &HeaderMap::default()).await;
        assert!(empty_headers.found().is_none());

        let matched_headers = cache.lookup(&key, &request_headers).await;
        assert!(matched_headers.found().is_some());

        let r2_headers: HeaderMap = [(header_name.clone(), HeaderValue::from_static("assert"))]
            .into_iter()
            .collect();
        let mismatched_headers = cache.lookup(&key, &r2_headers).await;
        assert!(mismatched_headers.found().is_none());
    }

    #[tokio::test]
    async fn insert_stale_while_revalidate() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        // Insert an already-stale entry, with an SWR period:
        let write_options = WriteOptions {
            max_age: Duration::from_secs(1),
            initial_age: Duration::from_secs(2),
            stale_while_revalidate: Duration::from_secs(10),
            ..Default::default()
        };

        let mut body = Body::empty();
        body.push_back([1u8].as_slice());
        cache
            .insert(&key, HeaderMap::default(), write_options, body)
            .await;

        let nonempty = cache.lookup(&key, &HeaderMap::default()).await;
        let found = nonempty.found().expect("should have found inserted key");
        assert!(!found.meta().is_fresh());
        assert!(found.meta().is_usable());
    }

    #[tokio::test]
    async fn insert_and_revalidate() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        // Insert an already-stale entry, with an SWR period:
        let write_options = WriteOptions {
            max_age: Duration::from_secs(1),
            initial_age: Duration::from_secs(2),
            stale_while_revalidate: Duration::from_secs(10),
            ..Default::default()
        };

        let mut body = Body::empty();
        body.push_back([1u8].as_slice());
        cache
            .insert(&key, HeaderMap::default(), write_options, body)
            .await;

        let mut txn1 = cache
            .transaction_lookup(&key, &HeaderMap::default(), true)
            .await;
        let found = txn1.found().expect("should have found inserted key");
        assert!(!found.meta().is_fresh());
        assert!(found.meta().is_usable());

        assert!(txn1.go_get().is_some());

        // Another lookup should not get the obligation *or* block
        {
            let txn2 = cache
                .transaction_lookup(&key, &HeaderMap::default(), true)
                .await;
            let found = txn2.found().expect("should have found inserted key");
            assert!(!found.meta().is_fresh());
            assert!(found.meta().is_usable());
            assert!(txn2.go_get().is_none());
        }

        txn1.update(WriteOptions {
            max_age: Duration::from_secs(10),
            stale_while_revalidate: Duration::from_secs(10),
            ..WriteOptions::default()
        })
        .await
        .unwrap();

        // After this, should get the new response:
        let txn3 = cache
            .transaction_lookup(&key, &HeaderMap::default(), true)
            .await;
        assert!(txn3.go_get().is_none());
        let found = txn3.found().expect("should find updated entry");
        assert!(found.meta().is_usable());
        assert!(found.meta().is_fresh());
    }

    #[tokio::test]
    async fn cannot_revalidate_first_insert() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        let mut txn1 = cache
            .transaction_lookup(&key, &HeaderMap::default(), false)
            .await;
        assert!(txn1.found().is_none());
        let opts = WriteOptions {
            max_age: Duration::from_secs(10),
            stale_while_revalidate: Duration::from_secs(10),
            ..WriteOptions::default()
        };
        txn1.update(opts.clone()).await.unwrap_err();

        // But we should still be able to insert.
        txn1.insert(opts.clone(), Body::empty()).unwrap();
    }

    #[tokio::test]
    async fn purge_by_surrogate_key() {
        let cache = Cache::default();
        let key: CacheKey = ([1u8].as_slice()).try_into().unwrap();

        let surrogate_keys: SurrogateKeySet = "one two three".parse().unwrap();
        let opts = WriteOptions {
            max_age: Duration::from_secs(100),
            surrogate_keys: surrogate_keys.clone(),
            ..WriteOptions::default()
        };

        cache
            .insert(&key, HeaderMap::default(), opts.clone(), Body::empty())
            .await;
        let result = cache.lookup(&key, &HeaderMap::default()).await;
        assert!(result.found().is_some());

        assert_eq!(cache.purge("two".parse().unwrap(), false), 1);
        let result = cache.lookup(&key, &HeaderMap::default()).await;
        assert!(result.found().is_none());
    }

    #[tokio::test]
    async fn purge_variant_by_surrogate_key() {
        let cache = Rc::new(Cache::default());
        let key: CacheKey = ([1u8].as_slice()).try_into().unwrap();

        // "Introduction and Variations on a Theme by Mozart, Op. 9, is one of Fernando Sor's most
        // famous works for guitar."
        let vary_header = HeaderName::from_static("opus-9");
        let insert = {
            let cache = cache.clone();
            let key = key.clone();

            move |surrogate_keys: &str, variation: &str| {
                let mut headers = HeaderMap::new();
                headers.insert(vary_header.clone(), variation.try_into().unwrap());
                let surrogate_keys: SurrogateKeySet = surrogate_keys.parse().unwrap();
                let opts = WriteOptions {
                    max_age: Duration::from_secs(100),
                    surrogate_keys: surrogate_keys.clone(),
                    vary_rule: VaryRule::new([&vary_header]),
                    ..WriteOptions::default()
                };
                let cache = cache.clone();
                let key = key.clone();
                async move {
                    cache
                        .insert(&key, headers.clone(), opts.clone(), Body::empty())
                        .await;
                    headers
                }
            }
        };

        let h1 = insert("one two three", "twinkle twinkle").await;
        let h2 = insert("one three", "abcdefg").await;
        assert!(cache.lookup(&key, &h1).await.found().is_some());
        assert!(cache.lookup(&key, &h2).await.found().is_some());

        assert_eq!(cache.purge("two".parse().unwrap(), false), 1);
        assert!(cache.lookup(&key, &h1).await.found().is_none());
        assert!(cache.lookup(&key, &h2).await.found().is_some());
    }

    #[tokio::test]
    async fn soft_purge() {
        let cache = Cache::default();
        let key: CacheKey = ([1u8].as_slice()).try_into().unwrap();

        let surrogate_keys: SurrogateKeySet = "one two three".parse().unwrap();
        let opts = WriteOptions {
            max_age: Duration::from_secs(100),
            surrogate_keys: surrogate_keys.clone(),
            ..WriteOptions::default()
        };

        cache
            .insert(&key, HeaderMap::default(), opts.clone(), Body::empty())
            .await;
        let result = cache.lookup(&key, &HeaderMap::default()).await;
        assert!(result.found().is_some());

        assert_eq!(cache.purge("two".parse().unwrap(), true), 1);
        let result = cache
            .transaction_lookup(&key, &HeaderMap::default(), false)
            .await;
        let found = result.found().unwrap();
        assert!(!found.meta().is_fresh());
        assert!(found.meta().is_usable());
        assert!(result.go_get().is_some());
    }
}