1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use std::pin::Pin;

use crate::error::Error;
use futures::Future;
use http::{header::HeaderName, Request, Response};
use hyper::{client::conn::SendRequest, service::Service, Body};
use log::error;
use tokio::sync::{mpsc, oneshot};
use tower::Layer;

/// Funnels requests from a channel onto a single `SendRequest` handle.
///
/// Work items arrive over an unbounded mpsc channel. Each item pairs the
/// request to send with a oneshot sender on which the outcome (the response
/// or an error) is reported back; see `run` for the processing loop.
pub(crate) struct RequestSendingSynchronizer {
    /// Handle through which every request is sent.
    request_sender: SendRequest<Body>,
    /// Queue of `(reply channel, request)` pairs waiting to be sent.
    receiver: mpsc::UnboundedReceiver<(
        oneshot::Sender<Result<Response<Body>, Error>>,
        Request<Body>,
    )>,
}

impl RequestSendingSynchronizer {
    /// Creates a synchronizer that sends every request read from `receiver`
    /// through `request_sender`.
    pub(crate) fn new(
        request_sender: SendRequest<Body>,
        receiver: mpsc::UnboundedReceiver<(
            oneshot::Sender<Result<Response<Body>, Error>>,
            Request<Body>,
        )>,
    ) -> Self {
        Self {
            request_sender,
            receiver,
        }
    }

    /// Processes queued requests one at a time until the channel is closed.
    ///
    /// For each `(reply channel, request)` pair:
    /// 1. the request URI is replaced by its path-and-query part,
    /// 2. the `proxy-connection` header is removed,
    /// 3. the request is sent and its response awaited,
    /// 4. the outcome is delivered on the reply channel.
    ///
    /// A request whose URI has no path, or whose path fails to re-parse, is
    /// never sent; the requester receives an `Error::RequestError` instead.
    /// If the requester has gone away the failure is logged and the loop
    /// carries on with the next item.
    pub(crate) async fn run(&mut self) {
        // Built once up front instead of once per request.
        let proxy_connection = HeaderName::from_lowercase(b"proxy-connection")
            .expect("Infallible: hardcoded header name");

        while let Some((sender, mut request)) = self.receiver.recv().await {
            // Keep only the path and query of the incoming URI.
            let relativized_uri = request
                .uri()
                .path_and_query()
                .ok_or_else(|| Error::RequestError("URI did not contain a path".to_string()))
                .and_then(|path| {
                    path.as_str()
                        .parse()
                        .map_err(|_| Error::RequestError("Given URI was invalid".to_string()))
                });

            let response_to_send = match relativized_uri {
                Ok(path) => {
                    *request.uri_mut() = path;
                    request.headers_mut().remove(&proxy_connection);
                    self.request_sender
                        .send_request(request)
                        .await
                        .map_err(|e| e.into())
                }
                Err(e) => Err(e),
            };

            // `send` hands the value back when the receiving half is gone.
            if let Err(e) = sender.send(response_to_send) {
                error!("Requester not available to receive request {:?}", e);
            }
        }
    }
}

/// A service that will proxy traffic to a target server and return unmodified responses
///
/// Cloning is cheap in the sense that every clone shares the same underlying
/// channel sender.
#[derive(Clone)]
pub struct ThirdWheel {
    /// Channel on which `(reply channel, request)` pairs are queued; the
    /// reply channel is a oneshot on which the response or error comes back.
    sender: mpsc::UnboundedSender<(
        oneshot::Sender<Result<Response<Body>, Error>>,
        Request<Body>,
    )>,
}

impl ThirdWheel {
    pub(crate) fn new(
        sender: mpsc::UnboundedSender<(
            oneshot::Sender<Result<Response<Body>, Error>>,
            Request<Body>,
        )>,
    ) -> Self {
        Self { sender }
    }
}

impl Service<Request<Body>> for ThirdWheel {
    type Response = Response<Body>;

    type Error = crate::error::Error;

    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        _: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// ThirdWheel performs very little modification of the request before
    /// transmitting it, but it does remove the proxy-connection header to
    /// ensure this is not passed to the target
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        let (response_sender, response_receiver) = oneshot::channel();
        let sender = self.sender.clone();
        let fut = async move {
            //TODO: clarify what errors are possible here
            sender.send((response_sender, request)).map_err(|_| {
                Error::ServerError("Failed to connect to server correctly".to_string())
            })?;
            response_receiver
                .await
                .map_err(|_| Error::ServerError("Failed to get response from server".to_string()))?
        };
        return Box::pin(fut);
    }
}

/// A service that hands every request to a user-supplied function together
/// with an instance of the wrapped inner service.
#[derive(Clone)]
pub struct MitmService<F: Clone, S: Clone> {
    /// Function invoked for each request; it produces the response future.
    f: F,
    /// The wrapped service that `f` receives alongside the request.
    inner: S,
}

impl<F, S> Service<Request<Body>> for MitmService<F, S>
where
    S: Service<Request<Body>, Error = crate::error::Error> + Clone,
    F: FnMut(
            Request<Body>,
            S,
        )
            -> Pin<Box<dyn Future<Output = Result<Response<Body>, crate::error::Error>> + Send>>
        + Clone,
{
    type Response = Response<Body>;
    type Error = crate::error::Error;

    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Invokes the user function with `req` and the inner service.
    ///
    /// The instance passed to `f` is the one `poll_ready` was driven on, not
    /// a fresh clone: a clone is not guaranteed to share the readiness that
    /// was just established. A new clone takes its place in `self`, and it
    /// will be polled for readiness before the next call.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let replacement = self.inner.clone();
        let ready_inner = std::mem::replace(&mut self.inner, replacement);
        (self.f)(req, ready_inner)
    }
}

/// A `Layer` that wraps a service in a [`MitmService`] driven by `f`.
#[derive(Clone)]
pub struct MitmLayer<F: Clone> {
    /// Function cloned into every `MitmService` this layer produces.
    f: F,
}

impl<S: Clone, F: Clone> Layer<S> for MitmLayer<F> {
    type Service = MitmService<F, S>;

    /// Wraps `inner` in a `MitmService` that carries its own clone of the
    /// layer's function.
    fn layer(&self, inner: S) -> Self::Service {
        let f = self.f.clone();
        MitmService { f, inner }
    }
}

/// Builds a [`MitmLayer`] from a request-handling function.
///
/// The function receives each request along with a [`ThirdWheel`] and returns
/// a boxed future resolving to the response. The resulting layer implements
/// the traits needed to act as a man-in-the-middle service, which is enough
/// for many use cases.
/// ```ignore
/// let mitm = mitm_layer(|req: Request<Body>, mut third_wheel: ThirdWheel| third_wheel.call(req));
/// let mitm_proxy = MitmProxy::builder(mitm, ca).build();
/// ```
pub fn mitm_layer<F>(f: F) -> MitmLayer<F>
where
    F: FnMut(
            Request<Body>,
            ThirdWheel,
        )
            -> Pin<Box<dyn Future<Output = Result<Response<Body>, crate::error::Error>> + Send>>
        + Clone,
{
    MitmLayer { f }
}