1use base64::prelude::*;
2use http_body_util::Empty;
3use hyper::{body::Bytes, header};
4use hyper_util::{
5 client::legacy::{
6 Client,
7 connect::{Connected, Connection},
8 },
9 rt::{TokioExecutor, TokioIo},
10};
11use std::{
12 future::Future,
13 io,
14 pin::Pin,
15 task::{Context, Poll},
16};
17use tokio::{
18 io::{AsyncRead, AsyncWrite, ReadBuf},
19 net::TcpStream,
20};
21use tonic::transport::{Channel, Endpoint};
22use tower::{Service, service_fn};
23
24#[cfg(unix)]
25use tokio::net::UnixStream;
26
/// Settings for reaching a server through an HTTP `CONNECT` proxy.
#[derive(Debug, Clone)]
pub struct HttpConnectProxyOptions {
    /// Address of the proxy itself.
    ///
    /// A value starting with `unix:/` is dialed as a Unix domain socket (the
    /// `unix:` prefix is removed to obtain the path); anything else is handed
    /// to `TcpStream::connect` unchanged.
    pub target_addr: String,
    /// Optional `(username, password)` pair. When present it is sent to the
    /// proxy as a `Proxy-Authorization: Basic <base64(user:pass)>` header.
    pub basic_auth: Option<(String, String)>,
}
36
37impl HttpConnectProxyOptions {
38 pub async fn connect_endpoint(
40 &self,
41 endpoint: &Endpoint,
42 ) -> Result<Channel, tonic::transport::Error> {
43 let proxy_options = self.clone();
44 let svc_fn = service_fn(move |uri: tonic::transport::Uri| {
45 let proxy_options = proxy_options.clone();
46 async move { proxy_options.connect(uri).await }
47 });
48 endpoint.connect_with_connector(svc_fn).await
49 }
50
51 async fn connect(
52 &self,
53 uri: tonic::transport::Uri,
54 ) -> anyhow::Result<hyper::upgrade::Upgraded> {
55 let uri = ensure_connect_authority_port(uri);
56 debug!("Connecting to {} via proxy at {}", uri, self.target_addr);
57 let mut req_build = hyper::Request::builder().method("CONNECT").uri(uri);
59 if let Some((user, pass)) = &self.basic_auth {
60 let creds = BASE64_STANDARD.encode(format!("{user}:{pass}"));
61 req_build = req_build.header(header::PROXY_AUTHORIZATION, format!("Basic {creds}"));
62 }
63 let req = req_build.body(Empty::<Bytes>::new())?;
64
65 let client = Client::builder(TokioExecutor::new())
68 .build(OverrideAddrConnector(self.target_addr.clone()));
69
70 let res = client.request(req).await?;
72 if res.status().is_success() {
73 Ok(hyper::upgrade::on(res).await?)
74 } else {
75 Err(anyhow::anyhow!(
76 "CONNECT call failed with status: {}",
77 res.status()
78 ))
79 }
80 }
81}
82
83#[derive(Clone)]
84struct OverrideAddrConnector(String);
85
86impl Service<hyper::Uri> for OverrideAddrConnector {
87 type Response = TokioIo<ProxyStream>;
88
89 type Error = anyhow::Error;
90
91 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
92
93 fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
94 Poll::Ready(Ok(()))
95 }
96
97 fn call(&mut self, _uri: hyper::Uri) -> Self::Future {
98 let target_addr = self.0.clone();
99 let fut = async move {
100 Ok(TokioIo::new(
101 ProxyStream::connect(target_addr.as_str()).await?,
102 ))
103 };
104 Box::pin(fut)
105 }
106}
107
108#[doc(hidden)]
110pub enum ProxyStream {
111 Tcp(TcpStream),
112 #[cfg(unix)]
113 Unix(UnixStream),
114}
115
116impl ProxyStream {
117 async fn connect(target_addr: &str) -> anyhow::Result<Self> {
118 if target_addr.starts_with("unix:/") {
119 #[cfg(unix)]
120 {
121 Ok(ProxyStream::Unix(
122 UnixStream::connect(&target_addr[5..]).await?,
123 ))
124 }
125 #[cfg(not(unix))]
126 {
127 Err(anyhow::anyhow!(
128 "Unix sockets are not supported on this platform"
129 ))
130 }
131 } else {
132 Ok(ProxyStream::Tcp(TcpStream::connect(target_addr).await?))
133 }
134 }
135}
136
137impl AsyncRead for ProxyStream {
138 fn poll_read(
139 self: Pin<&mut Self>,
140 cx: &mut Context<'_>,
141 buf: &mut ReadBuf<'_>,
142 ) -> Poll<io::Result<()>> {
143 match self.get_mut() {
144 ProxyStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
145 #[cfg(unix)]
146 ProxyStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
147 }
148 }
149}
150
151impl AsyncWrite for ProxyStream {
152 fn poll_write(
153 self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: &[u8],
156 ) -> Poll<io::Result<usize>> {
157 match self.get_mut() {
158 ProxyStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
159 #[cfg(unix)]
160 ProxyStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
161 }
162 }
163
164 fn poll_write_vectored(
165 self: Pin<&mut Self>,
166 cx: &mut Context<'_>,
167 bufs: &[io::IoSlice<'_>],
168 ) -> Poll<io::Result<usize>> {
169 match self.get_mut() {
170 ProxyStream::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
171 #[cfg(unix)]
172 ProxyStream::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
173 }
174 }
175
176 fn is_write_vectored(&self) -> bool {
177 match self {
178 ProxyStream::Tcp(s) => s.is_write_vectored(),
179 #[cfg(unix)]
180 ProxyStream::Unix(s) => s.is_write_vectored(),
181 }
182 }
183
184 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185 match self.get_mut() {
186 ProxyStream::Tcp(s) => Pin::new(s).poll_flush(cx),
187 #[cfg(unix)]
188 ProxyStream::Unix(s) => Pin::new(s).poll_flush(cx),
189 }
190 }
191
192 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193 match self.get_mut() {
194 ProxyStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
195 #[cfg(unix)]
196 ProxyStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
197 }
198 }
199}
200
201impl Connection for ProxyStream {
202 fn connected(&self) -> Connected {
203 match self {
204 ProxyStream::Tcp(s) => s.connected(),
205 #[cfg(unix)]
207 ProxyStream::Unix(_) => Connected::new(),
208 }
209 }
210}
211
212fn ensure_connect_authority_port(uri: tonic::transport::Uri) -> tonic::transport::Uri {
215 if uri.port().is_some() {
216 return uri;
217 }
218 let port = match uri.scheme_str() {
219 Some("https") => 443,
220 Some("http") => 80,
221 _ => return uri,
222 };
223 let mut parts = uri.into_parts();
224 if let Some(ref authority) = parts.authority
225 && let Ok(new_auth) = format!("{}:{}", authority.host(), port).parse()
226 {
227 parts.authority = Some(new_auth);
228 }
229 tonic::transport::Uri::from_parts(parts).expect("adding port to valid URI should not fail")
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::{
        io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
        net::TcpListener,
    };

    /// What the mock proxy received from the client for one CONNECT request.
    struct CapturedConnect {
        /// First line of the request, including its line terminator.
        request_line: String,
        /// Header lines with trailing whitespace removed, in arrival order.
        headers: Vec<String>,
    }

    /// Starts a one-shot mock proxy on an ephemeral localhost port.
    ///
    /// Returns the listener's address and a handle to a task that accepts one
    /// connection, records the request line and headers, replies
    /// `HTTP/1.1 200 OK`, and resolves to the captured request.
    async fn mock_proxy() -> (String, tokio::task::JoinHandle<CapturedConnect>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        let handle = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            let mut request_line = String::new();
            reader.read_line(&mut request_line).await.unwrap();
            let mut headers = Vec::new();
            loop {
                let mut line = String::new();
                let bytes_read = reader.read_line(&mut line).await.unwrap();
                // A zero-byte read is EOF: the peer closed before sending the
                // blank line. Stop here instead of looping forever on empty
                // reads.
                if bytes_read == 0 || line == "\r\n" {
                    break;
                }
                headers.push(line.trim_end().to_string());
            }
            reader
                .into_inner()
                .write_all(b"HTTP/1.1 200 OK\r\n\r\n")
                .await
                .unwrap();
            CapturedConnect {
                request_line,
                headers,
            }
        });
        (addr, handle)
    }

    /// The CONNECT request line must always carry an explicit port.
    #[rstest::rstest]
    #[case("https://example.com/some/path", "CONNECT example.com:443 HTTP/1.1")]
    #[case("http://example.com", "CONNECT example.com:80 HTTP/1.1")]
    #[case("https://example.com:7233", "CONNECT example.com:7233 HTTP/1.1")]
    #[tokio::test]
    async fn connect_request_line(#[case] uri: &str, #[case] expected: &str) {
        let (proxy_addr, handle) = mock_proxy().await;
        let opts = HttpConnectProxyOptions {
            target_addr: proxy_addr,
            basic_auth: None,
        };
        let uri: tonic::transport::Uri = uri.parse().unwrap();
        // Only the bytes the proxy received are under test, so the outcome of
        // the connect call is not asserted.
        let _ = opts.connect(uri).await;

        let captured = handle.await.unwrap();
        assert_eq!(captured.request_line.trim(), expected);
    }

    /// Configured credentials must arrive as a Basic `proxy-authorization`
    /// header.
    #[tokio::test]
    async fn connect_includes_basic_auth() {
        let (proxy_addr, handle) = mock_proxy().await;
        let opts = HttpConnectProxyOptions {
            target_addr: proxy_addr,
            basic_auth: Some(("user".to_string(), "pass".to_string())),
        };
        let uri: tonic::transport::Uri = "https://example.com:7233".parse().unwrap();
        let _ = opts.connect(uri).await;

        let captured = handle.await.unwrap();
        let creds = BASE64_STANDARD.encode("user:pass");
        let auth_header = captured
            .headers
            .iter()
            .find(|h| h.to_lowercase().starts_with("proxy-authorization:"))
            .expect("missing proxy-authorization header");
        assert_eq!(
            auth_header.trim(),
            format!("proxy-authorization: Basic {creds}")
        );
    }
}