Skip to main content

temporalio_client/
proxy.rs

1use base64::prelude::*;
2use http_body_util::Empty;
3use hyper::{body::Bytes, header};
4use hyper_util::{
5    client::legacy::{
6        Client,
7        connect::{Connected, Connection},
8    },
9    rt::{TokioExecutor, TokioIo},
10};
11use std::{
12    future::Future,
13    io,
14    pin::Pin,
15    task::{Context, Poll},
16};
17use tokio::{
18    io::{AsyncRead, AsyncWrite, ReadBuf},
19    net::TcpStream,
20};
21use tonic::transport::{Channel, Endpoint};
22use tower::{Service, service_fn};
23
24#[cfg(unix)]
25use tokio::net::UnixStream;
26
/// Configuration for tunnelling connections through an HTTP `CONNECT` proxy.
#[derive(Clone, Debug)]
pub struct HttpConnectProxyOptions {
    /// Address of the proxy itself: either a TCP `host:port`, or a Unix domain
    /// socket written as `unix:/path/to/unix.sock` (the value must begin with
    /// `unix:/` to be treated as a socket path).
    pub target_addr: String,
    /// Optional `(username, password)` pair, sent to the proxy as HTTP Basic
    /// credentials in the `Proxy-Authorization` header.
    pub basic_auth: Option<(String, String)>,
}
36
37impl HttpConnectProxyOptions {
38    /// Create a channel from the given endpoint that uses the HTTP CONNECT proxy.
39    pub async fn connect_endpoint(
40        &self,
41        endpoint: &Endpoint,
42    ) -> Result<Channel, tonic::transport::Error> {
43        let proxy_options = self.clone();
44        let svc_fn = service_fn(move |uri: tonic::transport::Uri| {
45            let proxy_options = proxy_options.clone();
46            async move { proxy_options.connect(uri).await }
47        });
48        endpoint.connect_with_connector(svc_fn).await
49    }
50
51    async fn connect(
52        &self,
53        uri: tonic::transport::Uri,
54    ) -> anyhow::Result<hyper::upgrade::Upgraded> {
55        let uri = ensure_connect_authority_port(uri);
56        debug!("Connecting to {} via proxy at {}", uri, self.target_addr);
57        // Create CONNECT request
58        let mut req_build = hyper::Request::builder().method("CONNECT").uri(uri);
59        if let Some((user, pass)) = &self.basic_auth {
60            let creds = BASE64_STANDARD.encode(format!("{user}:{pass}"));
61            req_build = req_build.header(header::PROXY_AUTHORIZATION, format!("Basic {creds}"));
62        }
63        let req = req_build.body(Empty::<Bytes>::new())?;
64
65        // We have to create a client with a specific connector because Hyper is
66        // not letting us change the HTTP/2 authority
67        let client = Client::builder(TokioExecutor::new())
68            .build(OverrideAddrConnector(self.target_addr.clone()));
69
70        // Send request
71        let res = client.request(req).await?;
72        if res.status().is_success() {
73            Ok(hyper::upgrade::on(res).await?)
74        } else {
75            Err(anyhow::anyhow!(
76                "CONNECT call failed with status: {}",
77                res.status()
78            ))
79        }
80    }
81}
82
83#[derive(Clone)]
84struct OverrideAddrConnector(String);
85
86impl Service<hyper::Uri> for OverrideAddrConnector {
87    type Response = TokioIo<ProxyStream>;
88
89    type Error = anyhow::Error;
90
91    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
92
93    fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
94        Poll::Ready(Ok(()))
95    }
96
97    fn call(&mut self, _uri: hyper::Uri) -> Self::Future {
98        let target_addr = self.0.clone();
99        let fut = async move {
100            Ok(TokioIo::new(
101                ProxyStream::connect(target_addr.as_str()).await?,
102            ))
103        };
104        Box::pin(fut)
105    }
106}
107
108/// Visible only for tests
109#[doc(hidden)]
110pub enum ProxyStream {
111    Tcp(TcpStream),
112    #[cfg(unix)]
113    Unix(UnixStream),
114}
115
116impl ProxyStream {
117    async fn connect(target_addr: &str) -> anyhow::Result<Self> {
118        if target_addr.starts_with("unix:/") {
119            #[cfg(unix)]
120            {
121                Ok(ProxyStream::Unix(
122                    UnixStream::connect(&target_addr[5..]).await?,
123                ))
124            }
125            #[cfg(not(unix))]
126            {
127                Err(anyhow::anyhow!(
128                    "Unix sockets are not supported on this platform"
129                ))
130            }
131        } else {
132            Ok(ProxyStream::Tcp(TcpStream::connect(target_addr).await?))
133        }
134    }
135}
136
137impl AsyncRead for ProxyStream {
138    fn poll_read(
139        self: Pin<&mut Self>,
140        cx: &mut Context<'_>,
141        buf: &mut ReadBuf<'_>,
142    ) -> Poll<io::Result<()>> {
143        match self.get_mut() {
144            ProxyStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
145            #[cfg(unix)]
146            ProxyStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
147        }
148    }
149}
150
151impl AsyncWrite for ProxyStream {
152    fn poll_write(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &[u8],
156    ) -> Poll<io::Result<usize>> {
157        match self.get_mut() {
158            ProxyStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
159            #[cfg(unix)]
160            ProxyStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
161        }
162    }
163
164    fn poll_write_vectored(
165        self: Pin<&mut Self>,
166        cx: &mut Context<'_>,
167        bufs: &[io::IoSlice<'_>],
168    ) -> Poll<io::Result<usize>> {
169        match self.get_mut() {
170            ProxyStream::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
171            #[cfg(unix)]
172            ProxyStream::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
173        }
174    }
175
176    fn is_write_vectored(&self) -> bool {
177        match self {
178            ProxyStream::Tcp(s) => s.is_write_vectored(),
179            #[cfg(unix)]
180            ProxyStream::Unix(s) => s.is_write_vectored(),
181        }
182    }
183
184    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185        match self.get_mut() {
186            ProxyStream::Tcp(s) => Pin::new(s).poll_flush(cx),
187            #[cfg(unix)]
188            ProxyStream::Unix(s) => Pin::new(s).poll_flush(cx),
189        }
190    }
191
192    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193        match self.get_mut() {
194            ProxyStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
195            #[cfg(unix)]
196            ProxyStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
197        }
198    }
199}
200
201impl Connection for ProxyStream {
202    fn connected(&self) -> Connected {
203        match self {
204            ProxyStream::Tcp(s) => s.connected(),
205            // There is no special connected metadata for Unix sockets
206            #[cfg(unix)]
207            ProxyStream::Unix(_) => Connected::new(),
208        }
209    }
210}
211
212/// Ensure the URI authority includes an explicit port so that hyper emits a
213/// RFC 9110-compliant CONNECT request-target (authority-form requires host:port).
214fn ensure_connect_authority_port(uri: tonic::transport::Uri) -> tonic::transport::Uri {
215    if uri.port().is_some() {
216        return uri;
217    }
218    let port = match uri.scheme_str() {
219        Some("https") => 443,
220        Some("http") => 80,
221        _ => return uri,
222    };
223    let mut parts = uri.into_parts();
224    if let Some(ref authority) = parts.authority
225        && let Ok(new_auth) = format!("{}:{}", authority.host(), port).parse()
226    {
227        parts.authority = Some(new_auth);
228    }
229    tonic::transport::Uri::from_parts(parts).expect("adding port to valid URI should not fail")
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::{
        io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
        net::TcpListener,
    };

    /// What the mock proxy observed from a single CONNECT request.
    struct CapturedConnect {
        /// The raw request line, including its trailing line ending.
        request_line: String,
        /// Header lines with line endings stripped, in the order received.
        headers: Vec<String>,
    }

    /// Spawn a one-shot mock proxy listening on an ephemeral localhost port.
    ///
    /// It accepts one connection, records the request line and headers, replies
    /// `200 OK`, and returns what it captured through the join handle.
    async fn mock_proxy() -> (String, tokio::task::JoinHandle<CapturedConnect>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        let handle = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            let mut request_line = String::new();
            reader.read_line(&mut request_line).await.unwrap();
            let mut headers = Vec::new();
            loop {
                let mut line = String::new();
                let read = reader.read_line(&mut line).await.unwrap();
                // `read_line` returns 0 at EOF. Without this check, a client that hangs
                // up before the blank line would leave this loop spinning forever.
                if read == 0 || line == "\r\n" {
                    break;
                }
                headers.push(line.trim_end().to_string());
            }
            // Best effort: the client may already be gone, and the tests assert only on
            // what was captured, so a failed reply must not panic this task.
            let _ = reader
                .into_inner()
                .write_all(b"HTTP/1.1 200 OK\r\n\r\n")
                .await;
            CapturedConnect {
                request_line,
                headers,
            }
        });
        (addr, handle)
    }

    #[rstest::rstest]
    #[case("https://example.com/some/path", "CONNECT example.com:443 HTTP/1.1")]
    #[case("http://example.com", "CONNECT example.com:80 HTTP/1.1")]
    #[case("https://example.com:7233", "CONNECT example.com:7233 HTTP/1.1")]
    #[tokio::test]
    async fn connect_request_line(#[case] uri: &str, #[case] expected: &str) {
        let (proxy_addr, handle) = mock_proxy().await;
        let opts = HttpConnectProxyOptions {
            target_addr: proxy_addr,
            basic_auth: None,
        };
        let uri: tonic::transport::Uri = uri.parse().unwrap();
        let _ = opts.connect(uri).await;

        let captured = handle.await.unwrap();
        assert_eq!(captured.request_line.trim(), expected);
    }

    #[tokio::test]
    async fn connect_includes_basic_auth() {
        let (proxy_addr, handle) = mock_proxy().await;
        let opts = HttpConnectProxyOptions {
            target_addr: proxy_addr,
            basic_auth: Some(("user".to_string(), "pass".to_string())),
        };
        let uri: tonic::transport::Uri = "https://example.com:7233".parse().unwrap();
        let _ = opts.connect(uri).await;

        let captured = handle.await.unwrap();
        let creds = BASE64_STANDARD.encode("user:pass");
        // Header names are case-insensitive in HTTP, so match the name that way and
        // compare only the value exactly.
        let auth_value = captured
            .headers
            .iter()
            .filter_map(|h| h.split_once(':'))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("proxy-authorization"))
            .map(|(_, value)| value.trim())
            .expect("missing proxy-authorization header");
        assert_eq!(auth_value, format!("Basic {creds}"));
    }
}