1use std::collections::HashMap;
8
9use base64::Engine;
10use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use url::Url;
12
13use crate::config::{OAuth, TokenFormat};
14use crate::pkce::PkceChallenge;
15
16pub struct AuthorizationRequest<'a> {
21 pub oauth: &'a OAuth,
23 pub pkce: Option<&'a PkceChallenge>,
25 pub state: &'a str,
27}
28
29#[must_use]
39pub fn build_authorization_url(request: &AuthorizationRequest<'_>) -> Url {
40 let mut url = request.oauth.authorize_url.clone();
41
42 {
43 let mut params = url.query_pairs_mut();
44 params.append_pair("response_type", "code");
45 params.append_pair("client_id", &request.oauth.client_id);
46 params.append_pair("redirect_uri", request.oauth.redirect_uri.as_str());
47 params.append_pair("scope", &request.oauth.scopes.join(" "));
48 params.append_pair("state", request.state);
49
50 if let Some(pkce) = request.pkce {
51 params.append_pair("code_challenge", pkce.challenge());
52 params.append_pair("code_challenge_method", "S256");
53 }
54 }
55
56 url
57}
58
59#[derive(Debug, Clone)]
64pub struct TokenExchangeParams {
65 pub token_url: Url,
67 pub code: String,
69 pub redirect_uri: Url,
71 pub client_id: String,
73 pub code_verifier: Option<String>,
75 pub state: Option<String>,
78 pub token_format: TokenFormat,
80}
81
82pub fn generate_state(rng: &mut impl rand::Rng) -> String {
87 let mut bytes = [0u8; 16];
88 rng.fill_bytes(&mut bytes);
89 URL_SAFE_NO_PAD.encode(bytes)
90}
91
92#[must_use]
98pub fn is_localhost_redirect(url: &Url) -> bool {
99 matches!(url.host_str(), Some("localhost" | "127.0.0.1" | "[::1]"))
100}
101
102#[must_use]
107pub fn redirect_port(url: &Url) -> Option<u16> {
108 url.port()
109}
110
/// Truncates `code` at its first `#`, returning it unchanged if there is none.
///
/// Everything from the first `#` onward is discarded: `"abc#foo#bar"` becomes
/// `"abc"` and `"abc#"` becomes `"abc"`. Presumably this guards against codes
/// copied from a callback that appended a `#…` suffix; verify with callers.
#[must_use]
pub fn strip_code_fragment(code: &str) -> &str {
    // `#` is ASCII, so the byte index from `find` is always a char boundary.
    match code.find('#') {
        Some(hash) => &code[..hash],
        None => code,
    }
}
121
122impl TokenExchangeParams {
123 #[must_use]
125 pub fn form_params(&self) -> Vec<(&str, &str)> {
126 let mut params = vec![
127 ("grant_type", "authorization_code"),
128 ("code", &self.code),
129 ("redirect_uri", self.redirect_uri.as_str()),
130 ("client_id", &self.client_id),
131 ];
132
133 if let Some(verifier) = &self.code_verifier {
134 params.push(("code_verifier", verifier));
135 }
136
137 if let Some(state) = &self.state {
138 params.push(("state", state));
139 }
140
141 params
142 }
143
144 #[must_use]
146 pub fn json_body(&self) -> HashMap<&str, &str> {
147 let mut map = HashMap::new();
148 map.insert("grant_type", "authorization_code");
149 map.insert("code", &self.code);
150 map.insert("redirect_uri", self.redirect_uri.as_str());
151 map.insert("client_id", &self.client_id);
152
153 if let Some(verifier) = &self.code_verifier {
154 map.insert("code_verifier", verifier);
155 }
156
157 if let Some(state) = &self.state {
158 map.insert("state", state);
159 }
160
161 map
162 }
163}
164
165#[derive(Debug, Clone)]
170pub struct TokenRefreshParams {
171 pub token_url: Url,
173 pub refresh_token: String,
175 pub client_id: String,
177 pub token_format: TokenFormat,
179}
180
181impl TokenRefreshParams {
182 #[must_use]
184 pub fn form_params(&self) -> Vec<(&str, &str)> {
185 vec![
186 ("grant_type", "refresh_token"),
187 ("refresh_token", &self.refresh_token),
188 ("client_id", &self.client_id),
189 ]
190 }
191
192 #[must_use]
194 pub fn json_body(&self) -> HashMap<&str, &str> {
195 let mut map = HashMap::new();
196 map.insert("grant_type", "refresh_token");
197 map.insert("refresh_token", &self.refresh_token);
198 map.insert("client_id", &self.client_id);
199 map
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{Config, TokenFormat};

    const MINIMAL_CONFIG: &str = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "test-client-id"
scopes = ["scope1", "scope2"]
redirect_uri = "https://example.com/oauth/callback"
"#;

    /// Parses `MINIMAL_CONFIG`, panicking if it is invalid.
    fn test_config() -> Config {
        Config::from_toml(MINIMAL_CONFIG).unwrap()
    }

    /// Decoded query parameters of `url`, in order.
    fn query_pairs(url: &Url) -> Vec<(String, String)> {
        url.query_pairs().into_owned().collect()
    }

    /// Token-exchange parameters with fixed URL, code, and client fields. Only
    /// the optional fields and the body format vary between tests.
    fn exchange_params(
        code_verifier: Option<&str>,
        state: Option<&str>,
        token_format: TokenFormat,
    ) -> TokenExchangeParams {
        TokenExchangeParams {
            token_url: Url::parse("https://example.com/oauth/token").unwrap(),
            code: "auth-code-123".into(),
            redirect_uri: Url::parse("https://example.com/oauth/callback").unwrap(),
            client_id: "test-client".into(),
            code_verifier: code_verifier.map(str::to_owned),
            state: state.map(str::to_owned),
            token_format,
        }
    }

    /// Refresh parameters with fixed URL, token, and client fields.
    fn refresh_params(token_format: TokenFormat) -> TokenRefreshParams {
        TokenRefreshParams {
            token_url: Url::parse("https://example.com/oauth/token").unwrap(),
            refresh_token: "my-refresh-token".into(),
            client_id: "test-client".into(),
            token_format,
        }
    }

    #[test]
    fn authorization_url_without_pkce() {
        let config = test_config();
        let request = AuthorizationRequest {
            oauth: &config.oauth,
            pkce: None,
            state: "test-state",
        };

        let url = build_authorization_url(&request);

        assert_eq!(url.scheme(), "https");
        assert_eq!(url.host_str(), Some("example.com"));
        assert_eq!(url.path(), "/oauth/authorize");

        let pairs = query_pairs(&url);
        assert!(pairs.contains(&("response_type".into(), "code".into())));
        assert!(pairs.contains(&("client_id".into(), "test-client-id".into())));
        assert!(pairs.contains(&(
            "redirect_uri".into(),
            "https://example.com/oauth/callback".into()
        )));
        assert!(pairs.contains(&("scope".into(), "scope1 scope2".into())));
        assert!(pairs.contains(&("state".into(), "test-state".into())));
        assert!(
            !pairs.iter().any(|(k, _)| k == "code_challenge"),
            "should not include code_challenge without PKCE"
        );
    }

    #[test]
    fn authorization_url_with_pkce() {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let pkce = PkceChallenge::generate(&mut rng);

        let config = test_config();
        let request = AuthorizationRequest {
            oauth: &config.oauth,
            pkce: Some(&pkce),
            state: "test-state",
        };

        let url = build_authorization_url(&request);
        let pairs = query_pairs(&url);

        assert!(pairs.contains(&("code_challenge".into(), pkce.challenge().to_owned())));
        assert!(pairs.contains(&("code_challenge_method".into(), "S256".into())));
    }

    #[test]
    fn authorization_url_empty_scopes() {
        let toml = MINIMAL_CONFIG.replace("scopes = [\"scope1\", \"scope2\"]", "scopes = []");
        let config = Config::from_toml(&toml).unwrap();
        let request = AuthorizationRequest {
            oauth: &config.oauth,
            pkce: None,
            state: "s",
        };

        let url = build_authorization_url(&request);
        assert!(query_pairs(&url).contains(&("scope".into(), String::new())));
    }

    #[test]
    fn token_exchange_params_with_pkce() {
        let params = exchange_params(Some("my-verifier"), None, TokenFormat::Form);

        let form = params.form_params();
        assert!(form.contains(&("grant_type", "authorization_code")));
        assert!(form.contains(&("code", "auth-code-123")));
        assert!(form.contains(&("redirect_uri", "https://example.com/oauth/callback")));
        assert!(form.contains(&("client_id", "test-client")));
        assert!(form.contains(&("code_verifier", "my-verifier")));
    }

    #[test]
    fn token_exchange_params_without_pkce() {
        let params = exchange_params(None, None, TokenFormat::Form);

        let form = params.form_params();
        assert!(!form.iter().any(|(k, _)| *k == "code_verifier"));
    }

    #[test]
    fn token_exchange_form_params_all_fields_in_order() {
        let params = exchange_params(Some("my-verifier"), Some("test-state"), TokenFormat::Form);

        assert_eq!(
            params.form_params(),
            vec![
                ("grant_type", "authorization_code"),
                ("code", "auth-code-123"),
                ("redirect_uri", "https://example.com/oauth/callback"),
                ("client_id", "test-client"),
                ("code_verifier", "my-verifier"),
                ("state", "test-state"),
            ]
        );
    }

    #[test]
    fn token_exchange_json_body_with_pkce() {
        let params = exchange_params(Some("my-verifier"), None, TokenFormat::Json);

        let body = params.json_body();
        assert_eq!(body.get("grant_type"), Some(&"authorization_code"));
        assert_eq!(body.get("code"), Some(&"auth-code-123"));
        assert_eq!(
            body.get("redirect_uri"),
            Some(&"https://example.com/oauth/callback")
        );
        assert_eq!(body.get("client_id"), Some(&"test-client"));
        assert_eq!(body.get("code_verifier"), Some(&"my-verifier"));
    }

    #[test]
    fn token_exchange_json_body_without_pkce() {
        let params = exchange_params(None, None, TokenFormat::Json);

        let body = params.json_body();
        assert!(!body.contains_key("code_verifier"));
    }

    #[test]
    fn token_exchange_json_body_matches_form_params() {
        let params = exchange_params(Some("my-verifier"), Some("test-state"), TokenFormat::Json);

        let from_form: HashMap<&str, &str> = params.form_params().into_iter().collect();
        assert_eq!(params.json_body(), from_form);
    }

    #[test]
    fn token_exchange_form_params_with_state() {
        let params = exchange_params(None, Some("test-state"), TokenFormat::Form);

        let form = params.form_params();
        assert!(form.contains(&("state", "test-state")));
    }

    #[test]
    fn token_exchange_form_params_without_state() {
        let params = exchange_params(None, None, TokenFormat::Form);

        let form = params.form_params();
        assert!(!form.iter().any(|(k, _)| *k == "state"));
    }

    #[test]
    fn token_exchange_json_body_with_state() {
        let params = exchange_params(None, Some("test-state"), TokenFormat::Json);

        let body = params.json_body();
        assert_eq!(body.get("state"), Some(&"test-state"));
    }

    #[test]
    fn token_exchange_json_body_without_state() {
        let params = exchange_params(None, None, TokenFormat::Json);

        let body = params.json_body();
        assert!(!body.contains_key("state"));
    }

    #[test]
    fn generate_state_length() {
        let mut rng = rand::rng();
        let state = generate_state(&mut rng);
        assert_eq!(state.len(), 22, "16 random bytes → 22 base64url chars");
    }

    #[test]
    fn generate_state_deterministic() {
        use rand::SeedableRng;
        let mut rng1 = rand::rngs::StdRng::seed_from_u64(99);
        let state1 = generate_state(&mut rng1);

        let mut rng2 = rand::rngs::StdRng::seed_from_u64(99);
        let state2 = generate_state(&mut rng2);

        assert_eq!(state1, state2);
    }

    #[test]
    fn is_localhost_redirect_127_0_0_1() {
        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
        assert!(is_localhost_redirect(&url));
    }

    #[test]
    fn is_localhost_redirect_localhost() {
        let url = Url::parse("http://localhost:9000/callback").unwrap();
        assert!(is_localhost_redirect(&url));
    }

    #[test]
    fn is_localhost_redirect_uppercase_host() {
        // The URL parser lowercases domains for special schemes such as http.
        let url = Url::parse("http://LOCALHOST:9000/callback").unwrap();
        assert!(is_localhost_redirect(&url));
    }

    #[test]
    fn is_localhost_redirect_ipv6() {
        let url = Url::parse("http://[::1]:8080/callback").unwrap();
        assert!(is_localhost_redirect(&url));
    }

    #[test]
    fn is_not_localhost_redirect() {
        let url = Url::parse("https://example.com/oauth/callback").unwrap();
        assert!(!is_localhost_redirect(&url));
    }

    #[test]
    fn redirect_port_explicit() {
        let url = Url::parse("http://localhost:8080/callback").unwrap();
        assert_eq!(redirect_port(&url), Some(8080));
    }

    #[test]
    fn redirect_port_default() {
        let url = Url::parse("http://localhost/callback").unwrap();
        assert_eq!(redirect_port(&url), None);
    }

    #[test]
    fn redirect_port_scheme_default_is_dropped() {
        // An explicit port equal to the scheme default is normalized away.
        let url = Url::parse("http://localhost:80/callback").unwrap();
        assert_eq!(redirect_port(&url), None);
    }

    #[test]
    fn token_refresh_params_form() {
        let params = refresh_params(TokenFormat::Form);

        let form = params.form_params();
        assert!(form.contains(&("grant_type", "refresh_token")));
        assert!(form.contains(&("refresh_token", "my-refresh-token")));
        assert!(form.contains(&("client_id", "test-client")));
        assert_eq!(form.len(), 3);
    }

    #[test]
    fn token_refresh_json_body() {
        let params = refresh_params(TokenFormat::Json);

        let body = params.json_body();
        assert_eq!(body.get("grant_type"), Some(&"refresh_token"));
        assert_eq!(body.get("refresh_token"), Some(&"my-refresh-token"));
        assert_eq!(body.get("client_id"), Some(&"test-client"));
        assert_eq!(body.len(), 3);
    }

    #[test]
    fn strip_code_fragment_removes_suffix() {
        assert_eq!(strip_code_fragment("abc123#state"), "abc123");
    }

    #[test]
    fn strip_code_fragment_no_fragment() {
        assert_eq!(strip_code_fragment("abc123"), "abc123");
    }

    #[test]
    fn strip_code_fragment_empty_fragment() {
        assert_eq!(strip_code_fragment("abc123#"), "abc123");
    }

    #[test]
    fn strip_code_fragment_empty_string() {
        assert_eq!(strip_code_fragment(""), "");
    }

    #[test]
    fn strip_code_fragment_multiple_hashes() {
        assert_eq!(strip_code_fragment("abc#foo#bar"), "abc");
    }

    #[test]
    fn strip_code_fragment_leading_hash() {
        assert_eq!(strip_code_fragment("#abc"), "");
    }
}