twirp_rs/
client.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use reqwest::header::{InvalidHeaderValue, CONTENT_TYPE};
5use reqwest::StatusCode;
6use thiserror::Error;
7use url::Url;
8
9use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
10use crate::{serialize_proto_message, TwirpErrorResponse};
11
12#[derive(Debug, Error)]
13pub enum ClientError {
14    #[error(transparent)]
15    InvalidHeader(#[from] InvalidHeaderValue),
16    #[error("base_url must end in /, but got: {0}")]
17    InvalidBaseUrl(Url),
18    #[error(transparent)]
19    InvalidUrl(#[from] url::ParseError),
20    #[error(
21        "http error, status code: {status}, msg:{msg} for path:{path} and content-type:{content_type}"
22    )]
23    HttpError {
24        status: StatusCode,
25        msg: String,
26        path: String,
27        content_type: String,
28    },
29    #[error(transparent)]
30    JsonDecodeError(#[from] serde_json::Error),
31    #[error("malformed response: {0}")]
32    MalformedResponse(String),
33    #[error(transparent)]
34    ProtoDecodeError(#[from] prost::DecodeError),
35    #[error(transparent)]
36    ReqwestError(#[from] reqwest::Error),
37    #[error("twirp error: {0:?}")]
38    TwirpError(TwirpErrorResponse),
39}
40
41pub type Result<T, E = ClientError> = std::result::Result<T, E>;
42
43pub struct ClientBuilder {
44    base_url: Url,
45    http_client: reqwest::Client,
46    middleware: Vec<Arc<dyn Middleware>>,
47}
48
49impl ClientBuilder {
50    pub fn new(base_url: Url, http_client: reqwest::Client) -> Self {
51        Self {
52            base_url,
53            middleware: vec![],
54            http_client,
55        }
56    }
57
58    /// Add middleware to the client that will be called on each request.
59    /// Middlewares are invoked in the order they are added as part of the
60    /// request cycle.
61    pub fn with<M>(self, middleware: M) -> Self
62    where
63        M: Middleware,
64    {
65        let mut mw = self.middleware.clone();
66        mw.push(Arc::new(middleware));
67        Self {
68            base_url: self.base_url,
69            http_client: self.http_client,
70            middleware: mw,
71        }
72    }
73
74    pub fn build(self) -> Result<Client> {
75        Client::new(self.base_url, self.http_client, self.middleware)
76    }
77}
78
79/// `HttpTwirpClient` is a TwirpClient that uses `reqwest::Client` to make http
80/// requests.
81#[derive(Clone)]
82pub struct Client {
83    pub base_url: Url,
84    http_client: reqwest::Client,
85    middlewares: Vec<Arc<dyn Middleware>>,
86}
87
88impl std::fmt::Debug for Client {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("TwirpClient")
91            .field("base_url", &self.base_url)
92            .field("client", &self.http_client)
93            .field("middlewares", &self.middlewares.len())
94            .finish()
95    }
96}
97
98impl Client {
99    /// Creates a `twirp::Client`.
100    ///
101    /// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that
102    /// you create one and **reuse** it.
103    pub fn new(
104        base_url: Url,
105        http_client: reqwest::Client,
106        middlewares: Vec<Arc<dyn Middleware>>,
107    ) -> Result<Self> {
108        if base_url.path().ends_with('/') {
109            Ok(Client {
110                base_url,
111                http_client,
112                middlewares,
113            })
114        } else {
115            Err(ClientError::InvalidBaseUrl(base_url))
116        }
117    }
118
119    /// Creates a `twirp::Client` with the default `reqwest::ClientBuilder`.
120    ///
121    /// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that
122    /// you create one and **reuse** it.
123    pub fn from_base_url(base_url: Url) -> Result<Self> {
124        Self::new(base_url, reqwest::Client::new(), vec![])
125    }
126
127    /// Add middleware to this specific request stack. Middlewares are invoked
128    /// in the order they are added as part of the request cycle. Middleware
129    /// added here will run after any middleware added with the `ClientBuilder`.
130    pub fn with<M>(mut self, middleware: M) -> Self
131    where
132        M: Middleware,
133    {
134        self.middlewares.push(Arc::new(middleware));
135        self
136    }
137
138    /// Make an HTTP twirp request.
139    pub async fn request<I, O>(&self, url: Url, body: I) -> Result<O>
140    where
141        I: prost::Message,
142        O: prost::Message + Default,
143    {
144        let path = url.path().to_string();
145        let req = self
146            .http_client
147            .post(url)
148            .header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
149            .body(serialize_proto_message(body))
150            .build()?;
151
152        // Create and execute the middleware handlers
153        let next = Next::new(&self.http_client, &self.middlewares);
154        let resp = next.run(req).await?;
155
156        // These have to be extracted because reading the body consumes `Response`.
157        let status = resp.status();
158        let content_type = resp.headers().get(CONTENT_TYPE).cloned();
159
160        // TODO: Include more info in the error cases: request path, content-type, etc.
161        match (status, content_type) {
162            (status, Some(ct)) if status.is_success() && ct.as_bytes() == CONTENT_TYPE_PROTOBUF => {
163                O::decode(resp.bytes().await?).map_err(|e| e.into())
164            }
165            (status, Some(ct))
166                if (status.is_client_error() || status.is_server_error())
167                    && ct.as_bytes() == CONTENT_TYPE_JSON =>
168            {
169                Err(ClientError::TwirpError(serde_json::from_slice(
170                    &resp.bytes().await?,
171                )?))
172            }
173            (status, ct) => Err(ClientError::HttpError {
174                status,
175                msg: "unknown error".to_string(),
176                path,
177                content_type: ct
178                    .map(|x| x.to_str().unwrap_or_default().to_string())
179                    .unwrap_or_default(),
180            }),
181        }
182    }
183}
184
185// This concept of reqwest middleware is taken pretty much directly from:
186// https://github.com/TrueLayer/reqwest-middleware, but simplified for the
187// specific needs of this twirp client.
188#[async_trait]
189pub trait Middleware: 'static + Send + Sync {
190    async fn handle(&self, mut req: reqwest::Request, next: Next<'_>) -> Result<reqwest::Response>;
191}
192
193#[async_trait]
194impl<F> Middleware for F
195where
196    F: Send
197        + Sync
198        + 'static
199        + for<'a> Fn(reqwest::Request, Next<'a>) -> BoxFuture<'a, Result<reqwest::Response>>,
200{
201    async fn handle(&self, req: reqwest::Request, next: Next<'_>) -> Result<reqwest::Response> {
202        (self)(req, next).await
203    }
204}
205
206#[derive(Clone)]
207pub struct Next<'a> {
208    client: &'a reqwest::Client,
209    middlewares: &'a [Arc<dyn Middleware>],
210}
211
/// A pinned, heap-allocated, `Send` future that yields `T` and may borrow data for `'a`.
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
213
214impl<'a> Next<'a> {
215    pub(crate) fn new(client: &'a reqwest::Client, middlewares: &'a [Arc<dyn Middleware>]) -> Self {
216        Next {
217            client,
218            middlewares,
219        }
220    }
221
222    pub fn run(mut self, req: reqwest::Request) -> BoxFuture<'a, Result<reqwest::Response>> {
223        if let Some((current, rest)) = self.middlewares.split_first() {
224            self.middlewares = rest;
225            Box::pin(current.handle(req, self))
226        } else {
227            Box::pin(async move { self.client.execute(req).await.map_err(ClientError::from) })
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use reqwest::{Request, Response};

    use crate::test::*;

    use super::*;

    /// Message carried by the error that `AssertRouting` returns. A test that sees it knows
    /// the middleware really ran, rather than the request failing for an unrelated reason.
    const ROUTING_CHECKED: &str = "routing checked";

    /// Middleware that asserts the outgoing request URL and then short-circuits the chain
    /// with a sentinel error, so the routing test needs no server and no free port.
    struct AssertRouting {
        expected_url: &'static str,
    }

    #[async_trait]
    impl Middleware for AssertRouting {
        async fn handle(&self, req: Request, _next: Next<'_>) -> Result<Response> {
            assert_eq!(self.expected_url, &req.url().to_string());
            Err(ClientError::MalformedResponse(ROUTING_CHECKED.to_string()))
        }
    }

    #[tokio::test]
    async fn test_base_url() {
        let url = Url::parse("http://localhost:3001/twirp/").unwrap();
        assert!(Client::from_base_url(url).is_ok());
        let url = Url::parse("http://localhost:3001/twirp").unwrap();
        assert_eq!(
            Client::from_base_url(url).unwrap_err().to_string(),
            "base_url must end in /, but got: http://localhost:3001/twirp",
        );
    }

    #[tokio::test]
    async fn test_routes() {
        let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();

        let client = ClientBuilder::new(base_url, reqwest::Client::new())
            .with(AssertRouting {
                expected_url: "http://localhost:3001/twirp/test.TestAPI/Ping",
            })
            .build()
            .unwrap();
        // The middleware never forwards the request. The call must fail with its sentinel
        // error whether or not a test server is listening on port 3001.
        let err = client
            .ping(PingRequest {
                name: "hi".to_string(),
            })
            .await
            .unwrap_err();
        assert!(
            err.to_string().contains(ROUTING_CHECKED),
            "routing middleware did not run; got: {err}"
        );
    }

    #[tokio::test]
    #[ignore = "integration"]
    async fn test_standard_client() {
        let h = run_test_server(3001).await;
        let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();
        let client = Client::from_base_url(base_url).unwrap();
        let resp = client
            .ping(PingRequest {
                name: "hi".to_string(),
            })
            .await
            .unwrap();
        assert_eq!(&resp.name, "hi");
        h.abort()
    }
}