1use crate::{
4 errors::{Error, Result},
5 types::{AuthUrlResult, BuildAuthUrl, CallbackParams, EndSession, OidcProviderMetadata},
6};
7use rand::{distributions::Alphanumeric, Rng};
8#[cfg(test)]
9use std::collections::HashMap;
10use url::Url;
11
12pub fn build_auth_url(params: BuildAuthUrl) -> Result<AuthUrlResult> {
32 if params.issuer.is_empty() {
34 return Err(Error::InvalidParam("issuer cannot be empty"));
35 }
36 if params.client_id.is_empty() {
37 return Err(Error::InvalidParam("client_id cannot be empty"));
38 }
39 if params.redirect_uri.is_empty() {
40 return Err(Error::InvalidParam("redirect_uri cannot be empty"));
41 }
42 if params.code_challenge.is_empty() {
43 return Err(Error::InvalidParam("code_challenge cannot be empty"));
44 }
45
46 let auth_endpoint = params
49 .authorization_endpoint
50 .ok_or_else(|| Error::InvalidParam("authorization_endpoint is required. Please use OIDC discovery to obtain the correct endpoint"))?;
51
52 let mut url = Url::parse(&auth_endpoint)?;
53
54 let scope = if params.scope.is_empty() { "openid profile email" } else { ¶ms.scope };
56 let state = params.state.unwrap_or_else(generate_state);
57 let nonce = if scope.contains("openid") {
58 Some(params.nonce.unwrap_or_else(generate_nonce))
59 } else {
60 None
61 };
62
63 {
65 let mut query = url.query_pairs_mut();
66 query.append_pair("response_type", "code");
67 query.append_pair("client_id", ¶ms.client_id);
68 query.append_pair("redirect_uri", ¶ms.redirect_uri);
69
70 query.append_pair("scope", scope);
72
73 query.append_pair("state", &state);
75
76 query.append_pair("code_challenge", ¶ms.code_challenge);
78 query.append_pair("code_challenge_method", "S256");
79
80 if let Some(ref nonce) = nonce {
82 query.append_pair("nonce", nonce);
83 }
84
85 if let Some(prompt) = ¶ms.prompt {
87 query.append_pair("prompt", prompt);
88 }
89
90 if let Some(tenant) = ¶ms.tenant {
92 query.append_pair("tenant", tenant);
93 }
94
95 if let Some(extra) = ¶ms.extra_params {
97 for (key, value) in extra {
98 query.append_pair(key, value);
99 }
100 }
101 }
102
103 Ok(AuthUrlResult { url, state, nonce })
104}
105
106pub fn build_end_session_url(params: EndSession) -> Result<Url> {
121 if params.issuer.is_empty() {
123 return Err(Error::InvalidParam("issuer cannot be empty"));
124 }
125 if params.id_token_hint.is_empty() {
126 return Err(Error::InvalidParam("id_token_hint cannot be empty"));
127 }
128
129 let end_session_endpoint = if let Some(endpoint) = ¶ms.end_session_endpoint {
131 endpoint.clone()
133 } else {
134 if params.issuer.ends_with('/') {
136 format!("{}oidc/end_session", params.issuer)
137 } else {
138 format!("{}/oidc/end_session", params.issuer)
139 }
140 };
141
142 let mut url = Url::parse(&end_session_endpoint)?;
143
144 {
146 let mut query = url.query_pairs_mut();
147 query.append_pair("id_token_hint", ¶ms.id_token_hint);
148
149 if let Some(redirect_uri) = ¶ms.post_logout_redirect_uri {
150 query.append_pair("post_logout_redirect_uri", redirect_uri);
151 }
152
153 if let Some(state) = ¶ms.state {
154 query.append_pair("state", state);
155 }
156 }
157
158 Ok(url)
159}
160
161pub fn build_end_session_url_with_discovery(
184 mut params: EndSession,
185 metadata: &OidcProviderMetadata,
186) -> Result<Url> {
187 if params.end_session_endpoint.is_none() {
189 params.end_session_endpoint = metadata.end_session_endpoint.clone();
190 }
191
192 build_end_session_url(params)
193}
194
195pub fn parse_callback_params(url: &str) -> CallbackParams {
206 let mut params =
207 CallbackParams { code: None, state: None, error: None, error_description: None };
208
209 if let Ok(parsed_url) = Url::parse(url) {
211 for (key, value) in parsed_url.query_pairs() {
212 match key.as_ref() {
213 "code" => params.code = Some(value.into_owned()),
214 "state" => params.state = Some(value.into_owned()),
215 "error" => params.error = Some(value.into_owned()),
216 "error_description" => params.error_description = Some(value.into_owned()),
217 _ => {} }
219 }
220 } else {
221 let query = if let Some(query_start) = url.find('?') {
224 &url[query_start + 1..]
225 } else if url.contains('=') {
226 url
228 } else {
229 ""
230 };
231
232 if !query.is_empty() {
233 for pair in query.split('&') {
234 if let Some(eq_pos) = pair.find('=') {
235 let key = &pair[..eq_pos];
236 let value = &pair[eq_pos + 1..];
237 let decoded_value = urlencoding::decode(value).unwrap_or_else(|_| value.into());
238
239 match key {
240 "code" => params.code = Some(decoded_value.into_owned()),
241 "state" => params.state = Some(decoded_value.into_owned()),
242 "error" => params.error = Some(decoded_value.into_owned()),
243 "error_description" => {
244 params.error_description = Some(decoded_value.into_owned())
245 }
246 _ => {} }
248 }
249 }
250 }
251 }
252
253 params
254}
255
256#[allow(dead_code)]
258pub fn build_auth_url_with_metadata(
259 metadata: &OidcProviderMetadata,
260 params: BuildAuthUrl,
261) -> Result<AuthUrlResult> {
262 let mut url = Url::parse(&metadata.authorization_endpoint)?;
263
264 let state = params.state.unwrap_or_else(generate_state);
266 let nonce = if params.scope.contains("openid") {
267 Some(params.nonce.unwrap_or_else(generate_nonce))
268 } else {
269 None
270 };
271
272 {
274 let mut query = url.query_pairs_mut();
275 query.append_pair("response_type", "code");
276 query.append_pair("client_id", ¶ms.client_id);
277 query.append_pair("redirect_uri", ¶ms.redirect_uri);
278 query.append_pair("scope", ¶ms.scope);
279
280 query.append_pair("state", &state);
282
283 query.append_pair("code_challenge", ¶ms.code_challenge);
285 query.append_pair("code_challenge_method", "S256");
286
287 if let Some(ref nonce) = nonce {
289 query.append_pair("nonce", nonce);
290 }
291
292 if let Some(prompt) = ¶ms.prompt {
294 query.append_pair("prompt", prompt);
295 }
296
297 if let Some(tenant) = ¶ms.tenant {
298 query.append_pair("tenant", tenant);
299 }
300
301 if let Some(extra) = ¶ms.extra_params {
303 for (key, value) in extra {
304 query.append_pair(key, value);
305 }
306 }
307 }
308
309 Ok(AuthUrlResult { url, state, nonce })
310}
311
312fn generate_state() -> String {
314 rand::thread_rng().sample_iter(&Alphanumeric).take(32).map(char::from).collect()
315}
316
317fn generate_nonce() -> String {
319 rand::thread_rng().sample_iter(&Alphanumeric).take(32).map(char::from).collect()
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline authorization request shared by the `build_auth_url` tests;
    /// individual tests override only the fields they exercise.
    fn sample_auth_request() -> BuildAuthUrl {
        BuildAuthUrl {
            issuer: "https://auth.example.com".into(),
            client_id: "test-client".into(),
            redirect_uri: "https://app.example.com/callback".into(),
            scope: "openid profile".into(),
            code_challenge: "test_challenge".into(),
            state: Some("test_state".into()),
            nonce: Some("test_nonce".into()),
            prompt: None,
            extra_params: None,
            tenant: None,
            authorization_endpoint: Some("https://auth.example.com/oauth/authorize".into()),
        }
    }

    /// Collects the decoded query pairs of `url` into an owned map.
    fn query_map(url: &Url) -> HashMap<String, String> {
        url.query_pairs().into_owned().collect()
    }

    #[test]
    fn test_build_auth_url() {
        let result = build_auth_url(sample_auth_request()).unwrap();

        assert_eq!(result.state, "test_state");
        assert_eq!(result.nonce.as_deref(), Some("test_nonce"));

        let query = query_map(&result.url);
        let expected = [
            ("response_type", "code"),
            ("client_id", "test-client"),
            ("redirect_uri", "https://app.example.com/callback"),
            ("scope", "openid profile"),
            ("state", "test_state"),
            ("nonce", "test_nonce"),
            ("code_challenge", "test_challenge"),
            ("code_challenge_method", "S256"),
        ];
        for (key, value) in expected.iter() {
            assert_eq!(
                query.get(*key).map(String::as_str),
                Some(*value),
                "query parameter `{}`",
                key
            );
        }
    }

    #[test]
    fn test_build_auth_url_auto_state_nonce() {
        let request = BuildAuthUrl { state: None, nonce: None, ..sample_auth_request() };
        let result = build_auth_url(request).unwrap();

        assert_eq!(result.state.len(), 32);
        assert_eq!(result.nonce.as_ref().map(String::len), Some(32));

        let query = query_map(&result.url);
        for key in ["state", "nonce"].iter() {
            let generated = query
                .get(*key)
                .unwrap_or_else(|| panic!("missing generated `{}` parameter", key));
            assert_eq!(generated.len(), 32);
        }
    }

    #[test]
    fn test_build_auth_url_missing_authorization_endpoint() {
        let request = BuildAuthUrl { authorization_endpoint: None, ..sample_auth_request() };

        match build_auth_url(request) {
            Err(Error::InvalidParam(msg)) => {
                assert!(msg.contains("authorization_endpoint is required"));
            }
            _ => panic!("Expected InvalidParam error"),
        }
    }

    #[test]
    fn test_parse_callback_params() {
        let params =
            parse_callback_params("https://app.example.com/callback?code=abc123&state=xyz456");

        assert_eq!(params.code.as_deref(), Some("abc123"));
        assert_eq!(params.state.as_deref(), Some("xyz456"));
        assert!(params.error.is_none());
        assert!(params.error_description.is_none());
    }

    #[test]
    fn test_parse_callback_params_error() {
        let params = parse_callback_params(
            "https://app.example.com/callback?error=access_denied&error_description=User%20denied%20access",
        );

        assert!(params.code.is_none());
        assert!(params.state.is_none());
        assert_eq!(params.error.as_deref(), Some("access_denied"));
        assert_eq!(params.error_description.as_deref(), Some("User denied access"));
    }

    #[test]
    fn test_parse_callback_params_relative_url() {
        let params = parse_callback_params("/callback?code=test&state=test");

        assert_eq!(params.code.as_deref(), Some("test"));
        assert_eq!(params.state.as_deref(), Some("test"));
    }

    #[test]
    fn test_build_end_session_url() {
        let url = build_end_session_url(EndSession {
            issuer: "https://auth.example.com".into(),
            id_token_hint: "test_token".into(),
            post_logout_redirect_uri: Some("https://app.example.com".into()),
            state: Some("logout_state".into()),
            end_session_endpoint: None,
        })
        .unwrap();

        let query = query_map(&url);
        assert_eq!(query.get("id_token_hint").map(String::as_str), Some("test_token"));
        assert_eq!(
            query.get("post_logout_redirect_uri").map(String::as_str),
            Some("https://app.example.com")
        );
        assert_eq!(query.get("state").map(String::as_str), Some("logout_state"));
    }
}