tower_cache/
lib.rs

1#![warn(missing_docs, unreachable_pub)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! # Cache layer for `tower::Service`s
5//!
6//! [`CacheLayer`] is a tower Layer that provides caches for `Service`s by using
7//! another service to handle the cache. This allows the usage of asynchronous
8//! and external caches.
9//!
10//! ## Usage
11//!
12//! ```rust
13//! use std::convert::Infallible;
14//! use tower::{ServiceBuilder, service_fn};
15//! use tower_cache::{
16//!     CacheLayer,
17//!     lru::LruProvider,
18//! };
19//! async fn handler(req: String) -> Result<String, Infallible> {
20//!     Ok(req.to_uppercase())
21//! }
22//!
23//! // Initialize the cache provider service
24//! let lru_provider = LruProvider::<String, String>::new(20);
25//!
26//! // Initialize the service
27//! let my_service = service_fn(handler);
28//!
29//! // Wrap the service with CacheLayer.
30//! let my_service = ServiceBuilder::new()
31//!     .layer(CacheLayer::<_, String>::new(lru_provider))
//!     .service(my_service);
33//! ```
34//!
35//! ## Creating cache providers
36//!
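//! A cache provider is a [`tower::Service`] that takes a [`ProviderRequest`]
//! as request and returns a [`ProviderResponse`]. For a
//! [`ProviderRequest::Get`] it replies [`ProviderResponse::Found`] when it holds
//! a cached response and [`ProviderResponse::NotFound`] otherwise; a
//! [`ProviderRequest::Insert`] asks it to store a response for later lookups.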
39//!
40
41use std::{
42    error, fmt,
43    future::Future,
44    marker::PhantomData,
45    pin::Pin,
46    task::{Context, Poll},
47};
48use tower::{Layer, Service};
49
50#[cfg(feature = "lru")]
51#[cfg_attr(docsrs, doc(cfg(feature = "lru")))]
52pub mod lru;
53
/// Layer that adds cache to a [`tower::Service`]
///
/// This works by using a cache provider service that takes a [`ProviderRequest`]
/// and returns a [`ProviderResponse`]. `R` is only a type marker (no `R` value is
/// stored), so the layer is `Clone`/`Debug` whenever the provider is.
pub struct CacheLayer<'a, P, R> {
    provider: P,
    _phantom: PhantomData<&'a R>,
}

// Implemented by hand: `#[derive(Clone)]` would add an `R: Clone` bound even
// though only the `PhantomData` marker mentions `R`.
impl<P: Clone, R> Clone for CacheLayer<'_, P, R> {
    fn clone(&self) -> Self {
        Self {
            provider: self.provider.clone(),
            _phantom: PhantomData,
        }
    }
}

// Implemented by hand for the same reason as `Clone`: no `R: Debug` bound.
impl<P: fmt::Debug, R> fmt::Debug for CacheLayer<'_, P, R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CacheLayer")
            .field("provider", &self.provider)
            .finish()
    }
}
62
63impl<'a, P, R> CacheLayer<'a, P, R> {
64    /// Create a new [`CacheLayer`]
65    pub fn new(provider: P) -> Self {
66        Self {
67            provider,
68            _phantom: PhantomData,
69        }
70    }
71}
72
73impl<'a, P, R, S> Layer<S> for CacheLayer<'a, P, R>
74where
75    P: Clone,
76{
77    type Service = CacheService<'a, S, P>;
78
79    fn layer(&self, inner: S) -> Self::Service {
80        CacheService {
81            inner,
82            provider: self.provider.clone(),
83            _phantom: PhantomData,
84        }
85    }
86}
87
/// Service generated by [`CacheLayer`].
///
/// For each request it first asks the cache provider for a stored response.
/// On a miss it calls the inner service and hands the response to the
/// provider for storage. It is `Clone` (when the inner service and provider
/// are), so it can itself be wrapped by another `CacheLayer` or handed to
/// servers that require cloneable services.
#[derive(Clone, Debug)]
pub struct CacheService<'a, S, P> {
    inner: S,
    provider: P,
    _phantom: PhantomData<&'a ()>,
}
94
95impl<'a, S, P, R> Service<R> for CacheService<'a, S, P>
96where
97    S: Service<R> + Clone + Send + 'a,
98    S::Response: Clone + Send + 'a,
99    S::Error: Into<Box<dyn error::Error + Send + Sync>>,
100    S::Future: Send + 'a,
101    P: Service<ProviderRequest<R, S::Response>, Response = ProviderResponse<S::Response>>
102        + Clone
103        + Send
104        + 'a,
105    P::Response: Send + 'a,
106    P::Error: Into<Box<dyn error::Error + Send + Sync>> + Send,
107    P::Future: Send + 'a,
108    R: Clone + Send + Sync + 'a,
109{
110    type Response = S::Response;
111    type Error = Error;
112    type Future = CacheFuture<'a, R, S>;
113
114    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115        self.provider
116            .poll_ready(cx)
117            .map_err(|e| Error::ServiceError(e.into()))
118    }
119
120    fn call(&mut self, request: R) -> Self::Future {
121        let mut provider = self.provider.clone();
122        let mut inner = self.inner.clone();
123        let idem_fut = self.provider.call(ProviderRequest::Get(request.clone()));
124
125        Box::pin(async move {
126            let res = match idem_fut.await {
127                // If we have a response in the cache, we can immediately return without
128                // calling the inner service.
129                Ok(ProviderResponse::Found(res)) => Ok(res),
130                // Response not found - we need to call the inner service and update the
131                // cache.
132                Ok(ProviderResponse::NotFound) => {
133                    // Fetch the response from the inner service.
134                    let response = inner
135                        .call(request.clone())
136                        .await
137                        .map_err(|e| Error::ServiceError(e.into()));
138                    match response {
139                        Ok(res) => {
140                            // Store the response in the cache provider.
141                            let new_res = res.clone();
142                            match provider
143                                .call(ProviderRequest::Insert(request, new_res))
144                                .await
145                            {
146                                Ok(_) => Ok(res),
147                                Err(e) => Err(Error::ProviderError(e.into())),
148                            }
149                        }
150                        res => res,
151                    }
152                }
153                Err(e) => Err(Error::ProviderError(e.into())),
154            };
155
156            res
157        })
158    }
159}
160
/// Requests sent to the cache provider
///
/// [`CacheService`] sends a [`ProviderRequest::Get`] for every incoming request
/// and, on a miss, a [`ProviderRequest::Insert`] carrying the inner service's
/// response for that same request.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProviderRequest<Req, Res> {
    /// Ask whether the provider holds a cached response for this request
    Get(Req),
    /// Store the response (second field) for the request (first field)
    Insert(Req, Res),
}
169
/// Responses sent by the cache provider
///
/// [`CacheService`] inspects the reply to a [`ProviderRequest::Get`]. For a
/// [`ProviderRequest::Insert`], only a failure is acted on and the successful
/// reply is ignored.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProviderResponse<Res> {
    /// The cache provider holds a response for the request; it is returned
    /// without invoking the inner service
    Found(Res),
    /// The cache provider has no response for the request; the inner service
    /// is called instead
    NotFound,
}
178
/// Error returned by the [`CacheLayer`]
///
/// A failure can originate in the cache provider, in the wrapped inner
/// service, or in the layer itself; each source has its own variant so
/// callers can tell them apart.
#[derive(Debug)]
pub enum Error {
    /// Failure reported by the cache provider
    ProviderError(Box<dyn error::Error + Send + Sync>),
    /// Failure reported by the inner service
    ServiceError(Box<dyn error::Error + Send + Sync>),
    /// Failure inside the layer itself
    InternalError,
}

// The payload-carrying variants already print their cause in `Display`, so
// the default `source()` is kept.
impl error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (label, cause) = match self {
            Self::ProviderError(inner) => ("provider error", Some(inner)),
            Self::ServiceError(inner) => ("service error", Some(inner)),
            Self::InternalError => ("internal error", None),
        };
        match cause {
            Some(inner) => write!(f, "{}: {}", label, inner),
            None => f.write_str(label),
        }
    }
}
204
205type CacheFuture<'a, R, S> =
206    Pin<Box<dyn Future<Output = Result<<S as Service<R>>::Response, Error>> + Send + 'a>>;
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::HashMap,
        future::{poll_fn, ready},
        sync::{Arc, Mutex},
    };
    use tower::{service_fn, Service, ServiceBuilder};

    /// Minimal in-memory cache provider.
    ///
    /// The map lives behind an `Arc`, so the clone kept by a test sees every
    /// insert made through the clone handed to [`CacheLayer`].
    #[derive(Clone, Default, Debug)]
    struct SimpleCache {
        cache: Arc<Mutex<HashMap<String, String>>>,
    }

    impl Service<ProviderRequest<String, String>> for SimpleCache {
        type Response = ProviderResponse<String>;
        type Error = Error;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, request: ProviderRequest<String, String>) -> Self::Future {
            Box::pin(ready(match request {
                ProviderRequest::Get(req) => match self.cache.lock().unwrap().get(&req) {
                    Some(res) => Ok(ProviderResponse::Found(res.clone())),
                    None => Ok(ProviderResponse::NotFound),
                },
                ProviderRequest::Insert(req, res) => {
                    self.cache.lock().unwrap().insert(req, res.clone());
                    Ok(ProviderResponse::Found(res))
                }
            }))
        }
    }

    /// Inner service: upper-cases the request.
    async fn service(req: String) -> Result<String, Error> {
        Ok(req.to_uppercase())
    }

    /// Inner service that always fails, used to check that errors are not cached.
    async fn failing_service(_req: String) -> Result<String, Error> {
        Err(Error::InternalError)
    }

    /// Drive `svc` to readiness, then dispatch `req`, as the tower
    /// `Service` contract requires.
    async fn ready_call<S>(svc: &mut S, req: String) -> Result<S::Response, S::Error>
    where
        S: Service<String>,
    {
        poll_fn(|cx| svc.poll_ready(cx)).await?;
        svc.call(req).await
    }

    #[tokio::test]
    async fn test_insert() -> Result<(), Error> {
        let cache = SimpleCache::default();
        let cache_layer = CacheLayer::<_, String>::new(cache.clone());

        let mut service = ServiceBuilder::new()
            .layer(cache_layer)
            .service(service_fn(service));

        assert_eq!(cache.cache.lock().unwrap().len(), 0);
        let res = ready_call(&mut service, String::from("Hello")).await?;

        assert_eq!(res, String::from("HELLO"));
        {
            let inner_cache = cache.cache.lock().unwrap();
            assert_eq!(inner_cache.len(), 1);
            assert_eq!(
                inner_cache.get(&String::from("Hello")),
                Some(&String::from("HELLO"))
            );
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_get() -> Result<(), Error> {
        let cache = SimpleCache::default();
        {
            let mut inner_cache = cache.cache.lock().unwrap();
            inner_cache.insert(String::from("Hello"), String::from("hello"));
        }
        let cache_layer = CacheLayer::<_, String>::new(cache.clone());

        let mut service = ServiceBuilder::new()
            .layer(cache_layer)
            .service(service_fn(service));

        // The lowercase value can only come from the cache: the inner service
        // would have upper-cased it.
        let res = ready_call(&mut service, String::from("Hello")).await?;
        assert_eq!(res, String::from("hello"));

        Ok(())
    }

    #[tokio::test]
    async fn test_inner_error_is_not_cached() {
        let cache = SimpleCache::default();
        let mut service = ServiceBuilder::new()
            .layer(CacheLayer::<_, String>::new(cache.clone()))
            .service(service_fn(failing_service));

        let err = ready_call(&mut service, String::from("Hello"))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::ServiceError(_)));
        assert!(cache.cache.lock().unwrap().is_empty());
    }
}