Skip to main content

xitca_web/middleware/
tower_http_compat.rs

//! Compatibility between tower-http layers and xitca-web middleware.
2
3use tower_layer::Layer;
4
5use crate::service::{Service, tower_http_compat::TowerCompatService};
6
/// A middleware type that bridges `xitca-service` and `tower-service`.
/// Any `tower-http` type that implements the [Layer] trait can be passed to it and used as
/// xitca-web's middleware.
///
/// # Type mutation
/// `TowerHttpCompat` mutates the response body type from `B` to `CompatBody<ResB>`. Services
/// enclosed by it must be able to handle this mutation or use [TypeEraser] to erase it.
///
/// [TypeEraser]: crate::middleware::eraser::TypeEraser
#[derive(Clone, Debug)]
pub struct TowerHttpCompat<L>(L);
17
18impl<L> TowerHttpCompat<L> {
19    /// Construct a new xitca-web middleware from tower-http layer type.
20    ///
21    /// # Limitation:
22    /// tower::Service::poll_ready implementation is ignored by TowerHttpCompat.
23    /// if a tower Service is sensitive to poll_ready implementation it should not be used.
24    ///
25    /// # Example:
26    /// ```rust
27    /// # use std::convert::Infallible;
28    /// # use xitca_web::{http::{StatusCode, WebResponse}, service::fn_service, App, WebContext};
29    /// use xitca_web::middleware::tower_http_compat::TowerHttpCompat;
30    /// use tower_http::set_status::SetStatusLayer;
31    ///
32    /// # fn doc_example() {
33    /// App::new()
34    ///     .at("/", fn_service(handler))
35    ///     .enclosed(TowerHttpCompat::new(SetStatusLayer::new(StatusCode::NOT_FOUND)));
36    /// # }
37    ///
38    /// # async fn handler(ctx: WebContext<'_>) -> Result<WebResponse, Infallible> {
39    /// #   todo!()
40    /// # }
41    /// ```
42    pub const fn new(layer: L) -> Self {
43        Self(layer)
44    }
45}
46
47impl<L, S, E> Service<Result<S, E>> for TowerHttpCompat<L>
48where
49    L: Layer<compat_layer::CompatLayer<S>>,
50{
51    type Response = TowerCompatService<L::Service>;
52    type Error = E;
53
54    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
55        res.map(|service| {
56            let service = self.0.layer(compat_layer::CompatLayer::new(service));
57            TowerCompatService::new(service)
58        })
59    }
60}
61
62mod compat_layer {
63    use core::{
64        cell::RefCell,
65        future::Future,
66        pin::Pin,
67        task::{Context, Poll},
68    };
69
70    use std::rc::Rc;
71
72    use crate::{
73        WebContext,
74        http::{Request, RequestExt, Response, WebResponse},
75        service::tower_http_compat::{CompatBody, CompatReqBody},
76    };
77
78    use super::*;
79
80    pub struct CompatLayer<S>(Rc<S>);
81
82    impl<S> CompatLayer<S> {
83        pub(super) fn new(service: S) -> Self {
84            Self(Rc::new(service))
85        }
86    }
87
88    impl<S, C, ReqB, ResB, Err> tower_service::Service<Request<CompatReqBody<RequestExt<ReqB>, C>>> for CompatLayer<S>
89    where
90        S: for<'r> Service<WebContext<'r, C, ReqB>, Response = WebResponse<ResB>, Error = Err> + 'static,
91        C: 'static,
92        ReqB: 'static,
93    {
94        type Response = Response<CompatBody<ResB>>;
95        type Error = Err;
96        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
97
98        #[inline]
99        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100            Poll::Ready(Ok(()))
101        }
102
103        fn call(&mut self, req: Request<CompatReqBody<RequestExt<ReqB>, C>>) -> Self::Future {
104            let service = self.0.clone();
105            Box::pin(async move {
106                let (parts, body) = req.into_parts();
107                let (body, ctx) = body.into_parts();
108                let (ext, body) = body.replace_body(());
109
110                let mut req = Request::from_parts(parts, ext);
111                let mut body = RefCell::new(body);
112                let req = WebContext::new(&mut req, &mut body, &ctx);
113
114                service.call(req).await.map(|res| res.map(CompatBody::new))
115            })
116        }
117    }
118}
119
#[cfg(test)]
mod test {
    use core::convert::Infallible;

    use tower_http::set_status::SetStatusLayer;
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App, WebContext,
        body::ResponseBody,
        http::{StatusCode, WebRequest, WebResponse},
        service::fn_service,
    };

    use super::*;

    /// Endpoint asserting that the application state still reaches it through the compat layers.
    async fn handler(ctx: WebContext<'_, &'static str>) -> Result<WebResponse, Infallible> {
        assert_eq!(*ctx.state(), "996");
        Ok(ctx.into_response(ResponseBody::empty()))
    }

    /// Stacks two `SetStatusLayer`s; the status from the layer enclosed last is what the
    /// client observes.
    #[test]
    fn tower_set_status() {
        let service = App::new()
            .with_state("996")
            .at("/", fn_service(handler))
            .enclosed(TowerHttpCompat::new(SetStatusLayer::new(StatusCode::OK)))
            .enclosed(TowerHttpCompat::new(SetStatusLayer::new(StatusCode::NOT_FOUND)))
            .finish()
            .call(())
            .now_or_panic()
            .unwrap();

        let res = service.call(WebRequest::default()).now_or_panic().unwrap();

        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }
}