server_fn/middleware/mod.rs

1use crate::error::ServerFnErrorErr;
2use bytes::Bytes;
3use std::{future::Future, pin::Pin};
4
5/// An abstraction over a middleware layer, which can be used to add additional
6/// middleware layer to a [`Service`].
7pub trait Layer<Req, Res>: Send + Sync + 'static {
8    /// Adds this layer to the inner service.
9    fn layer(&self, inner: BoxedService<Req, Res>) -> BoxedService<Req, Res>;
10}
11
12/// A type-erased service, which takes an HTTP request and returns a response.
13pub struct BoxedService<Req, Res> {
14    /// A function that converts a [`ServerFnErrorErr`] into a string.
15    pub ser: fn(ServerFnErrorErr) -> Bytes,
16    /// The inner service.
17    pub service: Box<dyn Service<Req, Res> + Send>,
18}
19
20impl<Req, Res> BoxedService<Req, Res> {
21    /// Constructs a type-erased service from this service.
22    pub fn new(
23        ser: fn(ServerFnErrorErr) -> Bytes,
24        service: impl Service<Req, Res> + Send + 'static,
25    ) -> Self {
26        Self {
27            ser,
28            service: Box::new(service),
29        }
30    }
31
32    /// Converts a request into a response by running the inner service.
33    pub fn run(
34        &mut self,
35        req: Req,
36    ) -> Pin<Box<dyn Future<Output = Res> + Send>> {
37        self.service.run(req, self.ser)
38    }
39}
40
41/// A service converts an HTTP request into a response.
42pub trait Service<Request, Response> {
43    /// Converts a request into a response.
44    fn run(
45        &mut self,
46        req: Request,
47        ser: fn(ServerFnErrorErr) -> Bytes,
48    ) -> Pin<Box<dyn Future<Output = Response> + Send>>;
49}
50
#[cfg(feature = "axum-no-default")]
mod axum {
    use super::{BoxedService, Service};
    use crate::{error::ServerFnErrorErr, response::Res, ServerFnError};
    use axum::body::Body;
    use bytes::Bytes;
    use http::{Request, Response};
    use std::{future::Future, pin::Pin};

    /// Lets any `tower::Service` over axum requests and responses act as a
    /// server-function [`Service`].
    ///
    /// If the tower service fails, its error is wrapped in
    /// [`ServerFnErrorErr::MiddlewareError`], serialized with `ser`, and
    /// returned as an error response for the request's path.
    impl<S> super::Service<Request<Body>, Response<Body>> for S
    where
        S: tower::Service<Request<Body>, Response = Response<Body>>,
        S::Future: Send + 'static,
        S::Error: std::fmt::Display + Send + 'static,
    {
        fn run(
            &mut self,
            req: Request<Body>,
            ser: fn(ServerFnErrorErr) -> Bytes,
        ) -> Pin<Box<dyn Future<Output = Response<Body>> + Send>> {
            // `call` consumes the request, so capture the path up front in
            // case an error response has to be built for it.
            let path = req.uri().path().to_string();
            // NOTE(review): `call` is invoked without first driving
            // `poll_ready`. Tower's `Service` contract requires readiness
            // before `call`, so services that depend on it may misbehave.
            // TODO confirm whether any wrapped layer relies on `poll_ready`.
            let pending = self.call(req);
            Box::pin(async move {
                match pending.await {
                    Ok(response) => response,
                    Err(failure) => {
                        // TODO: This does not set the Content-Type on the response. Doing so will
                        //  require a breaking change in order to get the correct encoding from the
                        //  error's `FromServerFnError::Encoder::CONTENT_TYPE` impl.
                        //  Note: This only applies to middleware errors.
                        let body = ser(ServerFnErrorErr::MiddlewareError(
                            failure.to_string(),
                        ));
                        Response::<Body>::error_response(&path, body)
                    }
                }
            })
        }
    }

