1#![doc = include_str!("../README.md")]
2#![doc(
3 test(attr(deny(warnings))),
4 html_favicon_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png",
5 html_logo_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8
9use http::header::CONTENT_TYPE;
10use http::{HeaderValue, Method, Request, Response, StatusCode};
11use http_body::{Body, Frame, SizeHint};
12use http_body_util::BodyExt;
13use prost_reflect::bytes::{Buf, Bytes, BytesMut};
14use prost_reflect::{DynamicMessage, ReflectMessage};
15use serde::Serialize;
16use std::convert::Infallible;
17use std::error::Error;
18use std::future::poll_fn;
19#[cfg(feature = "reqwest-012")]
20use std::future::Future;
21use std::mem::take;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24use tower_service::Service;
25pub use twurst_error::{TwirpError, TwirpErrorCode};
26
/// `Content-Type` header value for JSON-encoded Twirp messages.
const APPLICATION_JSON: HeaderValue = HeaderValue::from_static("application/json");
/// `Content-Type` header value for binary protobuf-encoded Twirp messages.
const APPLICATION_PROTOBUF: HeaderValue = HeaderValue::from_static("application/protobuf");
29
30#[derive(Clone)]
42pub struct TwirpHttpClient<S: TwirpHttpService> {
43 service: S,
44 base_url: Option<String>,
45 use_json: bool,
46}
47
#[cfg(feature = "reqwest-012")]
impl TwirpHttpClient<Reqwest012Service> {
    /// Creates a client backed by a new default `reqwest` 0.12 client.
    ///
    /// Every request is sent to `base_url` followed by the method path.
    pub fn new_using_reqwest_012(base_url: impl Into<String>) -> Self {
        Self::new_with_base(Reqwest012Service::new(), base_url)
    }

