1use std::time::{SystemTime, UNIX_EPOCH};
2
3use reqwest::blocking::Client;
4use reqwest::header::CONTENT_TYPE;
5use ring::digest::{Context, SHA256};
6use serde::Deserialize;
7use serde_json::Value;
8
9use crate::error::ViaError;
10use crate::redaction::Redactor;
11use crate::secrets::SecretValue;
12
/// The only credential bundle `type` value accepted by `validate_credential_type`.
const SERVICE_OAUTH_TYPE: &str = "service_oauth";
/// A cached access token is treated as expired this many seconds before its `expires_at`.
const CACHE_EXPIRY_SKEW_SECONDS: i64 = 60;
15
16pub fn access_token(credential: &SecretValue, redactor: &mut Redactor) -> Result<String, ViaError> {
17 access_token_with_mode(credential, redactor, crate::daemon::OAuthTokenMode::Cached)
18}
19
20pub fn refresh_access_token(
21 credential: &SecretValue,
22 redactor: &mut Redactor,
23) -> Result<String, ViaError> {
24 access_token_with_mode(credential, redactor, crate::daemon::OAuthTokenMode::Refresh)
25}
26
27fn access_token_with_mode(
28 credential: &SecretValue,
29 redactor: &mut Redactor,
30 mode: crate::daemon::OAuthTokenMode,
31) -> Result<String, ViaError> {
32 redactor.add(credential.expose());
33 let bundle = CredentialBundle::parse(credential.expose())?;
34 register_bundle_secrets(&bundle, redactor);
35
36 let token = crate::daemon::oauth_access_token(credential.expose(), mode)?;
37 redactor.add(token.expose());
38 Ok(token.expose().to_owned())
39}
40
41pub fn validate_credential_bundle(raw: &str) -> Result<(), ViaError> {
42 CredentialBundle::parse(raw).map(|_| ())
43}
44
45pub(crate) fn exchange_access_token(
46 client: &Client,
47 bundle: &CredentialBundle,
48 cached: Option<&CachedOAuthToken>,
49 redactor: &mut Redactor,
50) -> Result<OAuthAccessToken, ViaError> {
51 match &bundle.grant {
52 OAuthGrant::RefreshToken { refresh_token } => {
53 let cached_refresh_token = cached.and_then(|cached| cached.refresh_token.as_deref());
54 let refresh_token_for_request = cached_refresh_token.unwrap_or(refresh_token);
55 match exchange_refresh_token(client, bundle, refresh_token_for_request, redactor) {
56 Ok(token) => Ok(token),
57 Err(_error)
58 if cached_refresh_token.is_some_and(|cached| cached != refresh_token) =>
59 {
60 crate::timing::event(
61 "oauth refresh token fallback",
62 "cached_refresh_token_failed",
63 );
64 exchange_refresh_token(client, bundle, refresh_token, redactor)
65 }
66 Err(error) => Err(error),
67 }
68 }
69 OAuthGrant::ClientCredentials { .. } => {
70 exchange_client_credentials(client, bundle, redactor)
71 }
72 }
73}
74
75fn exchange_refresh_token(
76 client: &Client,
77 bundle: &CredentialBundle,
78 refresh_token: &str,
79 redactor: &mut Redactor,
80) -> Result<OAuthAccessToken, ViaError> {
81 redactor.add(refresh_token);
82 let mut form = vec![
83 ("grant_type", "refresh_token"),
84 ("refresh_token", refresh_token),
85 ("client_id", bundle.client_id.as_str()),
86 ];
87 if let Some(client_secret) = bundle.client_secret.as_deref() {
88 form.push(("client_secret", client_secret));
89 }
90
91 exchange_token_form(
92 client,
93 bundle,
94 &form,
95 TokenResponseRefreshMode::PreserveRefreshToken(refresh_token),
96 redactor,
97 )
98}
99
100fn exchange_client_credentials(
101 client: &Client,
102 bundle: &CredentialBundle,
103 redactor: &mut Redactor,
104) -> Result<OAuthAccessToken, ViaError> {
105 let OAuthGrant::ClientCredentials { scope } = &bundle.grant else {
106 unreachable!("caller only passes client_credentials grants");
107 };
108 let client_secret = bundle.client_secret.as_deref().ok_or_else(|| {
109 ViaError::InvalidConfig(
110 "oauth client_credentials credential bundle must include `client_secret`".to_owned(),
111 )
112 })?;
113
114 let form = vec![
115 ("grant_type", "client_credentials"),
116 ("scope", scope.as_str()),
117 ("client_id", bundle.client_id.as_str()),
118 ("client_secret", client_secret),
119 ];
120
121 exchange_token_form(
122 client,
123 bundle,
124 &form,
125 TokenResponseRefreshMode::NoRefreshToken,
126 redactor,
127 )
128}
129
130fn exchange_token_form(
131 client: &Client,
132 bundle: &CredentialBundle,
133 form: &[(&str, &str)],
134 refresh_mode: TokenResponseRefreshMode<'_>,
135 redactor: &mut Redactor,
136) -> Result<OAuthAccessToken, ViaError> {
137 let body = form_encode(form);
138 let exchange_span = crate::timing::span("oauth token exchange");
139 let response = match client
140 .post(&bundle.token_url)
141 .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
142 .body(body)
143 .send()
144 {
145 Ok(response) => {
146 let status = response.status();
147 exchange_span.finish(format!("status={status}"));
148 response
149 }
150 Err(error) => {
151 exchange_span.finish("failed");
152 return Err(error.into());
153 }
154 };
155 let status = response.status();
156 let body_span = crate::timing::span("oauth token body");
157 let body = match response.text() {
158 Ok(body) => {
159 body_span.finish(format!("bytes={}", body.len()));
160 body
161 }
162 Err(error) => {
163 body_span.finish("failed");
164 return Err(error.into());
165 }
166 };
167
168 if !status.is_success() {
169 let body = redactor.redact(&body);
170 return Err(ViaError::InvalidArgument(format!(
171 "OAuth token exchange failed with status {status}: {body}"
172 )));
173 }
174
175 parse_token_response(&body, refresh_mode, redactor)
176}
177
178fn parse_token_response(
179 body: &str,
180 refresh_mode: TokenResponseRefreshMode<'_>,
181 redactor: &mut Redactor,
182) -> Result<OAuthAccessToken, ViaError> {
183 let response: TokenResponse = serde_json::from_str(body)?;
184 if let Some(token_type) = &response.token_type {
185 if !token_type.eq_ignore_ascii_case("bearer") {
186 return Err(ViaError::InvalidArgument(format!(
187 "OAuth token response had unsupported token_type `{token_type}`"
188 )));
189 }
190 }
191
192 let refresh_token = match refresh_mode {
193 TokenResponseRefreshMode::PreserveRefreshToken(refresh_token) => Some(
194 response
195 .refresh_token
196 .unwrap_or_else(|| refresh_token.to_owned()),
197 ),
198 TokenResponseRefreshMode::NoRefreshToken => response.refresh_token,
199 };
200 let expires_at = expires_at(response.expires_in)?;
201
202 redactor.add(&response.access_token);
203 if let Some(refresh_token) = &refresh_token {
204 redactor.add(refresh_token);
205 }
206
207 Ok(OAuthAccessToken {
208 access_token: response.access_token,
209 refresh_token,
210 expires_at,
211 })
212}
213
214fn expires_at(expires_in: u64) -> Result<i64, ViaError> {
215 let now = unix_timestamp()?;
216 let expires_in = i64::try_from(expires_in).map_err(|_| {
217 ViaError::InvalidArgument("OAuth token response expires_in is too large".to_owned())
218 })?;
219 now.checked_add(expires_in).ok_or_else(|| {
220 ViaError::InvalidArgument("OAuth token response expires_at is too large".to_owned())
221 })
222}
223
224pub(crate) fn register_bundle_secrets(bundle: &CredentialBundle, redactor: &mut Redactor) {
225 if let Some(client_secret) = &bundle.client_secret {
226 redactor.add(client_secret);
227 }
228 match &bundle.grant {
229 OAuthGrant::RefreshToken { refresh_token } => redactor.add(refresh_token),
230 OAuthGrant::ClientCredentials { .. } => {}
231 }
232}
233
234pub(crate) fn register_cached_secrets(cached: Option<&CachedOAuthToken>, redactor: &mut Redactor) {
235 if let Some(cached) = cached {
236 redactor.add(&cached.access_token);
237 if let Some(refresh_token) = &cached.refresh_token {
238 redactor.add(refresh_token);
239 }
240 }
241}
242
243#[derive(Debug, PartialEq, Eq)]
244pub(crate) struct CredentialBundle {
245 credential_type: String,
246 pub(crate) token_url: String,
247 pub(crate) client_id: String,
248 pub(crate) client_secret: Option<String>,
249 grant: OAuthGrant,
250}
251
252impl CredentialBundle {
253 pub(crate) fn parse(raw: &str) -> Result<Self, ViaError> {
254 let value: Value = serde_json::from_str(raw).map_err(credential_json_error)?;
255 let credential_type = required_string(&value, "type")?;
256 validate_credential_type(&credential_type)?;
257 let token_url = required_string(&value, "token_url")?;
258 let client_id = required_string(&value, "client_id")?;
259 let client_secret = optional_string(&value, "client_secret")?;
260 let configured_grant_type = optional_string(&value, "grant_type")?;
261 let configured_refresh_token = optional_string(&value, "refresh_token")?;
262 let grant = match configured_grant_type.as_deref() {
263 Some("refresh_token") => OAuthGrant::RefreshToken {
264 refresh_token: configured_refresh_token.ok_or_else(|| {
265 ViaError::InvalidConfig(
266 "oauth refresh_token credential bundle must include `refresh_token`"
267 .to_owned(),
268 )
269 })?,
270 },
271 Some("client_credentials") => OAuthGrant::ClientCredentials {
272 scope: required_string(&value, "scope")?,
273 },
274 Some(grant_type) => {
275 return Err(ViaError::InvalidConfig(format!(
276 "unsupported oauth grant_type `{grant_type}`"
277 )));
278 }
279 None => match configured_refresh_token {
280 Some(refresh_token) => OAuthGrant::RefreshToken { refresh_token },
281 None => {
282 return Err(ViaError::InvalidConfig(
283 "oauth credential bundle must include `grant_type`".to_owned(),
284 ));
285 }
286 },
287 };
288
289 Ok(Self {
290 credential_type,
291 token_url,
292 client_id,
293 client_secret,
294 grant,
295 })
296 }
297}
298
299fn validate_credential_type(value: &str) -> Result<(), ViaError> {
300 if value == SERVICE_OAUTH_TYPE {
301 return Ok(());
302 }
303
304 Err(ViaError::InvalidConfig(format!(
305 "unsupported oauth credential type `{value}`; expected `{SERVICE_OAUTH_TYPE}`"
306 )))
307}
308
/// The OAuth grant a credential bundle uses to obtain access tokens.
///
/// `Debug` is hand-written so the refresh token is never printed; the
/// client-credentials scope is not secret and is shown as-is.
#[derive(PartialEq, Eq)]
enum OAuthGrant {
    RefreshToken { refresh_token: String },
    ClientCredentials { scope: String },
}

