Skip to main content

systemprompt_api/services/
request_base_url.rs

1//! Request-derived base URL for OAuth discovery responses.
2//!
3//! RFC 9728 implementations identify themselves coherently from the host the
4//! client actually dialled. A single gateway reachable via both `127.0.0.1`
5//! and `localhost` must echo whichever the client used in every URL it
6//! returns (`issuer`, `authorization_endpoint`, `token_endpoint`, `resource`…),
7//! otherwise the client's RFC 8707 `resource` indicator won't round-trip
8//! against the configured `api_external_url` origin.
9//!
10//! [`RequestBaseUrl`] is an axum extractor that resolves
11//! `scheme://host[:port]` from the incoming request, validating the host
12//! against a small allowlist seeded from `api_external_url`. On allowlist
13//! miss or missing/invalid header it falls back to `api_external_url` — the
14//! gateway never advertises a hostname an attacker fabricated via Host
15//! header injection.
16
17use axum::extract::FromRequestParts;
18use http::request::Parts;
19use http::{StatusCode, header};
20use systemprompt_models::Config;
21
22#[derive(Debug, Clone)]
23pub struct RequestBaseUrl {
24    base: String,
25    origin: url::Origin,
26}
27
28impl RequestBaseUrl {
29    #[must_use]
30    pub fn as_str(&self) -> &str {
31        &self.base
32    }
33
34    #[must_use]
35    pub const fn origin(&self) -> &url::Origin {
36        &self.origin
37    }
38
39    #[must_use]
40    pub fn into_string(self) -> String {
41        self.base
42    }
43}
44
/// Reports whether `host` names the local loopback interface.
///
/// `host` may be a bare host or a `host:port` authority. Recognised loopback
/// names are `localhost`, `127.0.0.1` and the IPv6 literal `::1` (bracketed
/// or not); matching is ASCII case-insensitive.
///
/// The port is stripped with IPv6 in mind: splitting at the first `:` would
/// cut `[::1]` down to `[` and make the IPv6 comparisons unreachable.
fn is_loopback_host(host: &str) -> bool {
    let bracketed = host.strip_prefix('[').and_then(|rest| rest.split_once(']'));
    let bare = match bracketed {
        // `[addr]` or `[addr]:port` — the address is what sits inside the brackets.
        Some((addr, tail)) if tail.is_empty() || tail.starts_with(':') => addr,
        // Trailing garbage after the closing bracket is not a valid authority.
        Some(_) => return false,
        // More than one colon without brackets: a bare IPv6 literal, which
        // cannot carry a port.
        None if host.matches(':').count() > 1 => host,
        // Plain `name` or `name:port`.
        None => host.split_once(':').map_or(host, |(name, _)| name),
    };
    matches!(
        bare.to_ascii_lowercase().as_str(),
        "localhost" | "127.0.0.1" | "::1"
    )
}
49
50fn host_in_allowlist(candidate_host: &str, configured: &url::Url) -> bool {
51    let candidate_bare = candidate_host
52        .rsplit_once(':')
53        .map_or(candidate_host, |(h, _)| h)
54        .to_ascii_lowercase();
55    let configured_host = configured.host_str().unwrap_or("").to_ascii_lowercase();
56
57    if candidate_bare == configured_host {
58        return true;
59    }
60    if is_loopback_host(&configured_host) && is_loopback_host(&candidate_bare) {
61        return true;
62    }
63    false
64}
65
66fn fallback_from_url(configured: &url::Url) -> RequestBaseUrl {
67    let trimmed = configured.as_str().trim_end_matches('/').to_owned();
68    RequestBaseUrl {
69        base: trimmed,
70        origin: configured.origin(),
71    }
72}
73
74/// Resolve a [`RequestBaseUrl`] from an optional Host header and configured
75/// `api_external_url`.
76///
77/// Exposed for unit testing — production callers use the [`FromRequestParts`]
78/// impl which reads both from the request and global config.
79#[must_use]
80pub fn resolve(raw_host: Option<&str>, configured: &url::Url) -> RequestBaseUrl {
81    if let Some(host) = raw_host.map(str::trim).filter(|s| !s.is_empty())
82        && let Ok(resolved) = build_from_host(host, configured)
83    {
84        return resolved;
85    }
86    fallback_from_url(configured)
87}
88
89fn build_from_host(raw_host: &str, configured: &url::Url) -> Result<RequestBaseUrl, &'static str> {
90    if raw_host.is_empty() || raw_host.contains('/') || raw_host.contains(' ') {
91        return Err("invalid host header");
92    }
93    if !host_in_allowlist(raw_host, configured) {
94        return Err("host not in allowlist");
95    }
96    let host_bare = raw_host
97        .rsplit_once(':')
98        .map_or(raw_host, |(h, _)| h)
99        .to_ascii_lowercase();
100    let scheme = if is_loopback_host(&host_bare) {
101        "http"
102    } else {
103        configured.scheme()
104    };
105    let base = format!("{scheme}://{raw_host}");
106    let parsed = url::Url::parse(&base).map_err(|_e| "host header did not parse as URL")?;
107    Ok(RequestBaseUrl {
108        base: base.trim_end_matches('/').to_owned(),
109        origin: parsed.origin(),
110    })
111}
112
113impl<S: Send + Sync> FromRequestParts<S> for RequestBaseUrl {
114    type Rejection = (StatusCode, String);
115
116    #[expect(
117        clippy::unused_async_trait_impl,
118        reason = "async signature required by the FromRequestParts trait; this \
119                  extractor resolves the base URL synchronously"
120    )]
121    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
122        let cfg = Config::get().map_err(|e| {
123            tracing::error!(error = %e, "Failed to load config for RequestBaseUrl");
124            (
125                StatusCode::INTERNAL_SERVER_ERROR,
126                "Configuration unavailable".to_owned(),
127            )
128        })?;
129        let configured = url::Url::parse(&cfg.api_external_url).map_err(|e| {
130            tracing::error!(
131                error = %e,
132                api_external_url = %cfg.api_external_url,
133                "api_external_url is not a valid URL — bootstrap validation should have caught this"
134            );
135            (
136                StatusCode::INTERNAL_SERVER_ERROR,
137                "Configuration invalid".to_owned(),
138            )
139        })?;
140
141        let raw_host = parts
142            .headers
143            .get(header::HOST)
144            .and_then(|v| v.to_str().ok());
145        Ok(resolve(raw_host, &configured))
146    }
147}