//! A reverse-proxy [`Handler`] for trillium.
//!
//! Each incoming request is sent, through a trillium [`Client`], to an upstream url chosen by
//! an upstream selector (see the [`upstream`] module), and the upstream response is relayed
//! back. Build one with [`proxy`] or [`Proxy::new`].
#![forbid(unsafe_code)]
#![deny(
    clippy::dbg_macro,
    missing_copy_implementations,
    rustdoc::missing_crate_level_docs,
    missing_debug_implementations,
    missing_docs,
    nonstandard_style,
    unused_qualifications
)]

// Attaches README.md as the documentation of an empty module so that the README's code
// examples can be compiled and checked as doctests.
// NOTE(review): rustdoc collects doctests with `cfg(doctest)` set, not `cfg(test)`; confirm
// that this `cfg(test)`-gated module is actually seen by `cargo test --doc`.
#[cfg(test)]
#[doc = include_str!("../README.md")]
mod readme {}
17
18mod body_streamer;
19mod forward_proxy_connect;
20pub mod upstream;
21
22use body_streamer::stream_body;
23pub use forward_proxy_connect::ForwardProxyConnect;
24use full_duplex_async_copy::full_duplex_copy;
25use futures_lite::future::zip;
26use size::{Base, Size};
27use std::{borrow::Cow, fmt::Debug, future::IntoFuture};
28use trillium::{
29 Conn, Handler, KnownHeaderName,
30 Status::{NotFound, SwitchingProtocols},
31 Upgrade,
32};
33use trillium_client::ConnExt as _;
34pub use trillium_client::{Client, Connector};
35use trillium_forwarding::Forwarded;
36use trillium_http::{HeaderName, Headers, HttpContext, Status, Version};
37use upstream::{IntoUpstreamSelector, UpstreamSelector};
38pub use url::Url;
39
40pub fn proxy<I>(client: impl Into<Client>, upstream: I) -> Proxy<I::UpstreamSelector>
42where
43 I: IntoUpstreamSelector,
44{
45 Proxy::new(client, upstream)
46}
47
/// A trillium [`Handler`] that forwards requests to an upstream chosen by an upstream
/// selector and relays the upstream response back to the original client.
///
/// Construct one with [`Proxy::new`] or [`proxy`], then adjust it with the builder methods.
#[derive(Debug)]
pub struct Proxy<U> {
    /// Chooses the upstream url for each request; when it yields `None` the conn is returned
    /// untouched.
    upstream: U,
    /// Client used to issue the upstream requests.
    client: Client,
    /// When true (the default set by `new`), an upstream `404 Not Found` is not relayed: the
    /// original conn is returned unhalted so later handlers can respond.
    pass_through_not_found: bool,
    /// When true (the default set by `new`), `run` halts the conn after relaying an upstream
    /// response.
    halt: bool,
    /// If set, a `Via` entry using this pseudonym is added to forwarded requests and to
    /// relayed responses.
    via_pseudonym: Option<Cow<'static, str>>,
    /// When true, `Connection: upgrade` + `Upgrade: websocket` requests are forwarded
    /// upstream as upgrades.
    allow_websocket_upgrade: bool,
}
58
59impl<U: UpstreamSelector> Proxy<U> {
60 pub fn new<I>(client: impl Into<Client>, upstream: I) -> Self
73 where
74 I: IntoUpstreamSelector<UpstreamSelector = U>,
75 {
76 let client = client
77 .into()
78 .without_default_header(KnownHeaderName::UserAgent)
79 .without_default_header(KnownHeaderName::Accept);
80
81 Self {
82 upstream: upstream.into_upstream(),
83 client,
84 pass_through_not_found: true,
85 halt: true,
86 via_pseudonym: None,
87 allow_websocket_upgrade: false,
88 }
89 }
90
91 pub fn proxy_not_found(mut self) -> Self {
105 self.pass_through_not_found = false;
106 self
107 }
108
109 pub fn without_halting(mut self) -> Self {
123 self.halt = false;
124 self
125 }
126
127 pub fn with_via_pseudonym(mut self, via_pseudonym: impl Into<Cow<'static, str>>) -> Self {
132 self.via_pseudonym = Some(via_pseudonym.into());
133 self
134 }
135
136 pub fn with_websocket_upgrades(mut self) -> Self {
141 self.allow_websocket_upgrade = true;
142 self
143 }
144
145 fn set_via_pseudonym(&self, headers: &mut Headers, version: Version) {
146 if self.via_pseudonym.is_none() {
147 return;
148 }
149
150 use std::fmt::Write;
151 let mut via = String::new();
152 let _ = write!(&mut via, "{version}");
153
154 if let Some(pseudonym) = &self.via_pseudonym {
155 let _ = write!(&mut via, " {pseudonym}");
156 }
157
158 if let Some(old_via) = headers.get_values(KnownHeaderName::Via) {
159 for old_via in old_via {
160 let _ = write!(&mut via, ", {old_via}");
161 }
162 }
163
164 headers.insert(KnownHeaderName::Via, via);
165 }
166}
167
/// The upstream side of an upgraded (`101 Switching Protocols`) connection.
///
/// `run` stores it in the downstream conn's state. `has_upgrade` checks for it, and `upgrade`
/// takes it back out to copy bytes in both directions between it and the downstream upgrade.
#[derive(Debug)]
struct UpstreamUpgrade(Upgrade);
170
171impl<U: UpstreamSelector> Handler for Proxy<U> {
172 async fn init(&mut self, info: &mut trillium::Info) {
173 let old_context = self.client.context();
176 let new_context = HttpContext::default()
177 .with_config(*old_context.config())
178 .with_swansong(info.swansong().clone());
179 self.client.set_context(new_context);
180 log::info!("proxying to {:?}", self.upstream);
181 }
182
183 async fn run(&self, mut conn: Conn) -> Conn {
184 let Some(request_url) = self.upstream.determine_upstream(&mut conn) else {
185 return conn;
186 };
187
188 log::debug!("proxying to {}", request_url.as_str());
189
190 let mut forwarded = Forwarded::from_headers(conn.request_headers())
191 .ok()
192 .flatten()
193 .unwrap_or_default()
194 .into_owned();
195
196 if let Some(peer_ip) = conn.peer_ip() {
197 forwarded.add_for(peer_ip.to_string());
198 };
199
200 if let Some(host) = conn.host() {
201 forwarded.set_host(host);
202 }
203
204 let mut request_headers = conn
205 .request_headers()
206 .clone()
207 .without_headers([
208 KnownHeaderName::Connection,
209 KnownHeaderName::KeepAlive,
210 KnownHeaderName::ProxyAuthenticate,
211 KnownHeaderName::ProxyAuthorization,
212 KnownHeaderName::Te,
213 KnownHeaderName::Trailer,
214 KnownHeaderName::TransferEncoding,
215 KnownHeaderName::Upgrade,
216 KnownHeaderName::Host,
217 KnownHeaderName::XforwardedBy,
218 KnownHeaderName::XforwardedFor,
219 KnownHeaderName::XforwardedHost,
220 KnownHeaderName::XforwardedProto,
221 KnownHeaderName::XforwardedSsl,
222 KnownHeaderName::AltUsed,
223 ])
224 .with_inserted_header(KnownHeaderName::Forwarded, forwarded.to_string());
225
226 let mut connection_is_upgrade = false;
227 for header in conn
228 .request_headers()
229 .get_str(KnownHeaderName::Connection)
230 .unwrap_or_default()
231 .split(',')
232 .map(|h| HeaderName::from(h.trim()))
233 {
234 if header == KnownHeaderName::Upgrade {
235 connection_is_upgrade = true;
236 }
237 request_headers.remove(header);
238 }
239
240 if self.allow_websocket_upgrade
241 && connection_is_upgrade
242 && conn
243 .request_headers()
244 .eq_ignore_ascii_case(KnownHeaderName::Upgrade, "websocket")
245 {
246 request_headers.extend([
247 (KnownHeaderName::Upgrade, "WebSocket"),
248 (KnownHeaderName::Connection, "Upgrade"),
249 ]);
250 }
251
252 self.set_via_pseudonym(&mut request_headers, conn.http_version());
253
254 let content_length = !matches!(
255 conn.request_headers()
256 .get_str(KnownHeaderName::ContentLength),
257 Some("0") | None
258 );
259
260 let chunked = conn
261 .request_headers()
262 .eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked");
263
264 let method = conn.method();
265 let conn_result = if chunked || content_length {
266 let (body_fut, request_body) = stream_body(&mut conn);
267
268 let client_fut = self
269 .client
270 .build_conn(method, request_url)
271 .with_request_headers(request_headers)
272 .with_body(request_body)
273 .into_future();
274
275 zip(body_fut, client_fut).await.1
276 } else {
277 self.client
278 .build_conn(method, request_url)
279 .with_request_headers(request_headers)
280 .await
281 };
282
283 let mut client_conn = match conn_result {
284 Ok(client_conn) => client_conn,
285 Err(e) => {
286 return conn
287 .with_status(Status::ServiceUnavailable)
288 .halt()
289 .with_state(e);
290 }
291 };
292
293 let client_conn_version = client_conn.http_version();
294
295 let mut conn = match client_conn.status() {
296 Some(SwitchingProtocols) => {
297 conn.response_headers_mut()
298 .extend(std::mem::take(client_conn.response_headers_mut()));
299
300 conn.with_state(UpstreamUpgrade(
301 trillium_http::Upgrade::from(client_conn).into(),
302 ))
303 .with_status(SwitchingProtocols)
304 }
305
306 Some(NotFound) if self.pass_through_not_found => {
307 client_conn.recycle().await;
308 return conn;
309 }
310
311 Some(status) => {
312 conn.response_headers_mut().remove(KnownHeaderName::Server);
313 conn.response_headers_mut()
314 .append_all(client_conn.response_headers().clone());
315 conn.with_body(client_conn).with_status(status)
316 }
317
318 None => return conn.with_status(Status::ServiceUnavailable).halt(),
319 };
320
321 if Some(SwitchingProtocols) != conn.status()
322 || !conn
323 .response_headers()
324 .eq_ignore_ascii_case(KnownHeaderName::Connection, "Upgrade")
325 {
326 let connection = conn
327 .response_headers_mut()
328 .remove(KnownHeaderName::Connection);
329
330 conn.response_headers_mut().remove_all(
331 connection
332 .iter()
333 .flatten()
334 .filter_map(|s| s.as_str())
335 .flat_map(|s| s.split(','))
336 .map(|t| HeaderName::from(t.trim()).into_owned()),
337 );
338 }
339
340 conn.response_headers_mut().remove_all([
341 KnownHeaderName::KeepAlive,
342 KnownHeaderName::ProxyAuthenticate,
343 KnownHeaderName::ProxyAuthorization,
344 KnownHeaderName::Te,
345 KnownHeaderName::Trailer,
346 KnownHeaderName::TransferEncoding,
347 ]);
348
349 self.set_via_pseudonym(conn.response_headers_mut(), client_conn_version);
350
351 if self.halt { conn.halt() } else { conn }
352 }
353
354 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
355 upgrade.state().contains::<UpstreamUpgrade>()
356 }
357
358 async fn upgrade(&self, mut upgrade: Upgrade) {
359 let Some(UpstreamUpgrade(upstream)) = upgrade.state_mut().take() else {
360 return;
361 };
362 let downstream = upgrade;
363 match full_duplex_copy(upstream, downstream).await {
364 Err(e) => log::error!("upgrade stream error: {:?}", e),
365 Ok((up, down)) => {
366 log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
367 }
368 }
369 }
370}
371
372fn bytes(bytes: u64) -> String {
373 Size::from_bytes(bytes)
374 .format()
375 .with_base(Base::Base10)
376 .to_string()
377}