squads_temporal_client/
proxy.rs

1use base64::prelude::*;
2use http_body_util::Empty;
3use hyper::{body::Bytes, header};
4use hyper_util::{
5    client::legacy::Client,
6    rt::{TokioExecutor, TokioIo},
7};
8use std::{
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll},
12};
13use tokio::net::TcpStream;
14use tonic::transport::{Channel, Endpoint};
15use tower::{Service, service_fn};
16
/// Options for HTTP CONNECT proxy.
///
/// `Debug` is implemented by hand so that the proxy password in `basic_auth`
/// never ends up in logs; only the user name and target address are printed.
#[derive(Clone)]
pub struct HttpConnectProxyOptions {
    /// The host:port to proxy through.
    pub target_addr: String,
    /// Optional HTTP basic auth for the proxy as user/pass tuple.
    pub basic_auth: Option<(String, String)>,
}

impl std::fmt::Debug for HttpConnectProxyOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Replace the password with a fixed marker; keep the user for diagnostics.
        let redacted_auth = self
            .basic_auth
            .as_ref()
            .map(|(user, _)| (user, "<redacted>"));
        f.debug_struct("HttpConnectProxyOptions")
            .field("target_addr", &self.target_addr)
            .field("basic_auth", &redacted_auth)
            .finish()
    }
}
25
26impl HttpConnectProxyOptions {
27    /// Create a channel from the given endpoint that uses the HTTP CONNECT proxy.
28    pub async fn connect_endpoint(
29        &self,
30        endpoint: &Endpoint,
31    ) -> Result<Channel, tonic::transport::Error> {
32        let proxy_options = self.clone();
33        let svc_fn = service_fn(move |uri: tonic::transport::Uri| {
34            let proxy_options = proxy_options.clone();
35            async move { proxy_options.connect(uri).await }
36        });
37        endpoint.connect_with_connector(svc_fn).await
38    }
39
40    async fn connect(
41        &self,
42        uri: tonic::transport::Uri,
43    ) -> anyhow::Result<hyper::upgrade::Upgraded> {
44        debug!("Connecting to {} via proxy at {}", uri, self.target_addr);
45        // Create CONNECT request
46        let mut req_build = hyper::Request::builder().method("CONNECT").uri(uri);
47        if let Some((user, pass)) = &self.basic_auth {
48            let creds = BASE64_STANDARD.encode(format!("{user}:{pass}"));
49            req_build = req_build.header(header::PROXY_AUTHORIZATION, format!("Basic {creds}"));
50        }
51        let req = req_build.body(Empty::<Bytes>::new())?;
52
53        // We have to create a client with a specific connector because Hyper is
54        // not letting us change the HTTP/2 authority
55        let client = Client::builder(TokioExecutor::new())
56            .build(OverrideAddrConnector(self.target_addr.clone()));
57
58        // Send request
59        let res = client.request(req).await?;
60        if res.status().is_success() {
61            Ok(hyper::upgrade::on(res).await?)
62        } else {
63            Err(anyhow::anyhow!(
64                "CONNECT call failed with status: {}",
65                res.status()
66            ))
67        }
68    }
69}
70
/// Connector that dials one fixed `host:port` string (the proxy address)
/// no matter which URI the HTTP client asks it to reach.
#[derive(Clone)]
struct OverrideAddrConnector(String);
73
74impl Service<hyper::Uri> for OverrideAddrConnector {
75    type Response = TokioIo<TcpStream>;
76
77    type Error = anyhow::Error;
78
79    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
80
81    fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
82        Poll::Ready(Ok(()))
83    }
84
85    fn call(&mut self, _uri: hyper::Uri) -> Self::Future {
86        let target_addr = self.0.clone();
87        let fut = async move { Ok(TokioIo::new(TcpStream::connect(target_addr).await?)) };
88        Box::pin(fut)
89    }
90}