Skip to main content

stoat_core/
oauth.rs

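//! OAuth authorization URL construction and helpers.
//!
//! Builds the authorization URL with the required query parameters for an
//! OAuth 2.0 authorization code flow (optionally with PKCE), plus the request
//! bodies for the token exchange and token refresh calls. This module performs
//! no I/O — it only manipulates URLs and strings; randomness for the `state`
//! parameter comes from a caller-supplied RNG.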
6
7use std::collections::HashMap;
8
9use base64::Engine;
10use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use url::Url;
12
13use crate::config::{OAuth, TokenFormat};
14use crate::pkce::PkceChallenge;
15
16/// Parameters for constructing an authorization URL.
17///
18/// Separating these from the [`OAuth`] config allows the caller to control
19/// the state parameter and whether PKCE parameters are included.
20pub struct AuthorizationRequest<'a> {
21    /// The OAuth configuration from the config file.
22    pub oauth: &'a OAuth,
23    /// The PKCE challenge (included only when PKCE is enabled).
24    pub pkce: Option<&'a PkceChallenge>,
25    /// An opaque state value for CSRF protection.
26    pub state: &'a str,
27}
28
29/// Build the authorization URL for an OAuth 2.0 authorization code flow.
30///
31/// The returned URL includes the following query parameters:
32/// - `response_type=code`
33/// - `client_id`
34/// - `redirect_uri`
35/// - `scope` (space-separated)
36/// - `state`
37/// - `code_challenge` and `code_challenge_method=S256` (when PKCE is provided)
38#[must_use]
39pub fn build_authorization_url(request: &AuthorizationRequest<'_>) -> Url {
40    let mut url = request.oauth.authorize_url.clone();
41
42    {
43        let mut params = url.query_pairs_mut();
44        params.append_pair("response_type", "code");
45        params.append_pair("client_id", &request.oauth.client_id);
46        params.append_pair("redirect_uri", request.oauth.redirect_uri.as_str());
47        params.append_pair("scope", &request.oauth.scopes.join(" "));
48        params.append_pair("state", request.state);
49
50        if let Some(pkce) = request.pkce {
51            params.append_pair("code_challenge", pkce.challenge());
52            params.append_pair("code_challenge_method", "S256");
53        }
54    }
55
56    url
57}
58
59/// Parameters for the token exchange request body.
60///
61/// This is a pure data structure — the actual HTTP POST is performed by the
62/// I/O layer.
63#[derive(Debug, Clone)]
64pub struct TokenExchangeParams {
65    /// The token endpoint URL.
66    pub token_url: Url,
67    /// The authorization code received from the authorization server.
68    pub code: String,
69    /// The redirect URI (must match the one used in the authorization request).
70    pub redirect_uri: Url,
71    /// The OAuth client identifier.
72    pub client_id: String,
73    /// The PKCE code verifier (if PKCE was used).
74    pub code_verifier: Option<String>,
75    /// The OAuth state parameter (included when the provider requires it in the
76    /// token exchange body).
77    pub state: Option<String>,
78    /// The body format for the token endpoint request.
79    pub token_format: TokenFormat,
80}
81
82/// Generate a random state parameter for CSRF protection.
83///
84/// Returns 16 random bytes encoded as base64url (no padding), producing
85/// a 22-character string.
86pub fn generate_state(rng: &mut impl rand::Rng) -> String {
87    let mut bytes = [0u8; 16];
88    rng.fill_bytes(&mut bytes);
89    URL_SAFE_NO_PAD.encode(bytes)
90}
91
92/// Check whether a redirect URI points to a localhost address.
93///
94/// Returns `true` if the URL's host is `localhost`, `127.0.0.1`, or `[::1]`.
95/// This determines whether the callback can be received via a local HTTP
96/// listener rather than paste mode.
97#[must_use]
98pub fn is_localhost_redirect(url: &Url) -> bool {
99    matches!(url.host_str(), Some("localhost" | "127.0.0.1" | "[::1]"))
100}
101
102/// Extract the port from a redirect URI, if present.
103///
104/// Returns `None` if no explicit port is set (the URL uses the scheme's
105/// default port).
106#[must_use]
107pub fn redirect_port(url: &Url) -> Option<u16> {
108    url.port()
109}
110
/// Remove a trailing URI fragment from a pasted authorization code.
///
/// Some providers display the code together with a fragment, as in
/// `CODE#STATE`. Because `#` delimits a URI fragment and never appears inside a
/// valid authorization code, the input is cut at its first `#`. Input without a
/// `#` is returned unchanged.
#[must_use]
pub fn strip_code_fragment(code: &str) -> &str {
    match code.find('#') {
        // `find` yields the byte offset of an ASCII char, so this is a valid
        // char boundary.
        Some(hash_at) => &code[..hash_at],
        None => code,
    }
}
121
122impl TokenExchangeParams {
123    /// Build the form parameters for the token exchange POST body.
124    #[must_use]
125    pub fn form_params(&self) -> Vec<(&str, &str)> {
126        let mut params = vec![
127            ("grant_type", "authorization_code"),
128            ("code", &self.code),
129            ("redirect_uri", self.redirect_uri.as_str()),
130            ("client_id", &self.client_id),
131        ];
132
133        if let Some(verifier) = &self.code_verifier {
134            params.push(("code_verifier", verifier));
135        }
136
137        if let Some(state) = &self.state {
138            params.push(("state", state));
139        }
140
141        params
142    }
143
144    /// Build a JSON-serializable map for the token exchange POST body.
145    #[must_use]
146    pub fn json_body(&self) -> HashMap<&str, &str> {
147        let mut map = HashMap::new();
148        map.insert("grant_type", "authorization_code");
149        map.insert("code", &self.code);
150        map.insert("redirect_uri", self.redirect_uri.as_str());
151        map.insert("client_id", &self.client_id);
152
153        if let Some(verifier) = &self.code_verifier {
154            map.insert("code_verifier", verifier);
155        }
156
157        if let Some(state) = &self.state {
158            map.insert("state", state);
159        }
160
161        map
162    }
163}
164
165/// Parameters for the token refresh request body.
166///
167/// This is a pure data structure — the actual HTTP POST is performed by the
168/// I/O layer. Corresponds to an OAuth 2.0 `grant_type=refresh_token` request.
169#[derive(Debug, Clone)]
170pub struct TokenRefreshParams {
171    /// The token endpoint URL.
172    pub token_url: Url,
173    /// The refresh token to exchange for a new access token.
174    pub refresh_token: String,
175    /// The OAuth client identifier.
176    pub client_id: String,
177    /// The body format for the token endpoint request.
178    pub token_format: TokenFormat,
179}
180
181impl TokenRefreshParams {
182    /// Build the form parameters for the token refresh POST body.
183    #[must_use]
184    pub fn form_params(&self) -> Vec<(&str, &str)> {
185        vec![
186            ("grant_type", "refresh_token"),
187            ("refresh_token", &self.refresh_token),
188            ("client_id", &self.client_id),
189        ]
190    }
191
192    /// Build a JSON-serializable map for the token refresh POST body.
193    #[must_use]
194    pub fn json_body(&self) -> HashMap<&str, &str> {
195        let mut map = HashMap::new();
196        map.insert("grant_type", "refresh_token");
197        map.insert("refresh_token", &self.refresh_token);
198        map.insert("client_id", &self.client_id);
199        map
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{Config, TokenFormat};

    const MINIMAL_CONFIG: &str = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "test-client-id"
scopes = ["scope1", "scope2"]
redirect_uri = "https://example.com/oauth/callback"
"#;

    fn test_config() -> Config {
        Config::from_toml(MINIMAL_CONFIG).unwrap()
    }

    /// Decode a URL's query string into owned pairs for easy assertions.
    fn pairs_of(url: &Url) -> Vec<(String, String)> {
        url.query_pairs().into_owned().collect()
    }

    /// Token exchange fixture with fixed code/URLs/client; only the optional
    /// fields and the body format vary between tests.
    fn exchange_params(
        code_verifier: Option<&str>,
        state: Option<&str>,
        token_format: TokenFormat,
    ) -> TokenExchangeParams {
        TokenExchangeParams {
            token_url: Url::parse("https://example.com/oauth/token").unwrap(),
            code: "auth-code-123".into(),
            redirect_uri: Url::parse("https://example.com/oauth/callback").unwrap(),
            client_id: "test-client".into(),
            code_verifier: code_verifier.map(str::to_owned),
            state: state.map(str::to_owned),
            token_format,
        }
    }

    /// Token refresh fixture; only the body format varies between tests.
    fn refresh_params(token_format: TokenFormat) -> TokenRefreshParams {
        TokenRefreshParams {
            token_url: Url::parse("https://example.com/oauth/token").unwrap(),
            refresh_token: "my-refresh-token".into(),
            client_id: "test-client".into(),
            token_format,
        }
    }

    // --- authorization URL ---------------------------------------------------

    #[test]
    fn authorization_url_without_pkce() {
        let config = test_config();
        let request = AuthorizationRequest {
            oauth: &config.oauth,
            pkce: None,
            state: "test-state",
        };

        let url = build_authorization_url(&request);

        assert_eq!(url.scheme(), "https");
        assert_eq!(url.host_str(), Some("example.com"));
        assert_eq!(url.path(), "/oauth/authorize");

        let pairs = pairs_of(&url);
        assert!(pairs.contains(&("response_type".into(), "code".into())));
        assert!(pairs.contains(&("client_id".into(), "test-client-id".into())));
        assert!(pairs.contains(&(
            "redirect_uri".into(),
            "https://example.com/oauth/callback".into()
        )));
        assert!(pairs.contains(&("scope".into(), "scope1 scope2".into())));
        assert!(pairs.contains(&("state".into(), "test-state".into())));
        assert!(
            !pairs
                .iter()
                .any(|(k, _)| k == "code_challenge" || k == "code_challenge_method"),
            "should not include PKCE parameters without PKCE"
        );
    }

    #[test]
    fn authorization_url_with_pkce() {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let pkce = PkceChallenge::generate(&mut rng);

        let config = test_config();
        let request = AuthorizationRequest {
            oauth: &config.oauth,
            pkce: Some(&pkce),
            state: "test-state",
        };

        let pairs = pairs_of(&build_authorization_url(&request));

        assert!(pairs.contains(&("code_challenge".into(), pkce.challenge().to_owned())));
        assert!(pairs.contains(&("code_challenge_method".into(), "S256".into())));
    }

    #[test]
    fn authorization_url_empty_scopes() {
        let toml = MINIMAL_CONFIG.replace("scopes = [\"scope1\", \"scope2\"]", "scopes = []");
        let config = Config::from_toml(&toml).unwrap();
        let request = AuthorizationRequest {
            oauth: &config.oauth,
            pkce: None,
            state: "s",
        };

        let pairs = pairs_of(&build_authorization_url(&request));
        assert!(pairs.contains(&("scope".into(), String::new())));
    }

    #[test]
    fn authorization_url_preserves_existing_query() {
        let toml = MINIMAL_CONFIG.replace(
            "/oauth/authorize\"",
            "/oauth/authorize?prompt=login\"",
        );
        let config = Config::from_toml(&toml).unwrap();
        let request = AuthorizationRequest {
            oauth: &config.oauth,
            pkce: None,
            state: "s",
        };

        let pairs = pairs_of(&build_authorization_url(&request));
        assert_eq!(
            pairs.first(),
            Some(&("prompt".to_owned(), "login".to_owned()))
        );
        assert!(pairs.contains(&("response_type".into(), "code".into())));
    }

    // --- token exchange ------------------------------------------------------

    #[test]
    fn token_exchange_params_with_pkce() {
        let params = exchange_params(Some("my-verifier"), None, TokenFormat::Form);

        let form = params.form_params();
        assert!(form.contains(&("grant_type", "authorization_code")));
        assert!(form.contains(&("code", "auth-code-123")));
        assert!(form.contains(&("redirect_uri", "https://example.com/oauth/callback")));
        assert!(form.contains(&("client_id", "test-client")));
        assert!(form.contains(&("code_verifier", "my-verifier")));
    }

    #[test]
    fn token_exchange_params_without_pkce() {
        let params = exchange_params(None, None, TokenFormat::Form);

        let form = params.form_params();
        assert!(!form.iter().any(|(k, _)| *k == "code_verifier"));
    }

    #[test]
    fn token_exchange_json_body_with_pkce() {
        let params = exchange_params(Some("my-verifier"), None, TokenFormat::Json);

        let body = params.json_body();
        assert_eq!(body.get("grant_type"), Some(&"authorization_code"));
        assert_eq!(body.get("code"), Some(&"auth-code-123"));
        assert_eq!(
            body.get("redirect_uri"),
            Some(&"https://example.com/oauth/callback")
        );
        assert_eq!(body.get("client_id"), Some(&"test-client"));
        assert_eq!(body.get("code_verifier"), Some(&"my-verifier"));
    }

    #[test]
    fn token_exchange_json_body_without_pkce() {
        let params = exchange_params(None, None, TokenFormat::Json);

        let body = params.json_body();
        assert!(!body.contains_key("code_verifier"));
    }

    #[test]
    fn token_exchange_form_params_with_state() {
        let params = exchange_params(None, Some("test-state"), TokenFormat::Form);

        let form = params.form_params();
        assert!(form.contains(&("state", "test-state")));
    }

    #[test]
    fn token_exchange_form_params_without_state() {
        let params = exchange_params(None, None, TokenFormat::Form);

        let form = params.form_params();
        assert!(!form.iter().any(|(k, _)| *k == "state"));
    }

    #[test]
    fn token_exchange_json_body_with_state() {
        let params = exchange_params(None, Some("test-state"), TokenFormat::Json);

        let body = params.json_body();
        assert_eq!(body.get("state"), Some(&"test-state"));
    }

    #[test]
    fn token_exchange_json_body_without_state() {
        let params = exchange_params(None, None, TokenFormat::Json);

        let body = params.json_body();
        assert!(!body.contains_key("state"));
    }

    #[test]
    fn token_exchange_json_body_matches_form_params() {
        let params = exchange_params(Some("my-verifier"), Some("test-state"), TokenFormat::Json);

        let form: HashMap<&str, &str> = params.form_params().into_iter().collect();
        assert_eq!(form.len(), 6, "all six fields should be present");
        assert_eq!(params.json_body(), form);
    }

    // --- state generation ----------------------------------------------------

    #[test]
    fn generate_state_length() {
        let mut rng = rand::rng();
        let state = generate_state(&mut rng);
        assert_eq!(state.len(), 22, "16 random bytes → 22 base64url chars");
    }

    #[test]
    fn generate_state_is_url_safe() {
        let mut rng = rand::rng();
        let state = generate_state(&mut rng);
        assert!(
            state
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
            "state must use only the base64url alphabet without padding: {state}"
        );
    }

    #[test]
    fn generate_state_deterministic() {
        use rand::SeedableRng;
        let mut rng1 = rand::rngs::StdRng::seed_from_u64(99);
        let state1 = generate_state(&mut rng1);

        let mut rng2 = rand::rngs::StdRng::seed_from_u64(99);
        let state2 = generate_state(&mut rng2);

        assert_eq!(state1, state2);
    }

    // --- redirect URI helpers ------------------------------------------------

    #[test]
    fn is_localhost_redirect_127_0_0_1() {
        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
        assert!(is_localhost_redirect(&url));
    }

    #[test]
    fn is_localhost_redirect_localhost() {
        let url = Url::parse("http://localhost:9000/callback").unwrap();
        assert!(is_localhost_redirect(&url));
    }

    #[test]
    fn is_localhost_redirect_ipv6() {
        let url = Url::parse("http://[::1]:8080/callback").unwrap();
        assert!(is_localhost_redirect(&url));
    }

    #[test]
    fn is_not_localhost_redirect() {
        let url = Url::parse("https://example.com/oauth/callback").unwrap();
        assert!(!is_localhost_redirect(&url));
    }

    #[test]
    fn is_not_localhost_redirect_lookalike_domain() {
        let url = Url::parse("http://localhost.example.com/callback").unwrap();
        assert!(!is_localhost_redirect(&url));
    }

    #[test]
    fn redirect_port_explicit() {
        let url = Url::parse("http://localhost:8080/callback").unwrap();
        assert_eq!(redirect_port(&url), Some(8080));
    }

    #[test]
    fn redirect_port_default() {
        let url = Url::parse("http://localhost/callback").unwrap();
        assert_eq!(redirect_port(&url), None);
    }

    #[test]
    fn redirect_port_explicit_default_is_none() {
        // The url crate drops a port equal to the scheme default while parsing.
        let url = Url::parse("http://localhost:80/callback").unwrap();
        assert_eq!(redirect_port(&url), None);
    }

    // --- token refresh -------------------------------------------------------

    #[test]
    fn token_refresh_params_form() {
        let params = refresh_params(TokenFormat::Form);

        let form = params.form_params();
        assert!(form.contains(&("grant_type", "refresh_token")));
        assert!(form.contains(&("refresh_token", "my-refresh-token")));
        assert!(form.contains(&("client_id", "test-client")));
        assert_eq!(form.len(), 3);
    }

    #[test]
    fn token_refresh_json_body() {
        let params = refresh_params(TokenFormat::Json);

        let body = params.json_body();
        assert_eq!(body.get("grant_type"), Some(&"refresh_token"));
        assert_eq!(body.get("refresh_token"), Some(&"my-refresh-token"));
        assert_eq!(body.get("client_id"), Some(&"test-client"));
        assert_eq!(body.len(), 3);
    }

    // --- pasted code cleanup -------------------------------------------------

    #[test]
    fn strip_code_fragment_removes_suffix() {
        assert_eq!(strip_code_fragment("abc123#state"), "abc123");
    }

    #[test]
    fn strip_code_fragment_no_fragment() {
        assert_eq!(strip_code_fragment("abc123"), "abc123");
    }

    #[test]
    fn strip_code_fragment_empty_fragment() {
        assert_eq!(strip_code_fragment("abc123#"), "abc123");
    }

    #[test]
    fn strip_code_fragment_empty_string() {
        assert_eq!(strip_code_fragment(""), "");
    }

    #[test]
    fn strip_code_fragment_multiple_hashes() {
        assert_eq!(strip_code_fragment("abc#foo#bar"), "abc");
    }

    #[test]
    fn strip_code_fragment_leading_hash() {
        assert_eq!(strip_code_fragment("#state"), "");
    }
}