1#![warn(missing_docs, unreachable_pub)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use std::{
42 error, fmt,
43 future::Future,
44 marker::PhantomData,
45 pin::Pin,
46 task::{Context, Poll},
47};
48use tower::{Layer, Service};
49
50#[cfg(feature = "lru")]
51#[cfg_attr(docsrs, doc(cfg(feature = "lru")))]
52pub mod lru;
53
/// A [`Layer`] that wraps services in a [`CacheService`] backed by a cache provider.
///
/// `P` is the provider type stored in the layer; it is cloned into every
/// service the layer produces. `R` is only tracked at the type level (through
/// `PhantomData`) and is never stored.
///
/// `Clone` and `Debug` are derived so the layer can be reused across service
/// stacks and printed for diagnostics; the derives only apply when `P` and `R`
/// implement the respective trait.
#[derive(Clone, Debug)]
pub struct CacheLayer<'a, P, R> {
    /// The cache provider handed (by clone) to each produced service.
    provider: P,
    /// Ties the otherwise unused `'a` and `R` parameters to the type.
    _phantom: PhantomData<&'a R>,
}
62
63impl<'a, P, R> CacheLayer<'a, P, R> {
64 pub fn new(provider: P) -> Self {
66 Self {
67 provider,
68 _phantom: PhantomData,
69 }
70 }
71}
72
73impl<'a, P, R, S> Layer<S> for CacheLayer<'a, P, R>
74where
75 P: Clone,
76{
77 type Service = CacheService<'a, S, P>;
78
79 fn layer(&self, inner: S) -> Self::Service {
80 CacheService {
81 inner,
82 provider: self.provider.clone(),
83 _phantom: PhantomData,
84 }
85 }
86}
87
/// Service that pairs a wrapped service with a cache provider.
///
/// Values of this type are produced by [`CacheLayer`]. `Clone` is derived
/// (available whenever `S` and `P` are `Clone`) because tower middleware
/// commonly requires the services it wraps to be cloneable; this crate's own
/// `Service` implementation asks exactly that of its inner service. `Debug`
/// is derived for diagnostics.
#[derive(Clone, Debug)]
pub struct CacheService<'a, S, P> {
    /// The wrapped service.
    inner: S,
    /// The cache provider.
    provider: P,
    /// Carries the `'a` lifetime parameter, which no field otherwise uses.
    _phantom: PhantomData<&'a ()>,
}
94
95impl<'a, S, P, R> Service<R> for CacheService<'a, S, P>
96where
97 S: Service<R> + Clone + Send + 'a,
98 S::Response: Clone + Send + 'a,
99 S::Error: Into<Box<dyn error::Error + Send + Sync>>,
100 S::Future: Send + 'a,
101 P: Service<ProviderRequest<R, S::Response>, Response = ProviderResponse<S::Response>>
102 + Clone
103 + Send
104 + 'a,
105 P::Response: Send + 'a,
106 P::Error: Into<Box<dyn error::Error + Send + Sync>> + Send,
107 P::Future: Send + 'a,
108 R: Clone + Send + Sync + 'a,
109{
110 type Response = S::Response;
111 type Error = Error;
112 type Future = CacheFuture<'a, R, S>;
113
114 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115 self.provider
116 .poll_ready(cx)
117 .map_err(|e| Error::ServiceError(e.into()))
118 }
119
120 fn call(&mut self, request: R) -> Self::Future {
121 let mut provider = self.provider.clone();
122 let mut inner = self.inner.clone();
123 let idem_fut = self.provider.call(ProviderRequest::Get(request.clone()));
124
125 Box::pin(async move {
126 let res = match idem_fut.await {
127 Ok(ProviderResponse::Found(res)) => Ok(res),
130 Ok(ProviderResponse::NotFound) => {
133 let response = inner
135 .call(request.clone())
136 .await
137 .map_err(|e| Error::ServiceError(e.into()));
138 match response {
139 Ok(res) => {
140 let new_res = res.clone();
142 match provider
143 .call(ProviderRequest::Insert(request, new_res))
144 .await
145 {
146 Ok(_) => Ok(res),
147 Err(e) => Err(Error::ProviderError(e.into())),
148 }
149 }
150 res => res,
151 }
152 }
153 Err(e) => Err(Error::ProviderError(e.into())),
154 };
155
156 res
157 })
158 }
159}
160
/// Request sent to a cache provider.
///
/// `Req` is the type of the request used as the lookup key and `Res` is the
/// type of the response stored for it.
#[derive(Debug, Clone)]
pub enum ProviderRequest<Req, Res> {
    /// Look up the response stored for the given request.
    Get(Req),
    /// Store the given response for the given request.
    Insert(Req, Res),
}
169
/// Answer returned by a cache provider.
#[derive(Debug)]
pub enum ProviderResponse<Res> {
    /// The provider holds a response; it is carried in the variant.
    Found(Res),
    /// The provider holds no response for the request.
    NotFound,
}
178
/// Errors produced by the caching middleware.
///
/// The wrapped causes are boxed so that arbitrary provider and service error
/// types can be carried.
#[derive(Debug)]
pub enum Error {
    /// A boxed error originating from the cache provider, for example a
    /// failed lookup or insert.
    ProviderError(Box<dyn error::Error + Send + Sync>),
    /// A boxed error raised on the service side of the middleware, for
    /// example by the wrapped service's call.
    ServiceError(Box<dyn error::Error + Send + Sync>),
    /// A failure with no underlying cause attached. This file never
    /// constructs it.
    InternalError,
}
192
193impl error::Error for Error {}
194
195impl fmt::Display for Error {
196 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197 match self {
198 Error::ProviderError(e) => write!(f, "provider error: {}", e),
199 Error::ServiceError(e) => write!(f, "service error: {}", e),
200 Error::InternalError => write!(f, "internal error"),
201 }
202 }
203}
204
205type CacheFuture<'a, R, S> =
206 Pin<Box<dyn Future<Output = Result<<S as Service<R>>::Response, Error>> + Send + 'a>>;
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::HashMap,
        future::ready,
        sync::{Arc, Mutex},
    };
    use tower::{service_fn, Service, ServiceBuilder};

    /// In-memory provider used by the tests: a shared `HashMap` keyed by the
    /// request string.
    #[derive(Clone, Default, Debug)]
    struct SimpleCache {
        cache: Arc<Mutex<HashMap<String, String>>>,
    }

    impl Service<ProviderRequest<String, String>> for SimpleCache {
        type Response = ProviderResponse<String>;
        type Error = Error;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        // The map is always available, so the provider is always ready.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, request: ProviderRequest<String, String>) -> Self::Future {
            let mut entries = self.cache.lock().unwrap();
            let outcome = match request {
                ProviderRequest::Get(key) => entries
                    .get(&key)
                    .cloned()
                    .map_or(ProviderResponse::NotFound, ProviderResponse::Found),
                ProviderRequest::Insert(key, value) => {
                    entries.insert(key, value.clone());
                    ProviderResponse::Found(value)
                }
            };
            Box::pin(ready(Ok(outcome)))
        }
    }

    /// The wrapped service under test: upper-cases its input.
    async fn service(req: String) -> Result<String, Error> {
        Ok(req.to_uppercase())
    }

    /// A miss must call the wrapped service and store its response in the provider.
    #[tokio::test]
    async fn test_insert() -> Result<(), Error> {
        let provider = SimpleCache::default();
        let layer = CacheLayer::<_, String>::new(provider.clone());
        let mut cached = ServiceBuilder::new()
            .layer(layer)
            .service(service_fn(service));

        assert_eq!(provider.cache.lock().unwrap().len(), 0);
        let reply = cached.call(String::from("Hello")).await?;
        assert_eq!(reply, String::from("HELLO"));

        {
            let entries = provider.cache.lock().unwrap();
            assert_eq!(entries.len(), 1);
            assert_eq!(
                entries.get(&String::from("Hello")),
                Some(&String::from("HELLO"))
            );
        }

        Ok(())
    }

    /// A hit must return the stored value instead of calling the wrapped
    /// service (the stored value is lower-case, the service would upper-case).
    #[tokio::test]
    async fn test_get() -> Result<(), Error> {
        let provider = SimpleCache::default();
        provider
            .cache
            .lock()
            .unwrap()
            .insert(String::from("Hello"), String::from("hello"));

        let layer = CacheLayer::<_, String>::new(provider.clone());
        let mut cached = ServiceBuilder::new()
            .layer(layer)
            .service(service_fn(service));

        let reply = cached.call(String::from("Hello")).await?;
        assert_eq!(reply, String::from("hello"));

        Ok(())
    }
}