    /// Creates a client backed by the given (possibly pre-configured) `reqwest` 0.12 client.
    ///
    /// Every request is sent to `base_url` followed by the method path.
    pub fn new_with_reqwest_012_client(
        client: reqwest_012::Client,
        base_url: impl Into<String>,
    ) -> Self {
        let service: Reqwest012Service = client.into();
        Self::new_with_base(service, base_url)
    }
}
81
82impl<S: TwirpHttpService> TwirpHttpClient<S> {
83 pub fn new_with_base(service: S, base_url: impl Into<String>) -> Self {
99 let mut base_url = base_url.into();
100 if base_url.ends_with('/') {
102 base_url.pop();
103 }
104 Self {
105 service,
106 base_url: Some(base_url),
107 use_json: false,
108 }
109 }
110
111 pub fn new(service: S) -> Self {
124 Self {
125 service,
126 base_url: None,
127 use_json: false,
128 }
129 }
130
131 pub fn use_json(&mut self) {
133 self.use_json = true;
134 }
135
136 pub fn use_binary_protobuf(&mut self) {
138 self.use_json = false;
139 }
140
141 pub async fn call<I: ReflectMessage, O: ReflectMessage + Default>(
145 &self,
146 path: &str,
147 request: &I,
148 ) -> Result<O, TwirpError> {
149 self.service.ready().await.map_err(|e| {
151 TwirpError::wrap(
152 TwirpErrorCode::Unknown,
153 format!("Service is not ready: {e}"),
154 e,
155 )
156 })?;
157 let request = self.build_request(path, request)?;
158 let response = self.service.call(request).await.map_err(|e| {
159 TwirpError::wrap(
160 TwirpErrorCode::Unknown,
161 format!("Transport error during the request: {e}"),
162 e,
163 )
164 })?;
165 self.extract_response(response).await
166 }
167
168 fn build_request<T: ReflectMessage>(
169 &self,
170 path: &str,
171 message: &T,
172 ) -> Result<Request<TwirpRequestBody>, TwirpError> {
173 let mut request_builder = Request::builder().method(Method::POST);
174 request_builder = if let Some(base_url) = &self.base_url {
175 request_builder.uri(format!("{}{}", base_url, path))
176 } else {
177 request_builder.uri(path)
178 };
179 if self.use_json {
180 request_builder
181 .header(CONTENT_TYPE, APPLICATION_JSON)
182 .body(json_encode(message)?.into())
183 } else {
184 let mut buffer = BytesMut::with_capacity(message.encoded_len());
185 message.encode(&mut buffer).map_err(|e| {
186 TwirpError::wrap(
187 TwirpErrorCode::Internal,
188 format!("Failed to serialize to protobuf: {e}"),
189 e,
190 )
191 })?;
192 request_builder
193 .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
194 .body(Bytes::from(buffer).into())
195 }
196 .map_err(|e| {
197 TwirpError::wrap(
198 TwirpErrorCode::Malformed,
199 format!("Failed to construct request: {e}"),
200 e,
201 )
202 })
203 }
204
205 async fn extract_response<T: ReflectMessage + Default>(
206 &self,
207 response: Response<S::ResponseBody>,
208 ) -> Result<T, TwirpError> {
209 let (parts, body) = response.into_parts();
212 let body = body.collect().await.map_err(|e| {
213 TwirpError::wrap(
214 TwirpErrorCode::Internal,
215 format!("Failed to load request body: {e}"),
216 e,
217 )
218 })?;
219 let response = Response::from_parts(parts, body);
220
221 if response.status() != StatusCode::OK {
223 return Err(response.map(|b| b.to_bytes()).into());
224 }
225
226 let content_type = response.headers().get(CONTENT_TYPE).cloned();
228 let body = response.into_body();
229 if content_type == Some(APPLICATION_PROTOBUF) {
230 T::decode(body.aggregate()).map_err(|e| {
231 TwirpError::wrap(
232 TwirpErrorCode::Malformed,
233 format!("Bad response binary protobuf encoding: {e}"),
234 e,
235 )
236 })
237 } else if content_type == Some(APPLICATION_JSON) {
238 json_decode(&body.to_bytes())
239 } else if let Some(content_type) = content_type {
240 Err(TwirpError::malformed(format!(
241 "Unsupported response content-type: {}",
242 String::from_utf8_lossy(content_type.as_bytes())
243 )))
244 } else {
245 Err(TwirpError::malformed("No content-type in the response"))
246 }
247 }
248}
249
#[trait_variant::make(Send)]
/// HTTP transport used by [`TwirpHttpClient`] to send requests and receive responses.
///
/// A blanket implementation exists for every `Clone + Send + Sync + 'static` `tower`
/// [`Service`] handling `Request<TwirpRequestBody>`, so it rarely needs a manual implementation.
pub trait TwirpHttpService: 'static {
    /// Body type of the HTTP responses; the client buffers it completely before decoding.
    type ResponseBody: Body<Error: Error + Send + Sync>;
    /// Error returned by [`TwirpHttpService::ready`] and [`TwirpHttpService::call`].
    type Error: Error + Send + Sync + 'static;

    /// Waits until the service is able to process a request.
    async fn ready(&self) -> Result<(), Self::Error>;

    /// Sends `request` and returns the raw HTTP response.
    async fn call(
        &self,
        request: Request<TwirpRequestBody>,
    ) -> Result<Response<Self::ResponseBody>, Self::Error>;
}
265
266impl<
267 S: Service<
268 Request<TwirpRequestBody>,
269 Error: Error + Send + Sync + 'static,
270 Response = Response<RespBody>,
271 Future: Send,
272 > + Clone
273 + Send
274 + Sync
275 + 'static,
276 RespBody: Body<Error: Error + Send + Sync + 'static>,
277 > TwirpHttpService for S
278{
279 type ResponseBody = RespBody;
280 type Error = S::Error;
281
282 async fn ready(&self) -> Result<(), Self::Error> {
283 poll_fn(|cx| Service::poll_ready(&mut self.clone(), cx)).await
284 }
285
286 async fn call(
287 &self,
288 request: Request<TwirpRequestBody>,
289 ) -> Result<Response<RespBody>, S::Error> {
290 Service::call(&mut self.clone(), request).await
291 }
292}
293
/// Request body built by [`TwirpHttpClient`]: the already-encoded message bytes (binary
/// protobuf or JSON), emitted as a single data frame by its [`Body`] implementation.
pub struct TwirpRequestBody(Bytes);
298
299impl From<Bytes> for TwirpRequestBody {
300 #[inline]
301 fn from(body: Bytes) -> Self {
302 Self(body)
303 }
304}
305
306impl From<TwirpRequestBody> for Bytes {
307 #[inline]
308 fn from(body: TwirpRequestBody) -> Self {
309 body.0
310 }
311}
312
313impl Body for TwirpRequestBody {
314 type Data = Bytes;
315 type Error = Infallible;
316
317 #[inline]
318 fn poll_frame(
319 mut self: Pin<&mut Self>,
320 _cx: &mut Context<'_>,
321 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
322 let data = take(&mut self.0);
323 Poll::Ready(if data.has_remaining() {
324 Some(Ok(Frame::data(data)))
325 } else {
326 None
327 })
328 }
329
330 #[inline]
331 fn is_end_stream(&self) -> bool {
332 !self.0.has_remaining()
333 }
334
335 #[inline]
336 fn size_hint(&self) -> SizeHint {
337 SizeHint::with_exact(self.0.remaining() as u64)
338 }
339}
340
341fn json_encode<T: ReflectMessage>(message: &T) -> Result<Bytes, TwirpError> {
342 let mut serializer = serde_json::Serializer::new(Vec::new());
343 message
344 .transcode_to_dynamic()
345 .serialize(&mut serializer)
346 .map_err(|e| {
347 TwirpError::wrap(
348 TwirpErrorCode::Malformed,
349 format!("Failed to serialize request to JSON: {e}"),
350 e,
351 )
352 })?;
353 Ok(serializer.into_inner().into())
354}
355
356fn json_decode<T: ReflectMessage + Default>(message: &[u8]) -> Result<T, TwirpError> {
357 let dynamic_message = dynamic_json_decode::<T>(message).map_err(|e| {
358 TwirpError::wrap(
359 TwirpErrorCode::Malformed,
360 format!("Failed to parse JSON response: {e}"),
361 e,
362 )
363 })?;
364 dynamic_message.transcode_to().map_err(|e| {
365 TwirpError::internal(format!(
366 "Internal error while parsing the JSON response: {e}"
367 ))
368 })
369}
370
371fn dynamic_json_decode<T: ReflectMessage + Default>(
372 message: &[u8],
373) -> Result<DynamicMessage, serde_json::Error> {
374 let mut deserializer = serde_json::Deserializer::from_slice(message);
375 let dynamic_message =
376 DynamicMessage::deserialize(T::default().descriptor(), &mut deserializer)?;
377 deserializer.end()?;
378 Ok(dynamic_message)
379}
380
#[cfg(feature = "reqwest-012")]
/// `tower` [`Service`] wrapping a `reqwest` 0.12 client; usable as a [`TwirpHttpService`]
/// through the blanket implementation.
#[derive(Clone, Default)]
pub struct Reqwest012Service(reqwest_012::Client);
385
#[cfg(feature = "reqwest-012")]
impl Reqwest012Service {
    /// Creates a service backed by `reqwest_012::Client::new()`.
    #[inline]
    pub fn new() -> Self {
        Self(reqwest_012::Client::new())
    }
}
393
#[cfg(feature = "reqwest-012")]
impl From<reqwest_012::Client> for Reqwest012Service {
    /// Wraps an existing, possibly customized, `reqwest` client.
    #[inline]
    fn from(client: reqwest_012::Client) -> Self {
        Reqwest012Service(client)
    }
}
401
#[cfg(feature = "reqwest-012")]
/// Sends `http` requests with the wrapped `reqwest` client.
///
/// The request is converted to a `reqwest_012::Request`, and the `reqwest` response is
/// converted back into an `http::Response` carrying a `reqwest_012::Body`.
impl<B: Into<reqwest_012::Body>> Service<Request<B>> for Reqwest012Service {
    type Response = Response<reqwest_012::Body>;
    type Error = reqwest_012::Error;
    type Future = Pin<
        Box<dyn Future<Output = Result<Response<reqwest_012::Body>, reqwest_012::Error>> + Send>,
    >;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is fully delegated to the wrapped client.
        self.0.poll_ready(cx)
    }

    fn call(&mut self, req: Request<B>) -> Self::Future {
        let converted: Result<reqwest_012::Request, reqwest_012::Error> = req.try_into();
        match converted {
            Ok(request) => {
                let pending = self.0.call(request);
                Box::pin(async move { pending.await.map(Response::<reqwest_012::Body>::from) })
            }
            // Conversion failures are reported through the returned future, like transport errors.
            Err(error) => {
                Box::pin(async move { Err::<Response<reqwest_012::Body>, _>(error) })
            }
        }
    }
}
423
#[cfg(feature = "reqwest-012")]
impl From<TwirpRequestBody> for reqwest_012::Body {
    /// Converts the encoded request payload into a `reqwest` body.
    #[inline]
    fn from(body: TwirpRequestBody) -> Self {
        Self::from(Bytes::from(body))
    }
}
431
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "reqwest-012")]
    use prost_reflect::prost::Message;
    use prost_reflect::prost_types::Timestamp;
    use std::future::Ready;
    use std::io;
    use std::task::{Context, Poll};
    use tower::service_fn;

    // A service whose `poll_ready` always fails must yield an `Unknown` error wrapping the
    // readiness failure; `call` is never reached (it would panic via `unimplemented!`).
    #[tokio::test]
    async fn not_ready_service() -> Result<(), Box<dyn Error>> {
        #[derive(Clone)]
        struct NotReadyService;

        impl<S> Service<S> for NotReadyService {
            type Response = Response<String>;
            type Error = TwirpError;
            type Future = Ready<Result<Response<String>, TwirpError>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Err(TwirpError::internal("foo")))
            }

            fn call(&mut self, _: S) -> Self::Future {
                unimplemented!()
            }
        }

        let client = TwirpHttpClient::new(NotReadyService);
        assert_eq!(
            client
                .call::<_, Timestamp>("", &Timestamp::default())
                .await
                .unwrap_err()
                .to_string(),
            "Twirp Unknown error: Service is not ready: Twirp Internal error: foo"
        );
        Ok(())
    }

    // JSON round trip without base URL: the path is used verbatim as URI and the JSON
    // `Timestamp` representation is decoded back into the message.
    #[tokio::test]
    async fn json_request_without_base_ok() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<_, TwirpError>(
                Response::builder()
                    .header(CONTENT_TYPE, APPLICATION_JSON)
                    .body("\"1970-01-01T00:00:10Z\"".to_string())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await?;
        assert_eq!(
            response,
            Timestamp {
                seconds: 10,
                nanos: 0
            }
        );
        Ok(())
    }

    // Binary protobuf round trip (the default encoding). Gated on `reqwest-012` because the
    // response body is built as a `reqwest_012::Body` from prost-encoded bytes.
    #[cfg(feature = "reqwest-012")]
    #[tokio::test]
    async fn binary_request_without_base_ok() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<_, TwirpError>(
                Response::builder()
                    .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
                    .body(reqwest_012::Body::from(
                        Timestamp {
                            seconds: 10,
                            nanos: 0,
                        }
                        .encode_to_vec(),
                    ))
                    .unwrap(),
            )
        });

        let response = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await?;
        assert_eq!(
            response,
            Timestamp {
                seconds: 10,
                nanos: 0
            }
        );
        Ok(())
    }

    // The base URL is prepended to the path, and a Twirp error response is decoded back
    // into the same `TwirpError`.
    #[tokio::test]
    async fn request_with_base_twirp_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "http://example.com/twirp/foo");
            Ok::<Response<String>, TwirpError>(TwirpError::not_found("not found").into())
        });

        let response_error = TwirpHttpClient::new_with_base(service, "http://example.com/twirp")
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(response_error, TwirpError::not_found("not found"));
        Ok(())
    }

    // A trailing `/` in the base URL is not duplicated in the URI, and a non-Twirp HTTP 401
    // response is mapped to an `unauthenticated` error carrying the raw body.
    #[tokio::test]
    async fn request_with_base_other_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "http://example.com/twirp/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body("foo".to_string())
                    .unwrap(),
            )
        });

        let response_error = TwirpHttpClient::new_with_base(service, "http://example.com/twirp/")
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(response_error, TwirpError::unauthenticated("foo"));
        Ok(())
    }

    // A transport failure returned by the service surfaces as an `Unknown` Twirp error.
    #[tokio::test]
    async fn request_transport_error() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Err::<Response<String>, _>(io::Error::other("Transport error"))
        });

        let response_error = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::new(
                TwirpErrorCode::Unknown,
                "Transport error during the request: Transport error"
            )
        );
        Ok(())
    }

    // A 200 response with an unknown content type is rejected as malformed.
    #[tokio::test]
    async fn wrong_content_type_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, "foo/bar")
                    .body("foo".into())
                    .unwrap(),
            )
        });

        let response_error = TwirpHttpClient::new(service)
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed("Unsupported response content-type: foo/bar")
        );
        Ok(())
    }

    // The client is switched to JSON on purpose: the response is decoded according to its
    // own content type (protobuf here), not the request encoding, so the protobuf error shows.
    #[tokio::test]
    async fn invalid_protobuf_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, APPLICATION_PROTOBUF)
                    .body("azerty".into())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response_error = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed("Bad response binary protobuf encoding: failed to decode Protobuf message: buffer underflow")
        );
        Ok(())
    }

    // A JSON content type with a non-JSON body yields a malformed error carrying the
    // serde_json parse message.
    #[tokio::test]
    async fn invalid_json_response() -> Result<(), Box<dyn Error>> {
        let service = service_fn(|request: Request<TwirpRequestBody>| async move {
            assert_eq!(request.method(), Method::POST);
            assert_eq!(request.uri(), "/foo");
            Ok::<Response<String>, TwirpError>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CONTENT_TYPE, APPLICATION_JSON)
                    .body("foo".into())
                    .unwrap(),
            )
        });

        let mut client = TwirpHttpClient::new(service);
        client.use_json();
        let response_error = client
            .call::<_, Timestamp>(
                "/foo",
                &Timestamp {
                    seconds: 10,
                    nanos: 0,
                },
            )
            .await
            .unwrap_err();
        assert_eq!(
            response_error,
            TwirpError::malformed(
                "Failed to parse JSON response: expected ident at line 1 column 2"
            )
        );
        Ok(())
    }

    // Compile-time check: the future returned by `TwirpHttpClient::call` is `Send`, so it
    // can be spawned on multi-threaded executors. It is never awaited.
    #[tokio::test]
    async fn response_future_is_send() {
        fn is_send<T: Send>(_: T) {}

        let service = service_fn(|_: Request<TwirpRequestBody>| async move {
            Ok::<_, TwirpError>(Response::new(String::new()))
        });
        let client = TwirpHttpClient::new(service);

        is_send(client.call::<_, Timestamp>("/foo", &Timestamp::default()));
    }
}