Skip to main content

viceroy_lib/
cache.rs

1use core::str;
2use std::{collections::HashSet, sync::Arc, time::Duration};
3
4use bytes::Bytes;
5#[cfg(test)]
6use proptest_derive::Arbitrary;
7
8use crate::{
9    body::Body,
10    component::bindings::fastly::compute::types::Error as ComponentError,
11    wiggle_abi::types::{BodyHandle, CacheOverrideTag, FastlyStatus},
12};
13
14use http::{HeaderMap, HeaderValue};
15
16mod store;
17mod variance;
18
19use store::{CacheData, CacheKeyObjects, GetBodyBuilder, ObjectMeta, Obligation};
20pub use variance::VaryRule;
21
22#[derive(Debug, thiserror::Error)]
23#[non_exhaustive]
24pub enum Error {
25    #[error("invalid key")]
26    InvalidKey,
27
28    #[error("handle is not writeable")]
29    CannotWrite,
30
31    #[error("no entry for key in cache")]
32    Missing,
33
34    #[error("invalid argument: {0}")]
35    InvalidArgument(&'static str),
36
37    #[error("cache entry's body is currently being read by another body")]
38    HandleBodyUsed,
39
40    #[error("cache entry is not revalidatable")]
41    NotRevalidatable,
42}
43
44impl From<Error> for crate::Error {
45    fn from(value: Error) -> Self {
46        crate::Error::CacheError(value)
47    }
48}
49
50impl From<&Error> for FastlyStatus {
51    fn from(value: &Error) -> Self {
52        match value {
53            // TODO: cceckman-at-fastly: These may not correspond to the same errors as the compute
54            // platform uses. Check!
55            Error::InvalidKey => FastlyStatus::Inval,
56            Error::InvalidArgument(_) => FastlyStatus::Inval,
57            Error::CannotWrite => FastlyStatus::Badf,
58            Error::Missing => FastlyStatus::None,
59            Error::HandleBodyUsed => FastlyStatus::Badf,
60            Error::NotRevalidatable => FastlyStatus::Badf,
61        }
62    }
63}
64
65impl From<Error> for ComponentError {
66    fn from(value: Error) -> Self {
67        match value {
68            // TODO: cceckman-at-fastly: These may not correspond to the same errors as the compute
69            // platform uses. Check!
70            Error::InvalidKey => ComponentError::InvalidArgument,
71            Error::InvalidArgument(_) => ComponentError::InvalidArgument,
72            Error::CannotWrite => ComponentError::AuxiliaryError,
73            Error::Missing => ComponentError::CannotRead,
74            Error::HandleBodyUsed => ComponentError::AuxiliaryError,
75            Error::NotRevalidatable => ComponentError::AuxiliaryError,
76        }
77    }
78}
79
/// Primary cache key: a byte buffer of at most [`CacheKey::MAX_LENGTH`] bytes.
///
/// Values are built through the `TryFrom` conversions, which enforce the length limit.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct CacheKey(
    // In tests, only generate keys that respect the length limit.
    #[cfg_attr(test, proptest(filter = "|f| f.len() <= CacheKey::MAX_LENGTH"))] Vec<u8>,
);

