1use futures::Future;
2use futures::FutureExt;
3use hyper::server::conn::Http;
4use hyper::{client::conn::Builder, service::Service};
5use native_tls::Certificate;
6use openssl::x509::X509;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::io::AsyncRead;
11use tokio::io::AsyncWrite;
12use tokio::net::TcpStream;
13use tower::Layer;
14
15use http::{Request, Response};
16
17use tokio_native_tls::{TlsAcceptor, TlsStream};
18
19use crate::certificates::spoof_certificate;
20use crate::error::Error;
21
22use log::error;
23
24use crate::{
25 certificates::{native_identity, CertificateAuthority},
26 proxy::mitm::ThirdWheel,
27};
28use hyper::service::{make_service_fn, service_fn};
29use hyper::{server::Server, Body};
30
31use self::mitm::RequestSendingSynchronizer;
32
33pub(crate) mod mitm;
34
/// Builds the hyper `make_service` for the proxy's front end from a `MitmProxy` value.
///
/// `$this` is consumed field by field. The certificate authority is wrapped in an `Arc`.
/// The MITM layer, host mappings and extra root certificates are cloned once per client
/// connection, and once more for every accepted `CONNECT` tunnel.
///
/// Per request:
/// * `CONNECT` with a parsable `host:port` target: replies `200 OK` and spawns a task.
///   The task waits for the connection upgrade and then runs `run_mitm_on_connection` on it.
///   Failures in that task are logged, not reported to the client.
/// * `CONNECT` whose target cannot be parsed: replies `400 Bad Request`.
/// * Any other method: replies `400 Bad Request`.
macro_rules! make_service {
    ($this:ident) => {{
        let ca = Arc::new($this.ca);
        let mitm = $this.mitm_layer;
        let additional_host_mapping = $this.additional_host_mappings;
        let additional_root_certificates = $this.additional_root_certificates;
        make_service_fn(move |_| {
            // Handles owned by this client connection's request handler.
            let ca = ca.clone();
            let mitm = mitm.clone();
            let additional_host_mapping = additional_host_mapping.clone();
            let additional_root_certificates = additional_root_certificates.clone();

            async move {
                Ok::<_, Error>(service_fn(move |mut req: Request<Body>| {
                    log::info!("Received request to connect: {}", req.uri());
                    let mut res = Response::new(Body::empty());

                    if req.method() == http::Method::CONNECT {
                        match target_host_port_from_connect(&req) {
                            Ok((host, port)) => {
                                // Owned copies for the tunnel task. They are moved straight
                                // into `run_mitm_on_connection`, so the task needs no
                                // further clones of its own.
                                let ca = ca.clone();
                                let mitm = mitm.clone();
                                let additional_host_mapping = additional_host_mapping.clone();
                                let additional_root_certificates =
                                    additional_root_certificates.clone();
                                tokio::task::spawn(async move {
                                    // Wait for hyper to hand over the raw connection for the
                                    // tunnel before starting the MITM session on it.
                                    match hyper::upgrade::on(&mut req).await {
                                        Ok(upgraded) => {
                                            if let Err(e) = run_mitm_on_connection(
                                                upgraded,
                                                ca,
                                                &host,
                                                &port,
                                                mitm,
                                                additional_host_mapping,
                                                additional_root_certificates,
                                            )
                                            .await
                                            {
                                                error!("Proxy failed: {}", e);
                                            }
                                        }
                                        Err(e) => error!("Failed to upgrade to TLS: {}", e),
                                    }
                                });
                                *res.status_mut() = http::status::StatusCode::OK;
                            }
                            Err(e) => {
                                error!(
                                    "Bad request: unable to parse host from connect request: {}",
                                    e
                                );
                                *res.status_mut() = http::status::StatusCode::BAD_REQUEST;
                            }
                        }
                    } else {
                        // Only CONNECT tunnelling is supported by this front end.
                        *res.status_mut() = http::status::StatusCode::BAD_REQUEST;
                    }
                    async move { Ok::<_, Error>(res) }
                }))
            }
        })
    }};
}
119
120pub struct MitmProxy<T, U>
127where
128 T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
129 U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
130 + std::marker::Sync
131 + std::marker::Send
132 + Clone
133 + 'static,
134 <U as Service<Request<Body>>>::Future: Send,
135 <U as Service<Request<Body>>>::Error: std::error::Error + Send + Sync + 'static,
136{
137 mitm_layer: T,
138 ca: CertificateAuthority,
139 additional_root_certificates: Vec<Certificate>,
140 additional_host_mappings: HashMap<String, String>, }
142
143pub struct MitmProxyBuilder<T, U>
145where
146 T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
147 U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
148 + std::marker::Sync
149 + std::marker::Send
150 + Clone
151 + 'static,
152 <U as Service<Request<Body>>>::Future: Send,
153 <U as Service<Request<Body>>>::Error: std::error::Error + Send + Sync + 'static,
154{
155 mitm_layer: T,
156 ca: CertificateAuthority,
157 additional_root_certificates: Vec<Certificate>,
158 additional_host_mappings: HashMap<String, String>,
159}
160
161impl<T, U> MitmProxyBuilder<T, U>
163where
164 T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
165 U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
166 + std::marker::Sync
167 + std::marker::Send
168 + Clone
169 + 'static,
170 <U as Service<Request<Body>>>::Future: Send,
171 <U as Service<Request<Body>>>::Error: std::error::Error + Send + Sync + 'static,
172{
173 pub fn build(self) -> MitmProxy<T, U> {
174 MitmProxy {
175 mitm_layer: self.mitm_layer,
176 ca: self.ca,
177 additional_root_certificates: self.additional_root_certificates,
178 additional_host_mappings: self.additional_host_mappings,
179 }
180 }
181
182 pub fn additional_root_certificates(
186 mut self,
187 additional_root_certificates: Vec<Certificate>,
188 ) -> Self {
189 self.additional_root_certificates = additional_root_certificates;
190 self
191 }
192
193 pub fn additional_host_mappings(
195 mut self,
196 additional_host_mappings: HashMap<String, String>,
197 ) -> Self {
198 self.additional_host_mappings = additional_host_mappings;
199 self
200 }
201}
202
203impl<T, U> MitmProxy<T, U>
205where
206 T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
207 U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
208 + std::marker::Sync
209 + std::marker::Send
210 + Clone
211 + 'static,
212 <U as Service<Request<Body>>>::Future: Send,
213 <U as Service<Request<Body>>>::Error: std::error::Error + Send + Sync + 'static,
214{
215 pub fn builder(mitm_layer: T, ca: CertificateAuthority) -> MitmProxyBuilder<T, U> {
216 MitmProxyBuilder {
217 mitm_layer,
218 ca,
219 additional_root_certificates: Vec::new(),
220 additional_host_mappings: HashMap::new(),
221 }
222 }
223
224 pub fn bind(self, addr: SocketAddr) -> (SocketAddr, impl Future<Output = Result<(), Error>>) {
227 let server = Server::bind(&addr).serve(make_service!(self));
228 (
229 server.local_addr(),
230 server.map(|result| result.map_err(|e| e.into())),
231 )
232 }
233
234 pub fn bind_with_graceful_shutdown<F>(
253 self,
254 addr: SocketAddr,
255 signal: F,
256 ) -> (SocketAddr, impl Future<Output = Result<(), Error>>)
257 where
258 F: Future<Output = ()>,
259 {
260 let server = Server::bind(&addr).serve(make_service!(self));
261 (
262 server.local_addr(),
263 server
264 .with_graceful_shutdown(signal)
265 .map(|result| result.map_err(|e| e.into())),
266 )
267 }
268}
269
270async fn run_mitm_on_connection<S, T, U>(
271 upgraded: S,
272 ca: Arc<CertificateAuthority>,
273 host: &str,
274 port: &str,
275 mitm_maker: T,
276 additional_host_mapping: HashMap<String, String>,
277 additional_root_certificates: Vec<Certificate>,
278) -> Result<(), Error>
279where
280 T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
281 S: AsyncRead + AsyncWrite + std::marker::Unpin + 'static,
282 U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
283 + std::marker::Sync
284 + std::marker::Send
285 + 'static
286 + Clone,
287 U::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
288 <U as Service<Request<Body>>>::Future: Send,
289{
290 let (target_stream, target_certificate) = connect_to_target_with_tls(
291 host,
292 port,
293 additional_host_mapping,
294 additional_root_certificates,
295 )
296 .await?;
297 let certificate = spoof_certificate(&target_certificate, &ca)?;
298 let identity = native_identity(&certificate, &ca.key)?;
299 let client = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity)?);
300 let client_stream = client.accept(upgraded).await?;
301
302 let (request_sender, connection) = Builder::new()
303 .handshake::<TlsStream<TcpStream>, Body>(target_stream)
304 .await?;
305 tokio::spawn(connection);
306 let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
307 tokio::spawn(async move {
308 RequestSendingSynchronizer::new(request_sender, receiver)
309 .run()
310 .await
311 });
312 let third_wheel = ThirdWheel::new(sender);
313 let mitm_layer = mitm_maker.layer(third_wheel);
314
315 Http::new()
316 .serve_connection(client_stream, mitm_layer)
317 .await
318 .map_err(|err| err.into())
319}
320
321async fn connect_to_target_with_tls(
322 host: &str,
323 port: &str,
324 additional_host_mapping: HashMap<String, String>,
325 additional_root_certificates: Vec<Certificate>,
326) -> Result<(TlsStream<TcpStream>, X509), Error> {
327 let host_address = additional_host_mapping
328 .get(host)
329 .map(|s| s.as_str())
330 .unwrap_or(host);
331 let target_stream = TcpStream::connect(format!("{}:{}", host_address, port)).await?;
332
333 let mut connector = native_tls::TlsConnector::builder();
334 for root_certificate in additional_root_certificates {
335 connector.add_root_certificate(root_certificate);
336 }
337 let connector = connector.build()?;
338
339 let tokio_connector = tokio_native_tls::TlsConnector::from(connector);
340 let target_stream = tokio_connector.connect(host, target_stream).await?;
341 let certificate = &target_stream.get_ref().peer_certificate()?;
343
344 let certificate = match certificate {
345 Some(cert) => cert,
346 None => {
347 return Err(Error::ServerError(
348 "Server did not provide a certificate for TLS connection".to_string(),
349 ))
350 }
351 };
352 let certificate = openssl::x509::X509::from_der(&certificate.to_der()?)?;
353
354 Ok((target_stream, certificate))
355}
356
357fn target_host_port_from_connect(request: &Request<Body>) -> Result<(String, String), Error> {
358 let host = request
359 .uri()
360 .host()
361 .map(std::string::ToString::to_string)
362 .ok_or(Error::RequestError(
363 "No host found on CONNECT request".to_string(),
364 ))?;
365 let port = request
366 .uri()
367 .port()
368 .map(|x| x.to_string())
369 .ok_or(Error::RequestError(
370 "No port found on CONNECT request".to_string(),
371 ))?;
372 Ok((host, port))
373}