xitca_http/util/middleware/
context.rs

//! Middleware for adding typed state to service requests.
2
3use core::fmt;
4
5use xitca_service::Service;
6
7use crate::http::{BorrowReq, BorrowReqMut};
8
/// ServiceFactory type for constructing compile time checked stateful service.
///
/// State is roughly doing the same thing as `move ||` style closure capture. The difference comes
/// down to:
///
/// - The captured state is constructed lazily when the builder's [Service::call] method is called.
/// - State can be referenced in nested types and beyond closures.
///
/// The wrapped service receives every request as a [Context], which bundles the parent request
/// with a shared reference to the state.
///
/// Building fails with a boxed [fmt::Debug] error when either the wrapped service builder or the
/// state closure returns an error.
///
/// # Examples
///```rust
/// # use std::convert::Infallible;
/// # use xitca_http::util::middleware::context::{ContextBuilder, Context};
/// # use xitca_service::{fn_service, Service, ServiceExt};
///
/// // Function service receiving the parent request (String) paired with the state (String).
/// async fn state_handler(req: Context<'_, String, String>) -> Result<String, Infallible> {
///    let (parent_req, state) = req.into_parts();
///    assert_eq!(state, "string_state");
///    Ok(String::from("string_response"))
/// }
///
/// # async fn stateful() {
/// // Construct the stateful service builder with a closure producing the state.
/// let service = fn_service(state_handler)
///     // The context builder builds the state once `.call(())` runs and then passes every request
///     // to the wrapped service's Service::call method as a `Context<'_, Req, State>`.
///     .enclosed(ContextBuilder::new(|| async { Ok::<_, Infallible>(String::from("string_state")) }))
///     .call(())
///     .await
///     .unwrap();
///
/// let req = String::default();
/// let res = service.call(req).await.unwrap();
/// assert_eq!(res, "string_response");
///
/// # }
///```
pub struct ContextBuilder<CF> {
    // Closure producing a future that resolves to the state (or an error).
    builder: CF,
}

impl<CF, Fut, C, CErr> ContextBuilder<CF>
where
    CF: Fn() -> Fut,
    Fut: Future<Output = Result<C, CErr>>,
{
    /// Create a stateful service factory from `builder`.
    ///
    /// `builder` is not invoked here. It is called every time this factory's [Service::call]
    /// runs, and the future it returns must resolve to either the state `C` or an error `CErr`.
    pub fn new(builder: CF) -> Self {
        Self { builder }
    }
}
60
/// Request type handed to services enclosed by [ContextBuilder].
///
/// Bundles the parent service request `Req` with a shared reference to the state `C` built by
/// [ContextBuilder]. [BorrowReq] and [BorrowReqMut] are forwarded to the parent request, so
/// anything the parent request can lend can also be borrowed from the context.
pub struct Context<'a, Req, C> {
    // The parent service request.
    req: Req,
    // State owned by the enclosing context service and lent for the duration of the call.
    state: &'a C,
}

