Skip to main content

tt_shared/
url_guard.rs

1//! SSRF and header-injection guard for customer-supplied provider URLs.
2//!
3//! ## URL guard
4//!
5//! [`validate_provider_url`] checks a customer-controlled `base_url` before any
6//! HTTP request is dispatched. It enforces:
7//!
8//! - Scheme must be `https` (or `http` when `allow_local` is true).
9//! - The host must not be a private/loopback/link-local/ULA IP address, the
10//!   cloud-metadata addresses (169.254.169.254, 100.100.100.200), or the
11//!   hostname `localhost`, `*.local`, or `metadata.google.internal`.
//! - When `allow_local` is `true` (self-hosted / `local` provider), the https
//!   requirement and all host checks (including the DNS step below) are
//!   bypassed — only URL parsing and an `http`/`https` scheme check remain.
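//! - A best-effort DNS resolution step (hostname hosts only; literal IPs are
//!   checked directly) rejects the URL if **any** resolved address falls in a
//!   blocked range. A resolution failure is logged and does not reject the URL.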
16//!
17//! ## DNS rebind caveat
18//!
19//! The DNS resolution step is defense-in-depth only. It is subject to TOCTOU
20//! races: a malicious DNS server can return a safe address for the validation
21//! call and a private address for the actual HTTP connection (classic
22//! DNS-rebinding attack). The connect-time enforcement (binding to a local
23//! policy agent or using a kernel eBPF hook) is out of scope here. Operators
//! concerned about DNS rebinding should additionally run the gateway behind a
//! network policy that blocks outbound connections to RFC 1918 private ranges
//! and to the link-local / cloud-metadata addresses listed above.
26//!
27//! ## Header filter
28//!
29//! [`filter_extra_headers`] strips any header whose name could override the
30//! adapter-set auth (`authorization`, `x-api-key`, `anthropic-version`,
31//! `content-type`) or the routing (`host`) headers, or inject hop-by-hop
32//! headers that must not be forwarded to upstream HTTP/1.1 servers.
33
34use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
35
36use thiserror::Error;
37use tracing::warn;
38
39// ---------------------------------------------------------------------------
40// Error type
41// ---------------------------------------------------------------------------
42
43#[derive(Debug, Error, PartialEq, Eq)]
44pub enum UrlGuardError {
45    #[error("invalid URL: {0}")]
46    InvalidUrl(String),
47
48    #[error("insecure scheme '{0}': only https is allowed for hosted providers")]
49    InsecureScheme(String),
50
51    #[error("blocked host '{0}': private/loopback/link-local addresses are not allowed")]
52    BlockedHost(String),
53
54    #[error("blocked hostname '{0}': internal/metadata hostnames are not allowed")]
55    BlockedHostname(String),
56
57    #[error("resolved address for '{host}' is in a blocked range: {addr}")]
58    BlockedResolvedAddress { host: String, addr: IpAddr },
59}
60
61// ---------------------------------------------------------------------------
62// Public API
63// ---------------------------------------------------------------------------
64
65/// Validate a customer-supplied provider base URL.
66///
67/// # Arguments
68///
69/// * `raw` — The raw URL string from customer credentials.
70/// * `allow_local` — When `true`, skip scheme and private-range checks
71///   (for the `local` provider that targets Ollama/vLLM/LM Studio).
72///
73/// # Errors
74///
75/// Returns [`UrlGuardError`] if the URL is malformed, uses an insecure scheme
76/// (when `allow_local` is false), or points to a private/internal host.
77pub fn validate_provider_url(raw: &str, allow_local: bool) -> Result<(), UrlGuardError> {
78    // Parse the URL.
79    let parsed =
80        url::Url::parse(raw).map_err(|e| UrlGuardError::InvalidUrl(format!("{raw}: {e}")))?;
81
82    // Extract scheme.
83    let scheme = parsed.scheme();
84
85    if allow_local {
86        // Local provider: accept http or https, no range checks.
87        if scheme != "http" && scheme != "https" {
88            return Err(UrlGuardError::InsecureScheme(scheme.to_string()));
89        }
90        return Ok(());
91    }
92
93    // Hosted providers: require https.
94    if scheme != "https" {
95        return Err(UrlGuardError::InsecureScheme(scheme.to_string()));
96    }
97
98    // Extract the host string.
99    let host_str = parsed
100        .host_str()
101        .ok_or_else(|| UrlGuardError::InvalidUrl(format!("{raw}: no host")))?;
102
103    // Check hostname denylist first.
104    check_hostname_denylist(host_str)?;
105
106    // If the host is a literal IP, check the range.
107    // The `url` crate returns IPv6 addresses without brackets in host_str(),
108    // but strip brackets defensively in case that behaviour changes.
109    let bare_host = if host_str.starts_with('[') && host_str.ends_with(']') {
110        &host_str[1..host_str.len() - 1]
111    } else {
112        host_str
113    };
114
115    if let Ok(ip) = bare_host.parse::<IpAddr>() {
116        if is_blocked_ip(ip) {
117            return Err(UrlGuardError::BlockedHost(host_str.to_string()));
118        }
119        // Literal IP is safe — skip DNS resolution.
120        return Ok(());
121    }
122
123    // Best-effort DNS resolution to catch SSRF via hostname.
124    // TOCTOU/DNS-rebind note: see module-level doc comment.
125    let port = parsed.port_or_known_default().unwrap_or(443);
126    let lookup_target = format!("{host_str}:{port}");
127    match lookup_target.to_socket_addrs() {
128        Ok(addrs) => {
129            for sa in addrs {
130                let ip = sa.ip();
131                if is_blocked_ip(ip) {
132                    warn!(
133                        host = %host_str,
134                        addr = %ip,
135                        "validate_provider_url: resolved address is in a blocked range"
136                    );
137                    return Err(UrlGuardError::BlockedResolvedAddress {
138                        host: host_str.to_string(),
139                        addr: ip,
140                    });
141                }
142            }
143        }
144        Err(e) => {
145            // DNS failure is not treated as a security block — the upstream
146            // request will fail with a network error, which is fine.
147            warn!(
148                host = %host_str,
149                error = %e,
150                "validate_provider_url: DNS resolution failed (allowed to proceed)"
151            );
152        }
153    }
154
155    Ok(())
156}
157
/// Return `true` if the IP address is in a blocked range.
///
/// Blocked ranges:
/// - IPv4 "this" network 0.0.0.0/8
/// - IPv4 loopback 127.0.0.0/8
/// - IPv4 private 10.0.0.0/8, 172.16.0.0/12 and 192.168.0.0/16
/// - IPv4 link-local 169.254.0.0/16 (includes cloud metadata 169.254.169.254)
/// - IPv4 shared-address 100.64.0.0/10 (includes cloud metadata 100.100.100.200)
/// - IPv6 unspecified `::` and loopback `::1`
/// - IPv6 link-local fe80::/10
/// - IPv6 ULA fc00::/7
/// - IPv6 addresses embedding an IPv4 address that is itself blocked:
///   IPv4-mapped (`::ffff:0:0/96`), IPv4-compatible (`::/96`) and
///   NAT64 (`64:ff9b::/96`)
fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_blocked_v4(v4),
        IpAddr::V6(v6) => is_blocked_v6(v6),
    }
}

/// IPv4 half of [`is_blocked_ip`]; see there for the list of ranges.
fn is_blocked_v4(ip: Ipv4Addr) -> bool {
    let o = ip.octets();

    // 127.0.0.0/8 — loopback
    if o[0] == 127 {
        return true;
    }
    // 10.0.0.0/8 — private
    if o[0] == 10 {
        return true;
    }
    // 172.16.0.0/12 — private (second octet 16..=31)
    if o[0] == 172 && (o[1] & 0xf0) == 16 {
        return true;
    }
    // 192.168.0.0/16 — private
    if o[0] == 192 && o[1] == 168 {
        return true;
    }
    // 169.254.0.0/16 — link-local (includes AWS/GCP metadata 169.254.169.254)
    if o[0] == 169 && o[1] == 254 {
        return true;
    }
    // 100.64.0.0/10 — CGNAT / shared address space, second octet 64..=127
    // (includes Alibaba metadata 100.100.100.200)
    if o[0] == 100 && (o[1] & 0xc0) == 64 {
        return true;
    }
    // 0.0.0.0/8 — "this" network
    if o[0] == 0 {
        return true;
    }
    false
}

/// IPv6 half of [`is_blocked_ip`]; see there for the list of ranges.
///
/// Addresses that carry an embedded IPv4 address are judged by that IPv4
/// address, so `::ffff:127.0.0.1` is blocked exactly like `127.0.0.1`.
fn is_blocked_v6(ip: Ipv6Addr) -> bool {
    // ::1 — loopback. :: — unspecified; blocked for the same reason as
    // 0.0.0.0/8: common network stacks (e.g. Linux) route a connect() to the
    // unspecified address to the local host.
    if ip.is_loopback() || ip.is_unspecified() {
        return true;
    }

    let seg = ip.segments();

    // fe80::/10 — link-local
    if (seg[0] & 0xffc0) == 0xfe80 {
        return true;
    }
    // fc00::/7 — ULA (unique local addresses)
    if (seg[0] & 0xfe00) == 0xfc00 {
        return true;
    }
    // ::ffff:0:0/96 (IPv4-mapped) and ::/96 (deprecated IPv4-compatible):
    // `to_ipv4` yields the embedded IPv4 address for both forms.
    if let Some(v4) = ip.to_ipv4() {
        return is_blocked_v4(v4);
    }
    // 64:ff9b::/96 — IPv4/IPv6 translation (RFC 6052); IPv4 in the low 32 bits.
    if matches!(seg, [0x0064, 0xff9b, 0, 0, 0, 0, _, _]) {
        let o = ip.octets();
        return is_blocked_v4(Ipv4Addr::new(o[12], o[13], o[14], o[15]));
    }
    false
}
256
257/// Check hostname-based denylist (not IP-literal hosts).
258fn check_hostname_denylist(host: &str) -> Result<(), UrlGuardError> {
259    let lower = host.to_ascii_lowercase();
260
261    if lower == "localhost" {
262        return Err(UrlGuardError::BlockedHostname(host.to_string()));
263    }
264    if lower.ends_with(".local") || lower == "local" {
265        return Err(UrlGuardError::BlockedHostname(host.to_string()));
266    }
267    if lower == "metadata.google.internal" {
268        return Err(UrlGuardError::BlockedHostname(host.to_string()));
269    }
270
271    Ok(())
272}
273
274// ---------------------------------------------------------------------------
275// Header filter
276// ---------------------------------------------------------------------------
277
/// Header names that customers must not override via `extra_headers`.
///
/// Every entry is lowercase ASCII; lookups compare against the lowercased
/// header name.
///
/// - `authorization`     — carries the API key set by the adapter
/// - `x-api-key`         — Anthropic's auth header
/// - `host`              — routing; overridable by reqwest/hyper but dangerous
/// - `content-type`      — set correctly by the adapter
/// - `content-length`    — message framing is computed by the HTTP client; a
///   customer-supplied value could desynchronize the request body
/// - `anthropic-version` — must not be overridden
/// - Hop-by-hop headers per RFC 7230 §6.1 that must not be forwarded
const DENIED_HEADERS: &[&str] = &[
    "authorization",
    "x-api-key",
    "host",
    "content-type",
    "content-length",
    "anthropic-version",
    // Hop-by-hop (RFC 7230 §6.1)
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    // Non-standard legacy hop-by-hop header.
    "proxy-connection",
];
302
303/// Check whether any header in `extra_headers` is on the denylist.
304///
305/// Returns the **name** (as supplied by the caller) of the first disallowed
306/// header found, or `None` if all headers are acceptable. Matching is
307/// case-insensitive.
308///
309/// Use this at write time (e.g. credential creation) to reject bad input with
310/// a clear error message rather than silently dropping the header at use time.
311pub fn find_denied_header(headers: &[(String, String)]) -> Option<&str> {
312    for (name, _) in headers {
313        let lower = name.to_ascii_lowercase();
314        if DENIED_HEADERS.contains(&lower.as_str()) {
315            return Some(name.as_str());
316        }
317    }
318    None
319}
320
321/// Filter `extra_headers`, dropping any header whose name is in the denylist.
322///
323/// Matching is case-insensitive. Returns a new `Vec` containing only the
324/// allowed headers. A `warn!` log line is emitted for each dropped header.
325pub fn filter_extra_headers(headers: &[(String, String)]) -> Vec<(String, String)> {
326    headers
327        .iter()
328        .filter_map(|(name, value)| {
329            let lower = name.to_ascii_lowercase();
330            if DENIED_HEADERS.contains(&lower.as_str()) {
331                warn!(
332                    header = %name,
333                    "extra_headers: dropping denied header (authorization/host/hop-by-hop)"
334                );
335                None
336            } else {
337                Some((name.clone(), value.clone()))
338            }
339        })
340        .collect()
341}
342
343// ---------------------------------------------------------------------------
344// Tests
345// ---------------------------------------------------------------------------
346
#[cfg(test)]
mod tests {
    use super::*;

    // -- validate_provider_url --

    #[test]
    fn accepts_normal_https_url() {
        // Exercises the DNS step; resolution failures are tolerated by design,
        // so this also passes without network access.
        assert!(validate_provider_url("https://api.openai.com/v1", false).is_ok());
        assert!(validate_provider_url("https://api.anthropic.com", false).is_ok());
        assert!(validate_provider_url("https://api.together.xyz/v1", false).is_ok());
    }

    #[test]
    fn rejects_http_for_non_local() {
        let err = validate_provider_url("http://api.example.com/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::InsecureScheme(_)));
    }

    #[test]
    fn allows_http_when_allow_local() {
        assert!(validate_provider_url("http://localhost:11434/v1", true).is_ok());
        assert!(validate_provider_url("http://127.0.0.1:8000/v1", true).is_ok());
    }

    #[test]
    fn allow_local_still_rejects_non_http_schemes() {
        let err = validate_provider_url("ftp://localhost/models", true).unwrap_err();
        assert!(matches!(err, UrlGuardError::InsecureScheme(_)), "got: {err}");
    }

    #[test]
    fn rejects_cloud_metadata_ip() {
        let err =
            validate_provider_url("https://169.254.169.254/latest/meta-data/", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_alibaba_metadata_ip() {
        let err = validate_provider_url("https://100.100.100.200/meta-data/", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_ipv4_mapped_metadata_literal() {
        let err = validate_provider_url("https://[::ffff:169.254.169.254]/latest/meta-data/", false)
            .unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_unspecified_ipv4() {
        let err = validate_provider_url("https://0.0.0.0/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_loopback_ipv4() {
        let err = validate_provider_url("https://127.0.0.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_private_10_x() {
        let err = validate_provider_url("https://10.0.0.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_private_192_168() {
        let err = validate_provider_url("https://192.168.1.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_private_172_16() {
        let err = validate_provider_url("https://172.16.0.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_loopback_ipv6() {
        let err = validate_provider_url("https://[::1]/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_ula_ipv6_fc00() {
        let err = validate_provider_url("https://[fc00::1]/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_link_local_ipv6_fe80() {
        let err = validate_provider_url("https://[fe80::1]/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_localhost_hostname() {
        let err = validate_provider_url("https://localhost/v1", false).unwrap_err();
        assert!(
            matches!(err, UrlGuardError::BlockedHostname(_)),
            "got: {err}"
        );
    }

    #[test]
    fn hostname_denylist_is_case_insensitive() {
        let err = validate_provider_url("https://LocalHost/v1", false).unwrap_err();
        assert!(
            matches!(err, UrlGuardError::BlockedHostname(_)),
            "got: {err}"
        );
    }

    #[test]
    fn rejects_dot_local_hostname() {
        let err = validate_provider_url("https://myhost.local/v1", false).unwrap_err();
        assert!(
            matches!(err, UrlGuardError::BlockedHostname(_)),
            "got: {err}"
        );
    }

    #[test]
    fn rejects_metadata_google_internal() {
        let err = validate_provider_url(
            "https://metadata.google.internal/computeMetadata/v1/",
            false,
        )
        .unwrap_err();
        assert!(
            matches!(err, UrlGuardError::BlockedHostname(_)),
            "got: {err}"
        );
    }

    #[test]
    fn allows_localhost_when_allow_local() {
        assert!(validate_provider_url("http://localhost:11434/v1", true).is_ok());
        assert!(validate_provider_url("https://localhost:11434/v1", true).is_ok());
    }

    #[test]
    fn rejects_invalid_url() {
        let err = validate_provider_url("not-a-url", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::InvalidUrl(_)), "got: {err}");
    }

    #[test]
    fn rejects_ftp_scheme() {
        let err = validate_provider_url("ftp://example.com/v1", false).unwrap_err();
        assert!(
            matches!(err, UrlGuardError::InsecureScheme(_)),
            "got: {err}"
        );
    }

    // -- find_denied_header --

    #[test]
    fn find_denied_header_returns_first_match_as_spelled() {
        let headers = vec![
            ("X-Ok".to_string(), "1".to_string()),
            ("HoSt".to_string(), "evil".to_string()),
            ("Authorization".to_string(), "Bearer x".to_string()),
        ];
        assert_eq!(find_denied_header(&headers), Some("HoSt"));
    }

    #[test]
    fn find_denied_header_none_for_clean_headers() {
        let headers = vec![
            ("X-Org-Id".to_string(), "org-123".to_string()),
            ("Accept-Language".to_string(), "en".to_string()),
        ];
        assert_eq!(find_denied_header(&headers), None);
        assert_eq!(find_denied_header(&[]), None);
    }

    // -- filter_extra_headers --

    #[test]
    fn drops_authorization_header() {
        let headers = vec![
            ("Authorization".to_string(), "Bearer fake".to_string()),
            ("X-Custom".to_string(), "value".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "X-Custom");
    }

    #[test]
    fn drops_host_header() {
        let headers = vec![
            ("Host".to_string(), "evil.internal".to_string()),
            ("X-Org-ID".to_string(), "abc".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "X-Org-ID");
    }

    #[test]
    fn drops_content_type_and_anthropic_version() {
        let headers = vec![
            ("Content-Type".to_string(), "text/plain".to_string()),
            ("Anthropic-Version".to_string(), "1999-01-01".to_string()),
            ("X-Trace".to_string(), "t".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered, vec![("X-Trace".to_string(), "t".to_string())]);
    }

    #[test]
    fn drops_hop_by_hop_headers() {
        let headers = vec![
            ("Connection".to_string(), "close".to_string()),
            ("Proxy-Authorization".to_string(), "Basic xyz".to_string()),
            ("Transfer-Encoding".to_string(), "chunked".to_string()),
            ("Keep-Alive".to_string(), "timeout=5".to_string()),
            ("X-Real-Header".to_string(), "ok".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "X-Real-Header");
    }

    #[test]
    fn keeps_legitimate_extra_headers() {
        let headers = vec![
            ("X-Custom-Header".to_string(), "custom-value".to_string()),
            ("X-Org-Id".to_string(), "org-123".to_string()),
            ("Accept-Language".to_string(), "en".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered.len(), 3);
    }

    #[test]
    fn filter_is_case_insensitive() {
        let headers = vec![
            ("AUTHORIZATION".to_string(), "Bearer x".to_string()),
            ("authorization".to_string(), "Bearer y".to_string()),
            ("Authorization".to_string(), "Bearer z".to_string()),
            ("x-api-key".to_string(), "sk-...".to_string()),
            ("X-API-KEY".to_string(), "sk-...".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert!(
            filtered.is_empty(),
            "all auth headers should be dropped, got: {filtered:?}"
        );
    }

    // -- IP range helpers --

    #[test]
    fn cgnat_range_is_blocked() {
        // 100.64.0.0/10 — CGNAT range that includes Alibaba metadata 100.100.100.200
        let err = validate_provider_url("https://100.64.0.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn is_blocked_v4_spot_checks() {
        assert!(is_blocked_v4(Ipv4Addr::new(127, 0, 0, 1)));
        assert!(is_blocked_v4(Ipv4Addr::new(169, 254, 169, 254)));
        assert!(is_blocked_v4(Ipv4Addr::new(100, 100, 100, 200)));
        assert!(is_blocked_v4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(is_blocked_v4(Ipv4Addr::new(192, 168, 0, 1)));
        assert!(is_blocked_v4(Ipv4Addr::new(172, 16, 0, 1)));
        assert!(is_blocked_v4(Ipv4Addr::new(172, 31, 255, 255)));
        assert!(!is_blocked_v4(Ipv4Addr::new(1, 1, 1, 1)));
        assert!(!is_blocked_v4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(!is_blocked_v4(Ipv4Addr::new(172, 32, 0, 1))); // just outside 172.16/12
    }

    #[test]
    fn is_blocked_v4_range_boundaries() {
        assert!(is_blocked_v4(Ipv4Addr::new(0, 1, 2, 3))); // 0.0.0.0/8
        assert!(!is_blocked_v4(Ipv4Addr::new(172, 15, 255, 255))); // just below 172.16/12
        assert!(!is_blocked_v4(Ipv4Addr::new(100, 63, 255, 255))); // just below 100.64/10
        assert!(is_blocked_v4(Ipv4Addr::new(100, 127, 255, 255))); // top of 100.64/10
        assert!(!is_blocked_v4(Ipv4Addr::new(100, 128, 0, 0))); // just above 100.64/10
        assert!(!is_blocked_v4(Ipv4Addr::new(169, 253, 255, 255))); // just below 169.254/16
    }

    #[test]
    fn ipv6_mapped_v4_blocked() {
        // ::ffff:127.0.0.1 is the IPv4-mapped loopback
        let ip: Ipv6Addr = "::ffff:127.0.0.1".parse().unwrap();
        assert!(is_blocked_v6(ip));
    }

    #[test]
    fn nat64_embedded_v4_follows_v4_verdict() {
        // 64:ff9b::10.0.0.1 embeds a private address; 64:ff9b::8.8.8.8 a public one.
        assert!(is_blocked_v6(Ipv6Addr::new(0x64, 0xff9b, 0, 0, 0, 0, 0x0a00, 0x0001)));
        assert!(!is_blocked_v6(Ipv6Addr::new(0x64, 0xff9b, 0, 0, 0, 0, 0x0808, 0x0808)));
    }

    #[test]
    fn public_ipv6_not_blocked() {
        let ip: Ipv6Addr = "2001:4860:4860::8888".parse().unwrap(); // Google DNS
        assert!(!is_blocked_v6(ip));
    }
}