// third_wheel/proxy.rs

1use futures::Future;
2use futures::FutureExt;
3use hyper::server::conn::Http;
4use hyper::{client::conn::Builder, service::Service};
5use native_tls::Certificate;
6use openssl::x509::X509;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::io::AsyncRead;
11use tokio::io::AsyncWrite;
12use tokio::net::TcpStream;
13use tower::Layer;
14
15use http::{Request, Response};
16
17use tokio_native_tls::{TlsAcceptor, TlsStream};
18
19use crate::certificates::spoof_certificate;
20use crate::error::Error;
21
22use log::error;
23
24use crate::{
25    certificates::{native_identity, CertificateAuthority},
26    proxy::mitm::ThirdWheel,
27};
28use hyper::service::{make_service_fn, service_fn};
29use hyper::{server::Server, Body};
30
31use self::mitm::RequestSendingSynchronizer;
32
33pub(crate) mod mitm;
34
// TODO: do this without macro hackery
// The idea of using a macro here is borrowed from warp after hitting my head against it for some time.
// We want to be able to return a make service for reuse of code, but the return
// type is inordinately complex and/or hidden by hyper's module privacy, so instead we inline the code twice.
// Either we should replace this with a private function on MitmProxy, or we should do *something else*.
//
// Expands to a `make_service_fn` for the proxy listener. `$this` must be an owned
// `MitmProxy`; its fields are moved into the resulting service.
//
// Each accepted connection gets a service that only understands CONNECT:
// - CONNECT with a parseable host and port: reply 200 OK and spawn a task that waits
//   for hyper's upgrade and then runs `run_mitm_on_connection` on the upgraded stream;
// - CONNECT whose target cannot be parsed, or any other method: reply 400 Bad Request.
macro_rules! make_service {
    ($this:ident) => {{
        let ca = Arc::new($this.ca);
        let mitm = $this.mitm_layer;
        let additional_host_mapping = $this.additional_host_mappings;
        let additional_root_certificates = $this.additional_root_certificates;
        make_service_fn(move |_| {
            // While the state was moved into the make_service closure,
            // we need to clone it here because this closure is called
            // once for every connection.
            //
            // Each connection could send multiple requests, so
            // the `Service` needs a clone to handle later requests.
            let ca = ca.clone();
            let mitm = mitm.clone();
            let additional_host_mapping = additional_host_mapping.clone();
            let additional_root_certificates = additional_root_certificates.clone();

            async move {
                Ok::<_, Error>(service_fn(move |mut req: Request<Body>| {
                    log::info!("Received request to connect: {}", req.uri());
                    let mut res = Response::new(Body::empty());

                    // The proxy can only handle CONNECT requests
                    if req.method() == http::Method::CONNECT {
                        let target = target_host_port_from_connect(&req);
                        match target {
                            Ok((host, port)) => {
                                // TODO: handle non-encrypted proxying
                                // TODO: how to handle port != 80/443
                                // In the case of a TLS tunnel request we spawn a new
                                // service to handle the upgrade. This will only happen
                                // after the currently running function finishes so we need
                                // to spawn it as a separate future.
                                //
                                // The task needs its own copies of the shared state;
                                // they are moved (not cloned again) into the MITM run.
                                let ca = ca.clone();
                                let mitm = mitm.clone();
                                let additional_host_mapping = additional_host_mapping.clone();
                                let additional_root_certificates =
                                    additional_root_certificates.clone();
                                tokio::task::spawn(async move {
                                    match hyper::upgrade::on(&mut req).await {
                                        Ok(upgraded) => {
                                            if let Err(e) = run_mitm_on_connection(
                                                upgraded,
                                                ca,
                                                &host,
                                                &port,
                                                mitm,
                                                additional_host_mapping,
                                                additional_root_certificates,
                                            )
                                            .await
                                            {
                                                error!("Proxy failed: {}", e)
                                            }
                                        }
                                        Err(e) => error!("Failed to upgrade to TLS: {}", e),
                                    }
                                });
                                *res.status_mut() = http::status::StatusCode::OK;
                            }

                            Err(e) => {
                                error!(
                                    "Bad request: unable to parse host from connect request: {}",
                                    e
                                );
                                *res.status_mut() = http::status::StatusCode::BAD_REQUEST;
                            }
                        }
                    } else {
                        *res.status_mut() = http::status::StatusCode::BAD_REQUEST;
                    }
                    async move { Ok::<_, Error>(res) }
                }))
            }
        })
    }};
}
119
120/// The main struct of the crate. Start here.
121///
122/// This struct is the workhorse and main interface for third-wheel.
123/// By passing in a Mitm layer this can be customized to perform any required
124/// behavior on HTTP requests and responses. Use the `mitm_layer` function to
125/// easily construct services to pass in to this struct.
126pub struct MitmProxy<T, U>
127where
128    T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
129    U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
130        + std::marker::Sync
131        + std::marker::Send
132        + Clone
133        + 'static,
134    <U as Service<Request<Body>>>::Future: Send,
135    <U as Service<Request<Body>>>::Error: std::error::Error + Send + Sync + 'static,
136{
137    mitm_layer: T,
138    ca: CertificateAuthority,
139    additional_root_certificates: Vec<Certificate>,
140    additional_host_mappings: HashMap<String, String>, // TODO: this should be more restrictively typed
141}
142
143/// Builder interface for constructing `MitmProxy`'s
144pub struct MitmProxyBuilder<T, U>
145where
146    T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
147    U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
148        + std::marker::Sync
149        + std::marker::Send
150        + Clone
151        + 'static,
152    <U as Service<Request<Body>>>::Future: Send,
153    <U as Service<Request<Body>>>::Error: std::error::Error + Send + Sync + 'static,
154{
155    mitm_layer: T,
156    ca: CertificateAuthority,
157    additional_root_certificates: Vec<Certificate>,
158    additional_host_mappings: HashMap<String, String>,
159}
160
161// impl MitmProxyBuilder
162impl<T, U> MitmProxyBuilder<T, U>
163where
164    T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
165    U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
166        + std::marker::Sync
167        + std::marker::Send
168        + Clone
169        + 'static,
170    <U as Service<Request<Body>>>::Future: Send,
171    <U as Service<Request<Body>>>::Error: std::error::Error + Send + Sync + 'static,
172{
173    pub fn build(self) -> MitmProxy<T, U> {
174        MitmProxy {
175            mitm_layer: self.mitm_layer,
176            ca: self.ca,
177            additional_root_certificates: self.additional_root_certificates,
178            additional_host_mappings: self.additional_host_mappings,
179        }
180    }
181
182    /// Add root certificates that the proxy should trust when making outgoing
183    /// connections. This is in addition to the system certificates that are
184    /// already trusted.
185    pub fn additional_root_certificates(
186        mut self,
187        additional_root_certificates: Vec<Certificate>,
188    ) -> Self {
189        self.additional_root_certificates = additional_root_certificates;
190        self
191    }
192
193    /// Add mappings for particular hosts to IP addresses. Useful for testing against local TLS servers.
194    pub fn additional_host_mappings(
195        mut self,
196        additional_host_mappings: HashMap<String, String>,
197    ) -> Self {
198        self.additional_host_mappings = additional_host_mappings;
199        self
200    }
201}
202
203// impl MitmProxy
204impl<T, U> MitmProxy<T, U>
205where
206    T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
207    U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
208        + std::marker::Sync
209        + std::marker::Send
210        + Clone
211        + 'static,
212    <U as Service<Request<Body>>>::Future: Send,
213    <U as Service<Request<Body>>>::Error: std::error::Error + Send + Sync + 'static,
214{
215    pub fn builder(mitm_layer: T, ca: CertificateAuthority) -> MitmProxyBuilder<T, U> {
216        MitmProxyBuilder {
217            mitm_layer,
218            ca,
219            additional_root_certificates: Vec::new(),
220            additional_host_mappings: HashMap::new(),
221        }
222    }
223
224    /// Bind to a socket address. Returns the address actually bound to, and the
225    /// future to be executed that will run the server.
226    pub fn bind(self, addr: SocketAddr) -> (SocketAddr, impl Future<Output = Result<(), Error>>) {
227        let server = Server::bind(&addr).serve(make_service!(self));
228        (
229            server.local_addr(),
230            server.map(|result| result.map_err(|e| e.into())),
231        )
232    }
233
234    /// The same as bind except in the event that signal completes the proxy
235    /// will gracefully terminate itself.
236    /// ```ignore
237    /// let trivial_mitm = MitmProxy::builder(
238    ///     mitm_layer(|req: Request<Body>, mut third_wheel: ThirdWheel| third_wheel.call(req)),
239    ///     third_wheel_ca,
240    /// ).build();
241
242    /// let (third_wheel_killer, receiver) = tokio::sync::oneshot::channel();
243    /// let (third_wheel_address, mitm_fut) = trivial_mitm
244    ///     .bind_with_graceful_shutdown("127.0.0.1:0".parse().unwrap(), async {
245    ///         receiver.await.ok().unwrap()
246    ///     });
247    /// tokio::spawn(mitm_fut);
248    /// // Wait for some stuff to happen
249    /// third_wheel_killer.send(()).unwrap();
250    /// // This kills the proxy
251    /// ```
252    pub fn bind_with_graceful_shutdown<F>(
253        self,
254        addr: SocketAddr,
255        signal: F,
256    ) -> (SocketAddr, impl Future<Output = Result<(), Error>>)
257    where
258        F: Future<Output = ()>,
259    {
260        let server = Server::bind(&addr).serve(make_service!(self));
261        (
262            server.local_addr(),
263            server
264                .with_graceful_shutdown(signal)
265                .map(|result| result.map_err(|e| e.into())),
266        )
267    }
268}
269
270async fn run_mitm_on_connection<S, T, U>(
271    upgraded: S,
272    ca: Arc<CertificateAuthority>,
273    host: &str,
274    port: &str,
275    mitm_maker: T,
276    additional_host_mapping: HashMap<String, String>,
277    additional_root_certificates: Vec<Certificate>,
278) -> Result<(), Error>
279where
280    T: Layer<ThirdWheel, Service = U> + std::marker::Sync + std::marker::Send + 'static + Clone,
281    S: AsyncRead + AsyncWrite + std::marker::Unpin + 'static,
282    U: Service<Request<Body>, Response = <ThirdWheel as Service<Request<Body>>>::Response>
283        + std::marker::Sync
284        + std::marker::Send
285        + 'static
286        + Clone,
287    U::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
288    <U as Service<Request<Body>>>::Future: Send,
289{
290    let (target_stream, target_certificate) = connect_to_target_with_tls(
291        host,
292        port,
293        additional_host_mapping,
294        additional_root_certificates,
295    )
296    .await?;
297    let certificate = spoof_certificate(&target_certificate, &ca)?;
298    let identity = native_identity(&certificate, &ca.key)?;
299    let client = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity)?);
300    let client_stream = client.accept(upgraded).await?;
301
302    let (request_sender, connection) = Builder::new()
303        .handshake::<TlsStream<TcpStream>, Body>(target_stream)
304        .await?;
305    tokio::spawn(connection);
306    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
307    tokio::spawn(async move {
308        RequestSendingSynchronizer::new(request_sender, receiver)
309            .run()
310            .await
311    });
312    let third_wheel = ThirdWheel::new(sender);
313    let mitm_layer = mitm_maker.layer(third_wheel);
314
315    Http::new()
316        .serve_connection(client_stream, mitm_layer)
317        .await
318        .map_err(|err| err.into())
319}
320
321async fn connect_to_target_with_tls(
322    host: &str,
323    port: &str,
324    additional_host_mapping: HashMap<String, String>,
325    additional_root_certificates: Vec<Certificate>,
326) -> Result<(TlsStream<TcpStream>, X509), Error> {
327    let host_address = additional_host_mapping
328        .get(host)
329        .map(|s| s.as_str())
330        .unwrap_or(host);
331    let target_stream = TcpStream::connect(format!("{}:{}", host_address, port)).await?;
332
333    let mut connector = native_tls::TlsConnector::builder();
334    for root_certificate in additional_root_certificates {
335        connector.add_root_certificate(root_certificate);
336    }
337    let connector = connector.build()?;
338
339    let tokio_connector = tokio_native_tls::TlsConnector::from(connector);
340    let target_stream = tokio_connector.connect(host, target_stream).await?;
341    //TODO: Currently to copy the certificate we do a round trip from one library -> der -> other library. This is inefficient, it should be possible to do it better some how.
342    let certificate = &target_stream.get_ref().peer_certificate()?;
343
344    let certificate = match certificate {
345        Some(cert) => cert,
346        None => {
347            return Err(Error::ServerError(
348                "Server did not provide a certificate for TLS connection".to_string(),
349            ))
350        }
351    };
352    let certificate = openssl::x509::X509::from_der(&certificate.to_der()?)?;
353
354    Ok((target_stream, certificate))
355}
356
357fn target_host_port_from_connect(request: &Request<Body>) -> Result<(String, String), Error> {
358    let host = request
359        .uri()
360        .host()
361        .map(std::string::ToString::to_string)
362        .ok_or(Error::RequestError(
363            "No host found on CONNECT request".to_string(),
364        ))?;
365    let port = request
366        .uri()
367        .port()
368        .map(|x| x.to_string())
369        .ok_or(Error::RequestError(
370            "No port found on CONNECT request".to_string(),
371        ))?;
372    Ok((host, port))
373}