Skip to main content

sts_cat/
oidc.rs

1use crate::error::Error;
2
/// Upper bound, in bytes, on any HTTP response body read during OIDC
/// discovery (100 KiB).
const MAX_RESPONSE_SIZE: usize = 102_400;
4
5static PATH_CHAR_RE: std::sync::LazyLock<regex::Regex> =
6    std::sync::LazyLock::new(|| regex::Regex::new(r"^[a-zA-Z0-9\-._~/]+$").unwrap());
7
8pub fn validate_issuer(issuer: &str) -> Result<(), Error> {
9    if issuer.is_empty() || issuer.chars().count() > 255 {
10        return Err(Error::Unauthenticated(
11            "issuer empty or exceeds 255 characters".into(),
12        ));
13    }
14
15    let parsed = url::Url::parse(issuer)
16        .map_err(|_| Error::Unauthenticated("issuer is not a valid URL".into()))?;
17
18    match parsed.scheme() {
19        "https" => {}
20        "http" => match parsed.host() {
21            Some(url::Host::Domain("localhost")) => {}
22            Some(url::Host::Ipv4(ip)) if ip == std::net::Ipv4Addr::LOCALHOST => {}
23            Some(url::Host::Ipv6(ip)) if ip == std::net::Ipv6Addr::LOCALHOST => {}
24            _ => {
25                return Err(Error::Unauthenticated("issuer must use HTTPS".into()));
26            }
27        },
28        _ => {
29            return Err(Error::Unauthenticated("issuer must use HTTPS".into()));
30        }
31    }
32
33    // Check both parsed and raw: url::Url may normalize away certain encodings
34    if parsed.query().is_some() || parsed.fragment().is_some() {
35        return Err(Error::Unauthenticated(
36            "issuer must not contain query or fragment".into(),
37        ));
38    }
39    if issuer.contains('?') || issuer.contains('#') {
40        return Err(Error::Unauthenticated(
41            "issuer must not contain query or fragment".into(),
42        ));
43    }
44
45    if parsed.host_str().is_none() || parsed.host_str() == Some("") {
46        return Err(Error::Unauthenticated("issuer must have a host".into()));
47    }
48
49    if !parsed.username().is_empty() || parsed.password().is_some() {
50        return Err(Error::Unauthenticated(
51            "issuer must not contain userinfo".into(),
52        ));
53    }
54
55    // ASCII-only hostname — check the raw input string because url::Url
56    // converts IDN to punycode (e.g. exämple.com → xn--exmple-cua.com)
57    let raw_host = {
58        let after_scheme = issuer
59            .strip_prefix(parsed.scheme())
60            .and_then(|s| s.strip_prefix("://"))
61            .unwrap_or("");
62        let host_part = if let Some(pos) = after_scheme.find('/') {
63            &after_scheme[..pos]
64        } else {
65            after_scheme
66        };
67        if host_part.starts_with('[') {
68            // IPv6: take everything including brackets
69            host_part.to_owned()
70        } else if let Some(pos) = host_part.rfind(':') {
71            host_part[..pos].to_owned()
72        } else {
73            host_part.to_owned()
74        }
75    };
76    for ch in raw_host.chars() {
77        if ch as u32 > 127 {
78            return Err(Error::Unauthenticated(
79                "issuer hostname must be ASCII-only".into(),
80            ));
81        }
82        if ch.is_control() || ch.is_whitespace() {
83            return Err(Error::Unauthenticated(
84                "issuer hostname contains invalid characters".into(),
85            ));
86        }
87    }
88
89    // Path validation — use the raw issuer string to extract the path,
90    // since url::Url normalizes away `.` and `..` segments.
91    let raw_path = issuer
92        .strip_prefix(parsed.scheme())
93        .and_then(|s| s.strip_prefix("://"))
94        .and_then(|s| s.find('/').map(|pos| &s[pos..]))
95        .unwrap_or("");
96    let path = if raw_path.is_empty() {
97        parsed.path()
98    } else {
99        raw_path
100    };
101    if !path.is_empty() && path != "/" {
102        if !path.starts_with('/') {
103            return Err(Error::Unauthenticated(
104                "issuer path must start with /".into(),
105            ));
106        }
107        if path.contains("..") {
108            return Err(Error::Unauthenticated(
109                "issuer path must not contain ..".into(),
110            ));
111        }
112        if path.contains("//") {
113            return Err(Error::Unauthenticated(
114                "issuer path must not contain //".into(),
115            ));
116        }
117        if path.contains("~~") {
118            return Err(Error::Unauthenticated(
119                "issuer path must not contain ~~".into(),
120            ));
121        }
122        if path.ends_with('~') {
123            return Err(Error::Unauthenticated(
124                "issuer path must not end with ~".into(),
125            ));
126        }
127
128        if !PATH_CHAR_RE.is_match(path) {
129            return Err(Error::Unauthenticated(
130                "issuer path contains invalid characters".into(),
131            ));
132        }
133
134        for segment in path.split('/') {
135            if segment.is_empty() {
136                continue;
137            }
138            if segment == "." || segment == ".." || segment == "~" {
139                return Err(Error::Unauthenticated(
140                    "issuer path contains invalid segment".into(),
141                ));
142            }
143            if segment.len() > 150 {
144                return Err(Error::Unauthenticated(
145                    "issuer path segment exceeds 150 characters".into(),
146                ));
147            }
148        }
149    }
150
151    Ok(())
152}
153
154const SUBJECT_REJECT_CHARS: &str = "\"'`\\<>;&$(){}[]";
155const AUDIENCE_REJECT_CHARS: &str = "\"'`\\<>;|&$(){}[]@";
156
157pub fn validate_subject(value: &str) -> Result<(), Error> {
158    validate_claim_string(value, SUBJECT_REJECT_CHARS, "subject")
159}
160
161pub fn validate_audience(value: &str) -> Result<(), Error> {
162    validate_claim_string(value, AUDIENCE_REJECT_CHARS, "audience")
163}
164
165fn validate_claim_string(value: &str, reject_chars: &str, field: &str) -> Result<(), Error> {
166    if value.is_empty() {
167        return Err(Error::Unauthenticated(format!("{field} must not be empty")));
168    }
169    if value.chars().count() > 255 {
170        return Err(Error::Unauthenticated(format!(
171            "{field} exceeds 255 characters"
172        )));
173    }
174    for ch in value.chars() {
175        if (ch as u32) <= 0x1f {
176            return Err(Error::Unauthenticated(format!(
177                "{field} contains control characters"
178            )));
179        }
180        if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
181            return Err(Error::Unauthenticated(format!(
182                "{field} contains whitespace"
183            )));
184        }
185        if reject_chars.contains(ch) {
186            return Err(Error::Unauthenticated(format!(
187                "{field} contains invalid character"
188            )));
189        }
190        if !ch.is_alphanumeric() && !ch.is_ascii_punctuation() && ch as u32 > 127 {
191            // Approximate Go's unicode.IsPrint (categories L, M, N, P, S, Zs)
192            if !is_printable(ch) {
193                return Err(Error::Unauthenticated(format!(
194                    "{field} contains non-printable character"
195                )));
196            }
197        }
198    }
199    Ok(())
200}
201
202fn is_printable(ch: char) -> bool {
203    !ch.is_control() && ch as u32 != 0xFFFD
204}
205
206#[derive(Debug, serde::Deserialize)]
207pub(crate) struct OidcDiscoveryDocument {
208    pub(crate) issuer: String,
209    pub(crate) jwks_uri: String,
210}
211
212#[derive(Debug, Clone)]
213pub(crate) struct OidcProvider {
214    pub(crate) jwks: jsonwebtoken::jwk::JwkSet,
215}
216
217pub struct OidcVerifier {
218    http: reqwest::Client,
219    cache: moka::future::Cache<String, std::sync::Arc<OidcProvider>>,
220    allowed_issuers: Option<std::collections::HashSet<String>>,
221}
222
223impl Default for OidcVerifier {
224    fn default() -> Self {
225        Self::new(None)
226    }
227}
228
229#[derive(Debug, serde::Deserialize)]
230pub struct TokenClaims {
231    pub iss: String,
232    pub sub: String,
233    pub aud: OneOrMany,
234    #[serde(flatten)]
235    pub extra: std::collections::HashMap<String, serde_json::Value>,
236}
237
238#[derive(Debug, serde::Deserialize)]
239#[serde(untagged)]
240pub enum OneOrMany {
241    One(String),
242    Many(Vec<String>),
243}
244
245impl OneOrMany {
246    pub fn iter(&self) -> impl Iterator<Item = &str> {
247        let slice: &[String] = match self {
248            OneOrMany::One(s) => std::slice::from_ref(s),
249            OneOrMany::Many(v) => v.as_slice(),
250        };
251        slice.iter().map(|s| s.as_str())
252    }
253}
254
255impl OidcVerifier {
256    pub fn new(allowed_issuer_urls: Option<Vec<String>>) -> Self {
257        let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
258            let url_str = attempt.url().to_string();
259            if validate_issuer(&url_str).is_err() {
260                attempt.error(std::io::Error::new(
261                    std::io::ErrorKind::PermissionDenied,
262                    format!("redirect to invalid issuer URL: {url_str}"),
263                ))
264            } else {
265                attempt.follow()
266            }
267        });
268
269        let http = reqwest::Client::builder()
270            .connect_timeout(std::time::Duration::from_secs(10))
271            .timeout(std::time::Duration::from_secs(30))
272            .redirect(redirect_policy)
273            .user_agent(format!("sts-cat/{}", env!("CARGO_PKG_VERSION")))
274            .build()
275            .expect("failed to build OIDC HTTP client");
276
277        let cache = moka::future::Cache::builder()
278            .max_capacity(100)
279            .time_to_live(std::time::Duration::from_secs(900))
280            .build();
281
282        let allowed_issuers = allowed_issuer_urls.map(|urls| {
283            urls.into_iter()
284                .map(|u| u.trim_end_matches('/').to_owned())
285                .collect()
286        });
287
288        Self {
289            http,
290            cache,
291            allowed_issuers,
292        }
293    }
294
295    async fn discover(&self, issuer: &str) -> Result<std::sync::Arc<OidcProvider>, Error> {
296        if let Some(provider) = self.cache.get(issuer).await {
297            return Ok(provider);
298        }
299
300        let provider = self.discover_with_retry(issuer).await?;
301        let provider = std::sync::Arc::new(provider);
302        self.cache.insert(issuer.to_owned(), provider.clone()).await;
303        Ok(provider)
304    }
305
306    async fn discover_with_retry(&self, issuer: &str) -> Result<OidcProvider, Error> {
307        use backon::Retryable as _;
308
309        let discover_fn = || async { self.discover_once(issuer).await };
310
311        discover_fn
312            .retry(
313                backon::ExponentialBuilder::default()
314                    .with_min_delay(std::time::Duration::from_secs(1))
315                    .with_max_delay(std::time::Duration::from_secs(30))
316                    .with_factor(2.0)
317                    .with_jitter()
318                    .with_max_times(6),
319            )
320            .when(|e| !is_permanent_error(e))
321            .await
322    }
323
324    async fn discover_once(&self, issuer: &str) -> Result<OidcProvider, Error> {
325        let discovery_url = format!(
326            "{}/.well-known/openid-configuration",
327            issuer.trim_end_matches('/')
328        );
329
330        let resp = self
331            .http
332            .get(&discovery_url)
333            .send()
334            .await
335            .map_err(Error::OidcDiscovery)?;
336
337        let status = resp.status();
338        if !status.is_success() {
339            return Err(Error::OidcHttpError(status.as_u16()));
340        }
341
342        let body = read_limited_body(resp, MAX_RESPONSE_SIZE, Error::OidcDiscovery).await?;
343        let doc: OidcDiscoveryDocument =
344            serde_json::from_slice(&body).map_err(|e| Error::Internal(Box::new(e)))?;
345
346        let expected = issuer.trim_end_matches('/');
347        let actual = doc.issuer.trim_end_matches('/');
348        if expected != actual {
349            return Err(Error::Unauthenticated(
350                "OIDC discovery issuer mismatch".into(),
351            ));
352        }
353
354        let jwks_resp = self
355            .http
356            .get(&doc.jwks_uri)
357            .send()
358            .await
359            .map_err(Error::OidcDiscovery)?;
360
361        if !jwks_resp.status().is_success() {
362            return Err(Error::OidcHttpError(jwks_resp.status().as_u16()));
363        }
364
365        let jwks_body =
366            read_limited_body(jwks_resp, MAX_RESPONSE_SIZE, Error::OidcDiscovery).await?;
367        let jwks: jsonwebtoken::jwk::JwkSet =
368            serde_json::from_slice(&jwks_body).map_err(|e| Error::Internal(Box::new(e)))?;
369
370        Ok(OidcProvider { jwks })
371    }
372
373    pub async fn verify(&self, token: &str) -> Result<TokenClaims, Error> {
374        let header = jsonwebtoken::decode_header(token)?;
375
376        // Extract issuer without signature verification to discover the OIDC provider
377        let mut validation = jsonwebtoken::Validation::default();
378        validation.insecure_disable_signature_validation();
379        validation.validate_aud = false;
380        validation.validate_exp = false;
381
382        let unverified: jsonwebtoken::TokenData<TokenClaims> = jsonwebtoken::decode(
383            token,
384            &jsonwebtoken::DecodingKey::from_secret(&[]),
385            &validation,
386        )?;
387
388        let issuer = &unverified.claims.iss;
389
390        validate_issuer(issuer)?;
391
392        if let Some(ref allowed) = self.allowed_issuers {
393            let normalized = issuer.trim_end_matches('/');
394            if !allowed.contains(normalized) {
395                return Err(Error::Unauthenticated("issuer not in allowed list".into()));
396            }
397        }
398
399        let provider = self.discover(issuer).await?;
400
401        let kid = header.kid.as_deref();
402        let decoding_key = find_decoding_key(&provider.jwks, kid, &header.alg)?;
403
404        let mut verification = jsonwebtoken::Validation::new(header.alg);
405        verification.validate_aud = false; // Audience checked later by trust policy
406        verification.set_issuer(&[issuer]);
407
408        let token_data: jsonwebtoken::TokenData<TokenClaims> =
409            jsonwebtoken::decode(token, &decoding_key, &verification)?;
410
411        Ok(token_data.claims)
412    }
413}
414
415fn find_decoding_key(
416    jwks: &jsonwebtoken::jwk::JwkSet,
417    kid: Option<&str>,
418    alg: &jsonwebtoken::Algorithm,
419) -> Result<jsonwebtoken::DecodingKey, Error> {
420    let jwk = if let Some(kid) = kid {
421        jwks.find(kid).ok_or_else(|| {
422            Error::Unauthenticated(format!("no matching key found for kid: {kid}"))
423        })?
424    } else {
425        let alg_str = format!("{alg:?}");
426        jwks.keys
427            .iter()
428            .find(|k| {
429                k.common
430                    .key_algorithm
431                    .is_some_and(|ka| format!("{ka:?}") == alg_str)
432            })
433            .or_else(|| jwks.keys.first())
434            .ok_or_else(|| Error::Unauthenticated("no keys in JWKS".into()))?
435    };
436
437    jsonwebtoken::DecodingKey::from_jwk(jwk)
438        .map_err(|e| Error::Unauthenticated(format!("invalid JWK: {e}")))
439}
440
441fn is_permanent_error(e: &Error) -> bool {
442    match e {
443        // HTTP 4xx (except 408, 429) and 501 are permanent
444        Error::OidcHttpError(code) => matches!(
445            code,
446            400 | 401 | 403 | 404 | 405 | 406 | 410 | 415 | 422 | 501
447        ),
448        Error::OidcDiscovery(_) => false, // Network errors are transient
449        _ => true,                        // Parse errors etc. are permanent
450    }
451}
452
453pub(crate) async fn read_limited_body(
454    resp: reqwest::Response,
455    limit: usize,
456    map_err: impl Fn(reqwest::Error) -> Error,
457) -> Result<Vec<u8>, Error> {
458    if let Some(len) = resp.content_length()
459        && len as usize > limit
460    {
461        return Err(Error::Unauthenticated(format!(
462            "response too large: {len} bytes (limit: {limit})"
463        )));
464    }
465
466    use futures_util::StreamExt as _;
467    let initial_capacity = resp
468        .content_length()
469        .map_or(4096, |len| (len as usize).min(limit));
470    let mut stream = resp.bytes_stream();
471    let mut buf = Vec::with_capacity(initial_capacity);
472    while let Some(chunk) = stream.next().await {
473        let chunk = chunk.map_err(&map_err)?;
474        if buf.len() + chunk.len() > limit {
475            return Err(Error::Unauthenticated(format!(
476                "response too large (limit: {limit})"
477            )));
478        }
479        buf.extend_from_slice(&chunk);
480    }
481    Ok(buf)
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `validate_issuer` refuses `issuer`.
    fn issuer_rejected(issuer: &str) -> bool {
        validate_issuer(issuer).is_err()
    }

    #[test]
    fn test_validate_issuer_valid() {
        for issuer in [
            "https://accounts.google.com",
            "https://token.actions.githubusercontent.com",
            "https://example.com/path/to/issuer",
            "http://localhost",
            "http://127.0.0.1",
            "http://[::1]",
        ] {
            assert!(!issuer_rejected(issuer));
        }
    }

    #[test]
    fn test_validate_issuer_rejects_http_non_localhost() {
        assert!(issuer_rejected("http://example.com"));
    }

    #[test]
    fn test_validate_issuer_rejects_query_fragment() {
        for issuer in ["https://example.com?foo=bar", "https://example.com#frag"] {
            assert!(issuer_rejected(issuer));
        }
    }

    #[test]
    fn test_validate_issuer_rejects_userinfo() {
        assert!(issuer_rejected("https://user:pass@example.com"));
    }

    #[test]
    fn test_validate_issuer_rejects_path_traversal() {
        for issuer in ["https://example.com/..", "https://example.com/a/../b"] {
            assert!(issuer_rejected(issuer));
        }
    }

    #[test]
    fn test_validate_issuer_rejects_double_slash() {
        assert!(issuer_rejected("https://example.com//path"));
    }

    #[test]
    fn test_validate_issuer_rejects_tilde_issues() {
        for issuer in [
            "https://example.com/path~",
            "https://example.com/~~path",
            "https://example.com/~",
        ] {
            assert!(issuer_rejected(issuer));
        }
    }

    #[test]
    fn test_validate_issuer_rejects_dot_segment() {
        assert!(issuer_rejected("https://example.com/."));
    }

    #[test]
    fn test_validate_issuer_rejects_long_segment() {
        let issuer = format!("https://example.com/{}", "a".repeat(151));
        assert!(issuer_rejected(&issuer));
    }

    #[test]
    fn test_validate_issuer_rejects_non_ascii_host() {
        assert!(issuer_rejected("https://exämple.com"));
    }

    #[test]
    fn test_validate_subject_valid() {
        for subject in [
            "repo:org/repo:ref:refs/heads/main",
            "user@example.com",
            "simple-subject",
            "pipe|separated",
        ] {
            assert!(validate_subject(subject).is_ok());
        }
    }

    #[test]
    fn test_validate_subject_rejects() {
        for subject in [
            "",
            "has space",
            "has\"quote",
            "has'quote",
            "has\\backslash",
            "has<bracket",
            "has[bracket]",
        ] {
            assert!(validate_subject(subject).is_err());
        }
    }

    #[test]
    fn test_validate_audience_valid() {
        for audience in ["https://example.com", "my-audience"] {
            assert!(validate_audience(audience).is_ok());
        }
    }

    #[test]
    fn test_validate_audience_more_restrictive_than_subject() {
        // Subject allows these, audience rejects them
        for value in ["user@example.com", "pipe|value"] {
            assert!(validate_subject(value).is_ok());
            assert!(validate_audience(value).is_err());
        }

        // Both reject square brackets.
        assert!(validate_subject("has[bracket]").is_err());
        assert!(validate_audience("has[bracket]").is_err());
    }
}