Skip to main content

temporalio_client/
callback_based.rs

//! This module implements support for a callback-based gRPC service, which invokes a callback
//! for every gRPC call instead of directly using the network.
3
4use anyhow::anyhow;
5use bytes::{BufMut, BytesMut};
6use futures_util::{future::BoxFuture, stream};
7use http::{HeaderMap, Request, Response};
8use http_body_util::{BodyExt, StreamBody, combinators::BoxBody};
9use hyper::body::{Bytes, Frame};
10use std::{
11    sync::Arc,
12    task::{Context, Poll},
13};
14use tonic::{Status, metadata::GRPC_CONTENT_TYPE};
15use tower::Service;
16
17/// gRPC request for use by a callback.
18#[derive(Debug)]
19pub struct GrpcRequest {
20    /// Fully qualified gRPC service name.
21    pub service: String,
22    /// RPC name.
23    pub rpc: String,
24    /// Request headers.
25    pub headers: HeaderMap,
26    /// Protobuf bytes of the request.
27    pub proto: Bytes,
28}
29
30/// Successful gRPC response returned by a callback.
31#[derive(Debug)]
32pub struct GrpcSuccessResponse {
33    /// Response headers.
34    pub headers: HeaderMap,
35
36    /// Response proto bytes.
37    pub proto: Vec<u8>,
38}
39
40/// gRPC service that invokes the given callback on each call.
41#[derive(Clone)]
42pub struct CallbackBasedGrpcService {
43    /// Callback to invoke on each RPC call.
44    #[allow(clippy::type_complexity)] // Signature is not that complex
45    pub callback: Arc<
46        dyn Fn(GrpcRequest) -> BoxFuture<'static, Result<GrpcSuccessResponse, Status>>
47            + Send
48            + Sync,
49    >,
50}
51impl std::fmt::Debug for CallbackBasedGrpcService {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("CallbackBasedGrpcService").finish()
54    }
55}
56
57impl Service<Request<tonic::body::Body>> for CallbackBasedGrpcService {
58    type Response = http::Response<tonic::body::Body>;
59    type Error = anyhow::Error;
60    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
61
62    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63        Poll::Ready(Ok(()))
64    }
65
66    fn call(&mut self, req: Request<tonic::body::Body>) -> Self::Future {
67        let callback = self.callback.clone();
68
69        Box::pin(async move {
70            // Build req
71            let (parts, body) = req.into_parts();
72            let mut path_parts = parts.uri.path().trim_start_matches('/').split('/');
73            let req_body = body.collect().await.map_err(|e| anyhow!(e))?.to_bytes();
74            // Body is flag saying whether compressed (we do not support that), then 32-bit length,
75            // then the actual proto.
76            if req_body.len() < 5 {
77                return Err(anyhow!("Too few request bytes: {}", req_body.len()));
78            } else if req_body[0] != 0 {
79                return Err(anyhow!("Compression not supported"));
80            }
81            let req_proto_len =
82                u32::from_be_bytes([req_body[1], req_body[2], req_body[3], req_body[4]]) as usize;
83            if req_body.len() < 5 + req_proto_len {
84                return Err(anyhow!(
85                    "Expected request body length at least {}, got {}",
86                    5 + req_proto_len,
87                    req_body.len()
88                ));
89            }
90            let req = GrpcRequest {
91                service: path_parts.next().unwrap_or_default().to_owned(),
92                rpc: path_parts.next().unwrap_or_default().to_owned(),
93                headers: parts.headers,
94                proto: req_body.slice(5..5 + req_proto_len),
95            };
96
97            // Invoke and handle response
98            match (callback)(req).await {
99                Ok(success) => {
100                    // Create body bytes which requires a flag saying whether compressed, then
101                    // message len, then actual message. So we create a Bytes for those 5 prepend
102                    // parts, then stream it alongside the user-provided Vec. This allows us to
103                    // avoid copying the vec
104                    let mut body_prepend = BytesMut::with_capacity(5);
105                    body_prepend.put_u8(0); // 0 means no compression
106                    body_prepend.put_u32(success.proto.len() as u32);
107                    let stream = stream::iter(vec![
108                        Ok::<_, Status>(Frame::data(Bytes::from(body_prepend))),
109                        Ok::<_, Status>(Frame::data(Bytes::from(success.proto))),
110                    ]);
111                    let stream_body = StreamBody::new(stream);
112                    let full_body = BoxBody::new(stream_body).boxed();
113
114                    // Build response appending headers
115                    let mut resp_builder = Response::builder()
116                        .status(200)
117                        .header(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
118                    for (key, value) in success.headers.iter() {
119                        resp_builder = resp_builder.header(key, value);
120                    }
121                    Ok(resp_builder
122                        .body(tonic::body::Body::new(full_body))
123                        .map_err(|e| anyhow!(e))?)
124                }
125                Err(status) => Ok(status.into_http()),
126            }
127        })
128    }
129}