impl CacheKey {
    /// Largest permitted cache key size in bytes (4 KiB).
    pub const MAX_LENGTH: usize = 4 * 1024;
}
91
92impl TryFrom<&Vec<u8>> for CacheKey {
93    type Error = Error;
94
95    fn try_from(value: &Vec<u8>) -> Result<Self, Self::Error> {
96        value.as_slice().try_into()
97    }
98}
99
100impl TryFrom<Vec<u8>> for CacheKey {
101    type Error = Error;
102
103    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
104        if value.len() > Self::MAX_LENGTH {
105            Err(Error::InvalidKey)
106        } else {
107            Ok(CacheKey(value))
108        }
109    }
110}
111
112impl TryFrom<&[u8]> for CacheKey {
113    type Error = Error;
114
115    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
116        if value.len() > CacheKey::MAX_LENGTH {
117            Err(Error::InvalidKey)
118        } else {
119            Ok(CacheKey(value.to_owned()))
120        }
121    }
122}
123
124impl TryFrom<&str> for CacheKey {
125    type Error = Error;
126
127    fn try_from(value: &str) -> Result<Self, Self::Error> {
128        value.as_bytes().try_into()
129    }
130}
131
132/// The result of a lookup: the object (if found), and/or an obligation to fetch.
133#[derive(Debug)]
134pub struct CacheEntry {
135    key: CacheKey,
136    found: Option<Found>,
137    go_get: Option<Obligation>,
138
139    /// Respect the range in body() even when the body length is not yet known.
140    ///
141    /// When a cached item is Found, the length of the cached item may or may not be known:
142    /// if no expected length was provided and the body is still streaming, the length is unknown.
143    ///
144    /// When always_use_requested_range is false, and the length is unknown,
145    /// body() returns the full body regardless of the requested range.
146    /// When always_use_requested_range is true, and the length is unknown,
147    /// body() blocks until the start of the range is available.
148    always_use_requested_range: bool,
149}
150
151impl CacheEntry {
152    /// Set the always_use_requested_range flag.
153    /// This applies to all subsequent lookups from this CacheEntry or future entries derived from it.
154    pub fn with_always_use_requested_range(self, always_use_requested_range: bool) -> Self {
155        Self {
156            always_use_requested_range,
157            ..self
158        }
159    }
160
161    /// Return a stub entry to hold in CacheBusy.
162    pub fn stub(&self) -> CacheEntry {
163        Self {
164            key: self.key.clone(),
165            found: None,
166            go_get: None,
167            always_use_requested_range: false,
168        }
169    }
170
171    /// Returns the key used to generate this CacheEntry.
172    pub fn key(&self) -> &CacheKey {
173        &self.key
174    }
175    /// Returns the data found in the cache, if any was present.
176    pub fn found(&self) -> Option<&Found> {
177        self.found.as_ref()
178    }
179
180    /// Returns the data found in the cache, if any was present.
181    pub fn found_mut(&mut self) -> Option<&mut Found> {
182        self.found.as_mut()
183    }
184
185    /// Returns the obligation to fetch, if required
186    pub fn go_get(&self) -> Option<&Obligation> {
187        self.go_get.as_ref()
188    }
189
190    /// Ignore the obligation to fetch, if present.
191    /// Returns true if it actually canceled an obligation.
192    pub fn cancel(&mut self) -> bool {
193        self.go_get.take().is_some()
194    }
195
196    /// Access the body of the cached item, if available.
197    pub async fn body(&self, from: Option<u64>, to: Option<u64>) -> Result<Body, crate::Error> {
198        let found = self
199            .found
200            .as_ref()
201            .ok_or(crate::Error::CacheError(Error::Missing))?;
202        found
203            .get_body()
204            .with_range(from, to)
205            .with_always_use_requested_range(self.always_use_requested_range)
206            .build()
207            .await
208    }
209
210    /// Insert the provided body into the cache.
211    ///
212    /// Returns a CacheEntry where the new item is Found.
213    pub fn insert(
214        &mut self,
215        options: WriteOptions,
216        body: Body,
217    ) -> Result<CacheEntry, crate::Error> {
218        let go_get = self.go_get.take().ok_or(Error::NotRevalidatable)?;
219        let found = go_get.insert(options, body);
220        Ok(CacheEntry {
221            key: self.key.clone(),
222            found: Some(found),
223            go_get: None,
224            always_use_requested_range: self.always_use_requested_range,
225        })
226    }
227
228    /// Freshen the existing cache item according to the new write options,
229    /// without changing the body.
230    pub async fn update(&mut self, options: WriteOptions) -> Result<(), crate::Error> {
231        let go_get = self.go_get.take().ok_or(Error::NotRevalidatable)?;
232        match go_get.update(options).await {
233            Ok(()) => Ok(()),
234            Err((go_get, err)) => {
235                // On failure, preserve the obligation.
236                self.go_get = Some(go_get);
237                Err(err)
238            }
239        }
240    }
241}
242
243/// A successful retrieval of an item from the cache.
244#[derive(Debug)]
245pub struct Found {
246    data: Arc<CacheData>,
247
248    /// The handle for the last body used to read from this Found.
249    ///
250    /// Only one Body may be outstanding from a given Found at a time.
251    /// (This is an implementation restriction within the compute platform).
252    /// We mirror the BodyHandle here when we create it; we can later check whether the handle is
253    /// still valid, to find an outstanding read.
254    pub last_body_handle: Option<BodyHandle>,
255}
256
257impl From<Arc<CacheData>> for Found {
258    fn from(data: Arc<CacheData>) -> Self {
259        Found {
260            data,
261            last_body_handle: None,
262        }
263    }
264}
265
266impl Found {
267    fn get_body(&self) -> GetBodyBuilder<'_> {
268        self.data.as_ref().body()
269    }
270
271    /// Access the metadata of the cached object.
272    pub fn meta(&self) -> &ObjectMeta {
273        self.data.get_meta()
274    }
275
276    /// The length of the cached object, if known.
277    pub fn length(&self) -> Option<u64> {
278        self.data.length()
279    }
280}
281
282/// Cache for a service.
283///
284// TODO: cceckman-at-fastly:
285// Explain some about how this works:
286// - Request collapsing
287// - Stale-while-revalidate
288pub struct Cache {
289    inner: moka::future::Cache<CacheKey, Arc<CacheKeyObjects>>,
290}
291
292impl Default for Cache {
293    fn default() -> Self {
294        // TODO: cceckman-at-fastly:
295        // Weight by size, allow a cap on max size?
296        let inner = moka::future::Cache::builder()
297            .eviction_listener(|key, _value, cause| {
298                tracing::info!("cache eviction of {key:?}: {cause:?}")
299            })
300            .build();
301        Cache { inner }
302    }
303}
304
305impl Cache {
306    /// Perform a non-transactional lookup.
307    pub async fn lookup(&self, key: &CacheKey, headers: &HeaderMap) -> CacheEntry {
308        let found = self
309            .inner
310            .get_with_by_ref(key, async { Default::default() })
311            .await
312            .get(headers)
313            .map(|data| Found {
314                data,
315                last_body_handle: None,
316            });
317        CacheEntry {
318            key: key.clone(),
319            found,
320            go_get: None,
321            always_use_requested_range: false,
322        }
323    }
324
325    /// Perform a transactional lookup.
326    pub async fn transaction_lookup(
327        &self,
328        key: &CacheKey,
329        headers: &HeaderMap,
330        ok_to_wait: bool,
331    ) -> CacheEntry {
332        let (found, obligation) = self
333            .inner
334            .get_with_by_ref(key, async { Default::default() })
335            .await
336            .transaction_get(headers, ok_to_wait)
337            .await;
338        CacheEntry {
339            key: key.clone(),
340            found: found.map(|v| v.into()),
341            go_get: obligation,
342            always_use_requested_range: false,
343        }
344    }
345
346    /// Perform a non-transactional lookup for the given cache key.
347    /// Note: races with other insertions, including transactional insertions.
348    /// Last writer wins!
349    pub async fn insert(
350        &self,
351        key: &CacheKey,
352        request_headers: HeaderMap,
353        options: WriteOptions,
354        body: Body,
355    ) {
356        self.inner
357            .get_with_by_ref(key, async { Default::default() })
358            .await
359            .insert(request_headers, options, body, None);
360    }
361
362    /// Purge/soft-purge all cache entries corresponding to the given surrogate key.
363    /// Returns the number of entries (variants) purged.
364    ///
365    /// Note: this does not block concurrent reads _or inserts_; an insertion can race with the
366    /// purge.
367    pub fn purge(&self, key: SurrogateKey, soft_purge: bool) -> usize {
368        self.inner
369            .iter()
370            .map(|(_, entry)| entry.purge(&key, soft_purge))
371            .sum()
372    }
373}
374
375/// Options that can be applied to a write, e.g. insert or transaction_insert.
376#[derive(Default, Clone)]
377pub struct WriteOptions {
378    pub max_age: Duration,
379    pub initial_age: Duration,
380    pub stale_while_revalidate: Duration,
381    pub vary_rule: VaryRule,
382    pub user_metadata: Bytes,
383    pub length: Option<u64>,
384    pub sensitive_data: bool,
385    pub edge_max_age: Duration,
386    pub surrogate_keys: SurrogateKeySet,
387}
388
389impl WriteOptions {
390    pub fn new(max_age: Duration) -> Self {
391        WriteOptions {
392            max_age,
393            initial_age: Duration::ZERO,
394            stale_while_revalidate: Duration::ZERO,
395            vary_rule: VaryRule::default(),
396            user_metadata: Bytes::new(),
397            length: None,
398            sensitive_data: false,
399            edge_max_age: max_age,
400            surrogate_keys: Default::default(),
401        }
402    }
403}
404
405/// Optional override for response caching behavior.
406#[derive(Clone, Debug, Default)]
407pub enum CacheOverride {
408    /// Do not override the behavior specified in the origin response's cache control headers.
409    #[default]
410    None,
411    /// Do not cache the response to this request, regardless of the origin response's headers.
412    Pass,
413    /// Override particular cache control settings.
414    ///
415    /// The origin response's cache control headers will be used for ttl and stale_while_revalidate if `None`.
416    Override {
417        ttl: Option<u32>,
418        stale_while_revalidate: Option<u32>,
419        pci: bool,
420        surrogate_key: Option<HeaderValue>,
421    },
422}
423
424impl CacheOverride {
425    pub fn is_pass(&self) -> bool {
426        matches!(self, Self::Pass)
427    }
428
429    /// Convert from the representation suitable for passing across the ABI boundary.
430    ///
431    /// Returns `None` if the tag is not recognized. Depending on the tag, some of the values may be
432    /// ignored.
433    pub fn from_abi(
434        tag: u32,
435        ttl: u32,
436        swr: u32,
437        surrogate_key: Option<HeaderValue>,
438    ) -> Option<Self> {
439        CacheOverrideTag::from_bits(tag).map(|tag| {
440            if tag.contains(CacheOverrideTag::PASS) {
441                return CacheOverride::Pass;
442            }
443            if tag.is_empty() && surrogate_key.is_none() {
444                return CacheOverride::None;
445            }
446            let ttl = if tag.contains(CacheOverrideTag::TTL) {
447                Some(ttl)
448            } else {
449                None
450            };
451            let stale_while_revalidate = if tag.contains(CacheOverrideTag::STALE_WHILE_REVALIDATE) {
452                Some(swr)
453            } else {
454                None
455            };
456            let pci = tag.contains(CacheOverrideTag::PCI);
457            CacheOverride::Override {
458                ttl,
459                stale_while_revalidate,
460                pci,
461                surrogate_key,
462            }
463        })
464    }
465}
466
/// Upper bound, in bytes, on a whole space-separated list of surrogate keys
/// (the separating spaces count toward the limit).
const MAX_SURROGATE_KEYS_LENGTH: usize = 16 * 1024;
/// Upper bound, in bytes, on one surrogate key.
const MAX_SURROGATE_KEY_LENGTH: usize = 1024;
471
472#[derive(Debug, Default, Clone)]
473pub struct SurrogateKeySet(HashSet<SurrogateKey>);
474
475impl std::fmt::Display for SurrogateKeySet {
476    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477        write!(f, "{{")?;
478        for (i, item) in self.0.iter().enumerate() {
479            if i == 0 {
480                write!(f, "{item}")?;
481            } else {
482                write!(f, " {item}")?;
483            }
484        }
485        write!(f, "}}")
486    }
487}
488
489impl TryFrom<&[u8]> for SurrogateKeySet {
490    type Error = crate::Error;
491
492    fn try_from(s: &[u8]) -> Result<Self, Self::Error> {
493        if s.len() > MAX_SURROGATE_KEYS_LENGTH {
494            return Err(
495                Error::InvalidArgument("surrogate key set exceeds maximum length (16Ki)").into(),
496            );
497        }
498        let result: Result<HashSet<_>, _> = s
499            .split(|c| *c == b' ')
500            .filter(|sk| !sk.is_empty())
501            .map(SurrogateKey::try_from)
502            .collect();
503        Ok(SurrogateKeySet(result?))
504    }
505}
506
507impl std::str::FromStr for SurrogateKeySet {
508    type Err = crate::Error;
509
510    fn from_str(s: &str) -> Result<Self, Self::Err> {
511        s.as_bytes().try_into()
512    }
513}
514
/// A validated surrogate key: a non-empty string containing only visible ASCII characters.
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct SurrogateKey(String);

