xitca_http/util/middleware/
extension.rs1use xitca_service::{Service, ready::ReadyService};
2
3use crate::http::{BorrowReqMut, Extensions};
4
/// Middleware builder that carries a piece of typed state.
///
/// When used as a service over the build result of another service it wraps the
/// built service in an [`ExtensionService`] holding a clone of the state.
#[derive(Clone)]
pub struct Extension<S> {
    /// Value cloned into every [`ExtensionService`] produced by this builder.
    state: S,
}

impl<S> Extension<S> {
    /// Create a builder around `state`.
    ///
    /// The bounds on `S` are the same ones [`ExtensionService`] requires in order
    /// to insert a clone of the state into a request's extensions.
    pub fn new(state: S) -> Self
    where
        S: Send + Sync + Clone + 'static,
    {
        Self { state }
    }
}
21
22impl<T, S, E> Service<Result<S, E>> for Extension<T>
23where
24 T: Clone,
25{
26 type Response = ExtensionService<S, T>;
27 type Error = E;
28
29 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
30 res.map(|service| ExtensionService {
31 service,
32 state: self.state.clone(),
33 })
34 }
35}
36
/// Service produced by [`Extension`].
///
/// Pairs an inner service with a piece of state. `Clone` is derived, so the
/// type is cloneable exactly when both `S` and `St` are, and cloning clones
/// each field.
#[derive(Clone)]
pub struct ExtensionService<S, St> {
    /// The wrapped service that ultimately handles the request.
    service: S,
    /// State that gets cloned into each request's extensions.
    state: St,
}
54
55impl<S, St, Req> Service<Req> for ExtensionService<S, St>
56where
57 S: Service<Req>,
58 St: Send + Sync + Clone + 'static,
59 Req: BorrowReqMut<Extensions>,
60{
61 type Response = S::Response;
62 type Error = S::Error;
63
64 #[inline]
65 async fn call(&self, mut req: Req) -> Result<Self::Response, Self::Error> {
66 req.borrow_mut().insert(self.state.clone());
67 self.service.call(req).await
68 }
69}
70
71impl<S, St> ReadyService for ExtensionService<S, St>
72where
73 S: ReadyService,
74 St: Send + Sync + Clone + 'static,
75{
76 type Ready = S::Ready;
77
78 #[inline]
79 async fn ready(&self) -> Self::Ready {
80 self.service.ready().await
81 }
82}
83
#[cfg(test)]
mod test {
    use xitca_service::{ServiceExt, fn_service};
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::http::Request;

    use super::*;

    /// The state must be readable from the extensions of the crate's own
    /// `Request` type inside the wrapped handler.
    #[test]
    fn state_middleware() {
        // Handler that checks the state was attached before it ran.
        let handler = fn_service(|req: Request<()>| async move {
            assert_eq!("state", req.extensions().get::<String>().unwrap());
            Ok::<_, ()>("996")
        });

        // Build the service with the middleware applied.
        let built = handler
            .enclosed(Extension::new(String::from("state")))
            .call(())
            .now_or_panic()
            .unwrap();

        let body = built.call(Request::new(())).now_or_panic().unwrap();

        assert_eq!("996", body);
    }

    /// Same check as above, but with the `http` crate's `Request` type.
    #[test]
    fn state_middleware_http_request() {
        // Handler that checks the state was attached before it ran.
        let handler = fn_service(|req: http::Request<()>| async move {
            assert_eq!("state", req.extensions().get::<String>().unwrap());
            Ok::<_, ()>("996")
        });

        // Build the service with the middleware applied.
        let built = handler
            .enclosed(Extension::new(String::from("state")))
            .call(())
            .now_or_panic()
            .unwrap();

        let body = built.call(http::Request::new(())).now_or_panic().unwrap();

        assert_eq!("996", body);
    }
}