impl<'a, Req, C> Context<'a, Req, C> {
    /// Destruct the context into a tuple of `(parent_request, &state)`.
    #[inline]
    pub fn into_parts(self) -> (Req, &'a C) {
        (self.req, self.state)
    }
}
76
77// impls to forward trait from Req type.
78// BorrowReq/Mut are traits needed for nesting Router/Route service inside Context service.
79impl<T, Req, C> BorrowReq<T> for Context<'_, Req, C>
80where
81    Req: BorrowReq<T>,
82{
83    #[inline]
84    fn borrow(&self) -> &T {
85        self.req.borrow()
86    }
87}
88
89impl<T, Req, C> BorrowReqMut<T> for Context<'_, Req, C>
90where
91    Req: BorrowReqMut<T>,
92{
93    #[inline]
94    fn borrow_mut(&mut self) -> &mut T {
95        self.req.borrow_mut()
96    }
97}
98
99type Error = Box<dyn fmt::Debug>;
100
101impl<CF, Fut, C, CErr, S, E> Service<Result<S, E>> for ContextBuilder<CF>
102where
103    CF: Fn() -> Fut,
104    Fut: Future<Output = Result<C, CErr>>,
105    C: 'static,
106    CErr: fmt::Debug + 'static,
107    E: fmt::Debug + 'static,
108{
109    type Response = service::ContextService<C, S>;
110    type Error = Error;
111
112    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
113        let service = res.map_err(|e| Box::new(e) as Error)?;
114        let state = (self.builder)().await.map_err(|e| Box::new(e) as Error)?;
115        Ok(service::ContextService { service, state })
116    }
117}
118
119mod service {
120    use xitca_service::ready::ReadyService;
121
122    use super::*;
123
124    pub struct ContextService<C, S> {
125        pub(super) state: C,
126        pub(super) service: S,
127    }
128
129    impl<Req, C, S, Res, Err> Service<Req> for ContextService<C, S>
130    where
131        S: for<'c> Service<Context<'c, Req, C>, Response = Res, Error = Err>,
132    {
133        type Response = Res;
134        type Error = Err;
135
136        #[inline]
137        async fn call(&self, req: Req) -> Result<Self::Response, Self::Error> {
138            self.service
139                .call(Context {
140                    req,
141                    state: &self.state,
142                })
143                .await
144        }
145    }
146
147    impl<C, S> ReadyService for ContextService<C, S>
148    where
149        S: ReadyService,
150    {
151        type Ready = S::Ready;
152
153        #[inline]
154        async fn ready(&self) -> Self::Ready {
155            self.service.ready().await
156        }
157    }
158}
159
// Router integration (only with the `router` feature): lets `Context` act as the request type
// of type-erased route objects, so routers/routes can be nested inside a service enclosed by
// `ContextBuilder`.
#[cfg(feature = "router")]
mod router_impl {
    use xitca_service::object::ServiceObject;

    use crate::util::service::router::{IntoObject, PathGen, RouteGen, RouteObject};

    use super::*;

    /// Boxed, type-erased service accepting a `Context<'c, Req, C>` for any lifetime `'c`.
    pub type ContextObject<Req, C, Res, Err> =
        Box<dyn for<'c> ServiceObject<Context<'c, Req, C>, Response = Res, Error = Err>>;