/// Displays the key text itself, honoring the formatter's width, fill and precision.
impl std::fmt::Display for SurrogateKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
524
525impl TryFrom<&[u8]> for SurrogateKey {
526    type Error = Error;
527
528    fn try_from(b: &[u8]) -> Result<Self, Self::Error> {
529        if b.len() > MAX_SURROGATE_KEY_LENGTH {
530            return Err(Error::InvalidArgument(
531                "surrogate key exceeds maximum length (1024)",
532            ));
533        }
534
535        if !b.iter().all(|c| c.is_ascii_graphic()) {
536            return Err(Error::InvalidArgument(
537                "surrogate key contains characters other than graphical ASCII",
538            ));
539        }
540        let s = unsafe {
541            // All graphic ASCII characters are equivalent to the corresponding UTF-8 codepoints.
542            str::from_utf8_unchecked(b)
543        };
544        Ok(SurrogateKey(s.to_owned()))
545    }
546}
547
548impl std::str::FromStr for SurrogateKey {
549    type Err = Error;
550
551    fn from_str(s: &str) -> Result<Self, Self::Err> {
552        s.as_bytes().try_into()
553    }
554}
555
#[cfg(test)]
mod tests {
    //! Unit tests for cache keys, surrogate keys, and the insert / lookup / revalidate /
    //! purge behavior of `Cache`.

