1#![doc = include_str!("../README.md")]
2#![doc(
3 test(attr(deny(warnings))),
4 html_favicon_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png",
5 html_logo_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8
9use http::header::CONTENT_TYPE;
10use http::{HeaderValue, Method, Request, Response, StatusCode};
11use http_body::{Body, Frame, SizeHint};
12use http_body_util::BodyExt;
13use prost_reflect::bytes::{Buf, Bytes, BytesMut};
14use prost_reflect::{DeserializeOptions, DynamicMessage, ReflectMessage};
15use serde::Serialize;
16use std::convert::Infallible;
17use std::error::Error;
18#[cfg(feature = "reqwest-012")]
19use std::future::Future;
20use std::future::poll_fn;
21use std::mem::take;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24use tower_service::Service;
25pub use twurst_error::{TwirpError, TwirpErrorCode};
26
/// `Content-Type` header value used for JSON-encoded Twirp messages.
const APPLICATION_JSON: HeaderValue = HeaderValue::from_static("application/json");
/// `Content-Type` header value used for binary protobuf-encoded Twirp messages.
const APPLICATION_PROTOBUF: HeaderValue = HeaderValue::from_static("application/protobuf");
29
/// Twirp client sending `POST` requests through an underlying [`TwirpHttpService`].
///
/// Request bodies are binary protobuf by default; call [`TwirpHttpClient::use_json`] to send
/// JSON instead.
#[derive(Clone)]
pub struct TwirpHttpClient<S: TwirpHttpService> {
    // Transport used to send the HTTP requests.
    service: S,
    // Optional prefix prepended to every method path (one trailing `/` removed by
    // `new_with_base`); `None` means the path is used as the URI as-is.
    base_url: Option<String>,
    // `true` to send `application/json` request bodies, `false` for `application/protobuf`.
    use_json: bool,
}
47
#[cfg(feature = "reqwest-012")]
impl TwirpHttpClient<Reqwest012Service> {
    /// Creates a client backed by a freshly created `reqwest` 0.12 client, sending its
    /// requests to `base_url` (see [`TwirpHttpClient::new_with_base`]).
    pub fn new_using_reqwest_012(base_url: impl Into<String>) -> Self {
        Self::new_with_base(Reqwest012Service::new(), base_url)
    }

