rust_serv/middleware/
cache.rs1use std::task::{Context, Poll};
2use hyper::Request;
3use http_body_util::BodyExt;
4use tower::{Layer, Service};
5
/// Tower [`Layer`] that wraps a service in a [`CacheService`].
///
/// NOTE(review): the wrapped service currently forwards every request to the
/// inner service unchanged; no response caching or conditional-request
/// handling is performed by this middleware yet.
#[derive(Debug, Clone, Default)]
pub struct CacheLayer;
9
10impl<S> Layer<S> for CacheLayer {
11 type Service = CacheService<S>;
12
13 fn layer(&self, inner: S) -> Self::Service {
14 CacheService { inner }
15 }
16}
17
/// Service produced by [`CacheLayer`]; it owns the downstream service that
/// requests are delegated to.
#[derive(Debug, Clone)]
pub struct CacheService<S> {
    /// The wrapped downstream service.
    inner: S,
}
22
23impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CacheService<S>
24where
25 S: Service<Request<ReqBody>, Response = hyper::Response<ResBody>>,
26 ReqBody: BodyExt + Send + 'static,
27 ResBody: BodyExt + Send + 'static,
28{
29 type Response = S::Response;
30 type Error = S::Error;
31 type Future = S::Future;
32
33 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
34 self.inner.poll_ready(cx)
35 }
36
37 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
38 self.inner.call(req)
41 }
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::Full;
    use hyper::body::Bytes;
    use hyper::header::{HeaderName, ETAG, IF_MODIFIED_SINCE, IF_NONE_MATCH, LAST_MODIFIED};
    use hyper::{HeaderMap, Method, Request, Response, StatusCode};
    use std::future::{poll_fn, Future};
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Entity tag the mock attaches to every response.
    const MOCK_ETAG: &str = "\"test-etag\"";
    /// `Last-Modified` value the mock attaches to every response.
    const MOCK_LAST_MODIFIED: &str = "Wed, 21 Oct 2015 07:28:00 GMT";
    /// Body of every full (200) mock response.
    const MOCK_BODY: &str = "test content";

    /// Downstream stand-in for the service wrapped by `CacheLayer`.
    ///
    /// With `should_return_304` set, it evaluates the request's conditional
    /// headers against `MOCK_ETAG` and answers `304 Not Modified` with an
    /// empty body on a match. Otherwise it answers `200 OK` with `MOCK_BODY`.
    #[derive(Clone)]
    struct MockService {
        should_return_304: bool,
    }

    impl MockService {
        /// Returns whether the request's validators match the mock resource.
        fn is_not_modified(&self, headers: &HeaderMap) -> bool {
            if !self.should_return_304 {
                return false;
            }
            match headers.get(IF_NONE_MATCH) {
                // RFC 9110 §13.1.3: If-Modified-Since is ignored when
                // If-None-Match is present.
                Some(tag) => tag.as_bytes() == MOCK_ETAG.as_bytes(),
                None => headers.contains_key(IF_MODIFIED_SINCE),
            }
        }
    }

    impl Service<Request<Full<Bytes>>> for MockService {
        type Response = Response<Full<Bytes>>;
        type Error = std::convert::Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
            // Decide eagerly so the returned future does not borrow the request.
            let not_modified = self.is_not_modified(req.headers());
            Box::pin(async move {
                let (status, body) = if not_modified {
                    (StatusCode::NOT_MODIFIED, Full::new(Bytes::new()))
                } else {
                    (StatusCode::OK, Full::new(Bytes::from(MOCK_BODY)))
                };
                let response = Response::builder()
                    .status(status)
                    .header(ETAG, MOCK_ETAG)
                    .header(LAST_MODIFIED, MOCK_LAST_MODIFIED)
                    .body(body)
                    .expect("mock response parts are statically valid");
                Ok(response)
            })
        }
    }

    /// Waker that ignores wake-ups; sufficient for polling something that is
    /// ready on the first poll.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Builds the middleware under test on top of a fresh `MockService`.
    fn cache_over(should_return_304: bool) -> CacheService<MockService> {
        CacheLayer.layer(MockService { should_return_304 })
    }

    /// Builds a bodiless GET request for `uri` carrying the given headers.
    fn get_request(uri: &str, headers: &[(HeaderName, &'static str)]) -> Request<Full<Bytes>> {
        let mut builder = Request::builder().method(Method::GET).uri(uri);
        for (name, value) in headers {
            builder = builder.header(name, *value);
        }
        builder
            .body(Full::new(Bytes::new()))
            .expect("test request parts are valid")
    }

    /// Honors the tower contract (`poll_ready` before `call`) and awaits the
    /// response.
    async fn ready_call<S>(service: &mut S, request: Request<Full<Bytes>>) -> S::Response
    where
        S: Service<Request<Full<Bytes>>>,
        S::Error: std::fmt::Debug,
    {
        poll_fn(|cx| service.poll_ready(cx))
            .await
            .expect("service failed to become ready");
        service.call(request).await.expect("service call failed")
    }

    #[test]
    fn test_cache_layer_creation() {
        let _layer = CacheLayer;
    }

    #[test]
    fn test_cache_layer_clone() {
        let layer = CacheLayer;
        let _cloned = layer.clone();
    }

    #[test]
    fn test_cache_service_creation() {
        let _cache_service = cache_over(false);
    }

    #[tokio::test]
    async fn test_cache_service_call_without_headers() {
        let mut cache_service = cache_over(false);

        let response = ready_call(&mut cache_service, get_request("/test", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        // The middleware must pass the inner response through untouched.
        assert_eq!(
            response.headers().get(ETAG).map(|v| v.as_bytes()),
            Some(MOCK_ETAG.as_bytes())
        );
        let body = response
            .into_body()
            .collect()
            .await
            .expect("Full<Bytes> never fails")
            .to_bytes();
        assert_eq!(body, Bytes::from(MOCK_BODY));
    }

    #[tokio::test]
    async fn test_cache_service_with_if_none_match() {
        let mut cache_service = cache_over(true);
        let request = get_request("/test", &[(IF_NONE_MATCH, MOCK_ETAG)]);

        let response = ready_call(&mut cache_service, request).await;
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
    }

    #[tokio::test]
    async fn test_cache_service_with_if_modified_since() {
        let mut cache_service = cache_over(true);
        let request = get_request("/test", &[(IF_MODIFIED_SINCE, MOCK_LAST_MODIFIED)]);

        let response = ready_call(&mut cache_service, request).await;
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
    }

    #[tokio::test]
    async fn test_cache_service_with_both_cache_headers() {
        let mut cache_service = cache_over(true);
        let request = get_request(
            "/test",
            &[(IF_NONE_MATCH, MOCK_ETAG), (IF_MODIFIED_SINCE, MOCK_LAST_MODIFIED)],
        );

        let response = ready_call(&mut cache_service, request).await;
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
    }

    #[test]
    fn test_cache_service_poll_ready() {
        let mut cache_service = cache_over(false);
        // Safe no-op waker; avoids a hand-rolled `RawWaker` vtable.
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);

        let poll_result =
            Service::<Request<Full<Bytes>>>::poll_ready(&mut cache_service, &mut cx);
        assert!(matches!(poll_result, Poll::Ready(Ok(()))));
    }

    #[test]
    fn test_cache_service_clone() {
        let cache_service = cache_over(false);
        let _cloned = cache_service.clone();
    }

    #[tokio::test]
    async fn test_cache_service_multiple_requests() {
        let mut cache_service = cache_over(false);

        for i in 0..5 {
            let request = get_request(&format!("/page/{i}"), &[]);
            let response = ready_call(&mut cache_service, request).await;
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_cache_service_with_post_request() {
        let mut cache_service = cache_over(false);
        let request = Request::builder()
            .method(Method::POST)
            .uri("/api/data")
            .body(Full::new(Bytes::from("test data")))
            .expect("test request parts are valid");

        let response = ready_call(&mut cache_service, request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cache_service_with_different_etags() {
        // The mock would answer 304 for a matching tag, so a 200 here proves
        // the mismatch is honored.
        let mut cache_service = cache_over(true);
        let request = get_request("/test", &[(IF_NONE_MATCH, "\"different-etag\"")]);

        let response = ready_call(&mut cache_service, request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cache_service_if_none_match_takes_precedence() {
        // A mismatched If-None-Match wins over an If-Modified-Since that alone
        // would have produced a 304.
        let mut cache_service = cache_over(true);
        let request = get_request(
            "/test",
            &[
                (IF_NONE_MATCH, "\"different-etag\""),
                (IF_MODIFIED_SINCE, MOCK_LAST_MODIFIED),
            ],
        );

        let response = ready_call(&mut cache_service, request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}