    use std::rc::Rc;

    use http::HeaderName;
    use proptest::prelude::*;

    use super::*;

    proptest! {
        /// Any key longer than the limit is rejected.
        #[test]
        fn reject_cache_key_too_long(l in (CacheKey::MAX_LENGTH + 1)..5000usize) {
            let v = vec![0u8; l];
            CacheKey::try_from(&v).unwrap_err();
        }
    }

    proptest! {
        /// Every length up to and including the limit is accepted.
        #[test]
        fn accept_valid_cache_key_len(l in 0usize..=CacheKey::MAX_LENGTH) {
            let v = vec![0u8; l];
            let _ = CacheKey::try_from(&v).unwrap();
        }
    }

    proptest! {
        /// A value inserted non-transactionally is returned by a later lookup.
        #[test]
        fn nontransactional_insert_lookup(
                key in any::<CacheKey>(),
                max_age in any::<u32>(),
                initial_age in any::<u32>(),
                value in any::<Vec<u8>>()) {
            let cache = Cache::default();

            // We can't use tokio::test and proptest! together; both alter the signature of the
            // test function, and are not aware of each other enough for it to pass.
            let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
            rt.block_on(async {
                let empty = cache.lookup(&key, &HeaderMap::default()).await;
                assert!(empty.found().is_none());
                // The non-transactional case does not produce an obligation (go_get)
                assert!(empty.go_get().is_none());

                let write_options = WriteOptions {
                    max_age: Duration::from_secs(max_age as u64),
                    initial_age: Duration::from_secs(initial_age as u64),
                    ..Default::default()
                };

                cache.insert(&key, HeaderMap::default(), write_options, value.clone().into()).await;

                let nonempty = cache.lookup(&key, &HeaderMap::default()).await;
                let found = nonempty.found().expect("should have found inserted key");
                let got = found.get_body().build().await.unwrap().read_into_vec().await.unwrap();
                assert_eq!(got, value);
            });
        }
    }