    /// Creates a client sending its requests to `base_url` through the provided `reqwest`
    /// 0.12 `client` (see [`TwirpHttpClient::new_with_base`]).
    pub fn new_with_reqwest_012_client(
        client: reqwest_012::Client,
        base_url: impl Into<String>,
    ) -> Self {
        let service = Reqwest012Service::from(client);
        Self::new_with_base(service, base_url)
    }
}
81
82impl<S: TwirpHttpService> TwirpHttpClient<S> {
83 pub fn new_with_base(service: S, base_url: impl Into<String>) -> Self {
99 let mut base_url = base_url.into();
100 if base_url.ends_with('/') {
102 base_url.pop();
103 }
104 Self {
105 service,
106 base_url: Some(base_url),
107 use_json: false,
108 }
109 }
110
111 pub fn new(service: S) -> Self {
124 Self {
125 service,
126 base_url: None,
127 use_json: false,
128 }
129 }
130
131 pub fn use_json(&mut self) {
133 self.use_json = true;
134 }
135
136 pub fn use_binary_protobuf(&mut self) {
138 self.use_json = false;
139 }
140
141 pub async fn call<I: ReflectMessage, O: ReflectMessage + Default>(
145 &self,
146 path: &str,
147 request: &I,
148 ) -> Result<O, TwirpError> {
149 self.service.ready().await.map_err(|e| {
151 TwirpError::wrap(
152 TwirpErrorCode::Unknown,
153 format!("Service is not ready: {e}"),
154 e,
155 )
156 })?;
157 let request = self.build_request(path, request)?;
158 let response = self.service.call(request).await.map_err(|e| {
159 TwirpError::wrap(
160 TwirpErrorCode::Unknown,
161 format!("Transport error during the request: {e}"),
162 e,
163 )
164 })?;
165 self.extract_response(response).await
166 }
167
168 fn build_request<T: ReflectMessage>(
169 &self,
170 path: &str,
171 message: &T,
172 ) -> Result<Request<TwirpRequestBody>, TwirpError> {
173 let mut request_builder = Request::builder().method(Method::POST);
174 request_builder = if let Some(base_url) = &self.base_url {
175 request_builder.uri(format!("{base_url}{path}"))
176 } else {
177 request_builder.uri(path)
178 };
179 if self.use_json {
180 request_builder
181 .header(CONTENT_TYPE, APPLICATION_JSON)
182 .body(json_encode(message)?.into())
183 } else {
184 let mut buffer = BytesMut::with_capacity(message.encoded_len());
185 message.encode(&mut buffer).map_err(|e| {
186 TwirpError::wrap(
187 TwirpErrorCode::Internal,
188 format!("Failed to serialize to protobuf: {e}"),
189 e,
190 )
191 })?;
192 request_builder
193 .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
194 .body(Bytes::from(buffer).into())
195 }
196 .map_err(|e| {
197 TwirpError::wrap(
198 TwirpErrorCode::Malformed,
199 format!("Failed to construct request: {e}"),
200 e,
201 )
202 })
203 }
204
205 async fn extract_response<T: ReflectMessage + Default>(
206 &self,
207 response: Response<S::ResponseBody>,
208 ) -> Result<T, TwirpError> {
209 let (parts, body) = response.into_parts();
212 let body = body.collect().await.map_err(|e| {
213 TwirpError::wrap(
214 TwirpErrorCode::Internal,
215 format!("Failed to load request body: {e}"),
216 e,
217 )
218 })?;
219 let response = Response::from_parts(parts, body);
220
221 if response.status() != StatusCode::OK {
223 return Err(response.map(|b| b.to_bytes()).into());
224 }
225
226 let content_type = response.headers().get(CONTENT_TYPE).cloned();
228 let body = response.into_body();
229 if content_type == Some(APPLICATION_PROTOBUF) {
230 T::decode(body.aggregate()).map_err(|e| {
231 TwirpError::wrap(
232 TwirpErrorCode::Malformed,
233 format!("Bad response binary protobuf encoding: {e}"),
234 e,
235 )
236 })
237 } else if content_type == Some(APPLICATION_JSON) {
238 json_decode(&body.to_bytes())
239 } else if let Some(content_type) = content_type {
240 Err(TwirpError::malformed(format!(
241 "Unsupported response content-type: {}",
242 String::from_utf8_lossy(content_type.as_bytes())
243 )))
244 } else {
245 Err(TwirpError::malformed("No content-type in the response"))
246 }
247 }
248}
249
/// HTTP transport used by [`TwirpHttpClient`] to send requests.
///
/// `#[trait_variant::make(Send)]` makes the futures returned by the async methods `Send`.
/// Any cloneable `tower` service accepting `Request<TwirpRequestBody>` implements this trait
/// through the blanket implementation below.
#[trait_variant::make(Send)]
pub trait TwirpHttpService: 'static {
    /// Body type of the returned HTTP responses.
    type ResponseBody: Body<Error: Error + Send + Sync>;
    /// Transport-level error returned by [`Self::ready`] and [`Self::call`].
    type Error: Error + Send + Sync + 'static;

    /// Waits until the service is able to accept a request.
    async fn ready(&self) -> Result<(), Self::Error>;

    /// Sends `request` and returns the HTTP response.
    async fn call(
        &self,
        request: Request<TwirpRequestBody>,
    ) -> Result<Response<Self::ResponseBody>, Self::Error>;
}
265
266impl<
267 S: Service<
268 Request<TwirpRequestBody>,
269 Error: Error + Send + Sync + 'static,
270 Response = Response<RespBody>,
271 Future: Send,
272 > + Clone
273 + Send
274 + Sync
275 + 'static,
276 RespBody: Body<Error: Error + Send + Sync + 'static>,
277> TwirpHttpService for S
278{
279 type ResponseBody = RespBody;
280 type Error = S::Error;
281
282 async fn ready(&self) -> Result<(), Self::Error> {
283 poll_fn(|cx| Service::poll_ready(&mut self.clone(), cx)).await
284 }
285
286 async fn call(
287 &self,
288 request: Request<TwirpRequestBody>,
289 ) -> Result<Response<RespBody>, S::Error> {
290 Service::call(&mut self.clone(), request).await
291 }
292}
293
294pub struct TwirpRequestBody(Bytes);
298
299impl From<Bytes> for TwirpRequestBody {
300 #[inline]
301 fn from(body: Bytes) -> Self {
302 Self(body)
303 }
304}
305
306impl From<TwirpRequestBody> for Bytes {
307 #[inline]
308 fn from(body: TwirpRequestBody) -> Self {
309 body.0
310 }
311}
312
313impl Body for TwirpRequestBody {
314 type Data = Bytes;
315 type Error = Infallible;
316
317 #[inline]
318 fn poll_frame(
319 mut self: Pin<&mut Self>,
320 _cx: &mut Context<'_>,
321 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
322 let data = take(&mut self.0);
323 Poll::Ready(if data.has_remaining() {
324 Some(Ok(Frame::data(data)))
325 } else {
326 None
327 })
328 }
329
330 #[inline]
331 fn is_end_stream(&self) -> bool {
332 !self.0.has_remaining()
333 }
334
335 #[inline]
336 fn size_hint(&self) -> SizeHint {
337 SizeHint::with_exact(self.0.remaining() as u64)
338 }
339}
340
341fn json_encode<T: ReflectMessage>(message: &T) -> Result<Bytes, TwirpError> {
342 let mut serializer = serde_json::Serializer::new(Vec::new());
343 message
344 .transcode_to_dynamic()
345 .serialize(&mut serializer)
346 .map_err(|e| {
347 TwirpError::wrap(
348 TwirpErrorCode::Malformed,
349 format!("Failed to serialize request to JSON: {e}"),
350 e,
351 )
352 })?;
353 Ok(serializer.into_inner().into())
354}
355
356fn json_decode<T: ReflectMessage + Default>(message: &[u8]) -> Result<T, TwirpError> {
357 let dynamic_message = dynamic_json_decode::<T>(message).map_err(|e| {
358 TwirpError::wrap(
359 TwirpErrorCode::Malformed,
360 format!("Failed to parse JSON response: {e}"),
361 e,
362 )
363 })?;
364 dynamic_message.transcode_to().map_err(|e| {
365 TwirpError::internal(format!(
366 "Internal error while parsing the JSON response: {e}"
367 ))
368 })
369}
370
371fn dynamic_json_decode<T: ReflectMessage + Default>(
372 message: &[u8],
373) -> Result<DynamicMessage, serde_json::Error> {
374 let mut deserializer = serde_json::Deserializer::from_slice(message);
375 let dynamic_message = DynamicMessage::deserialize_with_options(
376 T::default().descriptor(),
377 &mut deserializer,
378 &DeserializeOptions::new().deny_unknown_fields(false),
381 )?;
382 deserializer.end()?;
383 Ok(dynamic_message)
384}
385
/// `tower` service wrapping a `reqwest` 0.12 [`reqwest_012::Client`], usable as the transport
/// of a [`TwirpHttpClient`].
#[cfg(feature = "reqwest-012")]
#[derive(Clone, Default)]
pub struct Reqwest012Service(reqwest_012::Client);

