1use std::sync::Arc;
2
3use async_trait::async_trait;
4use reqwest::header::{InvalidHeaderValue, CONTENT_TYPE};
5use reqwest::StatusCode;
6use thiserror::Error;
7use url::Url;
8
9use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
10use crate::{serialize_proto_message, TwirpErrorResponse};
11
12#[derive(Debug, Error)]
13pub enum ClientError {
14 #[error(transparent)]
15 InvalidHeader(#[from] InvalidHeaderValue),
16 #[error("base_url must end in /, but got: {0}")]
17 InvalidBaseUrl(Url),
18 #[error(transparent)]
19 InvalidUrl(#[from] url::ParseError),
20 #[error(
21 "http error, status code: {status}, msg:{msg} for path:{path} and content-type:{content_type}"
22 )]
23 HttpError {
24 status: StatusCode,
25 msg: String,
26 path: String,
27 content_type: String,
28 },
29 #[error(transparent)]
30 JsonDecodeError(#[from] serde_json::Error),
31 #[error("malformed response: {0}")]
32 MalformedResponse(String),
33 #[error(transparent)]
34 ProtoDecodeError(#[from] prost::DecodeError),
35 #[error(transparent)]
36 ReqwestError(#[from] reqwest::Error),
37 #[error("twirp error: {0:?}")]
38 TwirpError(TwirpErrorResponse),
39}
40
41pub type Result<T, E = ClientError> = std::result::Result<T, E>;
42
43pub struct ClientBuilder {
44 base_url: Url,
45 http_client: reqwest::Client,
46 middleware: Vec<Arc<dyn Middleware>>,
47}
48
49impl ClientBuilder {
50 pub fn new(base_url: Url, http_client: reqwest::Client) -> Self {
51 Self {
52 base_url,
53 middleware: vec![],
54 http_client,
55 }
56 }
57
58 pub fn with<M>(self, middleware: M) -> Self
62 where
63 M: Middleware,
64 {
65 let mut mw = self.middleware.clone();
66 mw.push(Arc::new(middleware));
67 Self {
68 base_url: self.base_url,
69 http_client: self.http_client,
70 middleware: mw,
71 }
72 }
73
74 pub fn build(self) -> Result<Client> {
75 Client::new(self.base_url, self.http_client, self.middleware)
76 }
77}
78
79#[derive(Clone)]
82pub struct Client {
83 pub base_url: Url,
84 http_client: reqwest::Client,
85 middlewares: Vec<Arc<dyn Middleware>>,
86}
87
88impl std::fmt::Debug for Client {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("TwirpClient")
91 .field("base_url", &self.base_url)
92 .field("client", &self.http_client)
93 .field("middlewares", &self.middlewares.len())
94 .finish()
95 }
96}
97
98impl Client {
99 pub fn new(
104 base_url: Url,
105 http_client: reqwest::Client,
106 middlewares: Vec<Arc<dyn Middleware>>,
107 ) -> Result<Self> {
108 if base_url.path().ends_with('/') {
109 Ok(Client {
110 base_url,
111 http_client,
112 middlewares,
113 })
114 } else {
115 Err(ClientError::InvalidBaseUrl(base_url))
116 }
117 }
118
119 pub fn from_base_url(base_url: Url) -> Result<Self> {
124 Self::new(base_url, reqwest::Client::new(), vec![])
125 }
126
127 pub fn with<M>(mut self, middleware: M) -> Self
131 where
132 M: Middleware,
133 {
134 self.middlewares.push(Arc::new(middleware));
135 self
136 }
137
138 pub async fn request<I, O>(&self, url: Url, body: I) -> Result<O>
140 where
141 I: prost::Message,
142 O: prost::Message + Default,
143 {
144 let path = url.path().to_string();
145 let req = self
146 .http_client
147 .post(url)
148 .header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
149 .body(serialize_proto_message(body))
150 .build()?;
151
152 let next = Next::new(&self.http_client, &self.middlewares);
154 let resp = next.run(req).await?;
155
156 let status = resp.status();
158 let content_type = resp.headers().get(CONTENT_TYPE).cloned();
159
160 match (status, content_type) {
162 (status, Some(ct)) if status.is_success() && ct.as_bytes() == CONTENT_TYPE_PROTOBUF => {
163 O::decode(resp.bytes().await?).map_err(|e| e.into())
164 }
165 (status, Some(ct))
166 if (status.is_client_error() || status.is_server_error())
167 && ct.as_bytes() == CONTENT_TYPE_JSON =>
168 {
169 Err(ClientError::TwirpError(serde_json::from_slice(
170 &resp.bytes().await?,
171 )?))
172 }
173 (status, ct) => Err(ClientError::HttpError {
174 status,
175 msg: "unknown error".to_string(),
176 path,
177 content_type: ct
178 .map(|x| x.to_str().unwrap_or_default().to_string())
179 .unwrap_or_default(),
180 }),
181 }
182 }
183}
184
185#[async_trait]
189pub trait Middleware: 'static + Send + Sync {
190 async fn handle(&self, mut req: reqwest::Request, next: Next<'_>) -> Result<reqwest::Response>;
191}
192
193#[async_trait]
194impl<F> Middleware for F
195where
196 F: Send
197 + Sync
198 + 'static
199 + for<'a> Fn(reqwest::Request, Next<'a>) -> BoxFuture<'a, Result<reqwest::Response>>,
200{
201 async fn handle(&self, req: reqwest::Request, next: Next<'_>) -> Result<reqwest::Response> {
202 (self)(req, next).await
203 }
204}
205
206#[derive(Clone)]
207pub struct Next<'a> {
208 client: &'a reqwest::Client,
209 middlewares: &'a [Arc<dyn Middleware>],
210}
211
/// A heap-allocated, pinned, `Send` future producing `T` that may borrow data for `'a`.
pub type BoxFuture<'a, T> =
    std::pin::Pin<Box<dyn std::future::Future<Output = T> + 'a + Send>>;
