1use std::collections::BTreeMap;
2use std::fs::{self, File};
3use std::io::{self, Read};
4use std::path::PathBuf;
5
6use serde::{Deserialize, Serialize};
7use serde_json::{Map, Value};
8use sha2::{Digest, Sha256};
9
10use crate::config::OAuthConfig;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13pub struct OAuthTokenSet {
14 pub access_token: String,
15 pub refresh_token: Option<String>,
16 pub expires_at: Option<u64>,
17 pub scopes: Vec<String>,
18}
19
/// A PKCE verifier together with the challenge derived from it.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct PkceCodePair {
    /// Secret kept by the client; later passed as the token-exchange `code_verifier`.
    pub verifier: String,
    /// Value derived from `verifier` and sent as `code_challenge` when authorizing.
    pub challenge: String,
    /// Transformation used to derive `challenge` from `verifier`.
    pub challenge_method: PkceChallengeMethod,
}

/// PKCE code-challenge transformations supported by this module.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum PkceChallengeMethod {
    /// SHA-256 of the verifier, base64url-encoded without padding.
    S256,
}

impl PkceChallengeMethod {
    /// Returns the value sent as the `code_challenge_method` query parameter.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            PkceChallengeMethod::S256 => "S256",
        }
    }
}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
42pub struct OAuthAuthorizationRequest {
43 pub authorize_url: String,
44 pub client_id: String,
45 pub redirect_uri: String,
46 pub scopes: Vec<String>,
47 pub state: String,
48 pub code_challenge: String,
49 pub code_challenge_method: PkceChallengeMethod,
50 pub extra_params: BTreeMap<String, String>,
51}
52
/// Parameters for exchanging an authorization code for tokens.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct OAuthTokenExchangeRequest {
    /// Always `"authorization_code"` when built via `from_config`.
    pub grant_type: &'static str,
    /// Authorization code received on the callback.
    pub code: String,
    /// Redirect URI used for the authorization request.
    pub redirect_uri: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// PKCE verifier matching the challenge sent earlier.
    pub code_verifier: String,
    /// State value from the authorization flow.
    pub state: String,
}
62
/// Parameters for exchanging a refresh token for a new token set.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct OAuthRefreshRequest {
    /// Always `"refresh_token"` when built via `from_config`.
    pub grant_type: &'static str,
    /// The refresh token being redeemed.
    pub refresh_token: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Scopes to request, joined with spaces when rendered.
    pub scopes: Vec<String>,
}
70
/// Query parameters of interest from the OAuth redirect callback.
/// Each field is `None` when the corresponding parameter was absent.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct OAuthCallbackParams {
    /// Decoded `code` parameter.
    pub code: Option<String>,
    /// Decoded `state` parameter.
    pub state: Option<String>,
    /// Decoded `error` parameter.
    pub error: Option<String>,
    /// Decoded `error_description` parameter.
    pub error_description: Option<String>,
}
78
79#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
80#[serde(rename_all = "camelCase")]
81struct StoredOAuthCredentials {
82 access_token: String,
83 #[serde(default)]
84 refresh_token: Option<String>,
85 #[serde(default)]
86 expires_at: Option<u64>,
87 #[serde(default)]
88 scopes: Vec<String>,
89}
90
91impl From<OAuthTokenSet> for StoredOAuthCredentials {
92 fn from(value: OAuthTokenSet) -> Self {
93 Self {
94 access_token: value.access_token,
95 refresh_token: value.refresh_token,
96 expires_at: value.expires_at,
97 scopes: value.scopes,
98 }
99 }
100}
101
102impl From<StoredOAuthCredentials> for OAuthTokenSet {
103 fn from(value: StoredOAuthCredentials) -> Self {
104 Self {
105 access_token: value.access_token,
106 refresh_token: value.refresh_token,
107 expires_at: value.expires_at,
108 scopes: value.scopes,
109 }
110 }
111}
112
113impl OAuthAuthorizationRequest {
114 #[must_use]
115 pub fn from_config(
116 config: &OAuthConfig,
117 redirect_uri: impl Into<String>,
118 state: impl Into<String>,
119 pkce: &PkceCodePair,
120 ) -> Self {
121 Self {
122 authorize_url: config.authorize_url.clone(),
123 client_id: config.client_id.clone(),
124 redirect_uri: redirect_uri.into(),
125 scopes: config.scopes.clone(),
126 state: state.into(),
127 code_challenge: pkce.challenge.clone(),
128 code_challenge_method: pkce.challenge_method,
129 extra_params: BTreeMap::new(),
130 }
131 }
132
133 #[must_use]
134 pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
135 self.extra_params.insert(key.into(), value.into());
136 self
137 }
138
139 #[must_use]
140 pub fn build_url(&self) -> String {
141 let mut params = vec![
142 ("response_type", "code".to_string()),
143 ("client_id", self.client_id.clone()),
144 ("redirect_uri", self.redirect_uri.clone()),
145 ("scope", self.scopes.join(" ")),
146 ("state", self.state.clone()),
147 ("code_challenge", self.code_challenge.clone()),
148 (
149 "code_challenge_method",
150 self.code_challenge_method.as_str().to_string(),
151 ),
152 ];
153 params.extend(
154 self.extra_params
155 .iter()
156 .map(|(key, value)| (key.as_str(), value.clone())),
157 );
158 let query = params
159 .into_iter()
160 .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
161 .collect::<Vec<_>>()
162 .join("&");
163 format!(
164 "{}{}{}",
165 self.authorize_url,
166 if self.authorize_url.contains('?') {
167 '&'
168 } else {
169 '?'
170 },
171 query
172 )
173 }
174}
175
176impl OAuthTokenExchangeRequest {
177 #[must_use]
178 pub fn from_config(
179 config: &OAuthConfig,
180 code: impl Into<String>,
181 state: impl Into<String>,
182 verifier: impl Into<String>,
183 redirect_uri: impl Into<String>,
184 ) -> Self {
185 Self {
186 grant_type: "authorization_code",
187 code: code.into(),
188 redirect_uri: redirect_uri.into(),
189 client_id: config.client_id.clone(),
190 code_verifier: verifier.into(),
191 state: state.into(),
192 }
193 }
194
195 #[must_use]
196 pub fn form_params(&self) -> BTreeMap<&str, String> {
197 BTreeMap::from([
198 ("grant_type", self.grant_type.to_string()),
199 ("code", self.code.clone()),
200 ("redirect_uri", self.redirect_uri.clone()),
201 ("client_id", self.client_id.clone()),
202 ("code_verifier", self.code_verifier.clone()),
203 ("state", self.state.clone()),
204 ])
205 }
206}
207
208impl OAuthRefreshRequest {
209 #[must_use]
210 pub fn from_config(
211 config: &OAuthConfig,
212 refresh_token: impl Into<String>,
213 scopes: Option<Vec<String>>,
214 ) -> Self {
215 Self {
216 grant_type: "refresh_token",
217 refresh_token: refresh_token.into(),
218 client_id: config.client_id.clone(),
219 scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
220 }
221 }
222
223 #[must_use]
224 pub fn form_params(&self) -> BTreeMap<&str, String> {
225 BTreeMap::from([
226 ("grant_type", self.grant_type.to_string()),
227 ("refresh_token", self.refresh_token.clone()),
228 ("client_id", self.client_id.clone()),
229 ("scope", self.scopes.join(" ")),
230 ])
231 }
232}
233
234pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
235 let verifier = generate_random_token(32)?;
236 Ok(PkceCodePair {
237 challenge: code_challenge_s256(&verifier),
238 verifier,
239 challenge_method: PkceChallengeMethod::S256,
240 })
241}
242
243pub fn generate_state() -> io::Result<String> {
244 generate_random_token(32)
245}
246
247#[must_use]
248pub fn code_challenge_s256(verifier: &str) -> String {
249 let digest = Sha256::digest(verifier.as_bytes());
250 base64url_encode(&digest)
251}
252
/// Returns the local redirect URI `http://localhost:<port>/callback`.
#[must_use]
pub fn loopback_redirect_uri(port: u16) -> String {
    format!("http://localhost:{}/callback", port)
}
257
258pub fn credentials_path() -> io::Result<PathBuf> {
259 Ok(credentials_home_dir()?.join("credentials.json"))
260}
261
262pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
263 let path = credentials_path()?;
264 let root = read_credentials_root(&path)?;
265 let Some(oauth) = root.get("oauth") else {
266 return Ok(None);
267 };
268 if oauth.is_null() {
269 return Ok(None);
270 }
271 let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
272 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
273 Ok(Some(stored.into()))
274}
275
276pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
277 let path = credentials_path()?;
278 let mut root = read_credentials_root(&path)?;
279 root.insert(
280 "oauth".to_string(),
281 serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
282 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
283 );
284 write_credentials_root(&path, &root)
285}
286
287pub fn clear_oauth_credentials() -> io::Result<()> {
288 let path = credentials_path()?;
289 let mut root = read_credentials_root(&path)?;
290 root.remove("oauth");
291 write_credentials_root(&path, &root)
292}
293
294pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
295 let (path, query) = target
296 .split_once('?')
297 .map_or((target, ""), |(path, query)| (path, query));
298 if path != "/callback" {
299 return Err(format!("unexpected callback path: {path}"));
300 }
301 parse_oauth_callback_query(query)
302}
303
304pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
305 let mut params = BTreeMap::new();
306 for pair in query.split('&').filter(|pair| !pair.is_empty()) {
307 let (key, value) = pair
308 .split_once('=')
309 .map_or((pair, ""), |(key, value)| (key, value));
310 params.insert(percent_decode(key)?, percent_decode(value)?);
311 }
312 Ok(OAuthCallbackParams {
313 code: params.get("code").cloned(),
314 state: params.get("state").cloned(),
315 error: params.get("error").cloned(),
316 error_description: params.get("error_description").cloned(),
317 })
318}
319
320fn generate_random_token(bytes: usize) -> io::Result<String> {
321 let mut buffer = vec![0_u8; bytes];
322 File::open("/dev/urandom")?.read_exact(&mut buffer)?;
323 Ok(base64url_encode(&buffer))
324}
325
/// Resolves the directory that holds `credentials.json`.
///
/// Uses `WRAITH_CONFIG_HOME` when it is set and non-empty, otherwise
/// `$HOME/.wraith`. Empty values are treated as unset. An empty value would
/// otherwise become a *relative* path, and token secrets would be written
/// into the current working directory.
///
/// # Errors
///
/// Returns `NotFound` when neither variable provides a non-empty value.
fn credentials_home_dir() -> io::Result<PathBuf> {
    /// Reads an environment variable, treating an empty value as absent.
    fn non_empty_var(name: &str) -> Option<std::ffi::OsString> {
        std::env::var_os(name).filter(|value| !value.is_empty())
    }

    if let Some(path) = non_empty_var("WRAITH_CONFIG_HOME") {
        return Ok(PathBuf::from(path));
    }
    let home = non_empty_var("HOME")
        .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
    Ok(PathBuf::from(home).join(".wraith"))
}
334
335fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
336 match fs::read_to_string(path) {
337 Ok(contents) => {
338 if contents.trim().is_empty() {
339 return Ok(Map::new());
340 }
341 serde_json::from_str::<Value>(&contents)
342 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
343 .as_object()
344 .cloned()
345 .ok_or_else(|| {
346 io::Error::new(
347 io::ErrorKind::InvalidData,
348 "credentials file must contain a JSON object",
349 )
350 })
351 }
352 Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
353 Err(error) => Err(error),
354 }
355}
356
357fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
358 if let Some(parent) = path.parent() {
359 fs::create_dir_all(parent)?;
360 }
361 let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
362 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
363 let temp_path = path.with_extension("json.tmp");
364 fs::write(&temp_path, format!("{rendered}\n"))?;
365 fs::rename(temp_path, path)
366}
367
/// Encodes `bytes` as base64url (RFC 4648 §5 alphabet) without `=` padding.
fn base64url_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    let mut encoded = String::with_capacity((bytes.len() * 4 + 2) / 3);
    for chunk in bytes.chunks(3) {
        // Pack up to three bytes into a 24-bit group, zero-filling missing bytes.
        let first = u32::from(chunk[0]);
        let second = chunk.get(1).copied().map_or(0, u32::from);
        let third = chunk.get(2).copied().map_or(0, u32::from);
        let group = (first << 16) | (second << 8) | third;
        // n input bytes produce n + 1 sextets; the rest would be padding, which is omitted.
        for shift in [18_u32, 12, 6, 0].iter().take(chunk.len() + 1) {
            encoded.push(char::from(ALPHABET[((group >> shift) & 0x3F) as usize]));
        }
    }
    encoded
}
398
/// Percent-encodes every byte of `value` except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`). Escapes use uppercase hex; spaces become `%20`.
fn percent_encode(value: &str) -> String {
    use std::fmt::Write as _;

    let mut out = String::with_capacity(value.len());
    for byte in value.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            let _ = write!(out, "%{:02X}", byte);
        }
    }
    out
}
414
/// Decodes `%XX` escapes and `+` (as a space) in `value`.
///
/// A `%` without two following bytes is kept literally. A `%` followed by two
/// non-hex bytes is an error.
///
/// # Errors
///
/// Returns a message for an invalid hex digit, or when the decoded bytes are
/// not valid UTF-8.
fn percent_decode(value: &str) -> Result<String, String> {
    let input = value.as_bytes();
    let mut output = Vec::with_capacity(input.len());
    let mut cursor = 0;
    while let Some(&current) = input.get(cursor) {
        // `get(range)` is `Some` exactly when both escape digits are present.
        match (current, input.get(cursor + 1..cursor + 3)) {
            (b'%', Some(&[high, low])) => {
                let high = decode_hex(high)?;
                let low = decode_hex(low)?;
                output.push((high << 4) | low);
                cursor += 3;
            }
            (b'+', _) => {
                output.push(b' ');
                cursor += 1;
            }
            (other, _) => {
                output.push(other);
                cursor += 1;
            }
        }
    }
    String::from_utf8(output).map_err(|error| error.to_string())
}

