1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
35
36use thiserror::Error;
37use tracing::warn;
38
39#[derive(Debug, Error, PartialEq, Eq)]
44pub enum UrlGuardError {
45 #[error("invalid URL: {0}")]
46 InvalidUrl(String),
47
48 #[error("insecure scheme '{0}': only https is allowed for hosted providers")]
49 InsecureScheme(String),
50
51 #[error("blocked host '{0}': private/loopback/link-local addresses are not allowed")]
52 BlockedHost(String),
53
54 #[error("blocked hostname '{0}': internal/metadata hostnames are not allowed")]
55 BlockedHostname(String),
56
57 #[error("resolved address for '{host}' is in a blocked range: {addr}")]
58 BlockedResolvedAddress { host: String, addr: IpAddr },
59}
60
61pub fn validate_provider_url(raw: &str, allow_local: bool) -> Result<(), UrlGuardError> {
78 let parsed =
80 url::Url::parse(raw).map_err(|e| UrlGuardError::InvalidUrl(format!("{raw}: {e}")))?;
81
82 let scheme = parsed.scheme();
84
85 if allow_local {
86 if scheme != "http" && scheme != "https" {
88 return Err(UrlGuardError::InsecureScheme(scheme.to_string()));
89 }
90 return Ok(());
91 }
92
93 if scheme != "https" {
95 return Err(UrlGuardError::InsecureScheme(scheme.to_string()));
96 }
97
98 let host_str = parsed
100 .host_str()
101 .ok_or_else(|| UrlGuardError::InvalidUrl(format!("{raw}: no host")))?;
102
103 check_hostname_denylist(host_str)?;
105
106 let bare_host = if host_str.starts_with('[') && host_str.ends_with(']') {
110 &host_str[1..host_str.len() - 1]
111 } else {
112 host_str
113 };
114
115 if let Ok(ip) = bare_host.parse::<IpAddr>() {
116 if is_blocked_ip(ip) {
117 return Err(UrlGuardError::BlockedHost(host_str.to_string()));
118 }
119 return Ok(());
121 }
122
123 let port = parsed.port_or_known_default().unwrap_or(443);
126 let lookup_target = format!("{host_str}:{port}");
127 match lookup_target.to_socket_addrs() {
128 Ok(addrs) => {
129 for sa in addrs {
130 let ip = sa.ip();
131 if is_blocked_ip(ip) {
132 warn!(
133 host = %host_str,
134 addr = %ip,
135 "validate_provider_url: resolved address is in a blocked range"
136 );
137 return Err(UrlGuardError::BlockedResolvedAddress {
138 host: host_str.to_string(),
139 addr: ip,
140 });
141 }
142 }
143 }
144 Err(e) => {
145 warn!(
148 host = %host_str,
149 error = %e,
150 "validate_provider_url: DNS resolution failed (allowed to proceed)"
151 );
152 }
153 }
154
155 Ok(())
156}
157
/// Returns `true` if `ip` is in a range a hosted-provider URL must never reach
/// (loopback, private, link-local, CGNAT, unspecified, ...).
fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_blocked_v4(v4),
        IpAddr::V6(v6) => is_blocked_v6(v6),
    }
}

/// Returns `true` if `ip` lies in a blocked IPv4 range.
fn is_blocked_v4(ip: Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        // 0.0.0.0/8 "this network". Linux treats a connect to 0.0.0.0 as loopback.
        [0, ..]
            // 10.0.0.0/8 private (RFC 1918).
            | [10, ..]
            // 127.0.0.0/8 loopback.
            | [127, ..]
            // 100.64.0.0/10 shared/CGNAT space (includes 100.100.100.200, a cloud metadata IP).
            | [100, 64..=127, ..]
            // 169.254.0.0/16 link-local (includes the 169.254.169.254 metadata endpoint).
            | [169, 254, ..]
            // 172.16.0.0/12 private (RFC 1918).
            | [172, 16..=31, ..]
            // 192.168.0.0/16 private (RFC 1918).
            | [192, 168, ..]
    )
}

