1use super::BatchRequestConfig;
10use std::{
11 error::Error as StdError,
12 net::{IpAddr, SocketAddr},
13 num::NonZeroU32,
14 str::FromStr,
15};
16
17use crate::DenyUnsafe;
18use forwarded_header_value::ForwardedHeaderValue;
19use http::header::{HeaderName, HeaderValue};
20use ip_network::IpNetwork;
21use jsonrpsee::{server::middleware::http::HostFilterLayer, RpcModule};
22use tower_http::cors::{AllowOrigin, CorsLayer};
23
24const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
25const X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip");
26const FORWARDED: HeaderName = HeaderName::from_static("forwarded");
27
/// Error reported when not a single configured listen address could be bound.
#[derive(Debug)]
pub(crate) struct ListenAddrError;

impl std::fmt::Display for ListenAddrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("No listen address was successfully bound")
    }
}

impl std::error::Error for ListenAddrError {}
38
/// Policy selecting which RPC methods an endpoint exposes.
///
/// Parsed from the strings `"safe"`, `"unsafe"` and `"auto"`, and turned into a
/// per-address [`DenyUnsafe`] decision by `deny_unsafe`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum RpcMethods {
    /// Expose only safe methods; unsafe methods are always denied.
    Safe,
    /// Expose all methods, including unsafe ones, for any address.
    Unsafe,
    /// Allow unsafe methods only when the address given to `deny_unsafe` is
    /// a loopback address; deny them otherwise.
    Auto,
}
49
50impl Default for RpcMethods {
51 fn default() -> Self {
52 RpcMethods::Auto
53 }
54}
55
56impl FromStr for RpcMethods {
57 type Err = String;
58
59 fn from_str(s: &str) -> Result<Self, Self::Err> {
60 match s {
61 "safe" => Ok(RpcMethods::Safe),
62 "unsafe" => Ok(RpcMethods::Unsafe),
63 "auto" => Ok(RpcMethods::Auto),
64 invalid => Err(format!("Invalid rpc methods {invalid}")),
65 }
66 }
67}
68
/// Per-listener server configuration, as produced by `RpcEndpoint::bind`.
///
/// Most fields are copied verbatim from the originating `RpcEndpoint`; `cors`
/// and `host_filter` are derived from its optional CORS origin list.
#[derive(Debug, Clone)]
pub(crate) struct RpcSettings {
    /// Batch-request configuration, copied from the endpoint.
    pub(crate) batch_config: BatchRequestConfig,
    /// Connection limit, copied from the endpoint (enforced outside this file).
    pub(crate) max_connections: u32,
    /// Inbound payload limit, copied from the endpoint (units per field name: MB).
    pub(crate) max_payload_in_mb: u32,
    /// Outbound payload limit, copied from the endpoint (units per field name: MB).
    pub(crate) max_payload_out_mb: u32,
    /// Per-connection subscription limit, copied from the endpoint.
    pub(crate) max_subscriptions_per_connection: u32,
    /// Per-connection buffer capacity, copied from the endpoint.
    pub(crate) max_buffer_capacity_per_connection: u32,
    /// Which RPC methods are exposed (see `RpcMethods`).
    pub(crate) rpc_methods: RpcMethods,
    /// Optional rate limit, copied from the endpoint; semantics are defined by
    /// the consumer of these settings.
    pub(crate) rate_limit: Option<NonZeroU32>,
    /// Whether proxy headers may be trusted for rate limiting — presumably gates
    /// use of `get_proxy_ip`; confirm at the consumer.
    pub(crate) rate_limit_trust_proxy_headers: bool,
    /// IP networks whitelisted with respect to rate limiting, copied from the endpoint.
    pub(crate) rate_limit_whitelisted_ips: Vec<IpNetwork>,
    /// CORS layer built by `try_into_cors`: permissive when no origin list was
    /// configured, otherwise restricted to the configured origins.
    pub(crate) cors: CorsLayer,
    /// When built by `RpcEndpoint::bind`, `Some` only if an explicit CORS origin
    /// list was configured; whitelists localhost-style hosts on the bound port.
    pub(crate) host_filter: Option<HostFilterLayer>,
}
84
/// User-supplied configuration for a single RPC listen endpoint.
///
/// `RpcEndpoint::bind` consumes it, binds the socket and turns the remaining
/// fields into an `RpcSettings`.
#[derive(Debug, Clone)]
pub struct RpcEndpoint {
    /// Socket address to bind first.
    pub listen_addr: SocketAddr,
    /// Batch-request configuration, forwarded to the listener settings.
    pub batch_config: BatchRequestConfig,
    /// Connection limit, forwarded to the listener settings (enforced elsewhere).
    pub max_connections: u32,
    /// Inbound payload limit (units per field name: MB), forwarded to the listener settings.
    pub max_payload_in_mb: u32,
    /// Outbound payload limit (units per field name: MB), forwarded to the listener settings.
    pub max_payload_out_mb: u32,
    /// Per-connection subscription limit, forwarded to the listener settings.
    pub max_subscriptions_per_connection: u32,
    /// Per-connection buffer capacity, forwarded to the listener settings.
    pub max_buffer_capacity_per_connection: u32,
    /// Optional rate limit, forwarded to the listener settings.
    pub rate_limit: Option<NonZeroU32>,
    /// Whether proxy-supplied client IP headers are trusted for rate limiting —
    /// presumably; confirm at the consumer of `RpcSettings`.
    pub rate_limit_trust_proxy_headers: bool,
    /// IP networks whitelisted with respect to rate limiting.
    pub rate_limit_whitelisted_ips: Vec<IpNetwork>,
    /// Allowed CORS origins. `None` yields a permissive CORS layer and no host
    /// filtering; `Some(list)` restricts origins to `list` and enables host filtering.
    pub cors: Option<Vec<String>>,
    /// Which RPC methods are exposed (see `RpcMethods`).
    pub rpc_methods: RpcMethods,
    /// Not read in this file. NOTE(review): presumably lets the caller tolerate
    /// a failure to bind this endpoint — confirm at the call site.
    pub is_optional: bool,
    /// If binding `listen_addr` fails, retry once on the same address with
    /// port `0` (OS-assigned) instead of returning the error.
    pub retry_random_port: bool,
}
119
120impl RpcEndpoint {
121 pub(crate) async fn bind(self) -> Result<Listener, Box<dyn StdError + Send + Sync>> {
123 let listener = match tokio::net::TcpListener::bind(self.listen_addr).await {
124 Ok(listener) => listener,
125 Err(_) if self.retry_random_port => {
126 let mut addr = self.listen_addr;
127 addr.set_port(0);
128
129 tokio::net::TcpListener::bind(addr).await?
130 },
131 Err(e) => return Err(e.into()),
132 };
133 let local_addr = listener.local_addr()?;
134 let host_filter = host_filtering(self.cors.is_some(), local_addr);
135 let cors = try_into_cors(self.cors)?;
136
137 Ok(Listener {
138 listener,
139 local_addr,
140 cfg: RpcSettings {
141 batch_config: self.batch_config,
142 max_connections: self.max_connections,
143 max_payload_in_mb: self.max_payload_in_mb,
144 max_payload_out_mb: self.max_payload_out_mb,
145 max_subscriptions_per_connection: self.max_subscriptions_per_connection,
146 max_buffer_capacity_per_connection: self.max_buffer_capacity_per_connection,
147 rpc_methods: self.rpc_methods,
148 rate_limit: self.rate_limit,
149 rate_limit_trust_proxy_headers: self.rate_limit_trust_proxy_headers,
150 rate_limit_whitelisted_ips: self.rate_limit_whitelisted_ips,
151 host_filter,
152 cors,
153 },
154 })
155 }
156}
157
/// A bound TCP listener together with its resolved local address and the
/// settings derived from the endpoint it was created from.
pub(crate) struct Listener {
    /// The bound socket that incoming connections are accepted from.
    listener: tokio::net::TcpListener,
    /// Address reported by the socket after binding (reflects the OS-chosen
    /// port when port `0` was used).
    local_addr: SocketAddr,
    /// Settings captured when the listener was bound.
    cfg: RpcSettings,
}
164
165impl Listener {
166 pub(crate) async fn accept(&mut self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
168 let (sock, remote_addr) = self.listener.accept().await?;
169 Ok((sock, remote_addr))
170 }
171
172 pub fn local_addr(&self) -> SocketAddr {
174 self.local_addr
175 }
176
177 pub fn rpc_settings(&self) -> RpcSettings {
178 self.cfg.clone()
179 }
180}
181
182pub(crate) fn host_filtering(enabled: bool, addr: SocketAddr) -> Option<HostFilterLayer> {
183 if enabled {
184 let hosts = [
187 format!("localhost:{}", addr.port()),
188 format!("127.0.0.1:{}", addr.port()),
189 format!("[::1]:{}", addr.port()),
190 ];
191
192 Some(HostFilterLayer::new(hosts).expect("Valid hosts; qed"))
193 } else {
194 None
195 }
196}
197
198pub(crate) fn build_rpc_api<M: Send + Sync + 'static>(mut rpc_api: RpcModule<M>) -> RpcModule<M> {
199 let mut available_methods = rpc_api.method_names().collect::<Vec<_>>();
200 available_methods.push("rpc_methods");
202 available_methods.sort();
203
204 rpc_api
205 .register_method("rpc_methods", move |_, _, _| {
206 serde_json::json!({
207 "methods": available_methods,
208 })
209 })
210 .expect("infallible all other methods have their own address space; qed");
211
212 rpc_api
213}
214
215pub(crate) fn try_into_cors(
216 maybe_cors: Option<Vec<String>>,
217) -> Result<CorsLayer, Box<dyn StdError + Send + Sync>> {
218 if let Some(cors) = maybe_cors {
219 let mut list = Vec::new();
220
221 for origin in cors {
222 list.push(HeaderValue::from_str(&origin)?)
223 }
224
225 Ok(CorsLayer::new().allow_origin(AllowOrigin::list(list)))
226 } else {
227 Ok(CorsLayer::permissive())
229 }
230}
231
232pub(crate) fn get_proxy_ip<B>(req: &http::Request<B>) -> Option<IpAddr> {
239 if let Some(ip) = req
240 .headers()
241 .get(&FORWARDED)
242 .and_then(|v| v.to_str().ok())
243 .and_then(|v| ForwardedHeaderValue::from_forwarded(v).ok())
244 .and_then(|v| v.remotest_forwarded_for_ip())
245 {
246 return Some(ip);
247 }
248
249 if let Some(ip) = req
250 .headers()
251 .get(&X_FORWARDED_FOR)
252 .and_then(|v| v.to_str().ok())
253 .and_then(|v| ForwardedHeaderValue::from_x_forwarded_for(v).ok())
254 .and_then(|v| v.remotest_forwarded_for_ip())
255 {
256 return Some(ip);
257 }
258
259 if let Some(ip) = req
260 .headers()
261 .get(&X_REAL_IP)
262 .and_then(|v| v.to_str().ok())
263 .and_then(|v| IpAddr::from_str(v).ok())
264 {
265 return Some(ip);
266 }
267
268 None
269}
270
271pub fn deny_unsafe(addr: &SocketAddr, methods: &RpcMethods) -> DenyUnsafe {
273 match (addr.ip().is_loopback(), methods) {
274 (_, RpcMethods::Unsafe) | (true, RpcMethods::Auto) => DenyUnsafe::No,
275 _ => DenyUnsafe::Yes,
276 }
277}
278
/// Renders listen addresses as a single comma-separated string.
///
/// An empty slice renders as `""`; several addresses are joined with `,`.
pub(crate) fn format_listen_addrs(addr: &[SocketAddr]) -> String {
    let mut joined = addr
        .iter()
        .map(ToString::to_string)
        .collect::<Vec<_>>()
        .join(",");

    // NOTE(review): a lone address is deliberately rendered with a trailing
    // comma (e.g. "127.0.0.1:9944,"); kept as-is to preserve existing output.
    if let [_] = addr {
        joined.push(',');
    }

    joined
}
298
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::header::HeaderValue;
    use jsonrpsee::server::{HttpBody, HttpRequest};

    /// Empty request with no headers.
    fn request() -> http::Request<HttpBody> {
        HttpRequest::builder().body(HttpBody::empty()).unwrap()
    }

    /// Parses a socket-address literal used in the tests below.
    fn sock(s: &str) -> SocketAddr {
        s.parse().expect("valid socket address literal")
    }

    #[test]
    fn empty_works() {
        let req = request();
        let host = get_proxy_ip(&req);
        assert!(host.is_none())
    }

    #[test]
    fn host_from_x_real_ip() {
        let mut req = request();

        req.headers_mut().insert(&X_REAL_IP, HeaderValue::from_static("127.0.0.1"));
        let ip = get_proxy_ip(&req);
        assert_eq!(Some(IpAddr::from_str("127.0.0.1").unwrap()), ip);
    }

    #[test]
    fn ip_from_forwarded_works() {
        let mut req = request();

        req.headers_mut().insert(
            &FORWARDED,
            HeaderValue::from_static("for=192.0.2.60;proto=http;by=203.0.113.43;host=example.com"),
        );
        let ip = get_proxy_ip(&req);
        assert_eq!(Some(IpAddr::from_str("192.0.2.60").unwrap()), ip);
    }

    #[test]
    fn ip_from_forwarded_multiple() {
        let mut req = request();

        req.headers_mut().append(&FORWARDED, HeaderValue::from_static("for=127.0.0.1"));
        req.headers_mut().append(&FORWARDED, HeaderValue::from_static("for=192.0.2.60"));
        req.headers_mut().append(&FORWARDED, HeaderValue::from_static("for=192.0.2.61"));
        let ip = get_proxy_ip(&req);
        assert_eq!(Some(IpAddr::from_str("127.0.0.1").unwrap()), ip);
    }

    #[test]
    fn ip_from_x_forwarded_works() {
        let mut req = request();

        req.headers_mut()
            .insert(&X_FORWARDED_FOR, HeaderValue::from_static("127.0.0.1,192.0.2.60,0.0.0.1"));
        let ip = get_proxy_ip(&req);
        assert_eq!(Some(IpAddr::from_str("127.0.0.1").unwrap()), ip);
    }

    #[test]
    fn forwarded_takes_precedence_over_other_headers() {
        let mut req = request();

        req.headers_mut().insert(&X_REAL_IP, HeaderValue::from_static("10.0.0.1"));
        req.headers_mut().insert(&X_FORWARDED_FOR, HeaderValue::from_static("10.0.0.2"));
        req.headers_mut().insert(&FORWARDED, HeaderValue::from_static("for=192.0.2.60"));
        let ip = get_proxy_ip(&req);
        assert_eq!(Some(IpAddr::from_str("192.0.2.60").unwrap()), ip);
    }

    #[test]
    fn x_forwarded_for_takes_precedence_over_x_real_ip() {
        let mut req = request();

        req.headers_mut().insert(&X_REAL_IP, HeaderValue::from_static("10.0.0.1"));
        req.headers_mut().insert(&X_FORWARDED_FOR, HeaderValue::from_static("10.0.0.2"));
        let ip = get_proxy_ip(&req);
        assert_eq!(Some(IpAddr::from_str("10.0.0.2").unwrap()), ip);
    }

    #[test]
    fn unparsable_x_real_ip_is_ignored() {
        let mut req = request();

        req.headers_mut().insert(&X_REAL_IP, HeaderValue::from_static("not-an-ip"));
        assert!(get_proxy_ip(&req).is_none());
    }

    #[test]
    fn rpc_methods_from_str() {
        assert!(matches!("safe".parse::<RpcMethods>(), Ok(RpcMethods::Safe)));
        assert!(matches!("unsafe".parse::<RpcMethods>(), Ok(RpcMethods::Unsafe)));
        assert!(matches!("auto".parse::<RpcMethods>(), Ok(RpcMethods::Auto)));
        assert_eq!("bogus".parse::<RpcMethods>().unwrap_err(), "Invalid rpc methods bogus");
    }

    #[test]
    fn deny_unsafe_policy() {
        let local = sock("127.0.0.1:9944");
        let remote = sock("192.0.2.1:9944");

        assert!(matches!(deny_unsafe(&local, &RpcMethods::Auto), DenyUnsafe::No));
        assert!(matches!(deny_unsafe(&remote, &RpcMethods::Auto), DenyUnsafe::Yes));
        assert!(matches!(deny_unsafe(&remote, &RpcMethods::Unsafe), DenyUnsafe::No));
        assert!(matches!(deny_unsafe(&local, &RpcMethods::Safe), DenyUnsafe::Yes));
    }

    #[test]
    fn format_listen_addrs_works() {
        let v4 = sock("127.0.0.1:9944");
        let v6 = sock("[::1]:9944");

        assert_eq!(format_listen_addrs(&[]), "");
        // Current behaviour: a single address is rendered with a trailing comma.
        assert_eq!(format_listen_addrs(&[v4]), "127.0.0.1:9944,");
        assert_eq!(format_listen_addrs(&[v4, v6]), "127.0.0.1:9944,[::1]:9944");
    }

    #[test]
    fn try_into_cors_validates_origins() {
        assert!(try_into_cors(None).is_ok());
        assert!(try_into_cors(Some(vec!["http://localhost:*".to_string()])).is_ok());
        assert!(try_into_cors(Some(vec!["bad\nvalue".to_string()])).is_err());
    }

    #[test]
    fn host_filtering_toggles() {
        let addr = sock("127.0.0.1:9944");

        assert!(host_filtering(false, addr).is_none());
        assert!(host_filtering(true, addr).is_some());
    }
}