temporalio_client/
callback_based.rs1use anyhow::anyhow;
5use bytes::{BufMut, BytesMut};
6use futures_util::{future::BoxFuture, stream};
7use http::{HeaderMap, Request, Response};
8use http_body_util::{BodyExt, StreamBody, combinators::BoxBody};
9use hyper::body::{Bytes, Frame};
10use std::{
11 sync::Arc,
12 task::{Context, Poll},
13};
14use tonic::{Status, metadata::GRPC_CONTENT_TYPE};
15use tower::Service;
16
17#[derive(Debug)]
19pub struct GrpcRequest {
20 pub service: String,
22 pub rpc: String,
24 pub headers: HeaderMap,
26 pub proto: Bytes,
28}
29
30#[derive(Debug)]
32pub struct GrpcSuccessResponse {
33 pub headers: HeaderMap,
35
36 pub proto: Vec<u8>,
38}
39
40#[derive(Clone)]
42pub struct CallbackBasedGrpcService {
43 #[allow(clippy::type_complexity)] pub callback: Arc<
46 dyn Fn(GrpcRequest) -> BoxFuture<'static, Result<GrpcSuccessResponse, Status>>
47 + Send
48 + Sync,
49 >,
50}
51impl std::fmt::Debug for CallbackBasedGrpcService {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("CallbackBasedGrpcService").finish()
54 }
55}
56
57impl Service<Request<tonic::body::Body>> for CallbackBasedGrpcService {
58 type Response = http::Response<tonic::body::Body>;
59 type Error = anyhow::Error;
60 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
61
62 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63 Poll::Ready(Ok(()))
64 }
65
66 fn call(&mut self, req: Request<tonic::body::Body>) -> Self::Future {
67 let callback = self.callback.clone();
68
69 Box::pin(async move {
70 let (parts, body) = req.into_parts();
72 let mut path_parts = parts.uri.path().trim_start_matches('/').split('/');
73 let req_body = body.collect().await.map_err(|e| anyhow!(e))?.to_bytes();
74 if req_body.len() < 5 {
77 return Err(anyhow!("Too few request bytes: {}", req_body.len()));
78 } else if req_body[0] != 0 {
79 return Err(anyhow!("Compression not supported"));
80 }
81 let req_proto_len =
82 u32::from_be_bytes([req_body[1], req_body[2], req_body[3], req_body[4]]) as usize;
83 if req_body.len() < 5 + req_proto_len {
84 return Err(anyhow!(
85 "Expected request body length at least {}, got {}",
86 5 + req_proto_len,
87 req_body.len()
88 ));
89 }
90 let req = GrpcRequest {
91 service: path_parts.next().unwrap_or_default().to_owned(),
92 rpc: path_parts.next().unwrap_or_default().to_owned(),
93 headers: parts.headers,
94 proto: req_body.slice(5..5 + req_proto_len),
95 };
96
97 match (callback)(req).await {
99 Ok(success) => {
100 let mut body_prepend = BytesMut::with_capacity(5);
105 body_prepend.put_u8(0); body_prepend.put_u32(success.proto.len() as u32);
107 let stream = stream::iter(vec![
108 Ok::<_, Status>(Frame::data(Bytes::from(body_prepend))),
109 Ok::<_, Status>(Frame::data(Bytes::from(success.proto))),
110 ]);
111 let stream_body = StreamBody::new(stream);
112 let full_body = BoxBody::new(stream_body).boxed();
113
114 let mut resp_builder = Response::builder()
116 .status(200)
117 .header(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
118 for (key, value) in success.headers.iter() {
119 resp_builder = resp_builder.header(key, value);
120 }
121 Ok(resp_builder
122 .body(tonic::body::Body::new(full_body))
123 .map_err(|e| anyhow!(e))?)
124 }
125 Err(status) => Ok(status.into_http()),
126 }
127 })
128 }
129}