twurst_client/
lib.rs

1#![doc = include_str!("../README.md")]
2#![doc(
3    test(attr(deny(warnings))),
4    html_favicon_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png",
5    html_logo_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8
9use http::header::CONTENT_TYPE;
10use http::{HeaderValue, Method, Request, Response, StatusCode};
11use http_body::{Body, Frame, SizeHint};
12use http_body_util::BodyExt;
13use prost_reflect::bytes::{Buf, Bytes, BytesMut};
14use prost_reflect::{DynamicMessage, ReflectMessage};
15use serde::Serialize;
16use std::convert::Infallible;
17use std::error::Error;
18use std::future::poll_fn;
19#[cfg(feature = "reqwest-012")]
20use std::future::Future;
21use std::mem::take;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24use tower_service::Service;
25pub use twurst_error::{TwirpError, TwirpErrorCode};
26
/// `Content-Type` header value attached to JSON-encoded requests and matched against responses.
const APPLICATION_JSON: HeaderValue = HeaderValue::from_static("application/json");
/// `Content-Type` header value attached to binary-protobuf-encoded requests and matched against responses.
const APPLICATION_PROTOBUF: HeaderValue = HeaderValue::from_static("application/protobuf");
29
/// Underlying client used by autogenerated clients to handle networking.
///
/// Can be constructed with [`TwirpHttpClient::new_using_reqwest_012`] to use [`reqwest 0.12`](reqwest_012)
/// or from a regular [`tower::Service`](Service) using [`TwirpHttpClient::new_with_base`]
/// or [`TwirpHttpClient::new`] if relative URLs are fine.
///
/// URL grammar for a Twirp service is `URL ::= Base-URL [ Prefix ] "/" [ Package "." ] Service "/" Method`.
/// The `/ [ Package "." ] Service "/" Method` part is auto-generated by the build step
/// but the `Base-URL [ Prefix ]` part must be provided to properly call remote services.
/// This is the `base_url` parameter.
/// If it is not set, the request URL is only going to be the auto-generated part.
#[derive(Clone)]
pub struct TwirpHttpClient<S: TwirpHttpService> {
    /// Transport used to check readiness and to send the HTTP requests.
    service: S,
    /// `Base-URL [ Prefix ]` part prepended to every request path.
    /// `None` means that the path is used as is (relative URL).
    base_url: Option<String>,
    /// `true` to encode requests as JSON, `false` (the default) to use binary protobuf.
    use_json: bool,
}
47
#[cfg(feature = "reqwest-012")]
impl TwirpHttpClient<Reqwest012Service> {
    /// Creates a client backed by a freshly built [`reqwest 0.12`](reqwest_012) client.
    ///
    /// `base_url` must be an absolute URL including the scheme (e.g. `https://`).
    ///
    /// ```
    /// use twurst_client::TwirpHttpClient;
    ///
    /// let _client = TwirpHttpClient::new_using_reqwest_012("http://example.com/twirp");
    /// ```
    pub fn new_using_reqwest_012(base_url: impl Into<String>) -> Self {
        let client = reqwest_012::Client::new();
        Self::new_with_reqwest_012_client(client, base_url)
    }

