xitca_http/util/middleware/
extension.rs

1use xitca_service::{Service, ready::ReadyService};
2
3use crate::http::{BorrowReqMut, Extensions};
4
/// Middleware builder that attaches a typed value to each [`Request`]'s [`Extensions`].
///
/// Enclosing a service with `Extension` wraps it in an [`ExtensionService`], which
/// inserts a clone of the held value into every request before forwarding it.
///
/// [`Request`]: crate::http::Request
#[derive(Clone)]
pub struct Extension<S> {
    state: S,
}

impl<S> Extension<S> {
    /// Creates the middleware builder holding `state`.
    ///
    /// The bounds mirror what [`ExtensionService`] requires of its state in order to act
    /// as a service, so an unusable state type is rejected at construction time.
    pub fn new(state: S) -> Self
    where
        S: Send + Sync + Clone + 'static,
    {
        Self { state }
    }
}
21
22impl<T, S, E> Service<Result<S, E>> for Extension<T>
23where
24    T: Clone,
25{
26    type Response = ExtensionService<S, T>;
27    type Error = E;
28
29    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
30        res.map(|service| ExtensionService {
31            service,
32            state: self.state.clone(),
33        })
34    }
35}
36
/// Service produced by [`Extension`].
///
/// Each call inserts a clone of `state` into the request's [`Extensions`], then
/// delegates to the wrapped `service`.
///
/// `Clone` is derived, so this type is cloneable whenever both the wrapped service and
/// the state are. That is the same bound set the previous hand-written impl declared.
#[derive(Clone)]
pub struct ExtensionService<S, St> {
    /// The wrapped inner service that requests are forwarded to.
    service: S,
    /// Value cloned into every request's extensions.
    state: St,
}
54
55impl<S, St, Req> Service<Req> for ExtensionService<S, St>
56where
57    S: Service<Req>,
58    St: Send + Sync + Clone + 'static,
59    Req: BorrowReqMut<Extensions>,
60{
61    type Response = S::Response;
62    type Error = S::Error;
63
64    #[inline]
65    async fn call(&self, mut req: Req) -> Result<Self::Response, Self::Error> {
66        req.borrow_mut().insert(self.state.clone());
67        self.service.call(req).await
68    }
69}
70
71impl<S, St> ReadyService for ExtensionService<S, St>
72where
73    S: ReadyService,
74    St: Send + Sync + Clone + 'static,
75{
76    type Ready = S::Ready;
77
78    #[inline]
79    async fn ready(&self) -> Self::Ready {
80        self.service.ready().await
81    }
82}
83
#[cfg(test)]
mod test {
    use xitca_service::{ServiceExt, fn_service};
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::http::Request;

    use super::*;

    /// The middleware must make the `String` state visible through `crate::http::Request`.
    #[test]
    fn state_middleware() {
        let builder = fn_service(|req: Request<()>| async move {
            let state = req.extensions().get::<String>().unwrap();
            assert_eq!("state", state);
            Ok::<_, ()>("996")
        })
        .enclosed(Extension::new(String::from("state")));

        let service = builder.call(()).now_or_panic().unwrap();
        let res = service.call(Request::new(())).now_or_panic().unwrap();

        assert_eq!("996", res);
    }

    /// Same check, but driven with a plain `http::Request`.
    #[test]
    fn state_middleware_http_request() {
        let builder = fn_service(|req: http::Request<()>| async move {
            let state = req.extensions().get::<String>().unwrap();
            assert_eq!("state", state);
            Ok::<_, ()>("996")
        })
        .enclosed(Extension::new(String::from("state")));

        let service = builder.call(()).now_or_panic().unwrap();
        let res = service.call(http::Request::new(())).now_or_panic().unwrap();

        assert_eq!("996", res);
    }
}