1use crate::error::Error;
2
3const MAX_RESPONSE_SIZE: usize = 100 * 1024; static PATH_CHAR_RE: std::sync::LazyLock<regex::Regex> =
6 std::sync::LazyLock::new(|| regex::Regex::new(r"^[a-zA-Z0-9\-._~/]+$").unwrap());
7
8pub fn validate_issuer(issuer: &str) -> Result<(), Error> {
9 if issuer.is_empty() || issuer.chars().count() > 255 {
10 return Err(Error::Unauthenticated(
11 "issuer empty or exceeds 255 characters".into(),
12 ));
13 }
14
15 let parsed = url::Url::parse(issuer)
16 .map_err(|_| Error::Unauthenticated("issuer is not a valid URL".into()))?;
17
18 match parsed.scheme() {
19 "https" => {}
20 "http" => match parsed.host() {
21 Some(url::Host::Domain("localhost")) => {}
22 Some(url::Host::Ipv4(ip)) if ip == std::net::Ipv4Addr::LOCALHOST => {}
23 Some(url::Host::Ipv6(ip)) if ip == std::net::Ipv6Addr::LOCALHOST => {}
24 _ => {
25 return Err(Error::Unauthenticated("issuer must use HTTPS".into()));
26 }
27 },
28 _ => {
29 return Err(Error::Unauthenticated("issuer must use HTTPS".into()));
30 }
31 }
32
33 if parsed.query().is_some() || parsed.fragment().is_some() {
35 return Err(Error::Unauthenticated(
36 "issuer must not contain query or fragment".into(),
37 ));
38 }
39 if issuer.contains('?') || issuer.contains('#') {
40 return Err(Error::Unauthenticated(
41 "issuer must not contain query or fragment".into(),
42 ));
43 }
44
45 if parsed.host_str().is_none() || parsed.host_str() == Some("") {
46 return Err(Error::Unauthenticated("issuer must have a host".into()));
47 }
48
49 if !parsed.username().is_empty() || parsed.password().is_some() {
50 return Err(Error::Unauthenticated(
51 "issuer must not contain userinfo".into(),
52 ));
53 }
54
55 let raw_host = {
58 let after_scheme = issuer
59 .strip_prefix(parsed.scheme())
60 .and_then(|s| s.strip_prefix("://"))
61 .unwrap_or("");
62 let host_part = if let Some(pos) = after_scheme.find('/') {
63 &after_scheme[..pos]
64 } else {
65 after_scheme
66 };
67 if host_part.starts_with('[') {
68 host_part.to_owned()
70 } else if let Some(pos) = host_part.rfind(':') {
71 host_part[..pos].to_owned()
72 } else {
73 host_part.to_owned()
74 }
75 };
76 for ch in raw_host.chars() {
77 if ch as u32 > 127 {
78 return Err(Error::Unauthenticated(
79 "issuer hostname must be ASCII-only".into(),
80 ));
81 }
82 if ch.is_control() || ch.is_whitespace() {
83 return Err(Error::Unauthenticated(
84 "issuer hostname contains invalid characters".into(),
85 ));
86 }
87 }
88
89 let raw_path = issuer
92 .strip_prefix(parsed.scheme())
93 .and_then(|s| s.strip_prefix("://"))
94 .and_then(|s| s.find('/').map(|pos| &s[pos..]))
95 .unwrap_or("");
96 let path = if raw_path.is_empty() {
97 parsed.path()
98 } else {
99 raw_path
100 };
101 if !path.is_empty() && path != "/" {
102 if !path.starts_with('/') {
103 return Err(Error::Unauthenticated(
104 "issuer path must start with /".into(),
105 ));
106 }
107 if path.contains("..") {
108 return Err(Error::Unauthenticated(
109 "issuer path must not contain ..".into(),
110 ));
111 }
112 if path.contains("//") {
113 return Err(Error::Unauthenticated(
114 "issuer path must not contain //".into(),
115 ));
116 }
117 if path.contains("~~") {
118 return Err(Error::Unauthenticated(
119 "issuer path must not contain ~~".into(),
120 ));
121 }
122 if path.ends_with('~') {
123 return Err(Error::Unauthenticated(
124 "issuer path must not end with ~".into(),
125 ));
126 }
127
128 if !PATH_CHAR_RE.is_match(path) {
129 return Err(Error::Unauthenticated(
130 "issuer path contains invalid characters".into(),
131 ));
132 }
133
134 for segment in path.split('/') {
135 if segment.is_empty() {
136 continue;
137 }
138 if segment == "." || segment == ".." || segment == "~" {
139 return Err(Error::Unauthenticated(
140 "issuer path contains invalid segment".into(),
141 ));
142 }
143 if segment.len() > 150 {
144 return Err(Error::Unauthenticated(
145 "issuer path segment exceeds 150 characters".into(),
146 ));
147 }
148 }
149 }
150
151 Ok(())
152}
153
/// Characters that are never allowed in a `sub` claim: quotes, backtick,
/// backslash, angle brackets, `;`, `&`, `$`, and all bracket kinds.
const SUBJECT_REJECT_CHARS: &str = "\"'`\\<>;&$(){}[]";

