vellaveto_http_proxy/proxy/
origin.rs1use axum::{
13 http::{HeaderMap, StatusCode},
14 response::{IntoResponse, Response},
15 Json,
16};
17use serde_json::json;
18use std::net::SocketAddr;
19
/// Returns `true` when `addr` is a loopback address.
///
/// IPv4 addresses in `127.0.0.0/8` and the IPv6 address `::1` are loopback. IPv4-mapped
/// IPv6 addresses (`::ffff:127.x.y.z`) are also treated as loopback. Otherwise a proxy bound
/// to such an address would skip the strict loopback Origin allowlist in `validate_origin`.
/// It would then fall back to the Host/Origin comparison, which does not stop DNS rebinding.
pub fn is_loopback_addr(addr: &SocketAddr) -> bool {
    match addr {
        SocketAddr::V4(v4) => v4.ip().is_loopback(),
        SocketAddr::V6(v6) => {
            let ip = v6.ip();
            // `Ipv6Addr::is_loopback` only recognises `::1`, so check the mapped form too.
            ip.is_loopback() || matches!(ip.to_ipv4_mapped(), Some(v4) if v4.is_loopback())
        }
    }
}
30
/// Host spellings accepted as "local" Origins for a loopback-bound proxy.
const LOOPBACK_HOSTS: &[&str] = &["localhost", "127.0.0.1", "[::1]"];

/// Produces every `scheme://host:port` Origin string accepted for a loopback-bound proxy.
///
/// Each entry of `LOOPBACK_HOSTS` is emitted twice, first with `http` and then with
/// `https`, so the result always has `2 * LOOPBACK_HOSTS.len()` entries.
pub fn build_loopback_origins(port: u16) -> Vec<String> {
    const SCHEMES: [&str; 2] = ["http", "https"];
    let mut accepted = Vec::with_capacity(LOOPBACK_HOSTS.len() * SCHEMES.len());
    accepted.extend(LOOPBACK_HOSTS.iter().flat_map(move |host| {
        SCHEMES
            .iter()
            .map(move |scheme| format!("{scheme}://{host}:{port}"))
    }));
    accepted
}
47
48#[allow(clippy::result_large_err)]
67pub fn validate_origin(
68 headers: &HeaderMap,
69 bind_addr: &SocketAddr,
70 allowed_origins: &[String],
71) -> Result<(), Response> {
72 let origin = match headers.get("origin").and_then(|o| o.to_str().ok()) {
74 Some(o) => o,
75 None => return Ok(()),
76 };
77
78 if !allowed_origins.is_empty() {
80 if allowed_origins.iter().any(|a| a == origin || a == "*") {
81 if allowed_origins.iter().any(|o| o == "*") {
83 tracing::warn!(
84 target: "vellaveto::security",
85 "SECURITY: allowed_origins contains '*' — CSRF and DNS rebinding protection is DISABLED"
86 );
87 }
88 return Ok(());
89 }
90 tracing::warn!(
91 origin = %origin,
92 "DNS rebinding defense: rejected request with Origin not in allowed_origins"
93 );
94 return Err(make_origin_rejection_response());
95 }
96
97 if is_loopback_addr(bind_addr) {
99 let loopback_origins = build_loopback_origins(bind_addr.port());
104 if loopback_origins.iter().any(|lo| lo == origin) {
105 return Ok(());
106 }
107 tracing::warn!(
108 origin = %origin,
109 bind_addr = %bind_addr,
110 "DNS rebinding defense: rejected non-localhost Origin on loopback-bound proxy"
111 );
112 return Err(make_origin_rejection_response());
113 }
114
115 let host_raw = headers
120 .get("host")
121 .and_then(|h| h.to_str().ok())
122 .unwrap_or("");
123 let host = host_raw.to_lowercase();
124 let host = host.as_str();
125
126 if let Some(origin_authority) = extract_authority_from_origin(origin) {
128 if origin_authority == host {
129 return Ok(());
130 }
131 if let Some(colon_pos) = origin_authority.rfind(':') {
133 if &origin_authority[..colon_pos] == host {
134 return Ok(());
135 }
136 }
137 }
138
139 tracing::warn!(
140 origin = %origin,
141 host = %host_raw,
142 "CSRF protection: rejected request with mismatched Origin and Host"
143 );
144 Err(make_origin_rejection_response())
145}
146
147pub fn make_origin_rejection_response() -> Response {
158 (
159 StatusCode::FORBIDDEN,
160 Json(json!({
161 "jsonrpc": "2.0",
162 "error": {
163 "code": -32001,
164 "message": "Origin not allowed"
165 }
166 })),
167 )
168 .into_response()
169}
170
/// Extracts the lowercased `host[:port]` authority from an Origin-style URL.
///
/// Everything before the first `://` is ignored as the scheme. The authority ends at the
/// first `/`, `?`, `#` or `\`. Any `userinfo@` prefix is dropped, and the host is taken
/// after the last `@`.
///
/// Returns `None` in three cases:
/// - there is no `://`;
/// - the resulting authority is empty;
/// - the authority contains characters other than ASCII alphanumerics, `.`, `-`, `:`, `[`, `]`.
pub fn extract_authority_from_origin(origin: &str) -> Option<String> {
    let (_scheme, rest) = origin.split_once("://")?;

    // Backslash is treated as a path delimiter because WHATWG URL parsing does the same
    // for http(s). Without this, `http://evil.com\@trusted.com` would be read as
    // `trusted.com` while a browser resolves it to `evil.com`.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#' | '\\'))
        .next()
        .unwrap_or("");

    let authority = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, host)| host);

    let well_formed = !authority.is_empty()
        && authority
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | ':' | '[' | ']'));

    // Every remaining character is ASCII, so ASCII lowercasing is exact.
    well_formed.then(|| authority.to_ascii_lowercase())
}
204
#[cfg(test)]
mod tests {
    // Unit tests covering loopback detection, loopback Origin construction, Origin
    // authority parsing, and each decision path of `validate_origin`.
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

