1use rfc7239::{parse, Forwarded, NodeIdentifier, NodeName};
2use std::convert::Infallible;
3use std::iter::once;
4use std::net::{IpAddr, SocketAddr};
5use std::str::FromStr;
6use warp::filters::addr::remote;
7use warp::Filter;
8
9pub fn real_ip(
30 trusted_proxies: Vec<IpAddr>,
31) -> impl Filter<Extract = (Option<IpAddr>,), Error = Infallible> + Clone {
32 remote().and(get_forwarded_for()).map(
33 move |addr: Option<SocketAddr>, forwarded_for: Vec<IpAddr>| {
34 addr.map(|addr| {
35 let hops = forwarded_for.iter().copied().chain(once(addr.ip()));
36 for hop in hops.rev() {
37 if !trusted_proxies.contains(&hop) {
38 return hop;
39 }
40 }
41
42 forwarded_for.first().copied().unwrap_or(addr.ip())
44 })
45 },
46 )
47}
48
49pub fn get_forwarded_for() -> impl Filter<Extract = (Vec<IpAddr>,), Error = Infallible> + Clone {
51 warp::header("x-forwarded-for")
52 .map(|list: CommaSeparated<IpAddr>| list.into_inner())
53 .or(warp::header("x-real-ip").map(|ip| vec![ip]))
54 .unify()
55 .or(warp::header("forwarded").map(|header: String| {
56 parse(&header)
57 .filter_map(|forward| match forward {
58 Ok(Forwarded {
59 forwarded_for:
60 Some(NodeIdentifier {
61 name: NodeName::Ip(ip),
62 ..
63 }),
64 ..
65 }) => Some(ip),
66 _ => None,
67 })
68 .collect::<Vec<_>>()
69 }))
70 .unify()
71 .or(warp::any().map(|| vec![]))
72 .unify()
73}
74
/// A list of values that was written as a single comma-separated string.
///
/// Parsing is provided through its `FromStr` implementation.
struct CommaSeparated<T>(Vec<T>);

impl<T> CommaSeparated<T> {
    /// Consumes the wrapper and hands back the parsed items in their original order.
    pub fn into_inner(self) -> Vec<T> {
        let CommaSeparated(items) = self;
        items
    }
}
83
84impl<T: FromStr> FromStr for CommaSeparated<T> {
85 type Err = T::Err;
86
87 fn from_str(s: &str) -> Result<Self, Self::Err> {
88 let vec = s
89 .split(',')
90 .map(str::trim)
91 .map(T::from_str)
92 .collect::<Result<Vec<_>, _>>()?;
93 Ok(CommaSeparated(vec))
94 }
95}