trillium_proxy/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11
12/*!
13http reverse and forward proxy trillium handler
14
15*/
16
17mod body_streamer;
18mod forward_proxy_connect;
19pub mod upstream;
20
21use body_streamer::stream_body;
22use full_duplex_async_copy::full_duplex_copy;
23use futures_lite::future::zip;
24use size::{Base, Size};
25use std::{borrow::Cow, fmt::Debug, future::IntoFuture};
26use trillium::{
27    async_trait, Conn, Handler, KnownHeaderName,
28    Status::{NotFound, SwitchingProtocols},
29    Upgrade,
30};
31use trillium_forwarding::Forwarded;
32use trillium_http::{HeaderName, HeaderValue, Headers, Status, Version};
33use upstream::{IntoUpstreamSelector, UpstreamSelector};
34
35pub use forward_proxy_connect::ForwardProxyConnect;
36pub use trillium_client::{Client, Connector};
37pub use url::Url;
38
39/// constructs a new [`Proxy`]. alias of [`Proxy::new`]
40pub fn proxy<I>(client: impl Into<Client>, upstream: I) -> Proxy<I::UpstreamSelector>
41where
42    I: IntoUpstreamSelector,
43{
44    Proxy::new(client, upstream)
45}
46
47/**
48the proxy handler
49*/
50#[derive(Debug)]
51pub struct Proxy<U> {
52    upstream: U,
53    client: Client,
54    pass_through_not_found: bool,
55    halt: bool,
56    via_pseudonym: Option<Cow<'static, str>>,
57    allow_websocket_upgrade: bool,
58}
59
60impl<U: UpstreamSelector> Proxy<U> {
61    /**
62    construct a new proxy handler that sends all requests to the upstream
63    provided
64
65    ```
66    use trillium_smol::ClientConfig;
67    use trillium_proxy::Proxy;
68
69    let proxy = Proxy::new(ClientConfig::default(), "http://docs.trillium.rs/trillium_proxy");
70    ```
71
72     */
73    pub fn new<I>(client: impl Into<Client>, upstream: I) -> Self
74    where
75        I: IntoUpstreamSelector<UpstreamSelector = U>,
76    {
77        Self {
78            upstream: upstream.into_upstream(),
79            client: client.into(),
80            pass_through_not_found: true,
81            halt: true,
82            via_pseudonym: None,
83            allow_websocket_upgrade: false,
84        }
85    }
86
87    /**
88    chainable constructor to set the 404 Not Found handling
89    behavior. By default, this proxy will pass through the trillium
90    Conn unmodified if the proxy response is a 404 not found, allowing
91    it to be chained in a tuple handler. To modify this behavior, call
92    proxy_not_found, and the full 404 response will be forwarded. The
93    Conn will be halted unless [`Proxy::without_halting`] was
94    configured
95
96    ```
97    # use trillium_smol::ClientConfig;
98    # use trillium_proxy::Proxy;
99    let proxy = Proxy::new(ClientConfig::default(), "http://trillium.rs")
100        .proxy_not_found();
101    ```
102    */
103    pub fn proxy_not_found(mut self) -> Self {
104        self.pass_through_not_found = false;
105        self
106    }
107
108    /**
109    The default behavior for this handler is to halt the conn on any
110    response other than a 404. If [`Proxy::proxy_not_found`] has been
111    configured, the default behavior for all response statuses is to
112    halt the trillium conn. To change this behavior, call
113    without_halting when constructing the proxy, and it will not halt
114    the conn. This is useful when passing the proxy reply through
115    [`trillium_html_rewriter`](https://docs.trillium.rs/trillium_html_rewriter).
116
117    ```
118    # use trillium_smol::ClientConfig;
119    # use trillium_proxy::Proxy;
120    let proxy = Proxy::new(ClientConfig::default(), "http://trillium.rs")
121        .without_halting();
122    ```
123    */
124    pub fn without_halting(mut self) -> Self {
125        self.halt = false;
126        self
127    }
128
129    /// populate the pseudonym for a
130    /// [`Via`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Via)
131    /// header. If no pseudonym is provided, no via header will be
132    /// inserted.
133    pub fn with_via_pseudonym(mut self, via_pseudonym: impl Into<Cow<'static, str>>) -> Self {
134        self.via_pseudonym = Some(via_pseudonym.into());
135        self
136    }
137
138    /// Allow websockets to be proxied
139    ///
140    /// This is not currently the default, but that may change at some (semver-minor) point in the
141    /// future
142    pub fn with_websocket_upgrades(mut self) -> Self {
143        self.allow_websocket_upgrade = true;
144        self
145    }
146
147    fn set_via_pseudonym(&self, headers: &mut Headers, version: Version) {
148        if let Some(via) = &self.via_pseudonym {
149            let via = match headers.get_values(KnownHeaderName::Via) {
150                Some(old_via) => format!(
151                    "{version} {via}, {}",
152                    old_via
153                        .iter()
154                        .filter_map(HeaderValue::as_str)
155                        .collect::<Vec<_>>()
156                        .join(", ")
157                ),
158
159                None => format!("{version} {via}"),
160            };
161
162            headers.insert(KnownHeaderName::Via, via);
163        };
164    }
165}
166
167#[derive(Debug)]
168struct UpstreamUpgrade(Upgrade);
169
170#[async_trait]
171impl<U: UpstreamSelector> Handler for Proxy<U> {
172    async fn init(&mut self, _info: &mut trillium::Info) {
173        log::info!("proxying to {:?}", self.upstream);
174    }
175
176    async fn run(&self, mut conn: Conn) -> Conn {
177        let Some(request_url) = self.upstream.determine_upstream(&mut conn) else {
178            return conn;
179        };
180
181        log::debug!("proxying to {}", request_url.as_str());
182
183        let mut forwarded = Forwarded::from_headers(conn.request_headers())
184            .ok()
185            .flatten()
186            .unwrap_or_default()
187            .into_owned();
188
189        if let Some(peer_ip) = conn.peer_ip() {
190            forwarded.add_for(peer_ip.to_string());
191        };
192
193        if let Some(host) = conn.inner().host() {
194            forwarded.set_host(host);
195        }
196
197        let mut request_headers = conn
198            .request_headers()
199            .clone()
200            .without_headers([
201                KnownHeaderName::Connection,
202                KnownHeaderName::KeepAlive,
203                KnownHeaderName::ProxyAuthenticate,
204                KnownHeaderName::ProxyAuthorization,
205                KnownHeaderName::Te,
206                KnownHeaderName::Trailer,
207                KnownHeaderName::TransferEncoding,
208                KnownHeaderName::Upgrade,
209                KnownHeaderName::Host,
210                KnownHeaderName::XforwardedBy,
211                KnownHeaderName::XforwardedFor,
212                KnownHeaderName::XforwardedHost,
213                KnownHeaderName::XforwardedProto,
214                KnownHeaderName::XforwardedSsl,
215            ])
216            .with_inserted_header(KnownHeaderName::Forwarded, forwarded.to_string());
217
218        let mut connection_is_upgrade = false;
219        for header in conn
220            .request_headers()
221            .get_str(KnownHeaderName::Connection)
222            .unwrap_or_default()
223            .split(',')
224            .map(|h| HeaderName::from(h.trim()))
225        {
226            if header == KnownHeaderName::Upgrade {
227                connection_is_upgrade = true;
228            }
229            request_headers.remove(header);
230        }
231
232        if self.allow_websocket_upgrade
233            && connection_is_upgrade
234            && conn
235                .request_headers()
236                .eq_ignore_ascii_case(KnownHeaderName::Upgrade, "websocket")
237        {
238            request_headers.extend([
239                (KnownHeaderName::Upgrade, "WebSocket"),
240                (KnownHeaderName::Connection, "Upgrade"),
241            ]);
242        }
243
244        self.set_via_pseudonym(&mut request_headers, conn.inner().http_version());
245        let content_length = !matches!(
246            conn.request_headers()
247                .get_str(KnownHeaderName::ContentLength),
248            Some("0") | None
249        );
250
251        let chunked = conn
252            .request_headers()
253            .eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked");
254        let method = conn.method();
255        let conn_result = if chunked || content_length {
256            let (body_fut, request_body) = stream_body(&mut conn);
257
258            let client_fut = self
259                .client
260                .build_conn(method, request_url)
261                .with_request_headers(request_headers)
262                .with_body(request_body)
263                .into_future();
264
265            zip(body_fut, client_fut).await.1
266        } else {
267            self.client
268                .build_conn(method, request_url)
269                .with_request_headers(request_headers)
270                .await
271        };
272
273        let mut client_conn = match conn_result {
274            Ok(client_conn) => client_conn,
275            Err(e) => {
276                return conn
277                    .with_status(Status::ServiceUnavailable)
278                    .halt()
279                    .with_state(e);
280            }
281        };
282
283        let mut conn = match client_conn.status() {
284            Some(SwitchingProtocols) => {
285                conn.response_headers_mut()
286                    .extend(std::mem::take(client_conn.response_headers_mut()));
287
288                conn.with_state(UpstreamUpgrade(Upgrade::from(client_conn)))
289                    .with_status(SwitchingProtocols)
290            }
291
292            Some(NotFound) if self.pass_through_not_found => {
293                client_conn.recycle().await;
294                return conn;
295            }
296
297            Some(status) => {
298                conn.response_headers_mut()
299                    .append_all(client_conn.response_headers().clone());
300                conn.with_body(client_conn).with_status(status)
301            }
302
303            None => return conn.with_status(Status::ServiceUnavailable).halt(),
304        };
305
306        let connection = conn
307            .response_headers_mut()
308            .remove(KnownHeaderName::Connection);
309
310        conn.response_headers_mut().remove_all(
311            connection
312                .iter()
313                .flatten()
314                .filter_map(|s| s.as_str())
315                .flat_map(|s| s.split(','))
316                .map(|t| HeaderName::from(t.trim()).into_owned()),
317        );
318
319        conn.response_headers_mut().remove_all([
320            KnownHeaderName::KeepAlive,
321            KnownHeaderName::ProxyAuthenticate,
322            KnownHeaderName::ProxyAuthorization,
323            KnownHeaderName::Te,
324            KnownHeaderName::Trailer,
325            KnownHeaderName::TransferEncoding,
326        ]);
327
328        self.set_via_pseudonym(conn.response_headers_mut(), Version::Http1_1);
329
330        if self.halt {
331            conn.halt()
332        } else {
333            conn
334        }
335    }
336
337    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
338        upgrade.state.contains::<UpstreamUpgrade>()
339    }
340
341    async fn upgrade(&self, mut upgrade: Upgrade) {
342        let Some(UpstreamUpgrade(upstream)) = upgrade.state.take() else {
343            return;
344        };
345        let downstream = upgrade;
346        match full_duplex_copy(upstream, downstream).await {
347            Err(e) => log::error!("upgrade stream error: {:?}", e),
348            Ok((up, down)) => {
349                log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
350            }
351        }
352    }
353}
354
355fn bytes(bytes: u64) -> String {
356    Size::from_bytes(bytes)
357        .format()
358        .with_base(Base::Base10)
359        .to_string()
360}