/// Returns `true` if `ip` lies in a blocked IPv6 range.
///
/// Addresses that embed an IPv4 address in their low 32 bits (IPv4-mapped,
/// IPv4-compatible, NAT64 well-known prefix) are judged by that IPv4 address. This way
/// an IPv4 block cannot be bypassed by spelling the address in IPv6 form.
fn is_blocked_v6(ip: Ipv6Addr) -> bool {
    // ::1 loopback, and :: unspecified. Linux treats a connect to :: as a connect to
    // the local host, which is the same reason 0.0.0.0/8 is blocked for IPv4.
    if ip.is_loopback() || ip.is_unspecified() {
        return true;
    }

    let seg = ip.segments();

    // fe80::/10 link-local.
    if (seg[0] & 0xffc0) == 0xfe80 {
        return true;
    }
    // fc00::/7 unique-local (covers fd00::/8).
    if (seg[0] & 0xfe00) == 0xfc00 {
        return true;
    }

    let embeds_v4 = matches!(
        seg,
        // ::ffff:0:0/96 IPv4-mapped.
        [0, 0, 0, 0, 0, 0xffff, _, _]
            // 64:ff9b::/96 NAT64 well-known prefix.
            | [0x64, 0xff9b, 0, 0, 0, 0, _, _]
            // ::/96 IPv4-compatible (deprecated, but still parseable and routable on some stacks).
            | [0, 0, 0, 0, 0, 0, _, _]
    );
    if embeds_v4 {
        let [.., a, b, c, d] = ip.octets();
        return is_blocked_v4(Ipv4Addr::new(a, b, c, d));
    }

    false
}
256
257fn check_hostname_denylist(host: &str) -> Result<(), UrlGuardError> {
259 let lower = host.to_ascii_lowercase();
260
261 if lower == "localhost" {
262 return Err(UrlGuardError::BlockedHostname(host.to_string()));
263 }
264 if lower.ends_with(".local") || lower == "local" {
265 return Err(UrlGuardError::BlockedHostname(host.to_string()));
266 }
267 if lower == "metadata.google.internal" {
268 return Err(UrlGuardError::BlockedHostname(host.to_string()));
269 }
270
271 Ok(())
272}
273
/// Header names that user-configured extra headers may never set.
///
/// Entries MUST stay ASCII-lowercase: lookups compare header names against them
/// ASCII case-insensitively.
///
/// `content-length` sits with `transfer-encoding` because both declare message framing.
/// A user-supplied value could disagree with the body actually sent.
const DENIED_HEADERS: &[&str] = &[
    // Credential headers.
    "authorization",
    "x-api-key",
    "proxy-authorization",
    // Request/protocol headers.
    "host",
    "content-type",
    "anthropic-version",
    // Message framing.
    "content-length",
    "transfer-encoding",
    "te",
    "trailer",
    // Connection management / hop-by-hop.
    "connection",
    "keep-alive",
    "proxy-connection",
    "upgrade",
];

/// Returns `true` if `name` matches an entry of [`DENIED_HEADERS`], ignoring ASCII case.
fn is_denied_header(name: &str) -> bool {
    DENIED_HEADERS
        .iter()
        .any(|denied| denied.eq_ignore_ascii_case(name))
}

/// Returns the name of the first header in `headers` that is denied, or `None`.
///
/// The returned name is spelled exactly as the caller supplied it (original casing),
/// which makes it suitable for an error message pointing at the offending config entry.
pub fn find_denied_header(headers: &[(String, String)]) -> Option<&str> {
    headers
        .iter()
        .map(|(name, _)| name.as_str())
        .find(|name| is_denied_header(name))
}
320
321pub fn filter_extra_headers(headers: &[(String, String)]) -> Vec<(String, String)> {
326 headers
327 .iter()
328 .filter_map(|(name, value)| {
329 let lower = name.to_ascii_lowercase();
330 if DENIED_HEADERS.contains(&lower.as_str()) {
331 warn!(
332 header = %name,
333 "extra_headers: dropping denied header (authorization/host/hop-by-hop)"
334 );
335 None
336 } else {
337 Some((name.clone(), value.clone()))
338 }
339 })
340 .collect()
341}
342
#[cfg(test)]
mod tests {
    //! Unit tests for provider-URL validation and extra-header filtering.
    //!
    //! `accepts_normal_https_url` performs real DNS lookups. It also passes when resolution
    //! fails (e.g. offline CI), because lookup failures are allowed to proceed.

    use super::*;

    // ---- validate_provider_url: scheme handling -------------------------------------