    // ---- is_loopback_addr ----

    #[test]
    fn test_is_loopback_addr_ipv4_localhost() {
        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3000));
        assert!(is_loopback_addr(&addr));
    }

    // Any address in 127.0.0.0/8 counts, not just 127.0.0.1.
    #[test]
    fn test_is_loopback_addr_ipv4_127_range() {
        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 42), 3000));
        assert!(is_loopback_addr(&addr));
    }

    #[test]
    fn test_is_loopback_addr_ipv6_localhost() {
        let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 3000, 0, 0));
        assert!(is_loopback_addr(&addr));
    }

    #[test]
    fn test_is_loopback_addr_non_loopback_ipv4() {
        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 3000));
        assert!(!is_loopback_addr(&addr));
    }

    // The unspecified address (`::`) is a wildcard bind, not loopback.
    #[test]
    fn test_is_loopback_addr_non_loopback_ipv6() {
        let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 3000, 0, 0));
        assert!(!is_loopback_addr(&addr));
    }

    // ---- build_loopback_origins ----

    // Three host spellings x two schemes = six origins.
    #[test]
    fn test_build_loopback_origins_includes_all_variants() {
        let origins = build_loopback_origins(3000);
        assert_eq!(origins.len(), 6);
        assert!(origins.contains(&"http://localhost:3000".to_string()));
        assert!(origins.contains(&"https://localhost:3000".to_string()));
        assert!(origins.contains(&"http://127.0.0.1:3000".to_string()));
        assert!(origins.contains(&"https://127.0.0.1:3000".to_string()));
        assert!(origins.contains(&"http://[::1]:3000".to_string()));
        assert!(origins.contains(&"https://[::1]:3000".to_string()));
    }

    #[test]
    fn test_build_loopback_origins_different_port() {
        let origins = build_loopback_origins(8443);
        assert!(origins.contains(&"http://localhost:8443".to_string()));
        assert!(origins.contains(&"https://[::1]:8443".to_string()));
    }

    // ---- extract_authority_from_origin ----

    #[test]
    fn test_extract_authority_http_localhost_port() {
        assert_eq!(
            extract_authority_from_origin("http://localhost:3001"),
            Some("localhost:3001".to_string())
        );
    }

    #[test]
    fn test_extract_authority_https_domain() {
        assert_eq!(
            extract_authority_from_origin("https://example.com"),
            Some("example.com".to_string())
        );
    }

    #[test]
    fn test_extract_authority_strips_path() {
        assert_eq!(
            extract_authority_from_origin("http://example.com:8080/path/to/page"),
            Some("example.com:8080".to_string())
        );
    }

    #[test]
    fn test_extract_authority_strips_query() {
        assert_eq!(
            extract_authority_from_origin("http://example.com?query=val"),
            Some("example.com".to_string())
        );
    }

    #[test]
    fn test_extract_authority_strips_fragment() {
        assert_eq!(
            extract_authority_from_origin("http://example.com#section"),
            Some("example.com".to_string())
        );
    }

    // The host is whatever follows the last '@'; credentials are discarded.
    #[test]
    fn test_extract_authority_strips_userinfo() {
        assert_eq!(
            extract_authority_from_origin("http://user:pass@example.com:8080"),
            Some("example.com:8080".to_string())
        );
    }

    #[test]
    fn test_extract_authority_lowercases() {
        assert_eq!(
            extract_authority_from_origin("http://EXAMPLE.COM"),
            Some("example.com".to_string())
        );
    }

    // Brackets are preserved so the result can be compared with a bracketed Host header.
    #[test]
    fn test_extract_authority_ipv6() {
        assert_eq!(
            extract_authority_from_origin("http://[::1]:3001"),
            Some("[::1]:3001".to_string())
        );
    }

    #[test]
    fn test_extract_authority_no_scheme_returns_none() {
        assert_eq!(extract_authority_from_origin("example.com"), None);
    }

    #[test]
    fn test_extract_authority_empty_authority_returns_none() {
        assert_eq!(extract_authority_from_origin("http:///path"), None);
    }

    // Characters outside the authority whitelist make the whole Origin unparseable.
    #[test]
    fn test_extract_authority_invalid_chars_returns_none() {
        assert_eq!(
            extract_authority_from_origin("http://example.com<script>"),
            None
        );
    }

    // ---- validate_origin ----

    // Requests without an Origin header are always allowed.
    #[test]
    fn test_validate_origin_no_header_allows() {
        let headers = HeaderMap::new();
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3000));
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }

    // Explicit allowlist path: exact string match.
    #[test]
    fn test_validate_origin_allowlist_match() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://app.example.com".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8080));
        let allowed = vec!["http://app.example.com".to_string()];
        assert!(validate_origin(&headers, &bind, &allowed).is_ok());
    }

    // Explicit allowlist path: "*" admits any Origin.
    #[test]
    fn test_validate_origin_allowlist_wildcard() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://any-origin.example.com".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8080));
        let allowed = vec!["*".to_string()];
        assert!(validate_origin(&headers, &bind, &allowed).is_ok());
    }

    #[test]
    fn test_validate_origin_allowlist_mismatch_rejected() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://evil.com".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8080));
        let allowed = vec!["http://trusted.com".to_string()];
        assert!(validate_origin(&headers, &bind, &allowed).is_err());
    }

    // Loopback-bound path: only local Origins on the bound port are accepted.
    #[test]
    fn test_validate_origin_loopback_localhost_allowed() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://localhost:3000".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3000));
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }

    #[test]
    fn test_validate_origin_loopback_rejects_external() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://evil.com".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3000));
        assert!(validate_origin(&headers, &bind, &[]).is_err());
    }

    // Non-loopback path: Origin authority is compared against the Host header.
    #[test]
    fn test_validate_origin_non_loopback_same_origin_allowed() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://myserver.com:8080".parse().unwrap());
        headers.insert("host", "myserver.com:8080".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 8080));
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }

    #[test]
    fn test_validate_origin_non_loopback_mismatch_rejected() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://evil.com:8080".parse().unwrap());
        headers.insert("host", "myserver.com:8080".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 8080));
        assert!(validate_origin(&headers, &bind, &[]).is_err());
    }

    // A port-less Host still matches once the Origin's port is dropped.
    #[test]
    fn test_validate_origin_host_without_port_matches_origin_with_port() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://myserver.com:8080".parse().unwrap());
        headers.insert("host", "myserver.com".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 8080));
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }

    // Both sides are lowercased before comparison.
    #[test]
    fn test_validate_origin_case_insensitive_host() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://MyServer.COM:8080".parse().unwrap());
        headers.insert("host", "MYSERVER.com:8080".parse().unwrap());
        let bind = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 8080));
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }
}