spectacles_proxy/
lib.rs

1//!
2//! A simple reverse proxy, to be used with [Hyper].
3//!
4//! The implementation ensures that [Hop-by-hop headers] are stripped correctly in both directions,
5//! and adds the client's IP address to a comma-space-separated list of forwarding addresses in the
6//! `X-Forwarded-For` header.
7//!
8//! The implementation is based on Go's [`httputil.ReverseProxy`].
9//!
10//! [Hyper]: http://hyper.rs/
11//! [Hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
12//! [`httputil.ReverseProxy`]: https://golang.org/pkg/net/http/httputil/#ReverseProxy
13//!
14//! # Example
15//!
16//! Add these dependencies to your `Cargo.toml` file.
17//!
18//! ```toml
19//! [dependencies]
20//! hyper-reverse-proxy = "0.4"
21//! hyper = "0.12"
22//! futures = "0.1"
23//! ```
24//!
//! The following example sets up a reverse proxy listening on `127.0.0.1:13900`
//! and proxies these calls:
27//!
28//! * `"/target/first"` will be proxied to `http://127.0.0.1:13901`
29//!
30//! * `"/target/second"` will be proxied to `http://127.0.0.1:13902`
31//!
//! * All other URLs will be handled by the `debug_request` function, which displays the request information.
33//!
34//! ```rust,no_run
35//! use hyper::server::conn::AddrStream;
36//! use hyper::{Body, Request, Response, Server};
37//! use hyper::service::{service_fn, make_service_fn};
38//! use futures::future::{self, Future};
39//!
40//! type BoxFut = Box<Future<Item=Response<Body>, Error=hyper::Error> + Send>;
41//!
42//! fn debug_request(req: Request<Body>) -> BoxFut {
43//!     let body_str = format!("{:?}", req);
44//!     let response = Response::new(Body::from(body_str));
45//!     Box::new(future::ok(response))
46//! }
47//!
48//! fn main() {
49//!
50//!     // This is our socket address...
51//!     let addr = ([127, 0, 0, 1], 13900).into();
52//!
53//!     // A `Service` is needed for every connection.
54//!     let make_svc = make_service_fn(|socket: &AddrStream| {
55//!         let remote_addr = socket.remote_addr();
56//!         service_fn(move |req: Request<Body>| { // returns BoxFut
57//!
58//!             if req.uri().path().starts_with("/target/first") {
59//!
60//!                 // will forward requests to port 13901
61//!                 return hyper_reverse_proxy::call(remote_addr.ip(), "http://127.0.0.1:13901", req)
62//!
63//!             } else if req.uri().path().starts_with("/target/second") {
64//!
65//!                 // will forward requests to port 13902
66//!                 return hyper_reverse_proxy::call(remote_addr.ip(), "http://127.0.0.1:13902", req)
67//!
68//!             } else {
69//!                 debug_request(req)
70//!             }
71//!         })
72//!     });
73//!
74//!     let server = Server::bind(&addr)
75//!         .serve(make_svc)
76//!         .map_err(|e| eprintln!("server error: {}", e));
77//!
78//!     println!("Running server on {:?}", addr);
79//!
80//!     // Run this server for... forever!
81//!     hyper::rt::run(server);
82//! }
83//! ```
84//!
85
86use hyper::Body;
87use std::sync::Arc;
88use std::io;
89use hyper::client::HttpConnector;
90use std::net::IpAddr;
91use std::str::FromStr;
92use tokio::net::TcpStream;
93use hyper::header::{HeaderMap, HeaderValue};
94use hyper::client::connect::{Connect, Connected, Destination};
95use hyper::{Request, Response, Client, Uri, StatusCode};
96use futures::future::{self, err, Future};
97use tokio_tls::{TlsConnector, TlsStream};
98use lazy_static::lazy_static;
99
100type BoxFut = Box<Future<Item=Response<Body>, Error=hyper::Error> + Send>;
101
/// Returns `true` if `name` is a hop-by-hop header that a proxy must not forward.
///
/// The comparison is ASCII case-insensitive, as HTTP header names are.
/// The list mirrors the one used by Go's `httputil.ReverseProxy`.
fn is_hop_header(name: &str) -> bool {
    // A plain const slice: no lazy initialisation or heap allocation needed.
    const HOP_HEADERS: &[&str] = &[
        "Connection",
        // Non-standard, but still sent by some clients.
        "Proxy-Connection",
        "Keep-Alive",
        "Proxy-Authenticate",
        "Proxy-Authorization",
        "Te",
        // The real header is `Trailer` (RFC 7230). RFC 2616 §13.5.1 misspelled it as
        // `Trailers` (erratum 4522); both are stripped to stay compatible.
        "Trailer",
        "Trailers",
        "Transfer-Encoding",
        "Upgrade",
    ];

    HOP_HEADERS.iter().any(|hop| hop.eq_ignore_ascii_case(name))
}
123
124/// Returns a clone of the headers without the [hop-by-hop headers].
125///
126/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
127fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
128    let mut result = HeaderMap::new();
129    for (k, v) in headers.iter() {
130        if !is_hop_header(k.as_str()) {
131            result.insert(k.clone(), v.clone());
132        }
133    }
134    result
135}
136
137fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
138    *response.headers_mut() = remove_hop_headers(response.headers());
139    response
140}
141
142fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Uri {
143    let forward_uri = match req.uri().query() {
144        Some(query) => format!("{}{}?{}", forward_url, req.uri().path(), query),
145        None => format!("{}{}", forward_url, req.uri().path()),
146    };
147
148    Uri::from_str(forward_uri.as_str()).unwrap()
149}
150
151fn create_proxied_request<B>(client_ip: IpAddr, forward_url: &str, mut request: Request<B>) -> Request<B> {
152    *request.headers_mut() = remove_hop_headers(request.headers());
153    *request.uri_mut() = forward_uri(forward_url, &request);
154
155    let x_forwarded_for_header_name = "x-forwarded-for";
156
157    // Add forwarding information in the headers
158    match request.headers_mut().entry(x_forwarded_for_header_name) {
159
160        Ok(header_entry) => {
161            match header_entry {
162                hyper::header::Entry::Vacant(entry) => {
163                    let addr = format!("{}", client_ip);
164                    entry.insert(addr.parse().unwrap());
165                },
166
167                hyper::header::Entry::Occupied(mut entry) => {
168                    let addr = format!("{}, {}", entry.get().to_str().unwrap(), client_ip);
169                    entry.insert(addr.parse().unwrap());
170                }
171            }
172        }
173
174        // shouldn't happen...
175        Err(_) => panic!("Invalid header name: {}", x_forwarded_for_header_name),
176    }
177
178    request
179}
180
181pub fn call(client_ip: IpAddr, forward_uri: &str, request: Request<Body>) -> BoxFut {
182
183	let proxied_request = create_proxied_request(client_ip, forward_uri, request);
184    let tls_cx = native_tls::TlsConnector::builder().build().expect("Failed to build TLS connector");
185    let mut connector = HttpsConnector {
186        tls: Arc::new(tls_cx.into()),
187        http: HttpConnector::new(2),
188    };
189    connector.http.enforce_http(false);
190
191	let client = Client::builder().build(connector);
192	let response = client.request(proxied_request).then(|response| {
193
194		let proxied_response = match response {
195            Ok(response) => create_proxied_response(response),
196            Err(error) => {
197                println!("Error: {}", error); // TODO: Configurable logging
198                Response::builder()
199                    .status(StatusCode::INTERNAL_SERVER_ERROR)
200                    .body(Body::empty())
201                    .unwrap()
202            },
203        };
204
205
206        future::ok(proxied_response)
207	});
208
209	Box::new(response)
210}
211
212
213struct HttpsConnector {
214    tls: Arc<TlsConnector>,
215    http: HttpConnector,
216}
217
218impl Connect for HttpsConnector {
219    type Transport = TlsStream<TcpStream>;
220    type Error = io::Error;
221    type Future =  Box<Future<Item = (Self::Transport, Connected), Error = Self::Error> + Send>;
222
223    fn connect(&self, dst: Destination) -> Self::Future {
224
225        if dst.scheme() != "https" {
226            return Box::new(err(io::Error::new(io::ErrorKind::Other,
227                                      "only works with https")))
228        }
229
230        let host = format!("{}{}", dst.host(), dst.port().map(|p| format!(":{}",p)).unwrap_or("".into()));
231
232        let tls_cx = self.tls.clone();
233        Box::new(self.http.connect(dst).and_then(move |(tcp, connected)| {
234            tls_cx.connect(&host, tcp)
235                .map(|s| (s, connected))
236                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
237        }))
238
239    }
240
241
242}