/// Converts one ASCII hex digit (either case) into its value `0..=15`.
fn decode_hex(byte: u8) -> Result<u8, String> {
    char::from(byte)
        .to_digit(16)
        // `to_digit(16)` only yields values below 16, so the cast is lossless.
        .map(|digit| digit as u8)
        .ok_or_else(|| format!("invalid percent-encoding byte: {}", byte))
}
448
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use super::{
        base64url_encode, clear_oauth_credentials, code_challenge_s256, credentials_path,
        generate_pkce_pair, generate_state, load_oauth_credentials, loopback_redirect_uri,
        parse_oauth_callback_query, parse_oauth_callback_request_target, percent_decode,
        percent_encode, save_oauth_credentials, OAuthAuthorizationRequest, OAuthConfig,
        OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
    };

    fn sample_config() -> OAuthConfig {
        OAuthConfig {
            client_id: "runtime-client".to_string(),
            authorize_url: "https://console.test/oauth/authorize".to_string(),
            token_url: "https://console.test/oauth/token".to_string(),
            callback_port: Some(4545),
            manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
            scopes: vec!["org:read".to_string(), "user:write".to_string()],
        }
    }

    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        crate::test_env_lock()
    }

    fn temp_config_home() -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "runtime-oauth-test-{}-{}",
            std::process::id(),
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("time")
                .as_nanos()
        ))
    }

    /// Points `WRAITH_CONFIG_HOME` at a fresh temporary directory and undoes
    /// that on drop, even when an assertion panics. A failing test therefore
    /// cannot leak the variable or the directory into later tests.
    struct ConfigHomeGuard {
        path: std::path::PathBuf,
    }

    impl ConfigHomeGuard {
        fn install() -> Self {
            let path = temp_config_home();
            std::env::set_var("WRAITH_CONFIG_HOME", &path);
            Self { path }
        }
    }

    impl Drop for ConfigHomeGuard {
        fn drop(&mut self) {
            std::env::remove_var("WRAITH_CONFIG_HOME");
            // Ignore errors: panicking in `drop` while already unwinding would abort.
            let _ = std::fs::remove_dir_all(&self.path);
        }
    }

    #[test]
    fn s256_challenge_matches_expected_vector() {
        assert_eq!(
            code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn generates_pkce_pair_and_state() {
        let pair = generate_pkce_pair().expect("pkce pair");
        let state = generate_state().expect("state");
        // 32 random bytes encode to 43 unpadded base64url characters.
        assert_eq!(pair.verifier.len(), 43);
        assert_eq!(pair.challenge, code_challenge_s256(&pair.verifier));
        assert_eq!(pair.challenge_method.as_str(), "S256");
        assert_eq!(state.len(), 43);
        assert_ne!(state, pair.verifier);
    }

    #[test]
    fn builds_authorize_url_and_form_requests() {
        let config = sample_config();
        let pair = generate_pkce_pair().expect("pkce");
        let url = OAuthAuthorizationRequest::from_config(
            &config,
            loopback_redirect_uri(4545),
            "state-123",
            &pair,
        )
        .with_extra_param("login_hint", "user@example.com")
        .build_url();
        assert!(url.starts_with("https://console.test/oauth/authorize?"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=runtime-client"));
        assert!(url.contains("redirect_uri=http%3A%2F%2Flocalhost%3A4545%2Fcallback"));
        assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
        assert!(url.contains("state=state-123"));
        // The challenge is base64url, so percent-encoding leaves it unchanged.
        assert!(url.contains(&format!("code_challenge={}", pair.challenge)));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("login_hint=user%40example.com"));

        let exchange = OAuthTokenExchangeRequest::from_config(
            &config,
            "auth-code",
            "state-123",
            pair.verifier,
            loopback_redirect_uri(4545),
        );
        let exchange_params = exchange.form_params();
        assert_eq!(
            exchange_params.get("grant_type").map(String::as_str),
            Some("authorization_code")
        );
        assert_eq!(
            exchange_params.get("state").map(String::as_str),
            Some("state-123")
        );

        let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
        assert_eq!(
            refresh.form_params().get("scope").map(String::as_str),
            Some("org:read user:write")
        );
    }

    #[test]
    fn base64url_encodes_without_padding() {
        assert_eq!(base64url_encode(b""), "");
        assert_eq!(base64url_encode(b"f"), "Zg");
        assert_eq!(base64url_encode(b"fo"), "Zm8");
        assert_eq!(base64url_encode(b"foo"), "Zm9v");
        assert_eq!(base64url_encode(b"foobar"), "Zm9vYmFy");
        assert_eq!(base64url_encode(&[0xfb, 0xff]), "-_8");
    }

    #[test]
    fn percent_coding_round_trips_and_rejects_bad_escapes() {
        let raw = "a b/c?d=é~";
        assert_eq!(percent_decode(&percent_encode(raw)).expect("decode"), raw);
        assert_eq!(percent_decode("a+b").expect("plus"), "a b");
        assert!(percent_decode("%zz").is_err());
    }

    #[test]
    fn oauth_credentials_round_trip_and_clear_preserves_other_fields() {
        let _lock = env_lock();
        // Declared after the lock so it is dropped first, while the lock is still held.
        let _home = ConfigHomeGuard::install();
        let path = credentials_path().expect("credentials path");
        std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
        std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");

        let token_set = OAuthTokenSet {
            access_token: "access-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            expires_at: Some(123),
            scopes: vec!["scope:a".to_string()],
        };
        save_oauth_credentials(&token_set).expect("save credentials");
        assert_eq!(
            load_oauth_credentials().expect("load credentials"),
            Some(token_set)
        );
        let saved = std::fs::read_to_string(&path).expect("read saved file");
        assert!(saved.contains("\"other\": \"value\""));
        assert!(saved.contains("\"oauth\""));

        clear_oauth_credentials().expect("clear credentials");
        assert_eq!(load_oauth_credentials().expect("load cleared"), None);
        let cleared = std::fs::read_to_string(&path).expect("read cleared file");
        assert!(cleared.contains("\"other\": \"value\""));
        assert!(!cleared.contains("\"oauth\""));
    }

    #[test]
    fn parses_callback_query_and_target() {
        let params =
            parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
                .expect("parse query");
        assert_eq!(params.code.as_deref(), Some("abc123"));
        assert_eq!(params.state.as_deref(), Some("state-1"));
        assert_eq!(params.error_description.as_deref(), Some("needs login"));

        let denied = parse_oauth_callback_query("error=access_denied&error_description=User+denied")
            .expect("parse error query");
        assert_eq!(denied.code, None);
        assert_eq!(denied.error.as_deref(), Some("access_denied"));
        assert_eq!(denied.error_description.as_deref(), Some("User denied"));

        let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
            .expect("parse callback target");
        assert_eq!(params.code.as_deref(), Some("abc"));
        assert_eq!(params.state.as_deref(), Some("xyz"));
        assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
    }
}