1use super::cache::{middleware::*, *};
2
3use {
4 http::{request::*, response::*},
5 http_body::*,
6 kutil::{
7 http::{transcoding::*, *},
8 std::{error::*, future::*, immutable::*},
9 },
10 std::{convert::*, mem, result::Result, sync::*, task::*},
11 tower::*,
12};
13
/// Tower [`Service`] that wraps `InnerServiceT`, adding response caching and
/// response transcoding (content encoding).
///
/// Requests that skip or miss the cache are forwarded to the inner service.
/// Construct with [`CachingService::new`].
pub struct CachingService<InnerServiceT, RequestBodyT, CacheT, CacheKeyT = CommonCacheKey>
where
    CacheT: Cache<CacheKeyT>,
    CacheKeyT: CacheKey,
{
    /// The wrapped service; requests that bypass or miss the cache are sent to it.
    inner_service: InnerServiceT,
    /// Caching configuration: holds the optional cache instance (`cache`) and the
    /// body-size limits (`inner.min_body_size` / `inner.max_body_size`), among
    /// other settings not visible here.
    caching: MiddlewareCachingConfiguration<RequestBodyT, CacheT, CacheKeyT>,
    /// Encoding negotiation/transcoding configuration.
    encoding: MiddlewareEncodingConfiguration,
}
31
32impl<InnerServiceT, RequestBodyT, CacheT, CacheKeyT>
33 CachingService<InnerServiceT, RequestBodyT, CacheT, CacheKeyT>
34where
35 CacheT: Cache<CacheKeyT>,
36 CacheKeyT: CacheKey,
37{
38 pub fn new(
40 inner_service: InnerServiceT,
41 caching: MiddlewareCachingConfiguration<RequestBodyT, CacheT, CacheKeyT>,
42 encoding: MiddlewareEncodingConfiguration,
43 ) -> Self {
44 assert!(caching.inner.min_body_size <= caching.inner.max_body_size);
45 Self {
46 inner_service,
47 caching: caching.clone(),
48 encoding: encoding.clone(),
49 }
50 }
51
52 fn clone_and_keep_inner_service(&mut self) -> Self
56 where
57 InnerServiceT: Clone,
58 {
59 let mut clone = self.clone();
60 clone.inner_service = mem::replace(&mut self.inner_service, clone.inner_service);
61 clone
62 }
63
64 async fn handle<ResponseBodyT>(
66 mut self,
67 request: Request<RequestBodyT>,
68 ) -> Result<Response<TranscodingBody<ResponseBodyT>>, InnerServiceT::Error>
69 where
70 InnerServiceT: Service<Request<RequestBodyT>, Response = Response<ResponseBodyT>>,
71 ResponseBodyT: 'static + Body + From<ImmutableBytes> + Send + Unpin,
72 ResponseBodyT::Data: From<ImmutableBytes> + Send,
73 ResponseBodyT::Error: Into<CapturedError>,
74 {
75 if request.should_skip_cache(&self.caching) {
76 let uri = request.uri().clone();
78 let encoding = request.select_encoding(&self.encoding);
79 let content_length = request.headers().content_length();
80
81 return self
82 .inner_service
83 .call(request)
84 .await
85 .map(|upstream_response| {
86 let (encoding, _skip_encoding) = upstream_response.validate_encoding(
87 &uri,
88 encoding,
89 content_length,
90 &self.encoding,
91 );
92 upstream_response
93 .with_transcoding_body(&encoding, self.encoding.inner.encodable_by_default)
94 });
95 }
96
97 let cache = self.caching.cache.clone().expect("has cache");
98 let cache_key = request.cache_key_with_hook(&self.caching);
99
100 match cache.get(&cache_key).await {
101 Some(cached_response) => Ok({
102 if modified(request.headers(), cached_response.headers()) {
103 tracing::debug!("hit");
104
105 cached_response
106 .to_transcoding_response(
107 &request.select_encoding(&self.encoding),
108 false,
109 cache,
110 cache_key,
111 &self.encoding.inner,
112 )
113 .await
114 } else {
115 tracing::debug!("hit (not modified)");
116
117 not_modified_transcoding_response()
118 }
119 }),
120
121 None => {
122 let uri = request.uri().clone();
124 let encoding = request.select_encoding(&self.encoding);
125
126 let upstream_response = self.inner_service.call(request).await?;
127
128 Ok({
129 let (skip_caching, content_length) =
130 upstream_response.should_skip_cache(&uri, &self.caching);
131 let (encoding, skip_encoding) = upstream_response.validate_encoding(
132 &uri,
133 encoding.clone(),
134 content_length,
135 &self.encoding,
136 );
137
138 if skip_caching {
139 upstream_response.with_transcoding_body(
140 &encoding,
141 self.encoding.inner.encodable_by_default,
142 )
143 } else {
144 tracing::debug!("miss");
145
146 match CachedResponse::new_for(
147 &uri,
148 upstream_response,
149 content_length,
150 encoding.clone(),
151 skip_encoding,
152 &self.caching.inner,
153 &self.encoding.inner,
154 )
155 .await
156 {
157 Ok(cached_response) => {
158 tracing::debug!("store ({})", encoding);
159 Arc::new(cached_response)
160 .to_transcoding_response(
161 &encoding,
162 true,
163 cache,
164 cache_key,
165 &self.encoding.inner,
166 )
167 .await
168 }
169
170 Err(error) => match error.pieces {
171 Some(pieces) => {
172 tracing::debug!("skip ({})", error.error);
173 pieces.response.with_transcoding_body_with_first_bytes(
174 Some(pieces.first_bytes),
175 &encoding,
176 self.encoding.inner.encodable_by_default,
177 )
178 }
179
180 None => {
181 tracing::error!(
182 "could not create cache entry: {} {}",
183 cache_key,
184 error
185 );
186 error_transcoding_response()
187 }
188 },
189 }
190 }
191 })
192 }
193 }
194 }
195}
196
197impl<InnerServiceT, RequestBodyT, CacheT, CacheKeyT> Clone
198 for CachingService<InnerServiceT, RequestBodyT, CacheT, CacheKeyT>
199where
200 InnerServiceT: Clone,
201 CacheT: Cache<CacheKeyT>,
202 CacheKeyT: CacheKey,
203{
204 fn clone(&self) -> Self {
205 Self {
206 inner_service: self.inner_service.clone(),
207 caching: self.caching.clone(),
208 encoding: self.encoding.clone(),
209 }
210 }
211}
212
213impl<InnerServiceT, RequestBodyT, ResponseBodyT, ErrorT, CacheT, CacheKeyT>
214 Service<Request<RequestBodyT>>
215 for CachingService<InnerServiceT, RequestBodyT, CacheT, CacheKeyT>
216where
217 InnerServiceT: 'static
218 + Service<Request<RequestBodyT>, Response = Response<ResponseBodyT>, Error = ErrorT>
219 + Clone
220 + Send,
221 InnerServiceT::Future: Send,
222 RequestBodyT: 'static + Send,
223 ResponseBodyT: 'static + Body + From<ImmutableBytes> + Send + Unpin,
224 ResponseBodyT::Data: From<ImmutableBytes> + Send,
225 ResponseBodyT::Error: Into<CapturedError>,
226 CacheT: Cache<CacheKeyT>,
227 CacheKeyT: CacheKey,
228{
229 type Response = Response<TranscodingBody<ResponseBodyT>>;
230 type Error = InnerServiceT::Error;
231 type Future = CapturedFuture<Result<Self::Response, Self::Error>>;
232
233 fn poll_ready(&mut self, context: &mut Context) -> Poll<Result<(), Self::Error>> {
234 self.inner_service.poll_ready(context)
238 }
239
240 fn call(&mut self, request: Request<RequestBodyT>) -> Self::Future {
241 let cloned_self = self.clone_and_keep_inner_service();
249 capture_async! { cloned_self.handle(request).await }
250 }
251}