Skip to main content

tower_http_cache_plus/
service.rs

1use super::cache::{middleware::*, *};
2
3use {
4    http::{request::*, response::*},
5    http_body::*,
6    kutil::{
7        http::{transcoding::*, *},
8        std::{error::*, future::*, immutable::*},
9    },
10    std::{convert::*, mem, result::Result, sync::*, task::*},
11    tower::*,
12};
13
14//
15// CachingService
16//
17
/// HTTP response caching service with integrated compression.
///
/// You will often be using [CachingLayer](super::CachingLayer) rather than this service directly,
/// thus this service's functionality is documented there.
pub struct CachingService<InnerServiceT, RequestBodyT, CacheT, CacheKeyT = CommonCacheKey>
where
    CacheT: Cache<CacheKeyT>,
    CacheKeyT: CacheKey,
{
    // Wrapped service; called on cache misses and for requests that skip the cache
    inner_service: InnerServiceT,
    // Caching settings: the optional cache itself, key hook, and body size limits
    caching: MiddlewareCachingConfiguration<RequestBodyT, CacheT, CacheKeyT>,
    // Encoding settings used to select and apply response transcoding (compression)
    encoding: MiddlewareEncodingConfiguration,
}
31
32impl<InnerServiceT, RequestBodyT, CacheT, CacheKeyT>
33    CachingService<InnerServiceT, RequestBodyT, CacheT, CacheKeyT>
34where
35    CacheT: Cache<CacheKeyT>,
36    CacheKeyT: CacheKey,
37{
38    /// Constructor.
39    pub fn new(
40        inner_service: InnerServiceT,
41        caching: MiddlewareCachingConfiguration<RequestBodyT, CacheT, CacheKeyT>,
42        encoding: MiddlewareEncodingConfiguration,
43    ) -> Self {
44        assert!(caching.inner.min_body_size <= caching.inner.max_body_size);
45        Self {
46            inner_service,
47            caching: caching.clone(),
48            encoding: encoding.clone(),
49        }
50    }
51
52    // Clone while keeping `inner_service`.
53    //
54    // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
55    fn clone_and_keep_inner_service(&mut self) -> Self
56    where
57        InnerServiceT: Clone,
58    {
59        let mut clone = self.clone();
60        clone.inner_service = mem::replace(&mut self.inner_service, clone.inner_service);
61        clone
62    }
63
64    // Handle request.
65    async fn handle<ResponseBodyT>(
66        mut self,
67        request: Request<RequestBodyT>,
68    ) -> Result<Response<TranscodingBody<ResponseBodyT>>, InnerServiceT::Error>
69    where
70        InnerServiceT: Service<Request<RequestBodyT>, Response = Response<ResponseBodyT>>,
71        ResponseBodyT: 'static + Body + From<ImmutableBytes> + Send + Unpin,
72        ResponseBodyT::Data: From<ImmutableBytes> + Send,
73        ResponseBodyT::Error: Into<CapturedError>,
74    {
75        if request.should_skip_cache(&self.caching) {
76            // Capture request data before moving the request to the inner service
77            let uri = request.uri().clone();
78            let encoding = request.select_encoding(&self.encoding);
79            let content_length = request.headers().content_length();
80
81            return self
82                .inner_service
83                .call(request)
84                .await
85                .map(|upstream_response| {
86                    let (encoding, _skip_encoding) = upstream_response.validate_encoding(
87                        &uri,
88                        encoding,
89                        content_length,
90                        &self.encoding,
91                    );
92                    upstream_response
93                        .with_transcoding_body(&encoding, self.encoding.inner.encodable_by_default)
94                });
95        }
96
97        let cache = self.caching.cache.clone().expect("has cache");
98        let cache_key = request.cache_key_with_hook(&self.caching);
99
100        match cache.get(&cache_key).await {
101            Some(cached_response) => Ok({
102                if modified(request.headers(), cached_response.headers()) {
103                    tracing::debug!("hit");
104
105                    cached_response
106                        .to_transcoding_response(
107                            &request.select_encoding(&self.encoding),
108                            false,
109                            cache,
110                            cache_key,
111                            &self.encoding.inner,
112                        )
113                        .await
114                } else {
115                    tracing::debug!("hit (not modified)");
116
117                    not_modified_transcoding_response()
118                }
119            }),
120
121            None => {
122                // Capture request data before moving the request to the inner service
123                let uri = request.uri().clone();
124                let encoding = request.select_encoding(&self.encoding);
125
126                let upstream_response = self.inner_service.call(request).await?;
127
128                Ok({
129                    let (skip_caching, content_length) =
130                        upstream_response.should_skip_cache(&uri, &self.caching);
131                    let (encoding, skip_encoding) = upstream_response.validate_encoding(
132                        &uri,
133                        encoding.clone(),
134                        content_length,
135                        &self.encoding,
136                    );
137
138                    if skip_caching {
139                        upstream_response.with_transcoding_body(
140                            &encoding,
141                            self.encoding.inner.encodable_by_default,
142                        )
143                    } else {
144                        tracing::debug!("miss");
145
146                        match CachedResponse::new_for(
147                            &uri,
148                            upstream_response,
149                            content_length,
150                            encoding.clone(),
151                            skip_encoding,
152                            &self.caching.inner,
153                            &self.encoding.inner,
154                        )
155                        .await
156                        {
157                            Ok(cached_response) => {
158                                tracing::debug!("store ({})", encoding);
159                                Arc::new(cached_response)
160                                    .to_transcoding_response(
161                                        &encoding,
162                                        true,
163                                        cache,
164                                        cache_key,
165                                        &self.encoding.inner,
166                                    )
167                                    .await
168                            }
169
170                            Err(error) => match error.pieces {
171                                Some(pieces) => {
172                                    tracing::debug!("skip ({})", error.error);
173                                    pieces.response.with_transcoding_body_with_first_bytes(
174                                        Some(pieces.first_bytes),
175                                        &encoding,
176                                        self.encoding.inner.encodable_by_default,
177                                    )
178                                }
179
180                                None => {
181                                    tracing::error!(
182                                        "could not create cache entry: {} {}",
183                                        cache_key,
184                                        error
185                                    );
186                                    error_transcoding_response()
187                                }
188                            },
189                        }
190                    }
191                })
192            }
193        }
194    }
195}
196
197impl<InnerServiceT, RequestBodyT, CacheT, CacheKeyT> Clone
198    for CachingService<InnerServiceT, RequestBodyT, CacheT, CacheKeyT>
199where
200    InnerServiceT: Clone,
201    CacheT: Cache<CacheKeyT>,
202    CacheKeyT: CacheKey,
203{
204    fn clone(&self) -> Self {
205        Self {
206            inner_service: self.inner_service.clone(),
207            caching: self.caching.clone(),
208            encoding: self.encoding.clone(),
209        }
210    }
211}
212
213impl<InnerServiceT, RequestBodyT, ResponseBodyT, ErrorT, CacheT, CacheKeyT>
214    Service<Request<RequestBodyT>>
215    for CachingService<InnerServiceT, RequestBodyT, CacheT, CacheKeyT>
216where
217    InnerServiceT: 'static
218        + Service<Request<RequestBodyT>, Response = Response<ResponseBodyT>, Error = ErrorT>
219        + Clone
220        + Send,
221    InnerServiceT::Future: Send,
222    RequestBodyT: 'static + Send,
223    ResponseBodyT: 'static + Body + From<ImmutableBytes> + Send + Unpin,
224    ResponseBodyT::Data: From<ImmutableBytes> + Send,
225    ResponseBodyT::Error: Into<CapturedError>,
226    CacheT: Cache<CacheKeyT>,
227    CacheKeyT: CacheKey,
228{
229    type Response = Response<TranscodingBody<ResponseBodyT>>;
230    type Error = InnerServiceT::Error;
231    type Future = CapturedFuture<Result<Self::Response, Self::Error>>;
232
233    fn poll_ready(&mut self, context: &mut Context) -> Poll<Result<(), Self::Error>> {
234        // Note that if we are using the cache, we technically don't have to depend on the inner
235        // service being poll_ready for us to be poll_ready, however Tower's design does not allow
236        // us to optimize here
237        self.inner_service.poll_ready(context)
238    }
239
240    fn call(&mut self, request: Request<RequestBodyT>) -> Self::Future {
241        // We unfortunately must clone the `&mut self` because it cannot be sent to the future as is;
242        //
243        // The worry is that we are cloning our inner service, too, which will clone *its* inner service,
244        // and so on... It can be a sizeable clone if there are many service layers
245        //
246        // But this seems to be standard practice in Tower due to its design!
247
248        let cloned_self = self.clone_and_keep_inner_service();
249        capture_async! { cloned_self.handle(request).await }
250    }
251}