#[cfg(feature = "reqwest-012")]
impl Reqwest012Service {
    /// Wraps a client created with `reqwest_012::Client::new()`.
    #[inline]
    pub fn new() -> Self {
        Self::from(reqwest_012::Client::new())
    }
}

#[cfg(feature = "reqwest-012")]
impl From<reqwest_012::Client> for Reqwest012Service {
    /// Wraps an existing `reqwest` 0.12 client.
    #[inline]
    fn from(client: reqwest_012::Client) -> Self {
        Reqwest012Service(client)
    }
}
406
#[cfg(feature = "reqwest-012")]
impl<B: Into<reqwest_012::Body>> Service<Request<B>> for Reqwest012Service {
    type Response = Response<reqwest_012::Body>;
    type Error = reqwest_012::Error;
    type Future = Pin<
        Box<dyn Future<Output = Result<Response<reqwest_012::Body>, reqwest_012::Error>> + Send>,
    >;

    /// Delegates readiness to the wrapped `reqwest` client.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.0.poll_ready(cx)
    }

    /// Converts the `http` request into a `reqwest` one, sends it, and converts the response
    /// back into an `http` response. A conversion error is returned through the future.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        let converted: Result<reqwest_012::Request, reqwest_012::Error> = req.try_into();
        match converted {
            Ok(request) => {
                let pending = self.0.call(request);
                Box::pin(async move {
                    let response: reqwest_012::Response = pending.await?;
                    Ok::<Response<reqwest_012::Body>, reqwest_012::Error>(response.into())
                })
            }
            Err(error) => {
                Box::pin(async move { Err::<Response<reqwest_012::Body>, _>(error) })
            }
        }
    }
}
428
#[cfg(feature = "reqwest-012")]
impl From<TwirpRequestBody> for reqwest_012::Body {
    /// Hands the encoded payload over as a `reqwest` body.
    #[inline]
    fn from(body: TwirpRequestBody) -> Self {
        let TwirpRequestBody(bytes) = body;
        Self::from(bytes)
    }
}
436
#[cfg(test)]
mod tests {
    use super::*;
    use prost_reflect::ReflectMessage;
    use prost_reflect::prost::Message;
    use prost_reflect::prost_types::Timestamp;
    use std::future::Ready;
    use std::io;
    use std::task::{Context, Poll};
    use tower::service_fn;