213
214impl<'a> Next<'a> {
215 pub(crate) fn new(client: &'a reqwest::Client, middlewares: &'a [Arc<dyn Middleware>]) -> Self {
216 Next {
217 client,
218 middlewares,
219 }
220 }
221
222 pub fn run(mut self, req: reqwest::Request) -> BoxFuture<'a, Result<reqwest::Response>> {
223 if let Some((current, rest)) = self.middlewares.split_first() {
224 self.middlewares = rest;
225 Box::pin(current.handle(req, self))
226 } else {
227 Box::pin(async move { self.client.execute(req).await.map_err(ClientError::from) })
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use reqwest::{Request, Response};

    use crate::test::*;

    use super::*;

    /// Middleware that fails the test unless the outgoing request targets
    /// `expected_url`, then forwards the request unchanged.
    struct AssertRouting {
        expected_url: &'static str,
    }

    #[async_trait]
    impl Middleware for AssertRouting {
        async fn handle(&self, req: Request, next: Next<'_>) -> Result<Response> {
            let actual_url = req.url().to_string();
            assert_eq!(self.expected_url, &actual_url);
            next.run(req).await
        }
    }

    #[tokio::test]
    async fn test_base_url() {
        let with_slash = Url::parse("http://localhost:3001/twirp/").unwrap();
        assert!(Client::from_base_url(with_slash).is_ok());

        let without_slash = Url::parse("http://localhost:3001/twirp").unwrap();
        let err = Client::from_base_url(without_slash).unwrap_err();
        assert_eq!(
            err.to_string(),
            "base_url must end in /, but got: http://localhost:3001/twirp",
        );
    }

    #[tokio::test]
    async fn test_routes() {
        let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();
        let client = ClientBuilder::new(base_url, reqwest::Client::new())
            .with(AssertRouting {
                expected_url: "http://localhost:3001/twirp/test.TestAPI/Ping",
            })
            .build()
            .unwrap();

        // NOTE(review): no server is started here, so the call presumably fails
        // at the transport layer. The real check is the URL assertion inside
        // `AssertRouting`, which runs before the request is sent.
        let outcome = client
            .ping(PingRequest {
                name: "hi".to_string(),
            })
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    #[ignore = "integration"]
    async fn test_standard_client() {
        let server = run_test_server(3001).await;
        let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();
        let client = Client::from_base_url(base_url).unwrap();

        let resp = client
            .ping(PingRequest {
                name: "hi".to_string(),
            })
            .await
            .unwrap();
        assert_eq!(&resp.name, "hi");

        server.abort();
    }
}