1#![forbid(unsafe_code)]
2#![deny(
3 clippy::dbg_macro,
4 missing_copy_implementations,
5 rustdoc::missing_crate_level_docs,
6 missing_debug_implementations,
7 missing_docs,
8 nonstandard_style,
9 unused_qualifications
10)]
11
12mod body_streamer;
18mod forward_proxy_connect;
19pub mod upstream;
20
21use body_streamer::stream_body;
22use full_duplex_async_copy::full_duplex_copy;
23use futures_lite::future::zip;
24use size::{Base, Size};
25use std::{borrow::Cow, fmt::Debug, future::IntoFuture};
26use trillium::{
27 async_trait, Conn, Handler, KnownHeaderName,
28 Status::{NotFound, SwitchingProtocols},
29 Upgrade,
30};
31use trillium_forwarding::Forwarded;
32use trillium_http::{HeaderName, HeaderValue, Headers, Status, Version};
33use upstream::{IntoUpstreamSelector, UpstreamSelector};
34
35pub use forward_proxy_connect::ForwardProxyConnect;
36pub use trillium_client::{Client, Connector};
37pub use url::Url;
38
39pub fn proxy<I>(client: impl Into<Client>, upstream: I) -> Proxy<I::UpstreamSelector>
41where
42 I: IntoUpstreamSelector,
43{
44 Proxy::new(client, upstream)
45}
46
/// A trillium [`Handler`] that forwards requests to an upstream server and
/// relays the upstream response back to the client.
///
/// Construct one with [`Proxy::new`] or [`proxy`], then adjust it with the
/// builder methods.
#[derive(Debug)]
pub struct Proxy<U> {
    /// Chooses the upstream url for each conn (see [`UpstreamSelector`]).
    upstream: U,
    /// HTTP client used to send the request upstream.
    client: Client,
    /// When `true`, an upstream `404 Not Found` response is discarded and the
    /// conn is returned unmodified instead of being proxied.
    pass_through_not_found: bool,
    /// When `true`, the conn is halted after a proxied response.
    halt: bool,
    /// Optional pseudonym this proxy uses for its `Via` header entries.
    via_pseudonym: Option<Cow<'static, str>>,
    /// Whether websocket upgrade requests are forwarded upstream.
    allow_websocket_upgrade: bool,
}
59
60impl<U: UpstreamSelector> Proxy<U> {
61 pub fn new<I>(client: impl Into<Client>, upstream: I) -> Self
74 where
75 I: IntoUpstreamSelector<UpstreamSelector = U>,
76 {
77 Self {
78 upstream: upstream.into_upstream(),
79 client: client.into(),
80 pass_through_not_found: true,
81 halt: true,
82 via_pseudonym: None,
83 allow_websocket_upgrade: false,
84 }
85 }
86
87 pub fn proxy_not_found(mut self) -> Self {
104 self.pass_through_not_found = false;
105 self
106 }
107
108 pub fn without_halting(mut self) -> Self {
125 self.halt = false;
126 self
127 }
128
129 pub fn with_via_pseudonym(mut self, via_pseudonym: impl Into<Cow<'static, str>>) -> Self {
134 self.via_pseudonym = Some(via_pseudonym.into());
135 self
136 }
137
138 pub fn with_websocket_upgrades(mut self) -> Self {
143 self.allow_websocket_upgrade = true;
144 self
145 }
146
147 fn set_via_pseudonym(&self, headers: &mut Headers, version: Version) {
148 if let Some(via) = &self.via_pseudonym {
149 let via = match headers.get_values(KnownHeaderName::Via) {
150 Some(old_via) => format!(
151 "{version} {via}, {}",
152 old_via
153 .iter()
154 .filter_map(HeaderValue::as_str)
155 .collect::<Vec<_>>()
156 .join(", ")
157 ),
158
159 None => format!("{version} {via}"),
160 };
161
162 headers.insert(KnownHeaderName::Via, via);
163 };
164 }
165}
166
/// Conn state inserted by `Proxy::run` when the upstream answers
/// `101 Switching Protocols`.
///
/// Wraps the upstream connection as an [`Upgrade`]. The proxy's `has_upgrade`
/// checks for this state, and `upgrade` takes it out to copy bytes between the
/// upstream and downstream connections.
#[derive(Debug)]
struct UpstreamUpgrade(Upgrade);
169
170#[async_trait]
171impl<U: UpstreamSelector> Handler for Proxy<U> {
172 async fn init(&mut self, _info: &mut trillium::Info) {
173 log::info!("proxying to {:?}", self.upstream);
174 }
175
176 async fn run(&self, mut conn: Conn) -> Conn {
177 let Some(request_url) = self.upstream.determine_upstream(&mut conn) else {
178 return conn;
179 };
180
181 log::debug!("proxying to {}", request_url.as_str());
182
183 let mut forwarded = Forwarded::from_headers(conn.request_headers())
184 .ok()
185 .flatten()
186 .unwrap_or_default()
187 .into_owned();
188
189 if let Some(peer_ip) = conn.peer_ip() {
190 forwarded.add_for(peer_ip.to_string());
191 };
192
193 if let Some(host) = conn.inner().host() {
194 forwarded.set_host(host);
195 }
196
197 let mut request_headers = conn
198 .request_headers()
199 .clone()
200 .without_headers([
201 KnownHeaderName::Connection,
202 KnownHeaderName::KeepAlive,
203 KnownHeaderName::ProxyAuthenticate,
204 KnownHeaderName::ProxyAuthorization,
205 KnownHeaderName::Te,
206 KnownHeaderName::Trailer,
207 KnownHeaderName::TransferEncoding,
208 KnownHeaderName::Upgrade,
209 KnownHeaderName::Host,
210 KnownHeaderName::XforwardedBy,
211 KnownHeaderName::XforwardedFor,
212 KnownHeaderName::XforwardedHost,
213 KnownHeaderName::XforwardedProto,
214 KnownHeaderName::XforwardedSsl,
215 ])
216 .with_inserted_header(KnownHeaderName::Forwarded, forwarded.to_string());
217
218 let mut connection_is_upgrade = false;
219 for header in conn
220 .request_headers()
221 .get_str(KnownHeaderName::Connection)
222 .unwrap_or_default()
223 .split(',')
224 .map(|h| HeaderName::from(h.trim()))
225 {
226 if header == KnownHeaderName::Upgrade {
227 connection_is_upgrade = true;
228 }
229 request_headers.remove(header);
230 }
231
232 if self.allow_websocket_upgrade
233 && connection_is_upgrade
234 && conn
235 .request_headers()
236 .eq_ignore_ascii_case(KnownHeaderName::Upgrade, "websocket")
237 {
238 request_headers.extend([
239 (KnownHeaderName::Upgrade, "WebSocket"),
240 (KnownHeaderName::Connection, "Upgrade"),
241 ]);
242 }
243
244 self.set_via_pseudonym(&mut request_headers, conn.inner().http_version());
245 let content_length = !matches!(
246 conn.request_headers()
247 .get_str(KnownHeaderName::ContentLength),
248 Some("0") | None
249 );
250
251 let chunked = conn
252 .request_headers()
253 .eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked");
254 let method = conn.method();
255 let conn_result = if chunked || content_length {
256 let (body_fut, request_body) = stream_body(&mut conn);
257
258 let client_fut = self
259 .client
260 .build_conn(method, request_url)
261 .with_request_headers(request_headers)
262 .with_body(request_body)
263 .into_future();
264
265 zip(body_fut, client_fut).await.1
266 } else {
267 self.client
268 .build_conn(method, request_url)
269 .with_request_headers(request_headers)
270 .await
271 };
272
273 let mut client_conn = match conn_result {
274 Ok(client_conn) => client_conn,
275 Err(e) => {
276 return conn
277 .with_status(Status::ServiceUnavailable)
278 .halt()
279 .with_state(e);
280 }
281 };
282
283 let mut conn = match client_conn.status() {
284 Some(SwitchingProtocols) => {
285 conn.response_headers_mut()
286 .extend(std::mem::take(client_conn.response_headers_mut()));
287
288 conn.with_state(UpstreamUpgrade(Upgrade::from(client_conn)))
289 .with_status(SwitchingProtocols)
290 }
291
292 Some(NotFound) if self.pass_through_not_found => {
293 client_conn.recycle().await;
294 return conn;
295 }
296
297 Some(status) => {
298 conn.response_headers_mut()
299 .append_all(client_conn.response_headers().clone());
300 conn.with_body(client_conn).with_status(status)
301 }
302
303 None => return conn.with_status(Status::ServiceUnavailable).halt(),
304 };
305
306 let connection = conn
307 .response_headers_mut()
308 .remove(KnownHeaderName::Connection);
309
310 conn.response_headers_mut().remove_all(
311 connection
312 .iter()
313 .flatten()
314 .filter_map(|s| s.as_str())
315 .flat_map(|s| s.split(','))
316 .map(|t| HeaderName::from(t.trim()).into_owned()),
317 );
318
319 conn.response_headers_mut().remove_all([
320 KnownHeaderName::KeepAlive,
321 KnownHeaderName::ProxyAuthenticate,
322 KnownHeaderName::ProxyAuthorization,
323 KnownHeaderName::Te,
324 KnownHeaderName::Trailer,
325 KnownHeaderName::TransferEncoding,
326 ]);
327
328 self.set_via_pseudonym(conn.response_headers_mut(), Version::Http1_1);
329
330 if self.halt {
331 conn.halt()
332 } else {
333 conn
334 }
335 }
336
337 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
338 upgrade.state.contains::<UpstreamUpgrade>()
339 }
340
341 async fn upgrade(&self, mut upgrade: Upgrade) {
342 let Some(UpstreamUpgrade(upstream)) = upgrade.state.take() else {
343 return;
344 };
345 let downstream = upgrade;
346 match full_duplex_copy(upstream, downstream).await {
347 Err(e) => log::error!("upgrade stream error: {:?}", e),
348 Ok((up, down)) => {
349 log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
350 }
351 }
352 }
353}
354
355fn bytes(bytes: u64) -> String {
356 Size::from_bytes(bytes)
357 .format()
358 .with_base(Base::Base10)
359 .to_string()
360}