    /// A space-separated string parses into the corresponding set of keys.
    #[test]
    fn parse_surrogate_keys() {
        let keys: SurrogateKeySet = "keyA keyB".parse().unwrap();
        let key_a: SurrogateKey = "keyA".parse().unwrap();
        let key_b: SurrogateKey = "keyB".parse().unwrap();
        assert_eq!(keys.0.len(), 2);
        assert!(keys.0.contains(&key_a));
        assert!(keys.0.contains(&key_b));
    }

    /// An entry whose initial age exceeds its max age is found, but is not fresh.
    #[tokio::test]
    async fn insert_immediately_stale() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        // Insert an already-stale entry:
        let write_options = WriteOptions {
            max_age: Duration::from_secs(1),
            initial_age: Duration::from_secs(2),
            ..Default::default()
        };

        let mut body = Body::empty();
        body.push_back([1u8].as_slice());

        cache
            .insert(&key, HeaderMap::default(), write_options, body)
            .await;

        let nonempty = cache.lookup(&key, &HeaderMap::default()).await;
        let found = nonempty.found().expect("should have found inserted key");
        assert!(!found.meta().is_fresh());
    }

    /// Lookups only match an entry when the headers named by its vary rule match.
    #[tokio::test]
    async fn test_vary() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        let header_name = HeaderName::from_static("x-viceroy-test");
        let request_headers: HeaderMap = [(header_name.clone(), HeaderValue::from_static("test"))]
            .into_iter()
            .collect();

        let write_options = WriteOptions {
            max_age: Duration::from_secs(100),
            vary_rule: VaryRule::new([&header_name].into_iter()),
            ..Default::default()
        };
        let body = Body::empty();
        cache
            .insert(&key, request_headers.clone(), write_options, body)
            .await;

        // Header absent: no match.
        let empty_headers = cache.lookup(&key, &HeaderMap::default()).await;
        assert!(empty_headers.found().is_none());

        // Same header value: match.
        let matched_headers = cache.lookup(&key, &request_headers).await;
        assert!(matched_headers.found().is_some());

        // Different header value: no match.
        let r2_headers: HeaderMap = [(header_name.clone(), HeaderValue::from_static("assert"))]
            .into_iter()
            .collect();
        let mismatched_headers = cache.lookup(&key, &r2_headers).await;
        assert!(mismatched_headers.found().is_none());
    }

    /// A stale entry inside its stale-while-revalidate window is not fresh but usable.
    #[tokio::test]
    async fn insert_stale_while_revalidate() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        // Insert an already-stale entry, with an SWR period:
        let write_options = WriteOptions {
            max_age: Duration::from_secs(1),
            initial_age: Duration::from_secs(2),
            stale_while_revalidate: Duration::from_secs(10),
            ..Default::default()
        };

        let mut body = Body::empty();
        body.push_back([1u8].as_slice());
        cache
            .insert(&key, HeaderMap::default(), write_options, body)
            .await;

        let nonempty = cache.lookup(&key, &HeaderMap::default()).await;
        let found = nonempty.found().expect("should have found inserted key");
        assert!(!found.meta().is_fresh());
        assert!(found.meta().is_usable());
    }

    /// Only one transaction receives the obligation to revalidate a stale entry; after
    /// it updates the entry, later lookups see a fresh object and no obligation.
    #[tokio::test]
    async fn insert_and_revalidate() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        // Insert an already-stale entry, with an SWR period:
        let write_options = WriteOptions {
            max_age: Duration::from_secs(1),
            initial_age: Duration::from_secs(2),
            stale_while_revalidate: Duration::from_secs(10),
            ..Default::default()
        };

        let mut body = Body::empty();
        body.push_back([1u8].as_slice());
        cache
            .insert(&key, HeaderMap::default(), write_options, body)
            .await;

        let mut txn1 = cache
            .transaction_lookup(&key, &HeaderMap::default(), true)
            .await;
        let found = txn1.found().expect("should have found inserted key");
        assert!(!found.meta().is_fresh());
        assert!(found.meta().is_usable());

        assert!(txn1.go_get().is_some());

        // Another lookup should not get the obligation *or* block
        {
            let txn2 = cache
                .transaction_lookup(&key, &HeaderMap::default(), true)
                .await;
            let found = txn2.found().expect("should have found inserted key");
            assert!(!found.meta().is_fresh());
            assert!(found.meta().is_usable());
            assert!(txn2.go_get().is_none());
        }

        txn1.update(WriteOptions {
            max_age: Duration::from_secs(10),
            stale_while_revalidate: Duration::from_secs(10),
            ..WriteOptions::default()
        })
        .await
        .unwrap();

        // After this, should get the new response:
        let txn3 = cache
            .transaction_lookup(&key, &HeaderMap::default(), true)
            .await;
        assert!(txn3.go_get().is_none());
        let found = txn3.found().expect("should find updated entry");
        assert!(found.meta().is_usable());
        assert!(found.meta().is_fresh());
    }

    /// With nothing cached yet, `update` fails, and the obligation survives that
    /// failure so a subsequent `insert` still succeeds.
    #[tokio::test]
    async fn cannot_revalidate_first_insert() {
        let cache = Cache::default();
        let key = ([1u8].as_slice()).try_into().unwrap();

        let mut txn1 = cache
            .transaction_lookup(&key, &HeaderMap::default(), false)
            .await;
        assert!(txn1.found().is_none());
        let opts = WriteOptions {
            max_age: Duration::from_secs(10),
            stale_while_revalidate: Duration::from_secs(10),
            ..WriteOptions::default()
        };
        txn1.update(opts.clone()).await.unwrap_err();

        // But we should still be able to insert.
        txn1.insert(opts.clone(), Body::empty()).unwrap();
    }

    /// A hard purge by one of an entry's surrogate keys removes the entry.
    #[tokio::test]
    async fn purge_by_surrogate_key() {
        let cache = Cache::default();
        let key: CacheKey = ([1u8].as_slice()).try_into().unwrap();

        let surrogate_keys: SurrogateKeySet = "one two three".parse().unwrap();
        let opts = WriteOptions {
            max_age: Duration::from_secs(100),
            surrogate_keys: surrogate_keys.clone(),
            ..WriteOptions::default()
        };

        cache
            .insert(&key, HeaderMap::default(), opts.clone(), Body::empty())
            .await;
        let result = cache.lookup(&key, &HeaderMap::default()).await;
        assert!(result.found().is_some());

        assert_eq!(cache.purge("two".parse().unwrap(), false), 1);
        let result = cache.lookup(&key, &HeaderMap::default()).await;
        assert!(result.found().is_none());
    }

    /// A purge removes only the variants tagged with the purged surrogate key.
    #[tokio::test]
    async fn purge_variant_by_surrogate_key() {
        let cache = Rc::new(Cache::default());
        let key: CacheKey = ([1u8].as_slice()).try_into().unwrap();

        // "Introduction and Variations on a Theme by Mozart, Op. 9, is one of Fernando Sor's most
        // famous works for guitar."
        let vary_header = HeaderName::from_static("opus-9");
        // Inserts one variant and returns the request headers that select it.
        let insert = {
            let cache = cache.clone();
            let key = key.clone();

            move |surrogate_keys: &str, variation: &str| {
                let mut headers = HeaderMap::new();
                headers.insert(vary_header.clone(), variation.try_into().unwrap());
                let surrogate_keys: SurrogateKeySet = surrogate_keys.parse().unwrap();
                let opts = WriteOptions {
                    max_age: Duration::from_secs(100),
                    surrogate_keys: surrogate_keys.clone(),
                    vary_rule: VaryRule::new([&vary_header]),
                    ..WriteOptions::default()
                };
                let cache = cache.clone();
                let key = key.clone();
                async move {
                    cache
                        .insert(&key, headers.clone(), opts.clone(), Body::empty())
                        .await;
                    headers
                }
            }
        };

        let h1 = insert("one two three", "twinkle twinkle").await;
        let h2 = insert("one three", "abcdefg").await;
        assert!(cache.lookup(&key, &h1).await.found().is_some());
        assert!(cache.lookup(&key, &h2).await.found().is_some());

        // Only the first variant carries "two".
        assert_eq!(cache.purge("two".parse().unwrap(), false), 1);
        assert!(cache.lookup(&key, &h1).await.found().is_none());
        assert!(cache.lookup(&key, &h2).await.found().is_some());
    }

    /// A soft purge keeps the entry usable but stale, and the next transactional
    /// lookup receives an obligation to fetch.
    #[tokio::test]
    async fn soft_purge() {
        let cache = Cache::default();
        let key: CacheKey = ([1u8].as_slice()).try_into().unwrap();

        let surrogate_keys: SurrogateKeySet = "one two three".parse().unwrap();
        let opts = WriteOptions {
            max_age: Duration::from_secs(100),
            surrogate_keys: surrogate_keys.clone(),
            ..WriteOptions::default()
        };

        cache
            .insert(&key, HeaderMap::default(), opts.clone(), Body::empty())
            .await;
        let result = cache.lookup(&key, &HeaderMap::default()).await;
        assert!(result.found().is_some());

        assert_eq!(cache.purge("two".parse().unwrap(), true), 1);
        let result = cache
            .transaction_lookup(&key, &HeaderMap::default(), false)
            .await;
        let found = result.found().unwrap();
        assert!(!found.meta().is_fresh());
        assert!(found.meta().is_usable());
        assert!(result.go_get().is_some());
    }
}
875        assert!(!found.meta().is_fresh());
876        assert!(found.meta().is_usable());
877        assert!(result.go_get().is_some());
878    }
879}