impl std::fmt::Debug for OAuthGrant {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RefreshToken { .. } => f
                .debug_struct("RefreshToken")
                .field("refresh_token", &"[REDACTED]")
                .finish(),
            Self::ClientCredentials { scope } => f
                .debug_struct("ClientCredentials")
                .field("scope", scope)
                .finish(),
        }
    }
}
314
/// How `parse_token_response` decides the resulting refresh token.
#[derive(Copy, Clone)]
enum TokenResponseRefreshMode<'token> {
    /// Used by the refresh-token grant: keep this token if the response does not rotate it.
    PreserveRefreshToken(&'token str),
    /// Used by the client-credentials grant: take the response's refresh token, if any.
    NoRefreshToken,
}
320
321#[derive(Debug, Deserialize)]
322struct TokenResponse {
323 access_token: String,
324 #[serde(default)]
325 token_type: Option<String>,
326 expires_in: u64,
327 #[serde(default)]
328 refresh_token: Option<String>,
329}
330
/// A freshly exchanged access token, its refresh token (if any), and its absolute expiry.
///
/// `Debug` is hand-written so neither token is ever printed.
pub(crate) struct OAuthAccessToken {
    pub(crate) access_token: String,
    pub(crate) refresh_token: Option<String>,
    /// UNIX timestamp (seconds), as computed by `expires_at`.
    pub(crate) expires_at: i64,
}

