Skip to main content

vellaveto_http_proxy/proxy/
origin.rs

1// Copyright 2026 Paolo Vella
2// SPDX-License-Identifier: BUSL-1.1
3//
4// Use of this software is governed by the Business Source License
5// included in the LICENSE-BSL-1.1 file at the root of this repository.
6//
7// Change Date: Three years from the date of publication of this version.
8// Change License: MPL-2.0
9
10//! CSRF and DNS rebinding origin validation.
11
12use axum::{
13    http::{HeaderMap, StatusCode},
14    response::{IntoResponse, Response},
15    Json,
16};
17use serde_json::json;
18use std::net::SocketAddr;
19
/// Reports whether `addr` is a loopback socket address.
///
/// IPv4 addresses in `127.0.0.0/8` and the IPv6 address `::1` count as
/// loopback; the port is ignored. `validate_origin` uses the answer to decide
/// whether the automatic localhost-only origin check applies when
/// `allowed_origins` is empty.
pub fn is_loopback_addr(addr: &SocketAddr) -> bool {
    // `IpAddr::is_loopback` dispatches to the per-family check.
    addr.ip().is_loopback()
}
30
/// Host spellings that refer to the local machine. Each entry is expanded
/// into an `http://` and an `https://` origin by `build_loopback_origins`.
const LOOPBACK_HOSTS: &[&str] = &["localhost", "127.0.0.1", "[::1]"];

