// xitca_web/middleware/tower_http_compat.rs

use tower_layer::Layer;

use crate::service::{Service, tower_http_compat::TowerCompatService};

/// Middleware that adapts a [`tower_layer::Layer`] (for example a `tower-http` layer) so it can
/// be used as xitca-web middleware.
///
/// Wrap a layer with [`TowerHttpCompat::new`] and pass it to `App::enclosed`. The enclosed xitca
/// service is handed to the layer as a tower service. The layered tower service is then wrapped
/// in `TowerCompatService` so it can be used as a xitca service again.
///
/// # Limitations
/// The tower layer always sees the enclosed service as ready: its `poll_ready` unconditionally
/// returns `Poll::Ready(Ok(()))`. Tower middleware that depend on real readiness or back-pressure
/// signals from the inner service will never observe any.
///
/// # Type mutation
/// Inside the tower layer, the enclosed service's response body appears as `CompatBody<ResB>`
/// instead of the original body type `ResB`.
#[derive(Clone, Debug)]
pub struct TowerHttpCompat<L>(L);

impl<L> TowerHttpCompat<L> {
    /// Construct a new middleware from a tower `Layer` value.
    ///
    /// This is a `const fn`, so the middleware can be built in a constant context whenever the
    /// layer value itself can be.
    pub const fn new(layer: L) -> Self {
        Self(layer)
    }
}

47impl<L, S, E> Service<Result<S, E>> for TowerHttpCompat<L>
48where
49 L: Layer<compat_layer::CompatLayer<S>>,
50{
51 type Response = TowerCompatService<L::Service>;
52 type Error = E;
53
54 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
55 res.map(|service| {
56 let service = self.0.layer(compat_layer::CompatLayer::new(service));
57 TowerCompatService::new(service)
58 })
59 }
60}
61
62mod compat_layer {
63 use core::{
64 cell::RefCell,
65 future::Future,
66 pin::Pin,
67 task::{Context, Poll},
68 };
69
70 use std::rc::Rc;
71
72 use crate::{
73 WebContext,
74 http::{Request, RequestExt, Response, WebResponse},
75 service::tower_http_compat::{CompatBody, CompatReqBody},
76 };
77
78 use super::*;
79
80 pub struct CompatLayer<S>(Rc<S>);
81
82 impl<S> CompatLayer<S> {
83 pub(super) fn new(service: S) -> Self {
84 Self(Rc::new(service))
85 }
86 }
87
88 impl<S, C, ReqB, ResB, Err> tower_service::Service<Request<CompatReqBody<RequestExt<ReqB>, C>>> for CompatLayer<S>
89 where
90 S: for<'r> Service<WebContext<'r, C, ReqB>, Response = WebResponse<ResB>, Error = Err> + 'static,
91 C: 'static,
92 ReqB: 'static,
93 {
94 type Response = Response<CompatBody<ResB>>;
95 type Error = Err;
96 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
97
98 #[inline]
99 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100 Poll::Ready(Ok(()))
101 }
102
103 fn call(&mut self, req: Request<CompatReqBody<RequestExt<ReqB>, C>>) -> Self::Future {
104 let service = self.0.clone();
105 Box::pin(async move {
106 let (parts, body) = req.into_parts();
107 let (body, ctx) = body.into_parts();
108 let (ext, body) = body.replace_body(());
109
110 let mut req = Request::from_parts(parts, ext);
111 let mut body = RefCell::new(body);
112 let req = WebContext::new(&mut req, &mut body, &ctx);
113
114 service.call(req).await.map(|res| res.map(CompatBody::new))
115 })
116 }
117 }
118}
119
#[cfg(test)]
mod test {
    use core::convert::Infallible;

    use tower_http::set_status::SetStatusLayer;
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App, WebContext,
        body::ResponseBody,
        http::{StatusCode, WebRequest, WebResponse},
        service::fn_service,
    };

    use super::*;

    /// Route handler that checks the app state reaches it intact through the tower layers.
    async fn handler(ctx: WebContext<'_, &'static str>) -> Result<WebResponse, Infallible> {
        assert_eq!(*ctx.state(), "996");
        Ok(ctx.into_response(ResponseBody::empty()))
    }

    #[test]
    fn tower_set_status() {
        let first_layer = TowerHttpCompat::new(SetStatusLayer::new(StatusCode::OK));
        let second_layer = TowerHttpCompat::new(SetStatusLayer::new(StatusCode::NOT_FOUND));

        let service = App::new()
            .with_state("996")
            .at("/", fn_service(handler))
            .enclosed(first_layer)
            .enclosed(second_layer)
            .finish()
            .call(())
            .now_or_panic()
            .unwrap();

        let response = service.call(WebRequest::default()).now_or_panic().unwrap();

        // The status set by the second `enclosed` layer is the one observed in the final response.
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}
159}