volo_http/client/layer/
fail_on_status.rs1use std::{error::Error, fmt};
2
3use http::status::StatusCode;
4use motore::{layer::Layer, service::Service};
5use url::Url;
6use volo::context::Context;
7
8use crate::{
9 error::{ClientError, client::request_error},
10 request::RequestPartsExt,
11 response::Response,
12};
13
/// [`Layer`] that turns responses carrying an error status code into a [`ClientError`].
///
/// Which status classes are treated as failures is chosen by the constructor:
/// [`FailOnStatus::all`] (4xx and 5xx), [`FailOnStatus::client_error`] (4xx only) or
/// [`FailOnStatus::server_error`] (5xx only). [`FailOnStatus::detailed`] additionally records
/// the request URL in the produced error.
///
/// The [`Default`] value has every flag disabled, so a layer built from it never fails a
/// response.
#[derive(Clone, Debug, Default)]
pub struct FailOnStatus {
    /// Fail on client-error (4xx) responses.
    client_error: bool,
    /// Fail on server-error (5xx) responses.
    server_error: bool,
    /// Capture the request URL into the generated error.
    detailed: bool,
}

impl FailOnStatus {
    /// Creates a layer that fails on both client errors (4xx) and server errors (5xx).
    #[must_use]
    pub fn all() -> Self {
        Self {
            client_error: true,
            server_error: true,
            detailed: false,
        }
    }

    /// Creates a layer that fails only on client errors (4xx).
    #[must_use]
    pub fn client_error() -> Self {
        Self {
            client_error: true,
            server_error: false,
            detailed: false,
        }
    }

    /// Creates a layer that fails only on server errors (5xx).
    #[must_use]
    pub fn server_error() -> Self {
        Self {
            client_error: false,
            server_error: true,
            detailed: false,
        }
    }

    /// Enables recording of the request URL (when the request exposes one) in the generated
    /// error, so that it shows up in the error message.
    ///
    /// This consumes and returns the configuration; calling it without using the result has
    /// no effect.
    #[must_use = "`detailed` returns the updated configuration and does not modify it in place"]
    pub fn detailed(mut self) -> Self {
        self.detailed = true;
        self
    }
}
66
67impl<S> Layer<S> for FailOnStatus {
68 type Service = FailOnStatusService<S>;
69
70 fn layer(self, inner: S) -> Self::Service {
71 FailOnStatusService {
72 inner,
73 fail_on: self,
74 }
75 }
76}
77
78pub struct FailOnStatusService<S> {
82 inner: S,
83 fail_on: FailOnStatus,
84}
85
86impl<Cx, Req, S, B> Service<Cx, Req> for FailOnStatusService<S>
87where
88 Cx: Context + Send,
89 Req: RequestPartsExt + Send,
90 S: Service<Cx, Req, Response = Response<B>, Error = ClientError> + Send + Sync,
91{
92 type Response = S::Response;
93 type Error = S::Error;
94
95 async fn call(&self, cx: &mut Cx, req: Req) -> Result<Self::Response, Self::Error> {
96 let url = if self.fail_on.detailed {
97 req.url()
98 } else {
99 None
100 };
101 let resp = self.inner.call(cx, req).await?;
102 let status = resp.status();
103 if (self.fail_on.client_error && status.is_client_error())
104 || (self.fail_on.server_error && status.is_server_error())
105 {
106 Err(request_error(StatusCodeError { status, url })
107 .with_endpoint(cx.rpc_info().callee()))
108 } else {
109 Ok(resp)
110 }
111 }
112}
113
114pub struct StatusCodeError {
116 status: StatusCode,
117 url: Option<Url>,
118}
119
120impl StatusCodeError {
121 pub fn status(&self) -> StatusCode {
123 self.status
124 }
125
126 pub fn url(&self) -> Option<&Url> {
130 self.url.as_ref()
131 }
132}
133
134impl fmt::Debug for StatusCodeError {
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 f.debug_struct("StatusCodeError")
137 .field("status", &self.status)
138 .finish()
139 }
140}
141
142impl fmt::Display for StatusCodeError {
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 write!(f, "client received an error status `{}`", self.status)?;
145 if let Some(url) = &self.url {
146 write!(f, " for `{url}`")?;
147 }
148 Ok(())
149 }
150}
151
152impl Error for StatusCodeError {}
153
#[cfg(test)]
mod fail_on_status_tests {
    use http::status::StatusCode;
    use motore::service::Service;
    use url::Url;

    use super::{FailOnStatus, StatusCodeError};
    use crate::{
        ClientBuilder, body::Body, client::test_helpers::MockTransport, context::ClientContext,
        error::ClientError, request::Request, response::Response,
    };

    /// Mock inner service that answers with the status code spelled by the request path.
    /// For example, a request to `/404` yields an empty `404 Not Found` response.
    struct ReturnStatus;

    impl Service<ClientContext, Request> for ReturnStatus {
        type Response = Response;
        type Error = ClientError;

        fn call(
            &self,
            _: &mut ClientContext,
            req: Request,
        ) -> impl std::future::Future<Output = Result<Self::Response, Self::Error>> + Send {
            let code = req
                .uri()
                .path()
                .strip_prefix('/')
                .expect("request path must start with `/`");
            let code = code.parse::<u16>().expect("invalid uri");
            let status_code = StatusCode::from_u16(code).expect("invalid status code");
            let mut resp = Response::new(Body::empty());
            *resp.status_mut() = status_code;
            async { Ok(resp) }
        }
    }

    #[tokio::test]
    async fn fail_on_status_test() {
        {
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::all())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            client.get("/200").send().await.unwrap();
            client.get("/400").send().await.unwrap_err();
            client.get("/500").send().await.unwrap_err();
        }
        {
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::client_error())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            client.get("/200").send().await.unwrap();
            client.get("/400").send().await.unwrap_err();
            client.get("/500").send().await.unwrap();
        }
        {
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::server_error())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            client.get("/200").send().await.unwrap();
            client.get("/400").send().await.unwrap();
            client.get("/500").send().await.unwrap_err();
        }
        {
            // `detailed` only enriches the error; it must not change which statuses fail.
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::all().detailed())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            client.get("/200").send().await.unwrap();
            client.get("/404").send().await.unwrap_err();
            client.get("/503").send().await.unwrap_err();
        }
        {
            // The default configuration has every flag disabled and never fails.
            let client = ClientBuilder::new()
                .layer_outer_front(FailOnStatus::default())
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            client.get("/400").send().await.unwrap();
            client.get("/500").send().await.unwrap();
        }
    }

    #[test]
    fn status_code_error_accessors_and_display() {
        let err = StatusCodeError {
            status: StatusCode::NOT_FOUND,
            url: None,
        };
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
        assert!(err.url().is_none());
        assert_eq!(
            err.to_string(),
            "client received an error status `404 Not Found`"
        );

        let url = Url::parse("http://example.com/path").unwrap();
        let err = StatusCodeError {
            status: StatusCode::BAD_GATEWAY,
            url: Some(url.clone()),
        };
        assert_eq!(err.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.url(), Some(&url));
        assert_eq!(
            err.to_string(),
            "client received an error status `502 Bad Gateway` for `http://example.com/path`"
        );
    }
}