    #[test]
    fn accepts_normal_https_url() {
        assert!(validate_provider_url("https://api.openai.com/v1", false).is_ok());
        assert!(validate_provider_url("https://api.anthropic.com", false).is_ok());
        assert!(validate_provider_url("https://api.together.xyz/v1", false).is_ok());
    }

    #[test]
    fn rejects_http_for_non_local() {
        let err = validate_provider_url("http://api.example.com/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::InsecureScheme(_)));
    }

    #[test]
    fn allows_http_when_allow_local() {
        assert!(validate_provider_url("http://localhost:11434/v1", true).is_ok());
        assert!(validate_provider_url("http://127.0.0.1:8000/v1", true).is_ok());
    }

    #[test]
    fn rejects_ftp_scheme() {
        let err = validate_provider_url("ftp://example.com/v1", false).unwrap_err();
        assert!(
            matches!(err, UrlGuardError::InsecureScheme(_)),
            "got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_url() {
        let err = validate_provider_url("not-a-url", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::InvalidUrl(_)), "got: {err}");
    }

    // ---- validate_provider_url: IP literals ------------------------------------------

    #[test]
    fn rejects_cloud_metadata_ip() {
        let err =
            validate_provider_url("https://169.254.169.254/latest/meta-data/", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_alibaba_metadata_ip() {
        let err = validate_provider_url("https://100.100.100.200/meta-data/", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_loopback_ipv4() {
        let err = validate_provider_url("https://127.0.0.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_decimal_encoded_loopback_ipv4() {
        // The URL parser normalises the single-number IPv4 form to 127.0.0.1.
        let err = validate_provider_url("https://2130706433/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_private_10_x() {
        let err = validate_provider_url("https://10.0.0.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_private_192_168() {
        let err = validate_provider_url("https://192.168.1.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_private_172_16() {
        let err = validate_provider_url("https://172.16.0.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn cgnat_range_is_blocked() {
        let err = validate_provider_url("https://100.64.0.1/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_loopback_ipv6() {
        let err = validate_provider_url("https://[::1]/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_ula_ipv6_fc00() {
        let err = validate_provider_url("https://[fc00::1]/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_link_local_ipv6_fe80() {
        let err = validate_provider_url("https://[fe80::1]/v1", false).unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    #[test]
    fn rejects_ipv4_mapped_metadata_ipv6_literal() {
        let err = validate_provider_url("https://[::ffff:169.254.169.254]/latest/", false)
            .unwrap_err();
        assert!(matches!(err, UrlGuardError::BlockedHost(_)), "got: {err}");
    }

    // ---- validate_provider_url: hostname denylist ------------------------------------

    #[test]
    fn rejects_localhost_hostname() {
        let err = validate_provider_url("https://localhost/v1", false).unwrap_err();
        assert!(
            matches!(err, UrlGuardError::BlockedHostname(_)),
            "got: {err}"
        );
    }

    #[test]
    fn rejects_dot_local_hostname() {
        let err = validate_provider_url("https://myhost.local/v1", false).unwrap_err();
        assert!(
            matches!(err, UrlGuardError::BlockedHostname(_)),
            "got: {err}"
        );
    }

    #[test]
    fn rejects_metadata_google_internal() {
        let err = validate_provider_url(
            "https://metadata.google.internal/computeMetadata/v1/",
            false,
        )
        .unwrap_err();
        assert!(
            matches!(err, UrlGuardError::BlockedHostname(_)),
            "got: {err}"
        );
    }

    #[test]
    fn allows_localhost_when_allow_local() {
        assert!(validate_provider_url("http://localhost:11434/v1", true).is_ok());
        assert!(validate_provider_url("https://localhost:11434/v1", true).is_ok());
    }

    // ---- header filtering ------------------------------------------------------------

    #[test]
    fn drops_authorization_header() {
        let headers = vec![
            ("Authorization".to_string(), "Bearer fake".to_string()),
            ("X-Custom".to_string(), "value".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "X-Custom");
    }

    #[test]
    fn drops_host_header() {
        let headers = vec![
            ("Host".to_string(), "evil.internal".to_string()),
            ("X-Org-ID".to_string(), "abc".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "X-Org-ID");
    }

    #[test]
    fn drops_hop_by_hop_headers() {
        let headers = vec![
            ("Connection".to_string(), "close".to_string()),
            ("Proxy-Authorization".to_string(), "Basic xyz".to_string()),
            ("Transfer-Encoding".to_string(), "chunked".to_string()),
            ("Keep-Alive".to_string(), "timeout=5".to_string()),
            ("X-Real-Header".to_string(), "ok".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "X-Real-Header");
    }

    #[test]
    fn keeps_legitimate_extra_headers() {
        let headers = vec![
            ("X-Custom-Header".to_string(), "custom-value".to_string()),
            ("X-Org-Id".to_string(), "org-123".to_string()),
            ("Accept-Language".to_string(), "en".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert_eq!(filtered.len(), 3);
    }

    #[test]
    fn filter_is_case_insensitive() {
        let headers = vec![
            ("AUTHORIZATION".to_string(), "Bearer x".to_string()),
            ("authorization".to_string(), "Bearer y".to_string()),
            ("Authorization".to_string(), "Bearer z".to_string()),
            ("x-api-key".to_string(), "sk-...".to_string()),
            ("X-API-KEY".to_string(), "sk-...".to_string()),
        ];
        let filtered = filter_extra_headers(&headers);
        assert!(
            filtered.is_empty(),
            "all auth headers should be dropped, got: {filtered:?}"
        );
    }

    #[test]
    fn find_denied_header_reports_original_name() {
        let headers = vec![
            ("X-Custom".to_string(), "v".to_string()),
            ("Proxy-Authorization".to_string(), "Basic xyz".to_string()),
        ];
        assert_eq!(find_denied_header(&headers), Some("Proxy-Authorization"));
        assert_eq!(find_denied_header(&headers[..1]), None);
    }

    // ---- address classification ------------------------------------------------------

    #[test]
    fn is_blocked_v4_spot_checks() {
        assert!(is_blocked_v4(Ipv4Addr::new(127, 0, 0, 1)));
        assert!(is_blocked_v4(Ipv4Addr::new(169, 254, 169, 254)));
        assert!(is_blocked_v4(Ipv4Addr::new(100, 100, 100, 200)));
        assert!(is_blocked_v4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(is_blocked_v4(Ipv4Addr::new(192, 168, 0, 1)));
        assert!(is_blocked_v4(Ipv4Addr::new(172, 16, 0, 1)));
        assert!(is_blocked_v4(Ipv4Addr::new(172, 31, 255, 255)));
        assert!(!is_blocked_v4(Ipv4Addr::new(1, 1, 1, 1)));
        assert!(!is_blocked_v4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(!is_blocked_v4(Ipv4Addr::new(172, 32, 0, 1)));
    }

    #[test]
    fn is_blocked_v4_range_boundaries() {
        assert!(is_blocked_v4(Ipv4Addr::new(0, 0, 0, 0)));
        // 100.64.0.0/10
        assert!(!is_blocked_v4(Ipv4Addr::new(100, 63, 255, 255)));
        assert!(is_blocked_v4(Ipv4Addr::new(100, 64, 0, 0)));
        assert!(is_blocked_v4(Ipv4Addr::new(100, 127, 255, 255)));
        assert!(!is_blocked_v4(Ipv4Addr::new(100, 128, 0, 0)));
        // 172.16.0.0/12 lower edge
        assert!(!is_blocked_v4(Ipv4Addr::new(172, 15, 255, 255)));
        // neighbours of 169.254/16 and 192.168/16
        assert!(!is_blocked_v4(Ipv4Addr::new(169, 253, 0, 1)));
        assert!(!is_blocked_v4(Ipv4Addr::new(192, 169, 0, 1)));
    }

    #[test]
    fn ipv6_mapped_v4_blocked() {
        let ip: Ipv6Addr = "::ffff:127.0.0.1".parse().unwrap();
        assert!(is_blocked_v6(ip));
    }

    #[test]
    fn nat64_embedded_v4_is_judged_by_v4_part() {
        // 64:ff9b::169.254.169.254
        assert!(is_blocked_v6(Ipv6Addr::new(0x64, 0xff9b, 0, 0, 0, 0, 0xa9fe, 0xa9fe)));
        // 64:ff9b::8.8.8.8
        assert!(!is_blocked_v6(Ipv6Addr::new(0x64, 0xff9b, 0, 0, 0, 0, 0x0808, 0x0808)));
    }

    #[test]
    fn public_ipv6_not_blocked() {
        let ip: Ipv6Addr = "2001:4860:4860::8888".parse().unwrap();
        assert!(!is_blocked_v6(ip));
    }
}