use anyhow::Context as _;
use base64::prelude::*;
use http_body_util::Empty;
use hyper::{body::Bytes, header};
use hyper_util::{
    client::legacy::Client,
    rt::{TokioExecutor, TokioIo},
};
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::net::TcpStream;
use tonic::transport::{Channel, Endpoint};
use tower::{Service, service_fn};

/// Options for tunnelling connections through an HTTP proxy using the `CONNECT` method.
#[derive(Clone)]
pub struct HttpConnectProxyOptions {
    /// Address of the proxy server. It is handed to `TcpStream::connect` as-is, so it is
    /// expected in `host:port` form.
    pub target_addr: String,
    /// Optional HTTP basic-auth credentials for the proxy, as a `(username, password)` pair.
    /// When present they are sent in a `Proxy-Authorization: Basic …` header.
    pub basic_auth: Option<(String, String)>,
}

// Hand-written instead of derived so the proxy password never leaks into logs or error
// messages through `{:?}`; the username and address are kept because they aid debugging.
impl std::fmt::Debug for HttpConnectProxyOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpConnectProxyOptions")
            .field("target_addr", &self.target_addr)
            .field(
                "basic_auth",
                &self
                    .basic_auth
                    .as_ref()
                    .map(|(user, _password)| (user, "<redacted>")),
            )
            .finish()
    }
}

26impl HttpConnectProxyOptions {
27 pub async fn connect_endpoint(
29 &self,
30 endpoint: &Endpoint,
31 ) -> Result<Channel, tonic::transport::Error> {
32 let proxy_options = self.clone();
33 let svc_fn = service_fn(move |uri: tonic::transport::Uri| {
34 let proxy_options = proxy_options.clone();
35 async move { proxy_options.connect(uri).await }
36 });
37 endpoint.connect_with_connector(svc_fn).await
38 }
39
40 async fn connect(
41 &self,
42 uri: tonic::transport::Uri,
43 ) -> anyhow::Result<hyper::upgrade::Upgraded> {
44 debug!("Connecting to {} via proxy at {}", uri, self.target_addr);
45 let mut req_build = hyper::Request::builder().method("CONNECT").uri(uri);
47 if let Some((user, pass)) = &self.basic_auth {
48 let creds = BASE64_STANDARD.encode(format!("{user}:{pass}"));
49 req_build = req_build.header(header::PROXY_AUTHORIZATION, format!("Basic {creds}"));
50 }
51 let req = req_build.body(Empty::<Bytes>::new())?;
52
53 let client = Client::builder(TokioExecutor::new())
56 .build(OverrideAddrConnector(self.target_addr.clone()));
57
58 let res = client.request(req).await?;
60 if res.status().is_success() {
61 Ok(hyper::upgrade::on(res).await?)
62 } else {
63 Err(anyhow::anyhow!(
64 "CONNECT call failed with status: {}",
65 res.status()
66 ))
67 }
68 }
69}
70
71#[derive(Clone)]
72struct OverrideAddrConnector(String);
73
74impl Service<hyper::Uri> for OverrideAddrConnector {
75 type Response = TokioIo<TcpStream>;
76
77 type Error = anyhow::Error;
78
79 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
80
81 fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
82 Poll::Ready(Ok(()))
83 }
84
85 fn call(&mut self, _uri: hyper::Uri) -> Self::Future {
86 let target_addr = self.0.clone();
87 let fut = async move { Ok(TokioIo::new(TcpStream::connect(target_addr).await?)) };
88 Box::pin(fut)
89 }
90}