/// Enumerate every origin accepted for a proxy bound to a loopback address.
///
/// For each entry of `LOOPBACK_HOSTS` the result contains
/// `http://<host>:<port>` immediately followed by `https://<host>:<port>`,
/// e.g. `http://localhost:3000`, `https://localhost:3000`,
/// `http://127.0.0.1:3000`, … for `port == 3000`.
pub fn build_loopback_origins(port: u16) -> Vec<String> {
    // Two schemes per host, so the final length is known up front.
    let mut origins = Vec::with_capacity(LOOPBACK_HOSTS.len() * 2);
    origins.extend(LOOPBACK_HOSTS.iter().flat_map(|host| {
        [
            format!("http://{host}:{port}"),
            format!("https://{host}:{port}"),
        ]
    }));
    origins
}
47
48/// Validate the Origin header for CSRF and DNS rebinding protection.
49///
50/// DNS rebinding defense (CVE-2025-66414/CVE-2025-66416): When the proxy is
51/// bound to a loopback address (`127.0.0.1`, `[::1]`) and no explicit
52/// `allowed_origins` are configured, only localhost origins are accepted.
53/// This prevents a malicious webpage from rebinding its domain to 127.0.0.1
54/// and making cross-origin requests that bypass browser same-origin policy.
55///
56/// Returns `Ok(())` if:
57/// - No `Origin` header is present (non-browser client — API clients don't send Origin)
58/// - `allowed_origins` is non-empty and contains the Origin value (or `"*"`)
59/// - `allowed_origins` is empty, bind address is loopback, and Origin is a localhost variant
60/// - `allowed_origins` is empty, bind address is non-loopback, and Origin host matches Host header
61///
62/// Returns `Err(response)` with HTTP 403 and a JSON-RPC error if the origin is not allowed.
63///
64/// SECURITY: Logs rejected origins at warn level. Does NOT log Cookie or
65/// Authorization headers to avoid credential leaks in logs.
66#[allow(clippy::result_large_err)]
67pub fn validate_origin(
68    headers: &HeaderMap,
69    bind_addr: &SocketAddr,
70    allowed_origins: &[String],
71) -> Result<(), Response> {
72    // If no Origin header present, allow (non-browser client)
73    let origin = match headers.get("origin").and_then(|o| o.to_str().ok()) {
74        Some(o) => o,
75        None => return Ok(()),
76    };
77
78    // If explicit allowlist is configured, use it
79    if !allowed_origins.is_empty() {
80        if allowed_origins.iter().any(|a| a == origin || a == "*") {
81            // SECURITY (FIND-R51-011): Warn when wildcard "*" disables origin protection.
82            if allowed_origins.iter().any(|o| o == "*") {
83                tracing::warn!(
84                    target: "vellaveto::security",
85                    "SECURITY: allowed_origins contains '*' — CSRF and DNS rebinding protection is DISABLED"
86                );
87            }
88            return Ok(());
89        }
90        tracing::warn!(
91            origin = %origin,
92            "DNS rebinding defense: rejected request with Origin not in allowed_origins"
93        );
94        return Err(make_origin_rejection_response());
95    }
96
97    // No explicit allowlist — use automatic detection based on bind address
98    if is_loopback_addr(bind_addr) {
99        // SECURITY (TASK-015): DNS rebinding defense for localhost-bound proxies.
100        // Only accept origins that resolve to loopback addresses.
101        // A DNS rebinding attack would present an Origin like "http://evil.com"
102        // even though the request reaches 127.0.0.1 — we must reject it.
103        let loopback_origins = build_loopback_origins(bind_addr.port());
104        if loopback_origins.iter().any(|lo| lo == origin) {
105            return Ok(());
106        }
107        tracing::warn!(
108            origin = %origin,
109            bind_addr = %bind_addr,
110            "DNS rebinding defense: rejected non-localhost Origin on loopback-bound proxy"
111        );
112        return Err(make_origin_rejection_response());
113    }
114
115    // Non-loopback bind: fall back to same-origin check (Origin host must match Host header)
116    // SECURITY (R23-PROXY-3): Lowercase the Host header for case-insensitive
117    // comparison — DNS names are case-insensitive per RFC 4343, and
118    // extract_authority_from_origin already lowercases the Origin authority.
119    let host_raw = headers
120        .get("host")
121        .and_then(|h| h.to_str().ok())
122        .unwrap_or("");
123    let host = host_raw.to_lowercase();
124    let host = host.as_str();
125
126    // Extract host:port from origin URL (e.g., "http://localhost:3001" -> "localhost:3001")
127    if let Some(origin_authority) = extract_authority_from_origin(origin) {
128        if origin_authority == host {
129            return Ok(());
130        }
131        // Also match if host lacks a port (e.g., origin "http://localhost:3001" vs host "localhost")
132        if let Some(colon_pos) = origin_authority.rfind(':') {
133            if &origin_authority[..colon_pos] == host {
134                return Ok(());
135            }
136        }
137    }
138
139    tracing::warn!(
140        origin = %origin,
141        host = %host_raw,
142        "CSRF protection: rejected request with mismatched Origin and Host"
143    );
144    Err(make_origin_rejection_response())
145}
146
147/// Build a 403 Forbidden response with a JSON-RPC error body for origin rejection.
148///
149/// Returns a JSON-RPC 2.0 error response instead of a plain REST error because
150/// the HTTP proxy speaks the MCP JSON-RPC protocol. Clients expect errors in
151/// the format `{ "jsonrpc": "2.0", "error": { "code": <int>, "message": <string> } }`.
152/// Code `-32001` is a server-defined error in the JSON-RPC reserved range
153/// (`-32000` to `-32099`), used here for origin/CSRF rejections.
154///
155/// FIND-R56-HTTP-004: Removed unused `origin` parameter — the origin value is
156/// intentionally NOT included in the response body to prevent information leakage.
157pub fn make_origin_rejection_response() -> Response {
158    (
159        StatusCode::FORBIDDEN,
160        Json(json!({
161            "jsonrpc": "2.0",
162            "error": {
163                "code": -32001,
164                "message": "Origin not allowed"
165            }
166        })),
167    )
168        .into_response()
169}
170
/// Pull the lowercase authority (`host[:port]`) out of an origin string.
///
/// * `"http://localhost:3001"` yields `Some("localhost:3001")`
/// * `"https://example.com"` yields `Some("example.com")`
///
/// A valid Origin header (RFC 6454: `scheme://host[:port]`) never carries a
/// path, query, fragment, or userinfo, but as defence-in-depth they are
/// stripped if present.
///
/// Returns `None` when there is no `://` separator, when the remaining
/// authority is empty, or when it contains any character other than ASCII
/// letters, digits, `.`, `-`, `:`, `[` and `]`.
pub fn extract_authority_from_origin(origin: &str) -> Option<String> {
    let (_scheme, after_scheme) = origin.split_once("://")?;

    // The authority ends at the first path, query, or fragment delimiter.
    let end = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let with_userinfo = &after_scheme[..end];

    // Drop userinfo (RFC 3986 §3.2.1: userinfo@host) by keeping only what
    // follows the last '@'.
    let host_port = with_userinfo.rsplit('@').next().unwrap_or(with_userinfo);

    // Brackets are permitted for IPv6 literals such as [::1]:3001.
    let is_allowed =
        |c: char| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | ':' | '[' | ']');
    if host_port.is_empty() || !host_port.chars().all(is_allowed) {
        return None;
    }

    // Every remaining character is ASCII, so ASCII lowercasing is complete.
    Some(host_port.to_ascii_lowercase())
}
204
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

    // =========================================================================
    // shared fixtures
    // =========================================================================

    /// IPv4 socket address from raw octets and a port.
    fn v4(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
    }

    /// IPv6 socket address with zero flow info and scope id.
    fn v6(ip: Ipv6Addr, port: u16) -> SocketAddr {
        SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))
    }

    /// Header map populated from `(name, value)` pairs.
    fn headers_of(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, value.parse().unwrap());
        }
        headers
    }

    // =========================================================================
    // is_loopback_addr tests
    // =========================================================================

    #[test]
    fn test_is_loopback_addr_ipv4_localhost() {
        assert!(is_loopback_addr(&v4([127, 0, 0, 1], 3000)));
    }

    #[test]
    fn test_is_loopback_addr_ipv4_127_range() {
        assert!(is_loopback_addr(&v4([127, 0, 0, 42], 3000)));
    }

    #[test]
    fn test_is_loopback_addr_ipv6_localhost() {
        assert!(is_loopback_addr(&v6(Ipv6Addr::LOCALHOST, 3000)));
    }

    #[test]
    fn test_is_loopback_addr_non_loopback_ipv4() {
        assert!(!is_loopback_addr(&v4([10, 0, 0, 1], 3000)));
    }

    #[test]
    fn test_is_loopback_addr_non_loopback_ipv6() {
        assert!(!is_loopback_addr(&v6(Ipv6Addr::UNSPECIFIED, 3000)));
    }

    // =========================================================================
    // build_loopback_origins tests
    // =========================================================================

    #[test]
    fn test_build_loopback_origins_includes_all_variants() {
        let origins = build_loopback_origins(3000);
        assert_eq!(origins.len(), 6);
        let expected = [
            "http://localhost:3000",
            "https://localhost:3000",
            "http://127.0.0.1:3000",
            "https://127.0.0.1:3000",
            "http://[::1]:3000",
            "https://[::1]:3000",
        ];
        for want in expected {
            assert!(origins.iter().any(|got| got == want));
        }
    }

    #[test]
    fn test_build_loopback_origins_different_port() {
        let origins = build_loopback_origins(8443);
        assert!(origins.iter().any(|o| o == "http://localhost:8443"));
        assert!(origins.iter().any(|o| o == "https://[::1]:8443"));
    }

    // =========================================================================
    // extract_authority_from_origin tests
    // =========================================================================

    #[test]
    fn test_extract_authority_http_localhost_port() {
        let got = extract_authority_from_origin("http://localhost:3001");
        assert_eq!(got.as_deref(), Some("localhost:3001"));
    }

    #[test]
    fn test_extract_authority_https_domain() {
        let got = extract_authority_from_origin("https://example.com");
        assert_eq!(got.as_deref(), Some("example.com"));
    }

    #[test]
    fn test_extract_authority_strips_path() {
        let got = extract_authority_from_origin("http://example.com:8080/path/to/page");
        assert_eq!(got.as_deref(), Some("example.com:8080"));
    }

    #[test]
    fn test_extract_authority_strips_query() {
        let got = extract_authority_from_origin("http://example.com?query=val");
        assert_eq!(got.as_deref(), Some("example.com"));
    }

    #[test]
    fn test_extract_authority_strips_fragment() {
        let got = extract_authority_from_origin("http://example.com#section");
        assert_eq!(got.as_deref(), Some("example.com"));
    }

    #[test]
    fn test_extract_authority_strips_userinfo() {
        let got = extract_authority_from_origin("http://user:pass@example.com:8080");
        assert_eq!(got.as_deref(), Some("example.com:8080"));
    }

    #[test]
    fn test_extract_authority_lowercases() {
        let got = extract_authority_from_origin("http://EXAMPLE.COM");
        assert_eq!(got.as_deref(), Some("example.com"));
    }

    #[test]
    fn test_extract_authority_ipv6() {
        let got = extract_authority_from_origin("http://[::1]:3001");
        assert_eq!(got.as_deref(), Some("[::1]:3001"));
    }

    #[test]
    fn test_extract_authority_no_scheme_returns_none() {
        assert!(extract_authority_from_origin("example.com").is_none());
    }

    #[test]
    fn test_extract_authority_empty_authority_returns_none() {
        assert!(extract_authority_from_origin("http:///path").is_none());
    }

    #[test]
    fn test_extract_authority_invalid_chars_returns_none() {
        assert!(extract_authority_from_origin("http://example.com<script>").is_none());
    }

    // =========================================================================
    // validate_origin tests
    // =========================================================================

    #[test]
    fn test_validate_origin_no_header_allows() {
        let headers = headers_of(&[]);
        let bind = v4([127, 0, 0, 1], 3000);
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }

    #[test]
    fn test_validate_origin_allowlist_match() {
        let headers = headers_of(&[("origin", "http://app.example.com")]);
        let bind = v4([0, 0, 0, 0], 8080);
        let allowed = ["http://app.example.com".to_string()];
        assert!(validate_origin(&headers, &bind, &allowed).is_ok());
    }

    #[test]
    fn test_validate_origin_allowlist_wildcard() {
        let headers = headers_of(&[("origin", "http://any-origin.example.com")]);
        let bind = v4([0, 0, 0, 0], 8080);
        let allowed = ["*".to_string()];
        assert!(validate_origin(&headers, &bind, &allowed).is_ok());
    }

    #[test]
    fn test_validate_origin_allowlist_mismatch_rejected() {
        let headers = headers_of(&[("origin", "http://evil.com")]);
        let bind = v4([0, 0, 0, 0], 8080);
        let allowed = ["http://trusted.com".to_string()];
        assert!(validate_origin(&headers, &bind, &allowed).is_err());
    }

    #[test]
    fn test_validate_origin_loopback_localhost_allowed() {
        let headers = headers_of(&[("origin", "http://localhost:3000")]);
        let bind = v4([127, 0, 0, 1], 3000);
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }

    #[test]
    fn test_validate_origin_loopback_rejects_external() {
        let headers = headers_of(&[("origin", "http://evil.com")]);
        let bind = v4([127, 0, 0, 1], 3000);
        assert!(validate_origin(&headers, &bind, &[]).is_err());
    }

    #[test]
    fn test_validate_origin_non_loopback_same_origin_allowed() {
        let headers = headers_of(&[
            ("origin", "http://myserver.com:8080"),
            ("host", "myserver.com:8080"),
        ]);
        let bind = v4([10, 0, 0, 1], 8080);
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }

    #[test]
    fn test_validate_origin_non_loopback_mismatch_rejected() {
        let headers = headers_of(&[
            ("origin", "http://evil.com:8080"),
            ("host", "myserver.com:8080"),
        ]);
        let bind = v4([10, 0, 0, 1], 8080);
        assert!(validate_origin(&headers, &bind, &[]).is_err());
    }

    #[test]
    fn test_validate_origin_host_without_port_matches_origin_with_port() {
        let headers = headers_of(&[
            ("origin", "http://myserver.com:8080"),
            ("host", "myserver.com"),
        ]);
        let bind = v4([10, 0, 0, 1], 8080);
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }

    #[test]
    fn test_validate_origin_case_insensitive_host() {
        let headers = headers_of(&[
            ("origin", "http://MyServer.COM:8080"),
            ("host", "MYSERVER.com:8080"),
        ]);
        let bind = v4([10, 0, 0, 1], 8080);
        assert!(validate_origin(&headers, &bind, &[]).is_ok());
    }
}