xitca_web/middleware/sync.rs

//! Synchronous function as middleware.
2
3use core::mem;
4
5use std::sync::mpsc::Receiver;
6
7use tokio::sync::mpsc::UnboundedSender;
8
9use crate::{
10    context::WebContext,
11    http::{Request, RequestExt, Response},
12    service::Service,
13};
14
15/// experimental type for sync function as middleware.
16pub struct SyncMiddleware<F>(F);
17
18impl<F> SyncMiddleware<F> {
19    /// *. Sync middleware does not have access to request/response body.
20    ///
21    /// construct a new middleware with given sync function.
22    /// the function must be actively calling [Next::call] and finish it to drive inner services to completion.
23    /// panic in sync function middleware would result in a panic at task level and it's client connection would
24    /// be terminated immediately.
25    pub fn new<C, E>(func: F) -> Self
26    where
27        F: Fn(&mut Next<E>, WebContext<'_, C>) -> Result<Response<()>, E> + Send + Sync + 'static,
28        C: Clone + Send + 'static,
29        E: Send + 'static,
30    {
31        Self(func)
32    }
33}
34
35/// next/inner services of a middleware function. [Next::call] must run to complete in order to drive
36/// services.
37pub struct Next<E> {
38    tx: UnboundedSender<Request<RequestExt<()>>>,
39    rx: Receiver<Result<Response<()>, E>>,
40}
41
42impl<E> Next<E> {
43    /// call next/inner services to complete where they would produce either a http response or an error.
44    pub fn call<C>(&mut self, mut ctx: WebContext<'_, C>) -> Result<Response<()>, E> {
45        let req = mem::take(ctx.req_mut());
46        self.tx.send(req).unwrap();
47        self.rx.recv().unwrap()
48    }
49}
50
51impl<F, S, E> Service<Result<S, E>> for SyncMiddleware<F>
52where
53    F: Clone,
54{
55    type Response = service::SyncService<F, S>;
56    type Error = E;
57
58    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
59        res.map(|service| service::SyncService {
60            func: self.0.clone(),
61            service,
62        })
63    }
64}
65
66mod service {
67    use core::cell::RefCell;
68
69    use std::sync::mpsc::sync_channel;
70
71    use tokio::sync::mpsc::unbounded_channel;
72
73    use crate::{body::RequestBody, http::WebResponse, service::ready::ReadyService};
74
75    use super::*;
76
77    pub struct SyncService<F, S> {
78        pub(super) func: F,
79        pub(super) service: S,
80    }
81
82    impl<'r, F, C, S, B, ResB, Err> Service<WebContext<'r, C, B>> for SyncService<F, S>
83    where
84        F: Fn(&mut Next<Err>, WebContext<'_, C>) -> Result<Response<()>, Err> + Send + Clone + 'static,
85        C: Clone + Send + 'static,
86        S: for<'r2> Service<WebContext<'r, C, B>, Response = WebResponse<ResB>, Error = Err>,
87        Err: Send + 'static,
88    {
89        type Response = WebResponse<ResB>;
90        type Error = Err;
91
92        async fn call(&self, mut ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
93            let func = self.func.clone();
94            let state = ctx.state().clone();
95            let mut req = mem::take(ctx.req_mut());
96
97            let (tx, mut rx) = unbounded_channel();
98            let (tx2, rx2) = sync_channel(1);
99
100            let mut next = Next { tx, rx: rx2 };
101            let handle = tokio::task::spawn_blocking(move || {
102                let mut body = RefCell::new(RequestBody::None);
103                let ctx = WebContext::new(&mut req, &mut body, &state);
104                func(&mut next, ctx)
105            });
106
107            *ctx.req_mut() = match rx.recv().await {
108                Some(req) => req,
109                None => {
110                    // tx is dropped which means spawned thread exited already. join it and panic if necessary.
111                    match handle.await.unwrap() {
112                        Ok(_) => todo!("there is no support for body type yet"),
113                        Err(e) => return Err(e),
114                    }
115                }
116            };
117
118            match self.service.call(ctx).await {
119                Ok(res) => {
120                    let (parts, body) = res.into_parts();
121                    let _ = tx2.send(Ok(Response::from_parts(parts, ())));
122                    let res = handle.await.unwrap()?;
123                    Ok(res.map(|_| body))
124                }
125                Err(e) => {
126                    let _ = tx2.send(Err(e));
127                    let res = handle.await.unwrap()?;
128                    Ok(res.map(|_| todo!("there is no support for body type yet")))
129                }
130            }
131        }
132    }
133
134    impl<F, S> ReadyService for SyncService<F, S>
135    where
136        S: ReadyService,
137    {
138        type Ready = S::Ready;
139
140        #[inline]
141        async fn ready(&self) -> Self::Ready {
142            self.service.ready().await
143        }
144    }
145}
146
#[cfg(test)]
mod test {
    use core::convert::Infallible;

    use crate::{
        App,
        body::ResponseBody,
        http::{StatusCode, WebResponse},
        service::fn_service,
    };

    use super::*;

    /// inner handler: checks that the app state made it through the sync middleware.
    async fn handler(ctx: WebContext<'_, &'static str>) -> Result<WebResponse, Infallible> {
        assert_eq!(*ctx.state(), "996");
        Ok(ctx.into_response(ResponseBody::empty()))
    }

    /// pass-through sync middleware that simply forwards to the inner services.
    fn middleware<E>(next: &mut Next<E>, ctx: WebContext<'_, &'static str>) -> Result<Response<()>, E> {
        next.call(ctx)
    }

    #[tokio::test]
    async fn sync_middleware() {
        let service = App::new()
            .with_state("996")
            .at("/", fn_service(handler))
            .enclosed(SyncMiddleware::new(middleware))
            .finish()
            .call(())
            .await
            .unwrap();

        let res = service.call(Request::default()).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }
}
185}