    /// Exposes a type-erased [`BoxedService`] as a `tower::Service`.
    ///
    /// The service always reports itself ready. Its future always resolves to
    /// `Ok`, because the inner service yields a bare `Response`: any failure
    /// must already be represented in that response.
    impl tower::Service<Request<Body>>
        for BoxedService<Request<Body>, Response<Body>>
    {
        type Response = Response<Body>;
        type Error = ServerFnError;
        type Future = Pin<
            Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>,
        >;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            // No backpressure to report: always accept new requests.
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Body>) -> Self::Future {
            let serializer = self.ser;
            let response = self.service.run(req, serializer);
            Box::pin(async move { Ok(response.await) })
        }
    }

    /// Lets any `tower_layer::Layer` that can wrap a [`BoxedService`] be used
    /// as a server-function [`super::Layer`].
    impl<L> super::Layer<Request<Body>, Response<Body>> for L
    where
        L: tower_layer::Layer<BoxedService<Request<Body>, Response<Body>>>
            + Sync
            + Send
            + 'static,
        L::Service: Service<Request<Body>, Response<Body>> + Send + 'static,
    {
        fn layer(
            &self,
            inner: BoxedService<Request<Body>, Response<Body>>,
        ) -> BoxedService<Request<Body>, Response<Body>> {
            // `inner` is consumed by the tower layer, so keep its serializer
            // for the re-boxed result.
            let ser = inner.ser;
            // Explicitly the tower layer's `layer`, not this trait's method.
            let layered = tower_layer::Layer::layer(self, inner);
            BoxedService::new(ser, layered)
        }
    }
}
129
#[cfg(feature = "actix-no-default")]
mod actix {
    use crate::{
        error::ServerFnErrorErr,
        request::actix::ActixRequest,
        response::{actix::ActixResponse, Res},
    };
    use actix_web::{HttpRequest, HttpResponse};
    use bytes::Bytes;
    use std::{future::Future, pin::Pin};

    /// Builds the response returned when an actix middleware fails.
    ///
    /// `err` is wrapped in [`ServerFnErrorErr::MiddlewareError`], serialized
    /// with `ser`, and turned into an error response for `path`.
    // TODO: This does not set the Content-Type on the response. Doing so will
    //  require a breaking change in order to get the correct encoding from the
    //  error's `FromServerFnError::Encoder::CONTENT_TYPE` impl.
    //  Note: This only applies to middleware errors.
    fn middleware_error_response(
        path: &str,
        ser: fn(ServerFnErrorErr) -> Bytes,
        err: impl std::fmt::Display,
    ) -> HttpResponse {
        let body = ser(ServerFnErrorErr::MiddlewareError(err.to_string()));
        ActixResponse::error_response(path, body).take()
    }

    /// Lets any actix `Service` over `HttpRequest`/`HttpResponse` act as a
    /// server-function service that works with raw actix types.
    impl<S> super::Service<HttpRequest, HttpResponse> for S
    where
        S: actix_web::dev::Service<HttpRequest, Response = HttpResponse>,
        S::Future: Send + 'static,
        S::Error: std::fmt::Display + Send + 'static,
    {
        fn run(
            &mut self,
            req: HttpRequest,
            ser: fn(ServerFnErrorErr) -> Bytes,
        ) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> {
            // Capture the path before the request is moved into `call`.
            let path = req.uri().path().to_string();
            // NOTE(review): `call` is invoked without first checking
            // `poll_ready`; confirm the wrapped services do not rely on it.
            let pending = self.call(req);
            Box::pin(async move {
                match pending.await {
                    Ok(response) => response,
                    Err(err) => middleware_error_response(&path, ser, err),
                }
            })
        }
    }

    /// Lets any actix `Service` over `HttpRequest`/`HttpResponse` act as a
    /// server-function service over the [`ActixRequest`]/[`ActixResponse`]
    /// wrappers.
    impl<S> super::Service<ActixRequest, ActixResponse> for S
    where
        S: actix_web::dev::Service<HttpRequest, Response = HttpResponse>,
        S::Future: Send + 'static,
        S::Error: std::fmt::Display + Send + 'static,
    {
        fn run(
            &mut self,
            req: ActixRequest,
            ser: fn(ServerFnErrorErr) -> Bytes,
        ) -> Pin<Box<dyn Future<Output = ActixResponse> + Send>> {
            // Capture the path before the inner request is taken out of the
            // wrapper and moved into `call`.
            let path = req.0 .0.uri().path().to_string();
            // NOTE(review): `call` is invoked without first checking
            // `poll_ready`; confirm the wrapped services do not rely on it.
            let pending = self.call(req.0.take().0);
            Box::pin(async move {
                let response = match pending.await {
                    Ok(response) => response,
                    Err(err) => middleware_error_response(&path, ser, err),
                };
                ActixResponse::from(response)
            })
        }
    }
}