    // `IntoObject` (from the router module) implemented for `Context`: type-erases a service
    // builder `I` whose built service accepts `Context<'c, Req, C>`, producing a `RouteObject`
    // whose built service is a boxed `ContextObject` and whose build error is `I::Error`.
    impl<C, I, Arg, Req, Res, Err> IntoObject<I, Arg> for Context<'_, Req, C>
    where
        C: 'static,
        Req: 'static,
        I: Service<Arg> + RouteGen + Send + Sync + 'static,
        I::Response: for<'c> Service<Context<'c, Req, C>, Response = Res, Error = Err> + 'static,
    {
        type Object = RouteObject<Arg, ContextObject<Req, C, Res, Err>, I::Error>;

        fn into_object(inner: I) -> Self::Object {
            // Adapter around the wrapped service builder. The `fn(Req, C)` marker only carries
            // the type parameters: a fn-pointer `PhantomData` stores no `Req`/`C` value and is
            // always `Send + Sync`.
            struct Builder<I, Req, C>(I, core::marker::PhantomData<fn(Req, C)>);

            // Path generation is forwarded unchanged to the wrapped builder.
            impl<I, Req, C> PathGen for Builder<I, Req, C>
            where
                I: PathGen,
            {
                fn path_gen(&mut self, prefix: &str) -> String {
                    self.0.path_gen(prefix)
                }
            }

            // Route generation is forwarded unchanged to the wrapped builder.
            impl<I, Req, C> RouteGen for Builder<I, Req, C>
            where
                I: RouteGen,
            {
                type Route<R> = I::Route<R>;

                fn route_gen<R>(route: R) -> Self::Route<R> {
                    I::route_gen(route)
                }
            }

            // Building: call the wrapped builder and box the resulting service as a
            // `ContextObject`; the builder's error is passed through unchanged.
            impl<C, I, Arg, Req, Res, Err> Service<Arg> for Builder<I, Req, C>
            where
                I: Service<Arg> + RouteGen,
                I::Response: for<'c> Service<Context<'c, Req, C>, Response = Res, Error = Err> + 'static,
            {
                type Response = ContextObject<Req, C, Res, Err>;
                type Error = I::Error;

                async fn call(&self, arg: Arg) -> Result<Self::Response, Self::Error> {
                    self.0.call(arg).await.map(|s| Box::new(s) as _)
                }
            }

            RouteObject(Box::new(Builder(inner, core::marker::PhantomData)))
        }
    }
}
220
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use xitca_service::{ServiceExt, fn_service};
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::http::{Request, Response};

    use super::*;

    // Second-stage context produced by the first service of an `and_then` chain. It keeps the
    // parent request and the borrowed state.
    struct Context2<'a, ST> {
        req: Request<()>,
        state: &'a ST,
    }

    // Converts the `Context` handed out by the context service into `Context2`, checking the
    // state value along the way.
    async fn into_context(req: Context<'_, Request<()>, String>) -> Result<Context2<'_, String>, Infallible> {
        let (req, state) = req.into_parts();
        assert_eq!(state, "string_state");
        Ok(Context2 { req, state })
    }

    // Final handler: verifies the state and that the request method is GET.
    async fn ctx_handler(ctx: Context2<'_, String>) -> Result<Response<()>, Infallible> {
        assert_eq!(ctx.state, "string_state");
        assert_eq!(ctx.req.method().as_str(), "GET");
        Ok(Response::new(()))
    }

    // The state borrowed through `Context` can be carried along an `and_then` chain.
    #[test]
    fn test_state_and_then() {
        let res = fn_service(into_context)
            .and_then(fn_service(ctx_handler))
            .enclosed(ContextBuilder::new(|| async {
                Ok::<_, Infallible>(String::from("string_state"))
            }))
            .call(())
            .now_or_panic()
            .ok()
            .unwrap()
            .call(Request::default())
            .now_or_panic()
            .unwrap();

        assert_eq!(res.status().as_u16(), 200);
    }

    // `Context` works as the request type of a `Router`, with and without extra middleware and
    // with a nested router. The enclosed router must also be `Send + Sync`.
    #[cfg(feature = "router")]
    #[test]
    fn test_state_in_router() {
        use crate::{
            http::RequestExt,
            util::service::{Router, route::get},
        };

        // Pass-through middleware for `enclosed_fn`: forwards the context request unchanged.
        async fn enclosed<S, Req, C, Res, Err>(service: &S, req: Context<'_, Req, C>) -> Result<Res, Err>
        where
            S: for<'c> Service<Context<'c, Req, C>, Response = Res, Error = Err>,
        {
            service.call(req).await
        }

        // Router with a single GET route on "/" whose handler checks the state.
        let router = || {
            Router::new().insert(
                "/",
                get(fn_service(async |req: Context<'_, Request<RequestExt<()>>, String>| {
                    let (_, state) = req.into_parts();
                    assert_eq!(state, "string_state");
                    Ok::<_, Infallible>(Response::new(()))
                })),
            )
        };

        // The same router wrapped in the pass-through middleware, then in the context builder.
        let router_with_ctx = || {
            router().enclosed_fn(enclosed).enclosed(ContextBuilder::new(|| async {
                Ok::<_, Infallible>(String::from("string_state"))
            }))
        };

        // Compile-time check that the argument type is `Send + Sync`.
        fn bound_check<T: Send + Sync>(_: T) {}

        bound_check(router_with_ctx());

        let res = router_with_ctx()
            .call(())
            .now_or_panic()
            .ok()
            .unwrap()
            .call(Request::default())
            .now_or_panic()
            .unwrap();

        assert_eq!(res.status().as_u16(), 200);

        // A router with another router nested under "/nest", enclosed directly by the
        // context builder.
        let res = router()
            .insert("/nest", router())
            .enclosed(ContextBuilder::new(|| async {
                Ok::<_, Infallible>(String::from("string_state"))
            }))
            .call(())
            .now_or_panic()
            .ok()
            .unwrap()
            .call(Request::default())
            .now_or_panic()
            .unwrap();

        assert_eq!(res.status().as_u16(), 200);
    }
}