    // Serialized `FileDescriptorSet` for a proto3 file `example_service.proto` in package
    // `package`, declaring an empty message `MyMessage` (plus source code info). It provides
    // the reflection descriptor of `MyMessage` below.
    const FILE_DESCRIPTOR_SET_BYTES: &[u8] = &[
        10, 107, 10, 21, 101, 120, 97, 109, 112, 108, 101, 95, 115, 101, 114, 118, 105, 99, 101,
        46, 112, 114, 111, 116, 111, 18, 7, 112, 97, 99, 107, 97, 103, 101, 34, 11, 10, 9, 77, 121,
        77, 101, 115, 115, 97, 103, 101, 74, 52, 10, 6, 18, 4, 0, 0, 5, 1, 10, 8, 10, 1, 12, 18, 3,
        0, 0, 18, 10, 8, 10, 1, 2, 18, 3, 2, 0, 16, 10, 10, 10, 2, 4, 0, 18, 4, 4, 0, 5, 1, 10, 10,
        10, 3, 4, 0, 1, 18, 3, 4, 8, 17, 98, 6, 112, 114, 111, 116, 111, 51,
    ];

    // Message without any field, used to check that unknown JSON fields are tolerated.
    #[derive(Message, ReflectMessage, PartialEq)]
    #[prost_reflect(
        file_descriptor_set_bytes = "crate::tests::FILE_DESCRIPTOR_SET_BYTES",
        message_name = "package.MyMessage"
    )]
    pub struct MyMessage {}

    // A failing `poll_ready` must be reported as an `Unknown` error before any call is made.
    #[tokio::test]
    async fn not_ready_service() -> Result<(), Box<dyn Error>> {
        #[derive(Clone)]
        struct NotReadyService;

        impl<S> Service<S> for NotReadyService {
            type Response = Response<String>;
            type Error = TwirpError;
            type Future = Ready<Result<Response<String>, TwirpError>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Err(TwirpError::internal("foo")))
            }

            fn call(&mut self, _: S) -> Self::Future {
                // Never reached: readiness always fails.
                unimplemented!()
            }
        }

        let client = TwirpHttpClient::new(NotReadyService);
        assert_eq!(
            client
                .call::<_, Timestamp>("", &Timestamp::default())
                .await
                .unwrap_err()
                .to_string(),
            "Twirp Unknown error: Service is not ready: Twirp Internal error: foo"
        );
        Ok(())
    }

    // Without a base URL the path is the URI; the JSON response is the canonical JSON
    // representation of a `Timestamp` (an RFC 3339 string).
    #[tokio::test]
    async fn json_request_without_base_ok() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<_, TwirpError>(
                Response::builder()
                    .header(CONTENT_TYPE, APPLICATION_JSON)
                    .body("\"1970-01-01T00:00:10Z\"".to_string())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await?;
        assert_eq!(
            response,
            Timestamp {
                seconds: 10,
                nanos: 0
            }
        );
        Ok(())
    }

    // Fields unknown to the response message must be ignored, not rejected.
    #[tokio::test]
    async fn json_request_with_unknown_fields_ok() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<_, TwirpError>(
                Response::builder()
                    .header(CONTENT_TYPE, APPLICATION_JSON)
                    .body("{\"unknown_field\":\"ignored\"}".to_string())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response = client
            .call::<_, MyMessage>("/foo", &MyMessage::default())
            .await?;
        assert_eq!(response, MyMessage::default());
        Ok(())
    }

    // Binary protobuf round trip; gated on the reqwest feature only because it uses
    // `reqwest_012::Body` as the response body type.
    #[cfg(feature = "reqwest-012")]
    #[tokio::test]
    async fn binary_request_without_base_ok() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<_, TwirpError>(
                Response::builder()
                    .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
                    .body(reqwest_012::Body::from(
                        Timestamp {
                            seconds: 10,
                            nanos: 0,
                        }
                        .encode_to_vec(),
                    ))
                    .unwrap(),
            )
        });

        let response = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await?;
        assert_eq!(
            response,
            Timestamp {
                seconds: 10,
                nanos: 0
            }
        );
        Ok(())
    }

    // The base URL is prefixed to the path, and a Twirp error response is decoded back into
    // the same `TwirpError`.
    #[tokio::test]
    async fn request_with_base_twirp_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "http://example.com/twirp/foo");
            Ok::<Response<String>, TwirpError>(TwirpError::not_found("not found").into())
        });

        let response_error = TwirpHttpClient::new_with_base(service, "http://example.com/twirp")
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(response_error, TwirpError::not_found("not found"));
        Ok(())
    }

    // A trailing `/` on the base URL is dropped, and a plain (non-Twirp) HTTP 401 response
    // is mapped to an `unauthenticated` error carrying the body as message.
    #[tokio::test]
    async fn request_with_base_other_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "http://example.com/twirp/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body("foo".to_string())
                    .unwrap(),
            )
        });

        let response_error = TwirpHttpClient::new_with_base(service, "http://example.com/twirp/")
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(response_error, TwirpError::unauthenticated("foo"));
        Ok(())
    }

    // A service error is wrapped into an `Unknown` Twirp error.
    #[tokio::test]
    async fn request_transport_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Err::<Response<String>, _>(io::Error::other("Transport error"))
        });

        let response_error = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::new(
                TwirpErrorCode::Unknown,
                "Transport error during the request: Transport error"
            )
        );
        Ok(())
    }

    // A successful response with an unknown content type is rejected as malformed.
    #[tokio::test]
    async fn wrong_content_type_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, "foo/bar")
                    .body("foo".into())
                    .unwrap(),
            )
        });

        let response_error = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed("Unsupported response content-type: foo/bar")
        );
        Ok(())
    }

    // The response content type (not the client's request encoding) picks the decoder: the
    // client sends JSON but the protobuf response is decoded as protobuf, and fails here.
    #[tokio::test]
    async fn invalid_protobuf_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
                    .body("azerty".into())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response_error = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed(
                "Bad response binary protobuf encoding: failed to decode Protobuf message: buffer underflow"
            )
        );
        Ok(())
    }

    // Invalid JSON bodies are reported as malformed with the parser's position.
    #[tokio::test]
    async fn invalid_json_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, APPLICATION_JSON)
                    .body("foo".into())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response_error = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed(
                "Failed to parse JSON response: expected ident at line 1 column 2"
            )
        );
        Ok(())
    }

    // Compile-time check: the future returned by `call` is `Send` (it is never awaited).
    #[tokio::test]
    async fn response_future_is_send() {
        fn is_send<T: Send>(_: T) {}

        let service = service_fn(|_: Request<TwirpRequestBody>| async move {
            Ok::<_, TwirpError>(Response::new(String::new()))
        });
        let client = TwirpHttpClient::new(service);

        is_send(client.call::<_, Timestamp>("/foo", &Timestamp::default()));

    }
}