impl std::fmt::Debug for OAuthAccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthAccessToken")
            .field("access_token", &"[REDACTED]")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
            )
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
337
338#[derive(Clone, Debug, Deserialize)]
339pub(crate) struct CachedOAuthToken {
340 pub(crate) access_token: String,
341 pub(crate) expires_at: i64,
342 #[serde(default)]
343 pub(crate) refresh_token: Option<String>,
344}
345
346pub(crate) fn cache_key(bundle: &CredentialBundle) -> String {
347 let mut context = Context::new(&SHA256);
348 context.update(bundle.credential_type.as_bytes());
349 context.update(b"\0");
350 context.update(bundle.token_url.as_bytes());
351 context.update(b"\0");
352 context.update(bundle.client_id.as_bytes());
353 context.update(b"\0");
354 match &bundle.grant {
355 OAuthGrant::RefreshToken { refresh_token } => {
356 context.update(b"refresh_token\0");
357 context.update(refresh_token.as_bytes());
358 }
359 OAuthGrant::ClientCredentials { scope } => {
360 context.update(b"client_credentials\0");
361 context.update(scope.as_bytes());
362 }
363 }
364 hex_encode(context.finish().as_ref())
365}
366
367pub(crate) fn cached_access_token(cached: Option<&CachedOAuthToken>, now: i64) -> Option<String> {
368 let cached = cached?;
369 if cached.expires_at <= now + CACHE_EXPIRY_SKEW_SECONDS {
370 return None;
371 }
372 Some(cached.access_token.clone())
373}
374
375pub(crate) fn unix_timestamp() -> Result<i64, ViaError> {
376 let duration = SystemTime::now()
377 .duration_since(UNIX_EPOCH)
378 .map_err(|_| ViaError::InvalidConfig("system clock is before UNIX epoch".to_owned()))?;
379 i64::try_from(duration.as_secs())
380 .map_err(|_| ViaError::InvalidConfig("system clock timestamp is too large".to_owned()))
381}
382
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
        .map(char::from)
        .collect()
}
392
393fn form_encode(fields: &[(&str, &str)]) -> String {
394 fields
395 .iter()
396 .map(|(name, value)| {
397 format!(
398 "{}={}",
399 form_percent_encode(name),
400 form_percent_encode(value)
401 )
402 })
403 .collect::<Vec<_>>()
404 .join("&")
405}
406
/// Form-encodes one value.
///
/// - ASCII alphanumerics and `-._~` are kept as-is.
/// - A space becomes `+`.
/// - Every other byte becomes `%XX` in uppercase hex.
fn form_percent_encode(value: &str) -> String {
    const UPPER_HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::new();
    for b in value.bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(b));
        } else if b == b' ' {
            out.push('+');
        } else {
            out.push('%');
            out.push(char::from(UPPER_HEX[usize::from(b >> 4)]));
            out.push(char::from(UPPER_HEX[usize::from(b & 0x0f)]));
        }
    }
    out
}
420
421fn credential_json_error(error: serde_json::Error) -> ViaError {
422 ViaError::InvalidConfig(format!(
423 "oauth credential bundle must be valid JSON: {error}"
424 ))
425}
426
427fn required_string(value: &Value, field: &str) -> Result<String, ViaError> {
428 value
429 .get(field)
430 .and_then(Value::as_str)
431 .filter(|value| !value.trim().is_empty())
432 .map(str::to_owned)
433 .ok_or_else(|| {
434 ViaError::InvalidConfig(format!(
435 "oauth credential bundle must include non-empty `{field}`"
436 ))
437 })
438}
439
440fn optional_string(value: &Value, field: &str) -> Result<Option<String>, ViaError> {
441 match value.get(field) {
442 Some(Value::String(value)) if !value.trim().is_empty() => Ok(Some(value.to_owned())),
443 Some(Value::String(_)) | None => Ok(None),
444 Some(_) => Err(ViaError::InvalidConfig(format!(
445 "oauth credential bundle `{field}` must be a string"
446 ))),
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    const LINEAR_TOKEN_URL: &str = "https://api.linear.app/oauth/token";

    #[test]
    fn parses_service_refresh_token_bundle() {
        let bundle = CredentialBundle::parse(
            &serde_json::json!({
                "type": "service_oauth",
                "token_url": LINEAR_TOKEN_URL,
                "grant_type": "refresh_token",
                "client_id": "client-id",
                "client_secret": "client-secret",
                "refresh_token": "refresh-token",
            })
            .to_string(),
        )
        .unwrap();

        assert_eq!(bundle.credential_type, SERVICE_OAUTH_TYPE);
        assert_eq!(bundle.token_url, LINEAR_TOKEN_URL);
        assert_eq!(bundle.client_id, "client-id");
        assert_eq!(bundle.client_secret.as_deref(), Some("client-secret"));
        assert_eq!(
            bundle.grant,
            OAuthGrant::RefreshToken {
                refresh_token: "refresh-token".to_owned()
            }
        );
    }

    #[test]
    fn parses_service_client_credentials_bundle() {
        let bundle = CredentialBundle::parse(
            &serde_json::json!({
                "type": "service_oauth",
                "token_url": LINEAR_TOKEN_URL,
                "grant_type": "client_credentials",
                "client_id": "client-id",
                "client_secret": "client-secret",
                "scope": "read,issues:create",
            })
            .to_string(),
        )
        .unwrap();

        assert_eq!(
            bundle.grant,
            OAuthGrant::ClientCredentials {
                scope: "read,issues:create".to_owned()
            }
        );
    }

    #[test]
    fn rejects_unsupported_oauth_credential_type() {
        let error = CredentialBundle::parse(
            &serde_json::json!({
                "type": "example_oauth",
                "token_url": LINEAR_TOKEN_URL,
                "grant_type": "refresh_token",
                "client_id": "client-id",
                "refresh_token": "refresh-token",
            })
            .to_string(),
        )
        .unwrap_err();

        assert!(matches!(
            error,
            ViaError::InvalidConfig(message) if message.contains("unsupported oauth credential type")
        ));
    }

    #[test]
    fn validates_credential_bundle() {
        validate_credential_bundle(
            &serde_json::json!({
                "type": "service_oauth",
                "token_url": LINEAR_TOKEN_URL,
                "client_id": "client-id",
                "refresh_token": "refresh-token",
            })
            .to_string(),
        )
        .unwrap();
    }

    #[test]
    fn returns_unexpired_cached_oauth_token() {
        let cached = CachedOAuthToken {
            access_token: "cached-access-token".to_owned(),
            expires_at: unix_timestamp().unwrap() + 3_600,
            refresh_token: Some("cached-refresh-token".to_owned()),
        };

        let token = cached_access_token(Some(&cached), unix_timestamp().unwrap()).unwrap();

        assert_eq!(token, "cached-access-token");
    }

    #[test]
    fn skips_cached_oauth_token_inside_expiry_skew() {
        let now = unix_timestamp().unwrap();
        let cached = CachedOAuthToken {
            access_token: "cached-access-token".to_owned(),
            expires_at: now + CACHE_EXPIRY_SKEW_SECONDS,
            refresh_token: None,
        };

        assert_eq!(cached_access_token(Some(&cached), now), None);
        assert_eq!(cached_access_token(None, now), None);
    }

    #[test]
    fn form_encodes_reserved_and_non_ascii_characters() {
        assert_eq!(
            form_encode(&[("scope", "read,issues:create"), ("note", "a b~é")]),
            "scope=read%2Cissues%3Acreate&note=a+b~%C3%A9"
        );
    }

    #[test]
    fn refreshes_and_returns_rotated_refresh_token() {
        let response_body = serde_json::json!({
            "access_token": "fresh-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "rotated-refresh-token",
            "scope": "read write",
        })
        .to_string();
        let (token_url, server) = token_server(response_body);
        let bundle = test_refresh_bundle(&token_url);

        let client = Client::new();
        let mut redactor = Redactor::new();
        let token = exchange_access_token(&client, &bundle, None, &mut redactor).unwrap();
        let request = server.join().unwrap();

        assert_eq!(token.access_token, "fresh-access-token");
        assert!(request.starts_with("POST /oauth/token "));
        assert!(request.contains("content-type: application/x-www-form-urlencoded"));
        assert!(request.contains("grant_type=refresh_token"));
        assert!(request.contains("refresh_token=configured-refresh-token"));
        assert_eq!(
            token.refresh_token.as_deref(),
            Some("rotated-refresh-token")
        );
        assert_eq!(
            redactor.redact("fresh-access-token rotated-refresh-token configured-refresh-token"),
            "[REDACTED] [REDACTED] [REDACTED]"
        );
    }

    #[test]
    fn refreshes_and_preserves_current_refresh_token_when_response_omits_rotation() {
        let response_body = serde_json::json!({
            "access_token": "fresh-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
        })
        .to_string();
        let (token_url, server) = token_server(response_body);
        let bundle = test_refresh_bundle(&token_url);

        let client = Client::new();
        let mut redactor = Redactor::new();
        let token = exchange_access_token(&client, &bundle, None, &mut redactor).unwrap();
        let request = server.join().unwrap();

        assert_eq!(token.access_token, "fresh-access-token");
        assert!(request.contains("grant_type=refresh_token"));
        assert_eq!(
            token.refresh_token.as_deref(),
            Some("configured-refresh-token")
        );
    }

    #[test]
    fn exchanges_client_credentials_and_returns_access_token() {
        let response_body = serde_json::json!({
            "access_token": "client-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "scope": "read issues:create",
        })
        .to_string();
        let (token_url, server) = token_server(response_body);
        let bundle = CredentialBundle {
            credential_type: SERVICE_OAUTH_TYPE.to_owned(),
            token_url,
            client_id: "client-id".to_owned(),
            client_secret: Some("client-secret".to_owned()),
            grant: OAuthGrant::ClientCredentials {
                scope: "read,issues:create".to_owned(),
            },
        };

        let client = Client::new();
        let mut redactor = Redactor::new();
        let token = exchange_access_token(&client, &bundle, None, &mut redactor).unwrap();
        let request = server.join().unwrap();

        assert_eq!(token.access_token, "client-access-token");
        assert!(request.contains("grant_type=client_credentials"));
        assert!(request.contains("scope=read%2Cissues%3Acreate"));
        assert!(request.contains("client_secret=client-secret"));
    }

    #[test]
    fn rejects_non_bearer_token_response() {
        let mut redactor = Redactor::new();
        let error = parse_token_response(
            &serde_json::json!({
                "access_token": "access-token",
                "token_type": "mac",
                "expires_in": 3600,
                "refresh_token": "refresh-token",
            })
            .to_string(),
            TokenResponseRefreshMode::PreserveRefreshToken("refresh-token"),
            &mut redactor,
        )
        .unwrap_err();

        assert!(
            matches!(error, ViaError::InvalidArgument(message) if message.contains("token_type"))
        );
    }

    fn test_refresh_bundle(token_url: &str) -> CredentialBundle {
        CredentialBundle {
            credential_type: SERVICE_OAUTH_TYPE.to_owned(),
            token_url: token_url.to_owned(),
            client_id: "client-id".to_owned(),
            client_secret: Some("client-secret".to_owned()),
            grant: OAuthGrant::RefreshToken {
                refresh_token: "configured-refresh-token".to_owned(),
            },
        }
    }

    /// Starts a one-shot HTTP server that answers `200 OK` with `response_body`
    /// and returns the full raw request text it received.
    fn token_server(response_body: String) -> (String, thread::JoinHandle<String>) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let address = listener.local_addr().unwrap();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let request = read_http_request(&mut stream);
            let response = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream.write_all(response.as_bytes()).unwrap();
            request
        });

        (format!("http://{address}/oauth/token"), handle)
    }

    /// Reads one HTTP/1.x request: the header block plus `Content-Length` body bytes.
    ///
    /// A single `read` is not enough, because the client may write the headers
    /// and the form body in separate segments. The body assertions above would
    /// then flake.
    fn read_http_request(stream: &mut impl Read) -> String {
        let mut request = Vec::new();
        let mut chunk = [0_u8; 4096];
        loop {
            if let Some(expected) = expected_request_len(&request) {
                if request.len() >= expected {
                    break;
                }
            }
            let read = stream.read(&mut chunk).unwrap();
            if read == 0 {
                break;
            }
            request.extend_from_slice(&chunk[..read]);
        }
        String::from_utf8_lossy(&request).into_owned()
    }

    /// Total request length (headers + body), or `None` until the header block is complete.
    fn expected_request_len(request: &[u8]) -> Option<usize> {
        let header_end = request
            .windows(4)
            .position(|window| window == b"\r\n\r\n")?
            + 4;
        let headers = String::from_utf8_lossy(&request[..header_end]);
        let content_length = headers
            .lines()
            .find_map(|line| {
                let (name, value) = line.split_once(':')?;
                name.trim()
                    .eq_ignore_ascii_case("content-length")
                    .then_some(value.trim())
            })
            .map_or(0, |value| value.parse::<usize>().unwrap());
        Some(header_end + content_length)
    }
}