// wraith_runtime/oauth.rs

use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::{self, Read};
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use sha2::{Digest, Sha256};

use crate::config::OAuthConfig;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13pub struct OAuthTokenSet {
14    pub access_token: String,
15    pub refresh_token: Option<String>,
16    pub expires_at: Option<u64>,
17    pub scopes: Vec<String>,
18}
19
/// A PKCE (RFC 7636) verifier together with the challenge derived from it.
///
/// The `verifier` stays local until the token exchange; the `challenge` and
/// `challenge_method` go into the authorization URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkceCodePair {
    pub verifier: String,
    pub challenge: String,
    pub challenge_method: PkceChallengeMethod,
}

/// Supported PKCE challenge transformations (only `S256`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PkceChallengeMethod {
    S256,
}

impl PkceChallengeMethod {
    /// Returns the wire value used for the `code_challenge_method` parameter.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            PkceChallengeMethod::S256 => "S256",
        }
    }
}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
42pub struct OAuthAuthorizationRequest {
43    pub authorize_url: String,
44    pub client_id: String,
45    pub redirect_uri: String,
46    pub scopes: Vec<String>,
47    pub state: String,
48    pub code_challenge: String,
49    pub code_challenge_method: PkceChallengeMethod,
50    pub extra_params: BTreeMap<String, String>,
51}
52
53#[derive(Debug, Clone, PartialEq, Eq)]
54pub struct OAuthTokenExchangeRequest {
55    pub grant_type: &'static str,
56    pub code: String,
57    pub redirect_uri: String,
58    pub client_id: String,
59    pub code_verifier: String,
60    pub state: String,
61}
62
63#[derive(Debug, Clone, PartialEq, Eq)]
64pub struct OAuthRefreshRequest {
65    pub grant_type: &'static str,
66    pub refresh_token: String,
67    pub client_id: String,
68    pub scopes: Vec<String>,
69}
70
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct OAuthCallbackParams {
73    pub code: Option<String>,
74    pub state: Option<String>,
75    pub error: Option<String>,
76    pub error_description: Option<String>,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
80#[serde(rename_all = "camelCase")]
81struct StoredOAuthCredentials {
82    access_token: String,
83    #[serde(default)]
84    refresh_token: Option<String>,
85    #[serde(default)]
86    expires_at: Option<u64>,
87    #[serde(default)]
88    scopes: Vec<String>,
89}
90
91impl From<OAuthTokenSet> for StoredOAuthCredentials {
92    fn from(value: OAuthTokenSet) -> Self {
93        Self {
94            access_token: value.access_token,
95            refresh_token: value.refresh_token,
96            expires_at: value.expires_at,
97            scopes: value.scopes,
98        }
99    }
100}
101
102impl From<StoredOAuthCredentials> for OAuthTokenSet {
103    fn from(value: StoredOAuthCredentials) -> Self {
104        Self {
105            access_token: value.access_token,
106            refresh_token: value.refresh_token,
107            expires_at: value.expires_at,
108            scopes: value.scopes,
109        }
110    }
111}
112
113impl OAuthAuthorizationRequest {
114    #[must_use]
115    pub fn from_config(
116        config: &OAuthConfig,
117        redirect_uri: impl Into<String>,
118        state: impl Into<String>,
119        pkce: &PkceCodePair,
120    ) -> Self {
121        Self {
122            authorize_url: config.authorize_url.clone(),
123            client_id: config.client_id.clone(),
124            redirect_uri: redirect_uri.into(),
125            scopes: config.scopes.clone(),
126            state: state.into(),
127            code_challenge: pkce.challenge.clone(),
128            code_challenge_method: pkce.challenge_method,
129            extra_params: BTreeMap::new(),
130        }
131    }
132
133    #[must_use]
134    pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
135        self.extra_params.insert(key.into(), value.into());
136        self
137    }
138
139    #[must_use]
140    pub fn build_url(&self) -> String {
141        let mut params = vec![
142            ("response_type", "code".to_string()),
143            ("client_id", self.client_id.clone()),
144            ("redirect_uri", self.redirect_uri.clone()),
145            ("scope", self.scopes.join(" ")),
146            ("state", self.state.clone()),
147            ("code_challenge", self.code_challenge.clone()),
148            (
149                "code_challenge_method",
150                self.code_challenge_method.as_str().to_string(),
151            ),
152        ];
153        params.extend(
154            self.extra_params
155                .iter()
156                .map(|(key, value)| (key.as_str(), value.clone())),
157        );
158        let query = params
159            .into_iter()
160            .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
161            .collect::<Vec<_>>()
162            .join("&");
163        format!(
164            "{}{}{}",
165            self.authorize_url,
166            if self.authorize_url.contains('?') {
167                '&'
168            } else {
169                '?'
170            },
171            query
172        )
173    }
174}
175
176impl OAuthTokenExchangeRequest {
177    #[must_use]
178    pub fn from_config(
179        config: &OAuthConfig,
180        code: impl Into<String>,
181        state: impl Into<String>,
182        verifier: impl Into<String>,
183        redirect_uri: impl Into<String>,
184    ) -> Self {
185        Self {
186            grant_type: "authorization_code",
187            code: code.into(),
188            redirect_uri: redirect_uri.into(),
189            client_id: config.client_id.clone(),
190            code_verifier: verifier.into(),
191            state: state.into(),
192        }
193    }
194
195    #[must_use]
196    pub fn form_params(&self) -> BTreeMap<&str, String> {
197        BTreeMap::from([
198            ("grant_type", self.grant_type.to_string()),
199            ("code", self.code.clone()),
200            ("redirect_uri", self.redirect_uri.clone()),
201            ("client_id", self.client_id.clone()),
202            ("code_verifier", self.code_verifier.clone()),
203            ("state", self.state.clone()),
204        ])
205    }
206}
207
208impl OAuthRefreshRequest {
209    #[must_use]
210    pub fn from_config(
211        config: &OAuthConfig,
212        refresh_token: impl Into<String>,
213        scopes: Option<Vec<String>>,
214    ) -> Self {
215        Self {
216            grant_type: "refresh_token",
217            refresh_token: refresh_token.into(),
218            client_id: config.client_id.clone(),
219            scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
220        }
221    }
222
223    #[must_use]
224    pub fn form_params(&self) -> BTreeMap<&str, String> {
225        BTreeMap::from([
226            ("grant_type", self.grant_type.to_string()),
227            ("refresh_token", self.refresh_token.clone()),
228            ("client_id", self.client_id.clone()),
229            ("scope", self.scopes.join(" ")),
230        ])
231    }
232}
233
234pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
235    let verifier = generate_random_token(32)?;
236    Ok(PkceCodePair {
237        challenge: code_challenge_s256(&verifier),
238        verifier,
239        challenge_method: PkceChallengeMethod::S256,
240    })
241}
242
243pub fn generate_state() -> io::Result<String> {
244    generate_random_token(32)
245}
246
247#[must_use]
248pub fn code_challenge_s256(verifier: &str) -> String {
249    let digest = Sha256::digest(verifier.as_bytes());
250    base64url_encode(&digest)
251}
252
/// Returns the loopback redirect URI `http://localhost:<port>/callback`.
///
/// The path must stay `/callback`: `parse_oauth_callback_request_target`
/// rejects every other path.
#[must_use]
pub fn loopback_redirect_uri(port: u16) -> String {
    let mut uri = String::from("http://localhost:");
    uri.push_str(&port.to_string());
    uri.push_str("/callback");
    uri
}
257
258pub fn credentials_path() -> io::Result<PathBuf> {
259    Ok(credentials_home_dir()?.join("credentials.json"))
260}
261
262pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
263    let path = credentials_path()?;
264    let root = read_credentials_root(&path)?;
265    let Some(oauth) = root.get("oauth") else {
266        return Ok(None);
267    };
268    if oauth.is_null() {
269        return Ok(None);
270    }
271    let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
272        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
273    Ok(Some(stored.into()))
274}
275
276pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
277    let path = credentials_path()?;
278    let mut root = read_credentials_root(&path)?;
279    root.insert(
280        "oauth".to_string(),
281        serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
282            .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
283    );
284    write_credentials_root(&path, &root)
285}
286
287pub fn clear_oauth_credentials() -> io::Result<()> {
288    let path = credentials_path()?;
289    let mut root = read_credentials_root(&path)?;
290    root.remove("oauth");
291    write_credentials_root(&path, &root)
292}
293
294pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
295    let (path, query) = target
296        .split_once('?')
297        .map_or((target, ""), |(path, query)| (path, query));
298    if path != "/callback" {
299        return Err(format!("unexpected callback path: {path}"));
300    }
301    parse_oauth_callback_query(query)
302}
303
304pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
305    let mut params = BTreeMap::new();
306    for pair in query.split('&').filter(|pair| !pair.is_empty()) {
307        let (key, value) = pair
308            .split_once('=')
309            .map_or((pair, ""), |(key, value)| (key, value));
310        params.insert(percent_decode(key)?, percent_decode(value)?);
311    }
312    Ok(OAuthCallbackParams {
313        code: params.get("code").cloned(),
314        state: params.get("state").cloned(),
315        error: params.get("error").cloned(),
316        error_description: params.get("error_description").cloned(),
317    })
318}
319
320fn generate_random_token(bytes: usize) -> io::Result<String> {
321    let mut buffer = vec![0_u8; bytes];
322    File::open("/dev/urandom")?.read_exact(&mut buffer)?;
323    Ok(base64url_encode(&buffer))
324}
325
/// Resolves the directory holding `credentials.json`.
///
/// `WRAITH_CONFIG_HOME` wins when set; otherwise `$HOME/.wraith` is used. An
/// empty variable is treated as unset. An empty path would otherwise resolve
/// against the current working directory and scatter credentials there.
///
/// # Errors
/// Returns `NotFound` when neither variable provides a non-empty value.
fn credentials_home_dir() -> io::Result<PathBuf> {
    let non_empty = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());
    if let Some(path) = non_empty("WRAITH_CONFIG_HOME") {
        return Ok(PathBuf::from(path));
    }
    let home = non_empty("HOME")
        .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
    Ok(PathBuf::from(home).join(".wraith"))
}
334
335fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
336    match fs::read_to_string(path) {
337        Ok(contents) => {
338            if contents.trim().is_empty() {
339                return Ok(Map::new());
340            }
341            serde_json::from_str::<Value>(&contents)
342                .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
343                .as_object()
344                .cloned()
345                .ok_or_else(|| {
346                    io::Error::new(
347                        io::ErrorKind::InvalidData,
348                        "credentials file must contain a JSON object",
349                    )
350                })
351        }
352        Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
353        Err(error) => Err(error),
354    }
355}
356
357fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
358    if let Some(parent) = path.parent() {
359        fs::create_dir_all(parent)?;
360    }
361    let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
362        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
363    let temp_path = path.with_extension("json.tmp");
364    fs::write(&temp_path, format!("{rendered}\n"))?;
365    fs::rename(temp_path, path)
366}
367
/// Encodes `bytes` as base64url (RFC 4648 §5) without `=` padding.
///
/// Each complete 3-byte group becomes four characters. A trailing 1- or 2-byte
/// group becomes two or three characters respectively.
fn base64url_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Maps the 6-bit group of `word` that starts `shift` bits from the right.
    let sextet = |word: u32, shift: u32| char::from(ALPHABET[((word >> shift) & 0x3F) as usize]);

    let groups = bytes.chunks_exact(3);
    let tail = groups.remainder();
    let mut encoded = String::with_capacity((bytes.len() * 4 + 2) / 3);
    for group in groups {
        let word = (u32::from(group[0]) << 16) | (u32::from(group[1]) << 8) | u32::from(group[2]);
        for shift in [18, 12, 6, 0] {
            encoded.push(sextet(word, shift));
        }
    }
    match *tail {
        [first] => {
            let word = u32::from(first) << 16;
            encoded.push(sextet(word, 18));
            encoded.push(sextet(word, 12));
        }
        [first, second] => {
            let word = (u32::from(first) << 16) | (u32::from(second) << 8);
            for shift in [18, 12, 6] {
                encoded.push(sextet(word, shift));
            }
        }
        _ => {}
    }
    encoded
}
398
/// Percent-encodes every byte of `value` except RFC 3986 unreserved characters
/// (`A-Z a-z 0-9 - _ . ~`).
///
/// Escapes use uppercase hex, and a space becomes `%20` (not `+`).
fn percent_encode(value: &str) -> String {
    use std::fmt::Write as _;

    value
        .bytes()
        .fold(String::with_capacity(value.len()), |mut encoded, byte| {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                encoded.push(char::from(byte));
            } else {
                // Formatting into a `String` cannot fail.
                let _ = write!(encoded, "%{byte:02X}");
            }
            encoded
        })
}
414
/// Decodes one `application/x-www-form-urlencoded` component.
///
/// `%XX` escapes become the byte they denote and `+` becomes a space; all other
/// bytes are copied unchanged.
///
/// # Errors
/// Returns a message when a `%` is not followed by two hex digits, or when the
/// decoded bytes are not valid UTF-8. This includes a truncated escape at the
/// end of the input, which was previously passed through literally even though
/// `%zz` was rejected.
fn percent_decode(value: &str) -> Result<String, String> {
    let bytes = value.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut index = 0;
    while index < bytes.len() {
        match bytes[index] {
            b'%' => {
                let (Some(&hi), Some(&lo)) = (bytes.get(index + 1), bytes.get(index + 2)) else {
                    return Err(format!("truncated percent-encoding at byte offset {index}"));
                };
                let hi = decode_hex(hi)?;
                let lo = decode_hex(lo)?;
                decoded.push((hi << 4) | lo);
                index += 3;
            }
            b'+' => {
                decoded.push(b' ');
                index += 1;
            }
            byte => {
                decoded.push(byte);
                index += 1;
            }
        }
    }
    String::from_utf8(decoded).map_err(|error| error.to_string())
}

