Skip to main content

temporalio_client/
proxy.rs

1use base64::prelude::*;
2use http_body_util::Empty;
3use hyper::{body::Bytes, header};
4use hyper_util::{
5    client::legacy::{
6        Client,
7        connect::{Connected, Connection},
8    },
9    rt::{TokioExecutor, TokioIo},
10};
11use std::{
12    future::Future,
13    io,
14    pin::Pin,
15    task::{Context, Poll},
16};
17use tokio::{
18    io::{AsyncRead, AsyncWrite, ReadBuf},
19    net::TcpStream,
20};
21use tonic::transport::{Channel, Endpoint};
22use tower::{Service, service_fn};
23
24#[cfg(unix)]
25use tokio::net::UnixStream;
26
/// Configuration for tunnelling connections through an HTTP `CONNECT` proxy.
#[derive(Debug, Clone)]
pub struct HttpConnectProxyOptions {
    /// Address of the proxy itself: `host:port` for a proxy reached over TCP,
    /// or `unix:/path/to/unix.sock` for a proxy listening on a Unix domain
    /// socket (the value must begin with `unix:/` to be treated as a path).
    pub target_addr: String,
    /// Optional `(username, password)` pair sent to the proxy as HTTP Basic
    /// credentials.
    pub basic_auth: Option<(String, String)>,
}
36
37impl HttpConnectProxyOptions {
38    /// Create a channel from the given endpoint that uses the HTTP CONNECT proxy.
39    pub async fn connect_endpoint(
40        &self,
41        endpoint: &Endpoint,
42    ) -> Result<Channel, tonic::transport::Error> {
43        let proxy_options = self.clone();
44        let svc_fn = service_fn(move |uri: tonic::transport::Uri| {
45            let proxy_options = proxy_options.clone();
46            async move { proxy_options.connect(uri).await }
47        });
48        endpoint.connect_with_connector(svc_fn).await
49    }
50
51    async fn connect(
52        &self,
53        uri: tonic::transport::Uri,
54    ) -> anyhow::Result<hyper::upgrade::Upgraded> {
55        debug!("Connecting to {} via proxy at {}", uri, self.target_addr);
56        // Create CONNECT request
57        let mut req_build = hyper::Request::builder().method("CONNECT").uri(uri);
58        if let Some((user, pass)) = &self.basic_auth {
59            let creds = BASE64_STANDARD.encode(format!("{user}:{pass}"));
60            req_build = req_build.header(header::PROXY_AUTHORIZATION, format!("Basic {creds}"));
61        }
62        let req = req_build.body(Empty::<Bytes>::new())?;
63
64        // We have to create a client with a specific connector because Hyper is
65        // not letting us change the HTTP/2 authority
66        let client = Client::builder(TokioExecutor::new())
67            .build(OverrideAddrConnector(self.target_addr.clone()));
68
69        // Send request
70        let res = client.request(req).await?;
71        if res.status().is_success() {
72            Ok(hyper::upgrade::on(res).await?)
73        } else {
74            Err(anyhow::anyhow!(
75                "CONNECT call failed with status: {}",
76                res.status()
77            ))
78        }
79    }
80}
81
82#[derive(Clone)]
83struct OverrideAddrConnector(String);
84
85impl Service<hyper::Uri> for OverrideAddrConnector {
86    type Response = TokioIo<ProxyStream>;
87
88    type Error = anyhow::Error;
89
90    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
91
92    fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
93        Poll::Ready(Ok(()))
94    }
95
96    fn call(&mut self, _uri: hyper::Uri) -> Self::Future {
97        let target_addr = self.0.clone();
98        let fut = async move {
99            Ok(TokioIo::new(
100                ProxyStream::connect(target_addr.as_str()).await?,
101            ))
102        };
103        Box::pin(fut)
104    }
105}
106
/// Raw transport to the proxy. It is either a TCP connection or, on Unix
/// platforms, a Unix domain socket connection.
///
/// Visible only for tests
#[doc(hidden)]
pub enum ProxyStream {
    /// Connection to a proxy given as a `host:port` address.
    Tcp(TcpStream),
    /// Connection to a proxy given as a `unix:/...` address (Unix only).
    #[cfg(unix)]
    Unix(UnixStream),
}
114
115impl ProxyStream {
116    async fn connect(target_addr: &str) -> anyhow::Result<Self> {
117        if target_addr.starts_with("unix:/") {
118            #[cfg(unix)]
119            {
120                Ok(ProxyStream::Unix(
121                    UnixStream::connect(&target_addr[5..]).await?,
122                ))
123            }
124            #[cfg(not(unix))]
125            {
126                Err(anyhow::anyhow!(
127                    "Unix sockets are not supported on this platform"
128                ))
129            }
130        } else {
131            Ok(ProxyStream::Tcp(TcpStream::connect(target_addr).await?))
132        }
133    }
134}
135
136impl AsyncRead for ProxyStream {
137    fn poll_read(
138        self: Pin<&mut Self>,
139        cx: &mut Context<'_>,
140        buf: &mut ReadBuf<'_>,
141    ) -> Poll<io::Result<()>> {
142        match self.get_mut() {
143            ProxyStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
144            #[cfg(unix)]
145            ProxyStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
146        }
147    }
148}
149
150impl AsyncWrite for ProxyStream {
151    fn poll_write(
152        self: Pin<&mut Self>,
153        cx: &mut Context<'_>,
154        buf: &[u8],
155    ) -> Poll<io::Result<usize>> {
156        match self.get_mut() {
157            ProxyStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
158            #[cfg(unix)]
159            ProxyStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
160        }
161    }
162
163    fn poll_write_vectored(
164        self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166        bufs: &[io::IoSlice<'_>],
167    ) -> Poll<io::Result<usize>> {
168        match self.get_mut() {
169            ProxyStream::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
170            #[cfg(unix)]
171            ProxyStream::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
172        }
173    }
174
175    fn is_write_vectored(&self) -> bool {
176        match self {
177            ProxyStream::Tcp(s) => s.is_write_vectored(),
178            #[cfg(unix)]
179            ProxyStream::Unix(s) => s.is_write_vectored(),
180        }
181    }
182
183    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184        match self.get_mut() {
185            ProxyStream::Tcp(s) => Pin::new(s).poll_flush(cx),
186            #[cfg(unix)]
187            ProxyStream::Unix(s) => Pin::new(s).poll_flush(cx),
188        }
189    }
190
191    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192        match self.get_mut() {
193            ProxyStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
194            #[cfg(unix)]
195            ProxyStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
196        }
197    }
198}
199
200impl Connection for ProxyStream {
201    fn connected(&self) -> Connected {
202        match self {
203            ProxyStream::Tcp(s) => s.connected(),
204            // There is no special connected metadata for Unix sockets
205            #[cfg(unix)]
206            ProxyStream::Unix(_) => Connected::new(),
207        }
208    }
209}