#![forbid(unsafe_code)]
//! A trillium handler that applies proxy forwarding information to a conn.
//!
//! The [`Forwarding`] handler acts only when the immediate peer of a
//! connection is trusted (see [`Forwarding::trust_ips`],
//! [`Forwarding::trust_fn`] and [`Forwarding::trust_always`]). It then parses
//! the request headers with [`Forwarded::from_headers`]. Any host, protocol,
//! or client address found there overrides the conn's host, secure flag and
//! peer ip, and the parsed [`Forwarded`] is stored in conn state. If a
//! trusted peer sends forwarding headers that cannot be parsed, the conn is
//! halted with `400 Bad Request`.
//!
//! The default [`Forwarding`] trusts no peer, so requests pass through
//! unchanged.
#![deny(
    missing_copy_implementations,
    rustdoc::missing_crate_level_docs,
    missing_debug_implementations,
    missing_docs,
    nonstandard_style,
    unused_qualifications
)]
mod forwarded;
pub use forwarded::Forwarded;

mod parse_utils;

use std::{
    fmt::Debug,
    net::{IpAddr, SocketAddr},
    ops::Deref,
};
use trillium::{async_trait, Conn, Handler, Status};

37#[derive(Debug)]
38#[non_exhaustive]
39enum TrustProxy {
40 Always,
41 Never,
42 Cidr(Vec<cidr::AnyIpCidr>),
43 Function(TrustFn),
44}
45
/// Type-erased trust predicate: a boxed `Fn(&IpAddr) -> bool` that is
/// `Send + Sync` and `'static` (it holds no short-lived borrows).
struct TrustFn(Box<dyn Fn(&IpAddr) -> bool + Send + Sync + 'static>);

impl<F: Fn(&IpAddr) -> bool + Send + Sync + 'static> From<F> for TrustFn {
    /// Boxes a closure or function so it can be stored as a trust policy.
    fn from(predicate: F) -> Self {
        let boxed: Box<dyn Fn(&IpAddr) -> bool + Send + Sync + 'static> = Box::new(predicate);
        TrustFn(boxed)
    }
}

55impl Debug for TrustFn {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_tuple("TrustPredicate").field(&"..").finish()
58 }
59}
60
61impl Deref for TrustFn {
62 type Target = dyn Fn(&IpAddr) -> bool + Send + Sync + 'static;
63
64 fn deref(&self) -> &Self::Target {
65 &self.0
66 }
67}
68
69impl TrustProxy {
70 fn is_trusted(&self, ip: Option<IpAddr>) -> bool {
71 match (self, ip) {
72 (TrustProxy::Always, _) => true,
73 (TrustProxy::Cidr(cidrs), Some(ip)) => cidrs.iter().any(|c| c.contains(&ip)),
74 (TrustProxy::Function(trust_predicate), Some(ip)) => trust_predicate(&ip),
75 _ => false,
76 }
77 }
78}
79
80#[derive(Default, Debug)]
86pub struct Forwarding(TrustProxy);
87
88impl From<TrustProxy> for Forwarding {
89 fn from(tp: TrustProxy) -> Self {
90 Self(tp)
91 }
92}
93
94impl Forwarding {
95 pub fn trust_ips<'a>(ips: impl IntoIterator<Item = &'a str>) -> Self {
104 Self(TrustProxy::Cidr(
105 ips.into_iter().map(|ip| ip.parse().unwrap()).collect(),
106 ))
107 }
108
109 pub fn trust_fn<F>(trust_predicate: F) -> Self
123 where
124 F: Fn(&IpAddr) -> bool + Send + Sync + 'static,
125 {
126 Self(TrustProxy::Function(TrustFn::from(trust_predicate)))
127 }
128
129 pub fn trust_always() -> Self {
140 Self(TrustProxy::Always)
141 }
142}
143
144impl Default for TrustProxy {
145 fn default() -> Self {
146 Self::Never
147 }
148}
149
150#[async_trait]
151impl Handler for Forwarding {
152 async fn run(&self, mut conn: Conn) -> Conn {
153 if !self.0.is_trusted(conn.inner().peer_ip()) {
154 return conn;
155 }
156
157 let forwarded = match Forwarded::from_headers(conn.request_headers()) {
158 Ok(Some(forwarded)) => forwarded.into_owned(),
159 Err(error) => {
160 log::error!("{error}");
161 return conn
162 .halt()
163 .with_state(error)
164 .with_status(Status::BadRequest);
165 }
166 Ok(None) => return conn,
167 };
168
169 log::debug!("received trusted forwarded {:?}", &forwarded);
170
171 let inner_mut = conn.inner_mut();
172
173 if let Some(host) = forwarded.host() {
174 inner_mut.set_host(String::from(host));
175 }
176
177 if let Some(proto) = forwarded.proto() {
178 inner_mut.set_secure(proto == "https");
179 }
180
181 if let Some(ip) = forwarded.forwarded_for().first() {
182 if let Ok(ip_addr) = ip.trim_start_matches('[').trim_end_matches(']').parse() {
183 inner_mut.set_peer_ip(Some(ip_addr));
184 }
185 }
186
187 conn.with_state(forwarded)
188 }
189}