/// Characters that are never allowed in an `aud` claim: everything in
/// [`SUBJECT_REJECT_CHARS`] plus `|` and `@`.
const AUDIENCE_REJECT_CHARS: &str = "\"'`\\<>;|&$(){}[]@";
156
157pub fn validate_subject(value: &str) -> Result<(), Error> {
158 validate_claim_string(value, SUBJECT_REJECT_CHARS, "subject")
159}
160
161pub fn validate_audience(value: &str) -> Result<(), Error> {
162 validate_claim_string(value, AUDIENCE_REJECT_CHARS, "audience")
163}
164
165fn validate_claim_string(value: &str, reject_chars: &str, field: &str) -> Result<(), Error> {
166 if value.is_empty() {
167 return Err(Error::Unauthenticated(format!("{field} must not be empty")));
168 }
169 if value.chars().count() > 255 {
170 return Err(Error::Unauthenticated(format!(
171 "{field} exceeds 255 characters"
172 )));
173 }
174 for ch in value.chars() {
175 if (ch as u32) <= 0x1f {
176 return Err(Error::Unauthenticated(format!(
177 "{field} contains control characters"
178 )));
179 }
180 if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
181 return Err(Error::Unauthenticated(format!(
182 "{field} contains whitespace"
183 )));
184 }
185 if reject_chars.contains(ch) {
186 return Err(Error::Unauthenticated(format!(
187 "{field} contains invalid character"
188 )));
189 }
190 if !ch.is_alphanumeric() && !ch.is_ascii_punctuation() && ch as u32 > 127 {
191 if !is_printable(ch) {
193 return Err(Error::Unauthenticated(format!(
194 "{field} contains non-printable character"
195 )));
196 }
197 }
198 }
199 Ok(())
200}
201
/// Returns `true` unless `ch` is a control character or the Unicode
/// replacement character U+FFFD.
fn is_printable(ch: char) -> bool {
    if ch == '\u{FFFD}' {
        return false;
    }
    !ch.is_control()
}
205
206#[derive(Debug, serde::Deserialize)]
207pub(crate) struct OidcDiscoveryDocument {
208 pub(crate) issuer: String,
209 pub(crate) jwks_uri: String,
210}
211
212#[derive(Debug, Clone)]
213pub(crate) struct OidcProvider {
214 pub(crate) jwks: jsonwebtoken::jwk::JwkSet,
215}
216
217pub struct OidcVerifier {
218 http: reqwest::Client,
219 cache: moka::future::Cache<String, std::sync::Arc<OidcProvider>>,
220 allowed_issuers: Option<std::collections::HashSet<String>>,
221}
222
223impl Default for OidcVerifier {
224 fn default() -> Self {
225 Self::new(None)
226 }
227}
228
229#[derive(Debug, serde::Deserialize)]
230pub struct TokenClaims {
231 pub iss: String,
232 pub sub: String,
233 pub aud: OneOrMany,
234 #[serde(flatten)]
235 pub extra: std::collections::HashMap<String, serde_json::Value>,
236}
237
238#[derive(Debug, serde::Deserialize)]
239#[serde(untagged)]
240pub enum OneOrMany {
241 One(String),
242 Many(Vec<String>),
243}
244
245impl OneOrMany {
246 pub fn iter(&self) -> impl Iterator<Item = &str> {
247 let slice: &[String] = match self {
248 OneOrMany::One(s) => std::slice::from_ref(s),
249 OneOrMany::Many(v) => v.as_slice(),
250 };
251 slice.iter().map(|s| s.as_str())
252 }
253}
254
255impl OidcVerifier {
256 pub fn new(allowed_issuer_urls: Option<Vec<String>>) -> Self {
257 let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
258 let url_str = attempt.url().to_string();
259 if validate_issuer(&url_str).is_err() {
260 attempt.error(std::io::Error::new(
261 std::io::ErrorKind::PermissionDenied,
262 format!("redirect to invalid issuer URL: {url_str}"),
263 ))
264 } else {
265 attempt.follow()
266 }
267 });
268
269 let http = reqwest::Client::builder()
270 .connect_timeout(std::time::Duration::from_secs(10))
271 .timeout(std::time::Duration::from_secs(30))
272 .redirect(redirect_policy)
273 .user_agent(format!("sts-cat/{}", env!("CARGO_PKG_VERSION")))
274 .build()
275 .expect("failed to build OIDC HTTP client");
276
277 let cache = moka::future::Cache::builder()
278 .max_capacity(100)
279 .time_to_live(std::time::Duration::from_secs(900))
280 .build();
281
282 let allowed_issuers = allowed_issuer_urls.map(|urls| {
283 urls.into_iter()
284 .map(|u| u.trim_end_matches('/').to_owned())
285 .collect()
286 });
287
288 Self {
289 http,
290 cache,
291 allowed_issuers,
292 }
293 }
294
295 async fn discover(&self, issuer: &str) -> Result<std::sync::Arc<OidcProvider>, Error> {
296 if let Some(provider) = self.cache.get(issuer).await {
297 return Ok(provider);
298 }
299
300 let provider = self.discover_with_retry(issuer).await?;
301 let provider = std::sync::Arc::new(provider);
302 self.cache.insert(issuer.to_owned(), provider.clone()).await;
303 Ok(provider)
304 }
305
306 async fn discover_with_retry(&self, issuer: &str) -> Result<OidcProvider, Error> {
307 use backon::Retryable as _;
308
309 let discover_fn = || async { self.discover_once(issuer).await };
310
311 discover_fn
312 .retry(
313 backon::ExponentialBuilder::default()
314 .with_min_delay(std::time::Duration::from_secs(1))
315 .with_max_delay(std::time::Duration::from_secs(30))
316 .with_factor(2.0)
317 .with_jitter()
318 .with_max_times(6),
319 )
320 .when(|e| !is_permanent_error(e))
321 .await
322 }
323
324 async fn discover_once(&self, issuer: &str) -> Result<OidcProvider, Error> {
325 let discovery_url = format!(
326 "{}/.well-known/openid-configuration",
327 issuer.trim_end_matches('/')
328 );
329
330 let resp = self
331 .http
332 .get(&discovery_url)
333 .send()
334 .await
335 .map_err(Error::OidcDiscovery)?;
336
337 let status = resp.status();
338 if !status.is_success() {
339 return Err(Error::OidcHttpError(status.as_u16()));
340 }
341
342 let body = read_limited_body(resp, MAX_RESPONSE_SIZE, Error::OidcDiscovery).await?;
343 let doc: OidcDiscoveryDocument =
344 serde_json::from_slice(&body).map_err(|e| Error::Internal(Box::new(e)))?;
345
346 let expected = issuer.trim_end_matches('/');
347 let actual = doc.issuer.trim_end_matches('/');
348 if expected != actual {
349 return Err(Error::Unauthenticated(
350 "OIDC discovery issuer mismatch".into(),
351 ));
352 }
353
354 let jwks_resp = self
355 .http
356 .get(&doc.jwks_uri)
357 .send()
358 .await
359 .map_err(Error::OidcDiscovery)?;
360
361 if !jwks_resp.status().is_success() {
362 return Err(Error::OidcHttpError(jwks_resp.status().as_u16()));
363 }
364
365 let jwks_body =
366 read_limited_body(jwks_resp, MAX_RESPONSE_SIZE, Error::OidcDiscovery).await?;
367 let jwks: jsonwebtoken::jwk::JwkSet =
368 serde_json::from_slice(&jwks_body).map_err(|e| Error::Internal(Box::new(e)))?;
369
370 Ok(OidcProvider { jwks })
371 }
372
373 pub async fn verify(&self, token: &str) -> Result<TokenClaims, Error> {
374 let header = jsonwebtoken::decode_header(token)?;
375
376 let mut validation = jsonwebtoken::Validation::default();
378 validation.insecure_disable_signature_validation();
379 validation.validate_aud = false;
380 validation.validate_exp = false;
381
382 let unverified: jsonwebtoken::TokenData<TokenClaims> = jsonwebtoken::decode(
383 token,
384 &jsonwebtoken::DecodingKey::from_secret(&[]),
385 &validation,
386 )?;
387
388 let issuer = &unverified.claims.iss;
389
390 validate_issuer(issuer)?;
391
392 if let Some(ref allowed) = self.allowed_issuers {
393 let normalized = issuer.trim_end_matches('/');
394 if !allowed.contains(normalized) {
395 return Err(Error::Unauthenticated("issuer not in allowed list".into()));
396 }
397 }
398
399 let provider = self.discover(issuer).await?;
400
401 let kid = header.kid.as_deref();
402 let decoding_key = find_decoding_key(&provider.jwks, kid, &header.alg)?;
403
404 let mut verification = jsonwebtoken::Validation::new(header.alg);
405 verification.validate_aud = false; verification.set_issuer(&[issuer]);
407
408 let token_data: jsonwebtoken::TokenData<TokenClaims> =
409 jsonwebtoken::decode(token, &decoding_key, &verification)?;
410
411 Ok(token_data.claims)
412 }
413}
414
415fn find_decoding_key(
416 jwks: &jsonwebtoken::jwk::JwkSet,
417 kid: Option<&str>,
418 alg: &jsonwebtoken::Algorithm,
419) -> Result<jsonwebtoken::DecodingKey, Error> {
420 let jwk = if let Some(kid) = kid {
421 jwks.find(kid).ok_or_else(|| {
422 Error::Unauthenticated(format!("no matching key found for kid: {kid}"))
423 })?
424 } else {
425 let alg_str = format!("{alg:?}");
426 jwks.keys
427 .iter()
428 .find(|k| {
429 k.common
430 .key_algorithm
431 .is_some_and(|ka| format!("{ka:?}") == alg_str)
432 })
433 .or_else(|| jwks.keys.first())
434 .ok_or_else(|| Error::Unauthenticated("no keys in JWKS".into()))?
435 };
436
437 jsonwebtoken::DecodingKey::from_jwk(jwk)
438 .map_err(|e| Error::Unauthenticated(format!("invalid JWK: {e}")))
439}
440
441fn is_permanent_error(e: &Error) -> bool {
442 match e {
443 Error::OidcHttpError(code) => matches!(
445 code,
446 400 | 401 | 403 | 404 | 405 | 406 | 410 | 415 | 422 | 501
447 ),
448 Error::OidcDiscovery(_) => false, _ => true, }
451}
452
453pub(crate) async fn read_limited_body(
454 resp: reqwest::Response,
455 limit: usize,
456 map_err: impl Fn(reqwest::Error) -> Error,
457) -> Result<Vec<u8>, Error> {
458 if let Some(len) = resp.content_length()
459 && len as usize > limit
460 {
461 return Err(Error::Unauthenticated(format!(
462 "response too large: {len} bytes (limit: {limit})"
463 )));
464 }
465
466 use futures_util::StreamExt as _;
467 let initial_capacity = resp
468 .content_length()
469 .map_or(4096, |len| (len as usize).min(limit));
470 let mut stream = resp.bytes_stream();
471 let mut buf = Vec::with_capacity(initial_capacity);
472 while let Some(chunk) = stream.next().await {
473 let chunk = chunk.map_err(&map_err)?;
474 if buf.len() + chunk.len() > limit {
475 return Err(Error::Unauthenticated(format!(
476 "response too large (limit: {limit})"
477 )));
478 }
479 buf.extend_from_slice(&chunk);
480 }
481 Ok(buf)
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every issuer in `cases` is accepted.
    fn assert_issuers_accepted(cases: &[&str]) {
        for issuer in cases {
            assert!(
                validate_issuer(issuer).is_ok(),
                "expected issuer to be accepted: {issuer:?}"
            );
        }
    }

    /// Asserts that every issuer in `cases` is rejected.
    fn assert_issuers_rejected(cases: &[&str]) {
        for issuer in cases {
            assert!(
                validate_issuer(issuer).is_err(),
                "expected issuer to be rejected: {issuer:?}"
            );
        }
    }

    #[test]
    fn test_validate_issuer_valid() {
        assert_issuers_accepted(&[
            "https://accounts.google.com",
            "https://token.actions.githubusercontent.com",
            "https://example.com/path/to/issuer",
            "http://localhost",
            "http://127.0.0.1",
            "http://[::1]",
        ]);
    }

    #[test]
    fn test_validate_issuer_rejects_http_non_localhost() {
        assert_issuers_rejected(&["http://example.com"]);
    }

    #[test]
    fn test_validate_issuer_rejects_query_fragment() {
        assert_issuers_rejected(&["https://example.com?foo=bar", "https://example.com#frag"]);
    }

    #[test]
    fn test_validate_issuer_rejects_userinfo() {
        assert_issuers_rejected(&["https://user:pass@example.com"]);
    }

    #[test]
    fn test_validate_issuer_rejects_path_traversal() {
        assert_issuers_rejected(&["https://example.com/..", "https://example.com/a/../b"]);
    }

    #[test]
    fn test_validate_issuer_rejects_double_slash() {
        assert_issuers_rejected(&["https://example.com//path"]);
    }

    #[test]
    fn test_validate_issuer_rejects_tilde_issues() {
        assert_issuers_rejected(&[
            "https://example.com/path~",
            "https://example.com/~~path",
            "https://example.com/~",
        ]);
    }

    #[test]
    fn test_validate_issuer_rejects_dot_segment() {
        assert_issuers_rejected(&["https://example.com/."]);
    }

    #[test]
    fn test_validate_issuer_rejects_long_segment() {
        // One byte over the 150-byte segment limit.
        let issuer = format!("https://example.com/{}", "a".repeat(151));
        assert!(validate_issuer(&issuer).is_err());
    }

    #[test]
    fn test_validate_issuer_rejects_non_ascii_host() {
        assert_issuers_rejected(&["https://exämple.com"]);
    }

    #[test]
    fn test_validate_subject_valid() {
        for subject in [
            "repo:org/repo:ref:refs/heads/main",
            "user@example.com",
            "simple-subject",
            "pipe|separated",
        ] {
            assert!(
                validate_subject(subject).is_ok(),
                "expected subject to be accepted: {subject:?}"
            );
        }
    }

    #[test]
    fn test_validate_subject_rejects() {
        for subject in [
            "",
            "has space",
            "has\"quote",
            "has'quote",
            "has\\backslash",
            "has<bracket",
            "has[bracket]",
        ] {
            assert!(
                validate_subject(subject).is_err(),
                "expected subject to be rejected: {subject:?}"
            );
        }
    }

    #[test]
    fn test_validate_audience_valid() {
        for audience in ["https://example.com", "my-audience"] {
            assert!(
                validate_audience(audience).is_ok(),
                "expected audience to be accepted: {audience:?}"
            );
        }
    }

    #[test]
    fn test_validate_audience_more_restrictive_than_subject() {
        // `@` and `|` are fine in a subject but not in an audience.
        for value in ["user@example.com", "pipe|value"] {
            assert!(validate_subject(value).is_ok());
            assert!(validate_audience(value).is_err());
        }

        // Brackets are rejected by both.
        assert!(validate_subject("has[bracket]").is_err());
        assert!(validate_audience("has[bracket]").is_err());
    }
}