Skip to main content

rust_serv/middleware/
cache.rs

1use std::task::{Context, Poll};
2use hyper::Request;
3use http_body_util::BodyExt;
4use tower::{Layer, Service};
5
/// Tower [`Layer`] that adds the cache middleware to a service stack.
///
/// The layer is a stateless unit struct: it carries no configuration and can be
/// created directly (`CacheLayer`) or through [`Default`].
#[derive(Clone, Debug, Default)]
pub struct CacheLayer;
9
10impl<S> Layer<S> for CacheLayer {
11    type Service = CacheService<S>;
12
13    fn layer(&self, inner: S) -> Self::Service {
14        CacheService { inner }
15    }
16}
17
/// Middleware service that wraps an inner service `S`.
///
/// `Clone` and `Debug` are available whenever `S` provides them.
#[derive(Clone, Debug)]
pub struct CacheService<S> {
    /// The wrapped service.
    inner: S,
}
22
23impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CacheService<S>
24where
25    S: Service<Request<ReqBody>, Response = hyper::Response<ResBody>>,
26    ReqBody: BodyExt + Send + 'static,
27    ResBody: BodyExt + Send + 'static,
28{
29    type Response = S::Response;
30    type Error = S::Error;
31    type Future = S::Future;
32
33    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
34        self.inner.poll_ready(cx)
35    }
36
37    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
38        // Check If-None-Match, If-Modified-Since headers
39        // Caching will be implemented in a later iteration
40        self.inner.call(req)
41    }
42}
43
44#[cfg(test)]
45mod tests {
46    use super::*;
47    use hyper::{Request, Response, Method, HeaderMap};
48    use hyper::header::{IF_NONE_MATCH, IF_MODIFIED_SINCE, ETAG, LAST_MODIFIED};
49    use http_body_util::Full;
50    use hyper::body::Bytes;
51    use std::pin::Pin;
52    use std::future::Future;
53    use std::task::{Context, Poll};
54
55    /// Mock service for testing
56    #[derive(Clone)]
57    struct MockService {
58        should_return_304: bool,
59    }
60
61    impl Service<Request<Full<Bytes>>> for MockService {
62        type Response = Response<Full<Bytes>>;
63        type Error = std::convert::Infallible;
64        type Future = Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
65
66        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67            Poll::Ready(Ok(()))
68        }
69
70        fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
71            let should_304 = self.should_return_304;
72            let if_none_match = req.headers().get(IF_NONE_MATCH).cloned();
73            let if_modified_since = req.headers().get(IF_MODIFIED_SINCE).cloned();
74
75            Box::pin(async move {
76                let mut builder = Response::builder()
77                    .status(if should_304 && (if_none_match.is_some() || if_modified_since.is_some()) {
78                        hyper::StatusCode::NOT_MODIFIED
79                    } else {
80                        hyper::StatusCode::OK
81                    });
82
83                // Add cache-related headers
84                let mut headers = HeaderMap::new();
85                headers.insert(ETAG, "\"test-etag\"".parse().unwrap());
86                headers.insert(LAST_MODIFIED, "Wed, 21 Oct 2015 07:28:00 GMT".parse().unwrap());
87
88                for (name, value) in headers.iter() {
89                    builder = builder.header(name, value);
90                }
91
92                let body = if should_304 && (if_none_match.is_some() || if_modified_since.is_some()) {
93                    Full::new(Bytes::new())
94                } else {
95                    Full::new(Bytes::from("test content"))
96                };
97
98                Ok(builder.body(body).unwrap())
99            })
100        }
101    }
102
103    #[test]
104    fn test_cache_layer_creation() {
105        let _layer = CacheLayer;
106        // Layer should be created successfully
107    }
108
109    #[test]
110    fn test_cache_layer_clone() {
111        let layer = CacheLayer;
112        let _cloned = layer.clone();
113        // Layer should be clonable
114    }
115
116    #[test]
117    fn test_cache_service_creation() {
118        let layer = CacheLayer;
119        let mock_service = MockService { should_return_304: false };
120        let _cache_service = layer.layer(mock_service);
121        // CacheService should be created successfully
122    }
123
124    #[tokio::test]
125    async fn test_cache_service_call_without_headers() {
126        let layer = CacheLayer;
127        let mock_service = MockService { should_return_304: false };
128        let mut cache_service = layer.layer(mock_service);
129
130        let request = Request::builder()
131            .method(Method::GET)
132            .uri("/test")
133            .body(Full::new(Bytes::new()))
134            .unwrap();
135
136        let response = cache_service.call(request).await.unwrap();
137        assert_eq!(response.status(), hyper::StatusCode::OK);
138    }
139
140    #[tokio::test]
141    async fn test_cache_service_with_if_none_match() {
142        let layer = CacheLayer;
143        let mock_service = MockService { should_return_304: true };
144        let mut cache_service = layer.layer(mock_service);
145
146        let request = Request::builder()
147            .method(Method::GET)
148            .uri("/test")
149            .header(IF_NONE_MATCH, "\"test-etag\"")
150            .body(Full::new(Bytes::new()))
151            .unwrap();
152
153        let response = cache_service.call(request).await.unwrap();
154        // Cache middleware should pass through to the service
155        assert_eq!(response.status(), hyper::StatusCode::NOT_MODIFIED);
156    }
157
158    #[tokio::test]
159    async fn test_cache_service_with_if_modified_since() {
160        let layer = CacheLayer;
161        let mock_service = MockService { should_return_304: true };
162        let mut cache_service = layer.layer(mock_service);
163
164        let request = Request::builder()
165            .method(Method::GET)
166            .uri("/test")
167            .header(IF_MODIFIED_SINCE, "Wed, 21 Oct 2015 07:28:00 GMT")
168            .body(Full::new(Bytes::new()))
169            .unwrap();
170
171        let response = cache_service.call(request).await.unwrap();
172        // Cache middleware should pass through to the service
173        assert_eq!(response.status(), hyper::StatusCode::NOT_MODIFIED);
174    }
175
176    #[tokio::test]
177    async fn test_cache_service_with_both_cache_headers() {
178        let layer = CacheLayer;
179        let mock_service = MockService { should_return_304: true };
180        let mut cache_service = layer.layer(mock_service);
181
182        let request = Request::builder()
183            .method(Method::GET)
184            .uri("/test")
185            .header(IF_NONE_MATCH, "\"test-etag\"")
186            .header(IF_MODIFIED_SINCE, "Wed, 21 Oct 2015 07:28:00 GMT")
187            .body(Full::new(Bytes::new()))
188            .unwrap();
189
190        let response = cache_service.call(request).await.unwrap();
191        // Cache middleware should pass through to the service
192        assert_eq!(response.status(), hyper::StatusCode::NOT_MODIFIED);
193    }
194
195    #[tokio::test]
196    async fn test_cache_service_poll_ready() {
197        use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
198
199        let layer = CacheLayer;
200        let mock_service = MockService { should_return_304: false };
201        let mut cache_service = layer.layer(mock_service);
202
203        // Create a dummy waker
204        fn dummy_clone(_: *const ()) -> RawWaker {
205            RawWaker::new(std::ptr::null(), &VTABLE)
206        }
207        fn dummy(_: *const ()) {}
208        static VTABLE: RawWakerVTable = RawWakerVTable::new(dummy_clone, dummy, dummy, dummy);
209        let raw_waker = RawWaker::new(std::ptr::null(), &VTABLE);
210        let waker = unsafe { Waker::from_raw(raw_waker) };
211        let mut cx = Context::from_waker(&waker);
212
213        let poll_result = cache_service.poll_ready(&mut cx);
214        assert!(matches!(poll_result, Poll::Ready(Ok(()))));
215    }
216
217    #[test]
218    fn test_cache_service_clone() {
219        let layer = CacheLayer;
220        let mock_service = MockService { should_return_304: false };
221        let cache_service = layer.layer(mock_service);
222        let _cloned = cache_service.clone();
223        // CacheService should be clonable
224    }
225
226    #[tokio::test]
227    async fn test_cache_service_multiple_requests() {
228        let layer = CacheLayer;
229        let mock_service = MockService { should_return_304: false };
230        let mut cache_service = layer.layer(mock_service);
231
232        // Make multiple requests
233        for i in 0..5 {
234            let request = Request::builder()
235                .method(Method::GET)
236                .uri(&format!("/page/{}", i))
237                .body(Full::new(Bytes::new()))
238                .unwrap();
239
240            let response = cache_service.call(request).await.unwrap();
241            assert_eq!(response.status(), hyper::StatusCode::OK);
242        }
243    }
244
245    #[tokio::test]
246    async fn test_cache_service_with_post_request() {
247        let layer = CacheLayer;
248        let mock_service = MockService { should_return_304: false };
249        let mut cache_service = layer.layer(mock_service);
250
251        let request = Request::builder()
252            .method(Method::POST)
253            .uri("/api/data")
254            .body(Full::new(Bytes::from("test data")))
255            .unwrap();
256
257        let response = cache_service.call(request).await.unwrap();
258        assert_eq!(response.status(), hyper::StatusCode::OK);
259    }
260
261    #[tokio::test]
262    async fn test_cache_service_with_different_etags() {
263        let layer = CacheLayer;
264        let mock_service = MockService { should_return_304: false };
265        let mut cache_service = layer.layer(mock_service);
266
267        let request = Request::builder()
268            .method(Method::GET)
269            .uri("/test")
270            .header(IF_NONE_MATCH, "\"different-etag\"")
271            .body(Full::new(Bytes::new()))
272            .unwrap();
273
274        let response = cache_service.call(request).await.unwrap();
275        // Service should return OK as etag doesn't match
276        assert_eq!(response.status(), hyper::StatusCode::OK);
277    }
278}