volo_http/client/layer/
fail_on_status.rs

1use std::{error::Error, fmt};
2
3use http::status::StatusCode;
4use motore::{layer::Layer, service::Service};
5use url::Url;
6use volo::context::Context;
7
8use crate::{
9    error::{ClientError, client::request_error},
10    request::RequestPartsExt,
11    response::Response,
12};
13
/// [`Layer`] that turns responses with an error status code into a service error.
///
/// Use [`FailOnStatus::all`], [`FailOnStatus::client_error`] or [`FailOnStatus::server_error`] to
/// create a layer that converts all error responses (4XX and 5XX), only client errors (4XX), or
/// only server errors (5XX) into a [`StatusCodeError`] returned by the service.
///
/// Note that `FailOnStatus::default()` enables neither class, so a layer built from it lets every
/// response through unchanged.
#[derive(Clone, Debug, Default)]
pub struct FailOnStatus {
    /// Reject responses whose status is a client error (4XX).
    client_error: bool,
    /// Reject responses whose status is a server error (5XX).
    server_error: bool,
    /// Record extra context (the request URL) in the returned [`StatusCodeError`].
    detailed: bool,
}

impl FailOnStatus {
    /// Create a [`FailOnStatus`] layer that returns a [`StatusCodeError`] for all error status
    /// codes (4XX and 5XX).
    #[must_use]
    pub fn all() -> Self {
        Self {
            client_error: true,
            server_error: true,
            detailed: false,
        }
    }

    /// Create a [`FailOnStatus`] layer that returns a [`StatusCodeError`] for client error
    /// status codes (4XX) only.
    #[must_use]
    pub fn client_error() -> Self {
        Self {
            client_error: true,
            server_error: false,
            detailed: false,
        }
    }

    /// Create a [`FailOnStatus`] layer that returns a [`StatusCodeError`] for server error
    /// status codes (5XX) only.
    #[must_use]
    pub fn server_error() -> Self {
        Self {
            client_error: false,
            server_error: true,
            detailed: false,
        }
    }

    /// Collect more details in [`StatusCodeError`].
    ///
    /// When an error is returned, the request has already been consumed and the original
    /// response is dropped, so neither is available to the caller afterwards. With this flag
    /// enabled, the layer captures the request URL before sending the request and stores it in
    /// the [`StatusCodeError`].
    #[must_use]
    pub fn detailed(mut self) -> Self {
        self.detailed = true;
        self
    }
}
66
67impl<S> Layer<S> for FailOnStatus {
68    type Service = FailOnStatusService<S>;
69
70    fn layer(self, inner: S) -> Self::Service {
71        FailOnStatusService {
72            inner,
73            fail_on: self,
74        }
75    }
76}
77
78/// The [`Service`] generated by [`FailOnStatus`] layer.
79///
80/// See [`FailOnStatus`] for more details.
81pub struct FailOnStatusService<S> {
82    inner: S,
83    fail_on: FailOnStatus,
84}
85
86impl<Cx, Req, S, B> Service<Cx, Req> for FailOnStatusService<S>
87where
88    Cx: Context + Send,
89    Req: RequestPartsExt + Send,
90    S: Service<Cx, Req, Response = Response<B>, Error = ClientError> + Send + Sync,
91{
92    type Response = S::Response;
93    type Error = S::Error;
94
95    async fn call(&self, cx: &mut Cx, req: Req) -> Result<Self::Response, Self::Error> {
96        let url = if self.fail_on.detailed {
97            req.url()
98        } else {
99            None
100        };
101        let resp = self.inner.call(cx, req).await?;
102        let status = resp.status();
103        if (self.fail_on.client_error && status.is_client_error())
104            || (self.fail_on.server_error && status.is_server_error())
105        {
106            Err(request_error(StatusCodeError { status, url })
107                .with_endpoint(cx.rpc_info().callee()))
108        } else {
109            Ok(resp)
110        }
111    }
112}
113
114/// Client received a response with an error status code.
115pub struct StatusCodeError {
116    status: StatusCode,
117    url: Option<Url>,
118}
119
120impl StatusCodeError {
121    /// The original status code.
122    pub fn status(&self) -> StatusCode {
123        self.status
124    }
125
126    /// The target [`Url`]
127    ///
128    /// It will only be saved when [`FailOnStatus::detailed`] enabled.
129    pub fn url(&self) -> Option<&Url> {
130        self.url.as_ref()
131    }
132}
133
134impl fmt::Debug for StatusCodeError {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        f.debug_struct("StatusCodeError")
137            .field("status", &self.status)
138            .finish()
139    }
140}
141
142impl fmt::Display for StatusCodeError {
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        write!(f, "client received an error status `{}`", self.status)?;
145        if let Some(url) = &self.url {
146            write!(f, " for `{url}`")?;
147        }
148        Ok(())
149    }
150}
151
152impl Error for StatusCodeError {}
153
#[cfg(test)]
mod fail_on_status_tests {
    use http::status::StatusCode;
    use motore::service::Service;

    use super::FailOnStatus;
    use crate::{
        ClientBuilder, body::Body, client::test_helpers::MockTransport, context::ClientContext,
        error::ClientError, request::Request, response::Response,
    };

    /// Mock inner service that answers with the status code encoded in the request path, e.g.
    /// a request to `/404` yields an empty response with status `404`.
    struct ReturnStatus;

    impl Service<ClientContext, Request> for ReturnStatus {
        type Response = Response;
        type Error = ClientError;

        fn call(
            &self,
            _: &mut ClientContext,
            req: Request,
        ) -> impl std::future::Future<Output = Result<Self::Response, Self::Error>> + Send {
            let code = req
                .uri()
                .path()
                .strip_prefix('/')
                .expect("request path must start with `/`");
            let status_code = code.parse::<u16>().expect("invalid uri");
            let status_code = StatusCode::from_u16(status_code).expect("invalid status code");
            let mut resp = Response::new(Body::empty());
            *resp.status_mut() = status_code;
            async { Ok(resp) }
        }
    }

    #[tokio::test]
    async fn fail_on_status_test() {
        {
            // Reject all error status codes
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::all())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            // Successful responses always pass through
            client.get("/200").send().await.unwrap();
            client.get("/400").send().await.unwrap_err();
            client.get("/500").send().await.unwrap_err();
        }
        {
            // Reject client error status codes only
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::client_error())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            client.get("/200").send().await.unwrap();
            client.get("/400").send().await.unwrap_err();
            // 5XX is server error, it should not be handled
            client.get("/500").send().await.unwrap();
        }
        {
            // Reject server error status codes only
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::server_error())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            client.get("/200").send().await.unwrap();
            // 4XX is client error, it should not be handled
            client.get("/400").send().await.unwrap();
            client.get("/500").send().await.unwrap_err();
        }
        {
            // The default layer enables neither class, so nothing is rejected
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::default())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            client.get("/400").send().await.unwrap();
            client.get("/500").send().await.unwrap();
        }
    }
}