    /// Creates a client backed by the given, already configured, [`reqwest 0.12`](reqwest_012) client.
    ///
    /// `base_url` must be an absolute URL including the scheme (e.g. `https://`).
    ///
    /// ```
    /// # use reqwest_012::Client;
    /// use twurst_client::TwirpHttpClient;
    ///
    /// let _client =
    ///     TwirpHttpClient::new_with_reqwest_012_client(Client::new(), "http://example.com/twirp");
    /// ```
    pub fn new_with_reqwest_012_client(
        client: reqwest_012::Client,
        base_url: impl Into<String>,
    ) -> Self {
        // The reqwest client is wrapped so that it can be driven as a tower service.
        let service = Reqwest012Service(client);
        Self::new_with_base(service, base_url)
    }
}
81
82impl<S: TwirpHttpService> TwirpHttpClient<S> {
83    /// Builds a new client from a [`tower::Service`](Service) and a base URL to the Twirp endpoint.
84    ///
85    /// ```
86    /// use http::Response;
87    /// use std::convert::Infallible;
88    /// use twurst_client::TwirpHttpClient;
89    /// use twurst_error::TwirpError;
90    ///
91    /// let _client = TwirpHttpClient::new_with_base(
92    ///     tower::service_fn(|_request| async {
93    ///         Ok::<Response<String>, Infallible>(TwirpError::unimplemented("not implemented").into())
94    ///     }),
95    ///     "http://example.com/twirp",
96    /// );
97    /// ```
98    pub fn new_with_base(service: S, base_url: impl Into<String>) -> Self {
99        let mut base_url = base_url.into();
100        // We remove the last '/' to make concatenation work
101        if base_url.ends_with('/') {
102            base_url.pop();
103        }
104        Self {
105            service,
106            base_url: Some(base_url),
107            use_json: false,
108        }
109    }
110
111    /// New service without base URL. Relative URLs will be used for requests!
112    ///
113    /// ```
114    /// use http::Response;
115    /// use std::convert::Infallible;
116    /// use twurst_client::TwirpHttpClient;
117    /// use twurst_error::TwirpError;
118    ///
119    /// let _client = TwirpHttpClient::new(tower::service_fn(|_request| async {
120    ///     Ok::<Response<String>, Infallible>(TwirpError::unimplemented("not implemented").into())
121    /// }));
122    /// ```
123    pub fn new(service: S) -> Self {
124        Self {
125            service,
126            base_url: None,
127            use_json: false,
128        }
129    }
130
131    /// Use JSON for requests and response instead of binary protobuf encoding that is used by default
132    pub fn use_json(&mut self) {
133        self.use_json = true;
134    }
135
136    /// Use binary protobuf encoding for requests and response (the default)
137    pub fn use_binary_protobuf(&mut self) {
138        self.use_json = false;
139    }
140
141    /// Send a Twirp request and get a response.
142    ///
143    /// Used internally by the generated code.
144    pub async fn call<I: ReflectMessage, O: ReflectMessage + Default>(
145        &self,
146        path: &str,
147        request: &I,
148    ) -> Result<O, TwirpError> {
149        // We ensure that the service is ready
150        self.service.ready().await.map_err(|e| {
151            TwirpError::wrap(
152                TwirpErrorCode::Unknown,
153                format!("Service is not ready: {e}"),
154                e,
155            )
156        })?;
157        let request = self.build_request(path, request)?;
158        let response = self.service.call(request).await.map_err(|e| {
159            TwirpError::wrap(
160                TwirpErrorCode::Unknown,
161                format!("Transport error during the request: {e}"),
162                e,
163            )
164        })?;
165        self.extract_response(response).await
166    }
167
168    fn build_request<T: ReflectMessage>(
169        &self,
170        path: &str,
171        message: &T,
172    ) -> Result<Request<TwirpRequestBody>, TwirpError> {
173        let mut request_builder = Request::builder().method(Method::POST);
174        request_builder = if let Some(base_url) = &self.base_url {
175            request_builder.uri(format!("{}{}", base_url, path))
176        } else {
177            request_builder.uri(path)
178        };
179        if self.use_json {
180            request_builder
181                .header(CONTENT_TYPE, APPLICATION_JSON)
182                .body(json_encode(message)?.into())
183        } else {
184            let mut buffer = BytesMut::with_capacity(message.encoded_len());
185            message.encode(&mut buffer).map_err(|e| {
186                TwirpError::wrap(
187                    TwirpErrorCode::Internal,
188                    format!("Failed to serialize to protobuf: {e}"),
189                    e,
190                )
191            })?;
192            request_builder
193                .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
194                .body(Bytes::from(buffer).into())
195        }
196        .map_err(|e| {
197            TwirpError::wrap(
198                TwirpErrorCode::Malformed,
199                format!("Failed to construct request: {e}"),
200                e,
201            )
202        })
203    }
204
205    async fn extract_response<T: ReflectMessage + Default>(
206        &self,
207        response: Response<S::ResponseBody>,
208    ) -> Result<T, TwirpError> {
209        // We collect the body
210        // TODO: size limit
211        let (parts, body) = response.into_parts();
212        let body = body.collect().await.map_err(|e| {
213            TwirpError::wrap(
214                TwirpErrorCode::Internal,
215                format!("Failed to load request body: {e}"),
216                e,
217            )
218        })?;
219        let response = Response::from_parts(parts, body);
220
221        // Error
222        if response.status() != StatusCode::OK {
223            return Err(response.map(|b| b.to_bytes()).into());
224        }
225
226        // Success
227        let content_type = response.headers().get(CONTENT_TYPE).cloned();
228        let body = response.into_body();
229        if content_type == Some(APPLICATION_PROTOBUF) {
230            T::decode(body.aggregate()).map_err(|e| {
231                TwirpError::wrap(
232                    TwirpErrorCode::Malformed,
233                    format!("Bad response binary protobuf encoding: {e}"),
234                    e,
235                )
236            })
237        } else if content_type == Some(APPLICATION_JSON) {
238            json_decode(&body.to_bytes())
239        } else if let Some(content_type) = content_type {
240            Err(TwirpError::malformed(format!(
241                "Unsupported response content-type: {}",
242                String::from_utf8_lossy(content_type.as_bytes())
243            )))
244        } else {
245            Err(TwirpError::malformed("No content-type in the response"))
246        }
247    }
248}
249
/// A service that can be used to send Twirp requests eg. an HTTP client
///
/// Used by [`TwirpHttpClient`] to handle HTTP.
#[trait_variant::make(Send)]
pub trait TwirpHttpService: 'static {
    /// Type of the body of the HTTP responses returned by [`call`](Self::call).
    type ResponseBody: Body<Error: Error + Send + Sync>;
    /// Error returned when the service is not ready or when sending a request fails.
    type Error: Error + Send + Sync + 'static;

    /// Waits until the service is able to process a request.
    ///
    /// [`TwirpHttpClient`] awaits it before each [`call`](Self::call).
    async fn ready(&self) -> Result<(), Self::Error>;

    /// Sends the HTTP `request` and returns the HTTP response.
    async fn call(
        &self,
        request: Request<TwirpRequestBody>,
    ) -> Result<Response<Self::ResponseBody>, Self::Error>;
}
265
266impl<
267        S: Service<
268                Request<TwirpRequestBody>,
269                Error: Error + Send + Sync + 'static,
270                Response = Response<RespBody>,
271                Future: Send,
272            > + Clone
273            + Send
274            + Sync
275            + 'static,
276        RespBody: Body<Error: Error + Send + Sync + 'static>,
277    > TwirpHttpService for S
278{
279    type ResponseBody = RespBody;
280    type Error = S::Error;
281
282    async fn ready(&self) -> Result<(), Self::Error> {
283        poll_fn(|cx| Service::poll_ready(&mut self.clone(), cx)).await
284    }
285
286    async fn call(
287        &self,
288        request: Request<TwirpRequestBody>,
289    ) -> Result<Response<RespBody>, S::Error> {
290        Service::call(&mut self.clone(), request).await
291    }
292}
293
/// Request body for Twirp requests.
///
/// It is a thin wrapper on top of [`Bytes`] to implement [`Body`].
/// The wrapped bytes are the full, already encoded, request payload.
pub struct TwirpRequestBody(Bytes);
298
299impl From<Bytes> for TwirpRequestBody {
300    #[inline]
301    fn from(body: Bytes) -> Self {
302        Self(body)
303    }
304}
305
306impl From<TwirpRequestBody> for Bytes {
307    #[inline]
308    fn from(body: TwirpRequestBody) -> Self {
309        body.0
310    }
311}
312
313impl Body for TwirpRequestBody {
314    type Data = Bytes;
315    type Error = Infallible;
316
317    #[inline]
318    fn poll_frame(
319        mut self: Pin<&mut Self>,
320        _cx: &mut Context<'_>,
321    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
322        let data = take(&mut self.0);
323        Poll::Ready(if data.has_remaining() {
324            Some(Ok(Frame::data(data)))
325        } else {
326            None
327        })
328    }
329
330    #[inline]
331    fn is_end_stream(&self) -> bool {
332        !self.0.has_remaining()
333    }
334
335    #[inline]
336    fn size_hint(&self) -> SizeHint {
337        SizeHint::with_exact(self.0.remaining() as u64)
338    }
339}
340
341fn json_encode<T: ReflectMessage>(message: &T) -> Result<Bytes, TwirpError> {
342    let mut serializer = serde_json::Serializer::new(Vec::new());
343    message
344        .transcode_to_dynamic()
345        .serialize(&mut serializer)
346        .map_err(|e| {
347            TwirpError::wrap(
348                TwirpErrorCode::Malformed,
349                format!("Failed to serialize request to JSON: {e}"),
350                e,
351            )
352        })?;
353    Ok(serializer.into_inner().into())
354}
355
356fn json_decode<T: ReflectMessage + Default>(message: &[u8]) -> Result<T, TwirpError> {
357    let dynamic_message = dynamic_json_decode::<T>(message).map_err(|e| {
358        TwirpError::wrap(
359            TwirpErrorCode::Malformed,
360            format!("Failed to parse JSON response: {e}"),
361            e,
362        )
363    })?;
364    dynamic_message.transcode_to().map_err(|e| {
365        TwirpError::internal(format!(
366            "Internal error while parsing the JSON response: {e}"
367        ))
368    })
369}
370
371fn dynamic_json_decode<T: ReflectMessage + Default>(
372    message: &[u8],
373) -> Result<DynamicMessage, serde_json::Error> {
374    let mut deserializer = serde_json::Deserializer::from_slice(message);
375    let dynamic_message =
376        DynamicMessage::deserialize(T::default().descriptor(), &mut deserializer)?;
377    deserializer.end()?;
378    Ok(dynamic_message)
379}
380
/// Wraps a [`reqwest::Client`](reqwest_012::Client) into a [`tower::Service`](Service) compatible with [`TwirpHttpClient`].
///
/// Only available if the `reqwest-012` feature is enabled.
#[cfg(feature = "reqwest-012")]
#[derive(Clone, Default)]
pub struct Reqwest012Service(reqwest_012::Client);
385
#[cfg(feature = "reqwest-012")]
impl Reqwest012Service {
    /// Builds a service on top of a new [`reqwest::Client`](reqwest_012::Client).
    #[inline]
    pub fn new() -> Self {
        Self(reqwest_012::Client::new())
    }
}
393
/// Wraps an already configured client.
#[cfg(feature = "reqwest-012")]
impl From<reqwest_012::Client> for Reqwest012Service {
    #[inline]
    fn from(inner: reqwest_012::Client) -> Self {
        Reqwest012Service(inner)
    }
}
401
#[cfg(feature = "reqwest-012")]
impl<B: Into<reqwest_012::Body>> Service<Request<B>> for Reqwest012Service {
    type Response = Response<reqwest_012::Body>;
    type Error = reqwest_012::Error;
    type Future = Pin<
        Box<dyn Future<Output = Result<Response<reqwest_012::Body>, reqwest_012::Error>> + Send>,
    >;

