temporalio_client/
proxy.rs1use base64::prelude::*;
2use http_body_util::Empty;
3use hyper::{body::Bytes, header};
4use hyper_util::{
5 client::legacy::{
6 Client,
7 connect::{Connected, Connection},
8 },
9 rt::{TokioExecutor, TokioIo},
10};
11use std::{
12 future::Future,
13 io,
14 pin::Pin,
15 task::{Context, Poll},
16};
17use tokio::{
18 io::{AsyncRead, AsyncWrite, ReadBuf},
19 net::TcpStream,
20};
21use tonic::transport::{Channel, Endpoint};
22use tower::{Service, service_fn};
23
24#[cfg(unix)]
25use tokio::net::UnixStream;
26
27#[derive(Clone, Debug)]
29pub struct HttpConnectProxyOptions {
30 pub target_addr: String,
33 pub basic_auth: Option<(String, String)>,
35}
36
37impl HttpConnectProxyOptions {
38 pub async fn connect_endpoint(
40 &self,
41 endpoint: &Endpoint,
42 ) -> Result<Channel, tonic::transport::Error> {
43 let proxy_options = self.clone();
44 let svc_fn = service_fn(move |uri: tonic::transport::Uri| {
45 let proxy_options = proxy_options.clone();
46 async move { proxy_options.connect(uri).await }
47 });
48 endpoint.connect_with_connector(svc_fn).await
49 }
50
51 async fn connect(
52 &self,
53 uri: tonic::transport::Uri,
54 ) -> anyhow::Result<hyper::upgrade::Upgraded> {
55 debug!("Connecting to {} via proxy at {}", uri, self.target_addr);
56 let mut req_build = hyper::Request::builder().method("CONNECT").uri(uri);
58 if let Some((user, pass)) = &self.basic_auth {
59 let creds = BASE64_STANDARD.encode(format!("{user}:{pass}"));
60 req_build = req_build.header(header::PROXY_AUTHORIZATION, format!("Basic {creds}"));
61 }
62 let req = req_build.body(Empty::<Bytes>::new())?;
63
64 let client = Client::builder(TokioExecutor::new())
67 .build(OverrideAddrConnector(self.target_addr.clone()));
68
69 let res = client.request(req).await?;
71 if res.status().is_success() {
72 Ok(hyper::upgrade::on(res).await?)
73 } else {
74 Err(anyhow::anyhow!(
75 "CONNECT call failed with status: {}",
76 res.status()
77 ))
78 }
79 }
80}
81
82#[derive(Clone)]
83struct OverrideAddrConnector(String);
84
85impl Service<hyper::Uri> for OverrideAddrConnector {
86 type Response = TokioIo<ProxyStream>;
87
88 type Error = anyhow::Error;
89
90 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
91
92 fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
93 Poll::Ready(Ok(()))
94 }
95
96 fn call(&mut self, _uri: hyper::Uri) -> Self::Future {
97 let target_addr = self.0.clone();
98 let fut = async move {
99 Ok(TokioIo::new(
100 ProxyStream::connect(target_addr.as_str()).await?,
101 ))
102 };
103 Box::pin(fut)
104 }
105}
106
107#[doc(hidden)]
109pub enum ProxyStream {
110 Tcp(TcpStream),
111 #[cfg(unix)]
112 Unix(UnixStream),
113}
114
115impl ProxyStream {
116 async fn connect(target_addr: &str) -> anyhow::Result<Self> {
117 if target_addr.starts_with("unix:/") {
118 #[cfg(unix)]
119 {
120 Ok(ProxyStream::Unix(
121 UnixStream::connect(&target_addr[5..]).await?,
122 ))
123 }
124 #[cfg(not(unix))]
125 {
126 Err(anyhow::anyhow!(
127 "Unix sockets are not supported on this platform"
128 ))
129 }
130 } else {
131 Ok(ProxyStream::Tcp(TcpStream::connect(target_addr).await?))
132 }
133 }
134}
135
136impl AsyncRead for ProxyStream {
137 fn poll_read(
138 self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 buf: &mut ReadBuf<'_>,
141 ) -> Poll<io::Result<()>> {
142 match self.get_mut() {
143 ProxyStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
144 #[cfg(unix)]
145 ProxyStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
146 }
147 }
148}
149
150impl AsyncWrite for ProxyStream {
151 fn poll_write(
152 self: Pin<&mut Self>,
153 cx: &mut Context<'_>,
154 buf: &[u8],
155 ) -> Poll<io::Result<usize>> {
156 match self.get_mut() {
157 ProxyStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
158 #[cfg(unix)]
159 ProxyStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
160 }
161 }
162
163 fn poll_write_vectored(
164 self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 bufs: &[io::IoSlice<'_>],
167 ) -> Poll<io::Result<usize>> {
168 match self.get_mut() {
169 ProxyStream::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
170 #[cfg(unix)]
171 ProxyStream::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
172 }
173 }
174
175 fn is_write_vectored(&self) -> bool {
176 match self {
177 ProxyStream::Tcp(s) => s.is_write_vectored(),
178 #[cfg(unix)]
179 ProxyStream::Unix(s) => s.is_write_vectored(),
180 }
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184 match self.get_mut() {
185 ProxyStream::Tcp(s) => Pin::new(s).poll_flush(cx),
186 #[cfg(unix)]
187 ProxyStream::Unix(s) => Pin::new(s).poll_flush(cx),
188 }
189 }
190
191 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192 match self.get_mut() {
193 ProxyStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
194 #[cfg(unix)]
195 ProxyStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
196 }
197 }
198}
199
200impl Connection for ProxyStream {
201 fn connected(&self) -> Connected {
202 match self {
203 ProxyStream::Tcp(s) => s.connected(),
204 #[cfg(unix)]
206 ProxyStream::Unix(_) => Connected::new(),
207 }
208 }
209}