/// Converts one ASCII hex digit (either case) to its value in `0..=15`.
///
/// # Errors
/// Returns a message naming the offending byte (as a decimal number) when
/// `byte` is not a hex digit.
fn decode_hex(byte: u8) -> Result<u8, String> {
    match byte {
        b'0'..=b'9' => Ok(byte - b'0'),
        b'a'..=b'f' => Ok(byte - b'a' + 10),
        b'A'..=b'F' => Ok(byte - b'A' + 10),
        _ => Err(format!("invalid percent-encoding byte: {byte}")),
    }
}
448
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use super::{
        clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
        generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
        parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
        OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
    };

    /// Fixed configuration shared by the request-building tests.
    fn sample_config() -> OAuthConfig {
        OAuthConfig {
            client_id: String::from("runtime-client"),
            authorize_url: String::from("https://console.test/oauth/authorize"),
            token_url: String::from("https://console.test/oauth/token"),
            callback_port: Some(4545),
            manual_redirect_url: Some(String::from("https://console.test/oauth/callback")),
            scopes: vec![String::from("org:read"), String::from("user:write")],
        }
    }

    /// Guard from the crate-wide lock used by tests that mutate environment variables.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        crate::test_env_lock()
    }

    /// Scratch directory under the system temp dir, unique per process id and
    /// per-call timestamp.
    fn temp_config_home() -> std::path::PathBuf {
        let pid = std::process::id();
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time")
            .as_nanos();
        std::env::temp_dir().join(format!("runtime-oauth-test-{pid}-{nanos}"))
    }

    #[test]
    fn s256_challenge_matches_expected_vector() {
        // Verifier/challenge pair from RFC 7636, Appendix B.
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        let expected = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
        assert_eq!(code_challenge_s256(verifier), expected);
    }

    #[test]
    fn generates_pkce_pair_and_state() {
        let pair = generate_pkce_pair().expect("pkce pair");
        let state = generate_state().expect("state");
        for value in [&pair.verifier, &pair.challenge, &state] {
            assert!(!value.is_empty());
        }
    }

    #[test]
    fn builds_authorize_url_and_form_requests() {
        let config = sample_config();
        let pair = generate_pkce_pair().expect("pkce");

        let url = OAuthAuthorizationRequest::from_config(
            &config,
            loopback_redirect_uri(4545),
            "state-123",
            &pair,
        )
        .with_extra_param("login_hint", "user@example.com")
        .build_url();
        assert!(url.starts_with("https://console.test/oauth/authorize?"));
        for expected in [
            "response_type=code",
            "client_id=runtime-client",
            "scope=org%3Aread%20user%3Awrite",
            "login_hint=user%40example.com",
        ] {
            assert!(url.contains(expected), "missing {expected} in {url}");
        }

        let exchange = OAuthTokenExchangeRequest::from_config(
            &config,
            "auth-code",
            "state-123",
            pair.verifier,
            loopback_redirect_uri(4545),
        );
        let exchange_form = exchange.form_params();
        assert_eq!(
            exchange_form.get("grant_type").map(String::as_str),
            Some("authorization_code")
        );

        let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
        let refresh_form = refresh.form_params();
        assert_eq!(
            refresh_form.get("scope").map(String::as_str),
            Some("org:read user:write")
        );
    }

    #[test]
    fn oauth_credentials_round_trip_and_clear_preserves_other_fields() {
        let _guard = env_lock();
        let config_home = temp_config_home();
        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);

        let path = credentials_path().expect("credentials path");
        let parent = path.parent().expect("parent");
        std::fs::create_dir_all(parent).expect("create parent");
        std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");

        let token_set = OAuthTokenSet {
            access_token: String::from("access-token"),
            refresh_token: Some(String::from("refresh-token")),
            expires_at: Some(123),
            scopes: vec![String::from("scope:a")],
        };
        save_oauth_credentials(&token_set).expect("save credentials");
        let loaded = load_oauth_credentials().expect("load credentials");
        assert_eq!(loaded, Some(token_set));
        let saved = std::fs::read_to_string(&path).expect("read saved file");
        assert!(saved.contains("\"other\": \"value\""));
        assert!(saved.contains("\"oauth\""));

        clear_oauth_credentials().expect("clear credentials");
        let after_clear = load_oauth_credentials().expect("load cleared");
        assert_eq!(after_clear, None);
        let cleared = std::fs::read_to_string(&path).expect("read cleared file");
        assert!(cleared.contains("\"other\": \"value\""));
        assert!(!cleared.contains("\"oauth\""));

        std::env::remove_var("WRAITH_CONFIG_HOME");
        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
    }

    #[test]
    fn parses_callback_query_and_target() {
        let from_query =
            parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
                .expect("parse query");
        assert_eq!(from_query.code.as_deref(), Some("abc123"));
        assert_eq!(from_query.state.as_deref(), Some("state-1"));
        assert_eq!(from_query.error_description.as_deref(), Some("needs login"));

        let from_target = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
            .expect("parse callback target");
        assert_eq!(from_target.code.as_deref(), Some("abc"));
        assert_eq!(from_target.state.as_deref(), Some("xyz"));

        let wrong_path = parse_oauth_callback_request_target("/wrong?code=abc");
        assert!(wrong_path.is_err());
    }
}