    /// Readiness is the one of the wrapped client.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.0.poll_ready(cx)
    }

    /// Converts the request to a `reqwest` one, sends it with the wrapped client
    /// and converts the response back.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        let request = match reqwest_012::Request::try_from(req) {
            Ok(request) => request,
            // The conversion error is reported by the returned future
            Err(error) => return Box::pin(async move { Err(error) }),
        };
        let pending = self.0.call(request);
        Box::pin(async move {
            let response = pending.await?;
            Ok(response.into())
        })
    }
}
423
/// Hands the request payload over to `reqwest`.
#[cfg(feature = "reqwest-012")]
impl From<TwirpRequestBody> for reqwest_012::Body {
    #[inline]
    fn from(body: TwirpRequestBody) -> Self {
        let TwirpRequestBody(bytes) = body;
        Self::from(bytes)
    }
}
431
/// Unit tests driving [`TwirpHttpClient`] against in-memory services.
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "reqwest-012")]
    use prost_reflect::prost::Message;
    use prost_reflect::prost_types::Timestamp;
    use std::future::Ready;
    use std::io;
    use std::task::{Context, Poll};
    use tower::service_fn;

    /// A service failing in `poll_ready` makes the call fail with an `Unknown` error
    /// wrapping the readiness error, without ever calling the service.
    #[tokio::test]
    async fn not_ready_service() -> Result<(), Box<dyn Error>> {
        #[derive(Clone)]
        struct NotReadyService;

        impl<S> Service<S> for NotReadyService {
            type Response = Response<String>;
            type Error = TwirpError;
            type Future = Ready<Result<Response<String>, TwirpError>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Err(TwirpError::internal("foo")))
            }

            fn call(&mut self, _: S) -> Self::Future {
                unimplemented!()
            }
        }

        let client = TwirpHttpClient::new(NotReadyService);
        assert_eq!(
            client
                .call::<_, Timestamp>("", &Timestamp::default())
                .await
                .unwrap_err()
                .to_string(),
            "Twirp Unknown error: Service is not ready: Twirp Internal error: foo"
        );
        Ok(())
    }

    /// Without base URL the path is used as the request URI and a JSON response is decoded.
    #[tokio::test]
    async fn json_request_without_base_ok() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<_, TwirpError>(
                Response::builder()
                    .header(CONTENT_TYPE, APPLICATION_JSON)
                    .body("\"1970-01-01T00:00:10Z\"".to_string())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await?;
        assert_eq!(
            response,
            Timestamp {
                seconds: 10,
                nanos: 0
            }
        );
        Ok(())
    }

    /// Without base URL the path is used as the request URI and a binary protobuf response is decoded.
    ///
    /// Gated on `reqwest-012` because the response body is built as a `reqwest_012::Body`.
    #[cfg(feature = "reqwest-012")]
    #[tokio::test]
    async fn binary_request_without_base_ok() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<_, TwirpError>(
                Response::builder()
                    .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
                    .body(reqwest_012::Body::from(
                        Timestamp {
                            seconds: 10,
                            nanos: 0,
                        }
                        .encode_to_vec(),
                    ))
                    .unwrap(),
            )
        });

        let response = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await?;
        assert_eq!(
            response,
            Timestamp {
                seconds: 10,
                nanos: 0
            }
        );
        Ok(())
    }

    /// The base URL is prepended to the path and a Twirp error response is returned as is.
    #[tokio::test]
    async fn request_with_base_twirp_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "http://example.com/twirp/foo");
            Ok::<Response<String>, TwirpError>(TwirpError::not_found("not found").into())
        });

        let response_error = TwirpHttpClient::new_with_base(service, "http://example.com/twirp")
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(response_error, TwirpError::not_found("not found"));
        Ok(())
    }

    /// A base URL with a trailing '/' gives the same request URI, and a plain `401` response
    /// is converted to an `unauthenticated` error carrying the response body.
    #[tokio::test]
    async fn request_with_base_other_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "http://example.com/twirp/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body("foo".to_string())
                    .unwrap(),
            )
        });

        let response_error = TwirpHttpClient::new_with_base(service, "http://example.com/twirp/")
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(response_error, TwirpError::unauthenticated("foo"));
        Ok(())
    }

    /// An error returned by the service itself is reported as an `Unknown` transport error.
    #[tokio::test]
    async fn request_transport_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Err::<Response<String>, _>(io::Error::other("Transport error"))
        });

        let response_error = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::new(
                TwirpErrorCode::Unknown,
                "Transport error during the request: Transport error"
            )
        );
        Ok(())
    }

    /// A successful response with an unknown content type is rejected as malformed.
    #[tokio::test]
    async fn wrong_content_type_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, "foo/bar")
                    .body("foo".into())
                    .unwrap(),
            )
        });

        let response_error = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed("Unsupported response content-type: foo/bar")
        );
        Ok(())
    }

    /// A response labelled as protobuf but with an invalid payload is rejected as malformed.
    ///
    /// The request is sent as JSON: the response decoding only depends on the response content type.
    #[tokio::test]
    async fn invalid_protobuf_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
                    .body("azerty".into())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response_error = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed("Bad response binary protobuf encoding: failed to decode Protobuf message: buffer underflow")
        );
        Ok(())
    }

    /// A response labelled as JSON but with an invalid payload is rejected as malformed.
    #[tokio::test]
    async fn invalid_json_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, APPLICATION_JSON)
                    .body("foo".into())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response_error = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed(
                "Failed to parse JSON response: expected ident at line 1 column 2"
            )
        );
        Ok(())
    }

    /// Compile-time check: the future returned by `call` must be `Send`.
    #[tokio::test]
    async fn response_future_is_send() {
        fn is_send<T: Send>(_: T) {}

        let service = service_fn(|_: Request<TwirpRequestBody>| async move {
            Ok::<_, TwirpError>(Response::new(String::new()))
        });
        let client = TwirpHttpClient::new(service);

        // This will fail to compile if the future is not Send
        is_send(client.call::<_, Timestamp>("/foo", &Timestamp::default()));
    }
}
733        is_send(client.call::<_, Timestamp>("/foo", &Timestamp::default()));
734    }
735}