1use std::collections::HashSet;
2use std::future::ready;
3use std::rc::Rc;
4use std::time::{Duration, Instant};
5use std::{future::Future, pin::Pin, str::FromStr, sync::Arc};
6
7use crate::webserver::http_client::get_http_client_from_appdata;
8use crate::{app_config::AppConfig, AppState};
9use actix_web::http::header;
10use actix_web::{
11 body::BoxBody,
12 cookie::Cookie,
13 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
14 middleware::Condition,
15 web::{self, Query},
16 Error, HttpMessage, HttpResponse,
17};
18use anyhow::{anyhow, Context};
19use awc::Client;
20use openidconnect::core::{
21 CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey,
22 CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreRevocableToken,
23 CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType,
24};
25use openidconnect::{
26 core::CoreAuthenticationFlow, url::Url, AsyncHttpClient, Audience, CsrfToken, EndpointMaybeSet,
27 EndpointNotSet, EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, RedirectUrl, Scope,
28 TokenResponse,
29};
30use openidconnect::{
31 EmptyExtraTokenFields, IdTokenFields, IdTokenVerifier, StandardErrorResponse,
32 StandardTokenResponse,
33};
34use serde::{Deserialize, Serialize};
35use tokio::sync::{RwLock, RwLockReadGuard};
36
37use super::http_client::make_http_client;
38
39type LocalBoxFuture<T> = Pin<Box<dyn Future<Output = T> + 'static>>;
40
41const SQLPAGE_AUTH_COOKIE_NAME: &str = "sqlpage_auth";
42const SQLPAGE_REDIRECT_URI: &str = "/sqlpage/oidc_callback";
43const SQLPAGE_NONCE_COOKIE_NAME: &str = "sqlpage_oidc_nonce";
44const SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX: &str = "sqlpage_oidc_state_";
45const OIDC_CLIENT_MAX_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60);
46const OIDC_CLIENT_MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5);
47const AUTH_COOKIE_EXPIRATION: awc::cookie::time::Duration =
48 actix_web::cookie::time::Duration::days(7);
49const LOGIN_FLOW_STATE_COOKIE_EXPIRATION: awc::cookie::time::Duration =
50 actix_web::cookie::time::Duration::minutes(10);
51
/// Additional ID token claims, kept as an untyped JSON object.
///
/// `serde(transparent)` makes this (de)serialize exactly like the inner map.
/// NOTE(review): presumably receives the claims not covered by openidconnect's standard claim
/// fields — confirm against the crate version in use.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct OidcAdditionalClaims(pub(crate) serde_json::Map<String, serde_json::Value>);

// Marker impl so `OidcAdditionalClaims` can be used as the additional-claims type below.
impl openidconnect::AdditionalClaims for OidcAdditionalClaims {}
/// ID token type used throughout this module.
type OidcToken = openidconnect::IdToken<
    OidcAdditionalClaims,
    openidconnect::core::CoreGenderClaim,
    openidconnect::core::CoreJweContentEncryptionAlgorithm,
    openidconnect::core::CoreJwsSigningAlgorithm,
>;
/// Verified claims of an ID token; inserted into the request extensions of authenticated requests.
pub type OidcClaims =
    openidconnect::IdTokenClaims<OidcAdditionalClaims, openidconnect::core::CoreGenderClaim>;
65
/// OIDC settings derived from the application configuration.
#[derive(Clone, Debug)]
pub struct OidcConfig {
    /// Issuer used for provider metadata discovery.
    pub issuer_url: IssuerUrl,
    pub client_id: String,
    pub client_secret: String,
    /// Path prefixes that require authentication.
    pub protected_paths: Vec<String>,
    /// Path prefixes served without authentication, even inside a protected prefix.
    pub public_paths: Vec<String>,
    /// Host (optionally with port) used to build the redirect URL.
    pub app_host: String,
    /// Scopes requested in the authorization URL.
    pub scopes: Vec<Scope>,
    /// Check applied to ID token audiences other than the client ID.
    pub additional_audience_verifier: AudienceVerifier,
}
77
78impl TryFrom<&AppConfig> for OidcConfig {
79 type Error = Option<&'static str>;
80
81 fn try_from(config: &AppConfig) -> Result<Self, Self::Error> {
82 let issuer_url = config.oidc_issuer_url.as_ref().ok_or(None)?;
83 let client_secret = config.oidc_client_secret.as_ref().ok_or(Some(
84 "The \"oidc_client_secret\" setting is required to authenticate with the OIDC provider",
85 ))?;
86 let protected_paths: Vec<String> = config.oidc_protected_paths.clone();
87 let public_paths: Vec<String> = config.oidc_public_paths.clone();
88
89 let app_host = get_app_host(config);
90
91 Ok(Self {
92 issuer_url: issuer_url.clone(),
93 client_id: config.oidc_client_id.clone(),
94 client_secret: client_secret.clone(),
95 protected_paths,
96 public_paths,
97 scopes: config
98 .oidc_scopes
99 .split_whitespace()
100 .map(|s| Scope::new(s.to_string()))
101 .collect(),
102 app_host: app_host.clone(),
103 additional_audience_verifier: AudienceVerifier::new(
104 config.oidc_additional_trusted_audiences.clone(),
105 ),
106 })
107 }
108}
109
110impl OidcConfig {
111 #[must_use]
112 pub fn is_public_path(&self, path: &str) -> bool {
113 !self.protected_paths.iter().any(|p| path.starts_with(p))
114 || self.public_paths.iter().any(|p| path.starts_with(p))
115 }
116
117 fn create_id_token_verifier<'a>(
119 &'a self,
120 oidc_client: &'a OidcClient,
121 ) -> IdTokenVerifier<'a, CoreJsonWebKey> {
122 oidc_client
123 .id_token_verifier()
124 .set_other_audience_verifier_fn(self.additional_audience_verifier.as_fn())
125 }
126}
127
128fn get_app_host(config: &AppConfig) -> String {
129 if let Some(host) = &config.host {
130 return host.clone();
131 }
132 if let Some(https_domain) = &config.https_domain {
133 return https_domain.clone();
134 }
135
136 let socket_addr = config.listen_on();
137 let ip = socket_addr.ip();
138 let host = if ip.is_unspecified() || ip.is_loopback() {
139 format!("localhost:{}", socket_addr.port())
140 } else {
141 socket_addr.to_string()
142 };
143 log::warn!(
144 "No host or https_domain provided in the configuration, \
145 using \"{host}\" as the app host to build the redirect URL. \
146 This will only work locally. \
147 Disable this warning by providing a value for the \"host\" setting."
148 );
149 host
150}
151
/// The current OIDC client together with the time it was built.
pub struct ClientWithTime {
    client: OidcClient,
    // When `client` was (re)built; compared against the refresh intervals.
    last_update: Instant,
}

/// Shared OIDC state: the static configuration plus a client that is periodically rebuilt.
pub struct OidcState {
    pub config: OidcConfig,
    // Requests take read locks to use the client; refreshes take the write lock to replace it.
    client: RwLock<ClientWithTime>,
}
161
162impl OidcState {
163 pub async fn new(oidc_cfg: OidcConfig, app_config: AppConfig) -> anyhow::Result<Self> {
164 let http_client = make_http_client(&app_config)?;
165 let client = build_oidc_client(&oidc_cfg, &http_client).await?;
166
167 Ok(Self {
168 config: oidc_cfg,
169 client: RwLock::new(ClientWithTime {
170 client,
171 last_update: Instant::now(),
172 }),
173 })
174 }
175
176 async fn refresh(&self, service_request: &ServiceRequest) {
177 let mut write_guard = self.client.write().await;
179 match build_oidc_client_from_appdata(&self.config, service_request).await {
180 Ok(http_client) => {
181 *write_guard = ClientWithTime {
182 client: http_client,
183 last_update: Instant::now(),
184 }
185 }
186 Err(e) => log::error!("Failed to refresh OIDC client: {e:#}"),
187 }
188 }
189
190 pub async fn refresh_if_expired(&self, service_request: &ServiceRequest) {
193 if self.client.read().await.last_update.elapsed() > OIDC_CLIENT_MAX_REFRESH_INTERVAL {
194 self.refresh(service_request).await;
195 }
196 }
197
198 pub async fn refresh_on_error(&self, service_request: &ServiceRequest) {
200 if self.client.read().await.last_update.elapsed() > OIDC_CLIENT_MIN_REFRESH_INTERVAL {
201 self.refresh(service_request).await;
202 }
203 }
204
205 pub async fn get_client(&self) -> RwLockReadGuard<'_, OidcClient> {
207 RwLockReadGuard::map(
208 self.client.read().await,
209 |ClientWithTime { client, .. }| client,
210 )
211 }
212
213 async fn get_token_claims(
215 &self,
216 id_token: OidcToken,
217 expected_nonce: &Nonce,
218 ) -> anyhow::Result<OidcClaims> {
219 let client = &self.get_client().await;
220 let verifier = self.config.create_id_token_verifier(client);
221 let nonce_verifier = |nonce: Option<&Nonce>| check_nonce(nonce, expected_nonce);
222 let claims: OidcClaims = id_token
223 .into_claims(&verifier, nonce_verifier)
224 .map_err(|e| anyhow::anyhow!("Could not verify the ID token: {}", e))?;
225 Ok(claims)
226 }
227}
228
229pub async fn initialize_oidc_state(
230 app_config: &AppConfig,
231) -> anyhow::Result<Option<Arc<OidcState>>> {
232 let oidc_cfg = match OidcConfig::try_from(app_config) {
233 Ok(c) => c,
234 Err(None) => return Ok(None), Err(Some(e)) => return Err(anyhow::anyhow!(e)),
236 };
237
238 Ok(Some(Arc::new(
239 OidcState::new(oidc_cfg, app_config.clone()).await?,
240 )))
241}
242
243async fn build_oidc_client_from_appdata(
244 cfg: &OidcConfig,
245 req: &ServiceRequest,
246) -> anyhow::Result<OidcClient> {
247 let http_client = get_http_client_from_appdata(req)?;
248 build_oidc_client(cfg, http_client).await
249}
250
251async fn build_oidc_client(
252 oidc_cfg: &OidcConfig,
253 http_client: &Client,
254) -> anyhow::Result<OidcClient> {
255 let issuer_url = oidc_cfg.issuer_url.clone();
256 let provider_metadata = discover_provider_metadata(http_client, issuer_url.clone()).await?;
257 let client = make_oidc_client(oidc_cfg, provider_metadata)?;
258 Ok(client)
259}
260
261pub struct OidcMiddleware {
262 oidc_state: Option<Arc<OidcState>>,
263}
264
265impl OidcMiddleware {
266 #[must_use]
267 pub fn new(app_state: &web::Data<AppState>) -> Condition<Self> {
268 let oidc_state = app_state.oidc_state.clone();
269 Condition::new(oidc_state.is_some(), Self { oidc_state })
270 }
271}
272
273async fn discover_provider_metadata(
274 http_client: &awc::Client,
275 issuer_url: IssuerUrl,
276) -> anyhow::Result<openidconnect::core::CoreProviderMetadata> {
277 log::debug!("Discovering provider metadata for {issuer_url}");
278 let provider_metadata = openidconnect::core::CoreProviderMetadata::discover_async(
279 issuer_url,
280 &AwcHttpClient::from_client(http_client),
281 )
282 .await
283 .with_context(|| "Failed to discover OIDC provider metadata".to_string())?;
284 log::debug!("Provider metadata discovered: {provider_metadata:?}");
285 Ok(provider_metadata)
286}
287
288impl<S> Transform<S, ServiceRequest> for OidcMiddleware
289where
290 S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + 'static,
291 S::Future: 'static,
292{
293 type Response = ServiceResponse<BoxBody>;
294 type Error = Error;
295 type InitError = ();
296 type Transform = OidcService<S>;
297 type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
298
299 fn new_transform(&self, service: S) -> Self::Future {
300 match &self.oidc_state {
301 Some(state) => ready(Ok(OidcService::new(service, Arc::clone(state)))),
302 None => ready(Err(())),
303 }
304 }
305}
306
/// Middleware service that authenticates requests before passing them to the wrapped `S`.
// NOTE(review): `#[derive(Clone)]` adds an `S: Clone` bound even though `service` is an `Rc`.
#[derive(Clone)]
pub struct OidcService<S> {
    service: Rc<S>,
    oidc_state: Arc<OidcState>,
}
312
313impl<S> OidcService<S>
314where
315 S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error>,
316 S::Future: 'static,
317{
318 pub fn new(service: S, oidc_state: Arc<OidcState>) -> Self {
319 Self {
320 service: Rc::new(service),
321 oidc_state,
322 }
323 }
324}
325
/// Outcome of the middleware's request handling.
enum MiddlewareResponse {
    /// Pass the request on to the wrapped service.
    Forward(ServiceRequest),
    /// Answer directly (redirect or callback result) without calling the wrapped service.
    Respond(ServiceResponse),
}
330
331async fn handle_request(oidc_state: &OidcState, request: ServiceRequest) -> MiddlewareResponse {
332 log::trace!("Started OIDC middleware request handling");
333 oidc_state.refresh_if_expired(&request).await;
334
335 if request.path() == SQLPAGE_REDIRECT_URI {
336 let response = handle_oidc_callback(oidc_state, request).await;
337 return MiddlewareResponse::Respond(response);
338 }
339
340 match get_authenticated_user_info(oidc_state, &request).await {
341 Ok(Some(claims)) => {
342 log::trace!("Storing authenticated user info in request extensions: {claims:?}");
343 request.extensions_mut().insert(claims);
344 MiddlewareResponse::Forward(request)
345 }
346 Ok(None) => {
347 log::trace!("No authenticated user found");
348 handle_unauthenticated_request(oidc_state, request).await
349 }
350 Err(e) => {
351 log::debug!("An auth cookie is present but could not be verified. Redirecting to OIDC provider to re-authenticate. {e:?}");
352 oidc_state.refresh_on_error(&request).await;
353 handle_unauthenticated_request(oidc_state, request).await
354 }
355 }
356}
357
358async fn handle_unauthenticated_request(
359 oidc_state: &OidcState,
360 request: ServiceRequest,
361) -> MiddlewareResponse {
362 log::debug!("Handling unauthenticated request to {}", request.path());
363
364 if oidc_state.config.is_public_path(request.path()) {
365 return MiddlewareResponse::Forward(request);
366 }
367
368 log::debug!("Redirecting to OIDC provider");
369
370 let initial_url = request.uri().to_string();
371 let response = build_auth_provider_redirect_response(oidc_state, &initial_url).await;
372 MiddlewareResponse::Respond(request.into_response(response))
373}
374
375async fn handle_oidc_callback(oidc_state: &OidcState, request: ServiceRequest) -> ServiceResponse {
376 match process_oidc_callback(oidc_state, &request).await {
377 Ok(response) => request.into_response(response),
378 Err(e) => {
379 log::error!("Failed to process OIDC callback. Refreshing oidc provider metadata, then redirecting to home page: {e:#}");
380 oidc_state.refresh_on_error(&request).await;
381 let resp = build_auth_provider_redirect_response(oidc_state, "/").await;
382 request.into_response(resp)
383 }
384 }
385}
386
387impl<S> Service<ServiceRequest> for OidcService<S>
388where
389 S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + 'static,
390 S::Future: 'static,
391{
392 type Response = ServiceResponse<BoxBody>;
393 type Error = Error;
394 type Future = LocalBoxFuture<Result<Self::Response, Self::Error>>;
395
396 forward_ready!(service);
397
398 fn call(&self, request: ServiceRequest) -> Self::Future {
399 let srv = Rc::clone(&self.service);
400 let oidc_state = Arc::clone(&self.oidc_state);
401 Box::pin(async move {
402 match handle_request(&oidc_state, request).await {
403 MiddlewareResponse::Respond(response) => Ok(response),
404 MiddlewareResponse::Forward(request) => srv.call(request).await,
405 }
406 })
407 }
408}
409
410async fn process_oidc_callback(
411 oidc_state: &OidcState,
412 request: &ServiceRequest,
413) -> anyhow::Result<HttpResponse> {
414 let params = Query::<OidcCallbackParams>::from_query(request.query_string())
415 .with_context(|| format!("{SQLPAGE_REDIRECT_URI}: invalid url parameters"))?
416 .into_inner();
417 log::debug!("Processing OIDC callback with params: {params:?}. Requesting token...");
418 let mut tmp_login_flow_state_cookie = get_tmp_login_flow_state_cookie(request, ¶ms.state)?;
419 let client = oidc_state.get_client().await;
420 let http_client = get_http_client_from_appdata(request)?;
421 let id_token = exchange_code_for_token(&client, http_client, params.clone()).await?;
422 log::debug!("Received token response: {id_token:?}");
423 let LoginFlowState {
424 nonce,
425 redirect_target,
426 } = parse_login_flow_state(&tmp_login_flow_state_cookie)?;
427 let redirect_target = validate_redirect_url(redirect_target.to_string());
428
429 log::info!("Redirecting to {redirect_target} after a successful login");
430 let mut response = build_redirect_response(redirect_target);
431 set_auth_cookie(&mut response, &id_token);
432 let claims = oidc_state
433 .get_token_claims(id_token, &nonce)
434 .await
435 .context("The identity provider returned an invalid ID token")?;
436 log::debug!("{} successfully logged in", claims.subject().as_str());
437 let nonce_cookie = create_final_nonce_cookie(&nonce);
438 response.add_cookie(&nonce_cookie)?;
439 tmp_login_flow_state_cookie.set_path("/"); response.add_removal_cookie(&tmp_login_flow_state_cookie)?;
441 Ok(response)
442}
443
444async fn exchange_code_for_token(
445 oidc_client: &OidcClient,
446 http_client: &awc::Client,
447 oidc_callback_params: OidcCallbackParams,
448) -> anyhow::Result<OidcToken> {
449 let token_response = oidc_client
450 .exchange_code(openidconnect::AuthorizationCode::new(
451 oidc_callback_params.code,
452 ))?
453 .request_async(&AwcHttpClient::from_client(http_client))
454 .await
455 .context("Failed to exchange code for token")?;
456 let access_token = token_response.access_token();
457 log::trace!("Received access token: {}", access_token.secret());
458 let id_token = token_response
459 .id_token()
460 .context("No ID token found in the token response. You may have specified an oauth2 provider that does not support OIDC.")?;
461 Ok(id_token.clone())
462}
463
464fn set_auth_cookie(response: &mut HttpResponse, id_token: &OidcToken) {
465 let id_token_str = id_token.to_string();
466 log::trace!("Setting auth cookie: {SQLPAGE_AUTH_COOKIE_NAME}=\"{id_token_str}\"");
467 let id_token_size_kb = id_token_str.len() / 1024;
468 if id_token_size_kb > 4 {
469 log::warn!(
470 "The ID token cookie from the OIDC provider is {id_token_size_kb}kb. \
471 Large cookies can cause performance issues and may be rejected by browsers or by reverse proxies."
472 );
473 }
474 let cookie = Cookie::build(SQLPAGE_AUTH_COOKIE_NAME, id_token_str)
475 .secure(true)
476 .http_only(true)
477 .max_age(AUTH_COOKIE_EXPIRATION)
478 .same_site(actix_web::cookie::SameSite::Lax)
479 .path("/")
480 .finish();
481
482 response.add_cookie(&cookie).unwrap();
483}
484
485async fn build_auth_provider_redirect_response(
486 oidc_state: &OidcState,
487 initial_url: &str,
488) -> HttpResponse {
489 let AuthUrl { url, params } = build_auth_url(oidc_state).await;
490 let tmp_login_flow_state_cookie = create_tmp_login_flow_state_cookie(¶ms, initial_url);
491 HttpResponse::TemporaryRedirect()
492 .append_header((header::LOCATION, url.to_string()))
493 .cookie(tmp_login_flow_state_cookie)
494 .body("Redirecting...")
495}
496
497fn build_redirect_response(target_url: String) -> HttpResponse {
498 HttpResponse::TemporaryRedirect()
499 .append_header(("Location", target_url))
500 .body("Redirecting...")
501}
502
503async fn get_authenticated_user_info(
505 oidc_state: &OidcState,
506 request: &ServiceRequest,
507) -> anyhow::Result<Option<OidcClaims>> {
508 let Some(cookie) = request.cookie(SQLPAGE_AUTH_COOKIE_NAME) else {
509 return Ok(None);
510 };
511 let cookie_value = cookie.value().to_string();
512 let id_token = OidcToken::from_str(&cookie_value)
513 .with_context(|| format!("Invalid SQLPage auth cookie: {cookie_value:?}"))?;
514
515 let nonce = get_final_nonce_from_cookie(request)?;
516 log::debug!("Verifying id token: {id_token:?}");
517 let claims = oidc_state.get_token_claims(id_token, &nonce).await?;
518 log::debug!("The current user is: {claims:?}");
519 Ok(Some(claims))
520}
521
522pub struct AwcHttpClient<'c> {
523 client: &'c awc::Client,
524}
525
526impl<'c> AwcHttpClient<'c> {
527 #[must_use]
528 pub fn from_client(client: &'c awc::Client) -> Self {
529 Self { client }
530 }
531}
532
533impl<'c> AsyncHttpClient<'c> for AwcHttpClient<'c> {
534 type Error = AwcWrapperError;
535 type Future =
536 Pin<Box<dyn Future<Output = Result<openidconnect::HttpResponse, Self::Error>> + 'c>>;
537
538 fn call(&'c self, request: openidconnect::HttpRequest) -> Self::Future {
539 let client = self.client.clone();
540 Box::pin(async move {
541 execute_oidc_request_with_awc(client, request)
542 .await
543 .map_err(AwcWrapperError)
544 })
545 }
546}
547
/// Executes an `openidconnect` HTTP request with awc.
///
/// The reply is converted back into the `http::Response<Vec<u8>>` that `openidconnect`
/// expects. Method, URI and header values are converted through their string forms.
/// NOTE(review): presumably because awc and openidconnect depend on different versions of the
/// `http` crate, so their types are not interchangeable — confirm.
///
/// # Errors
/// Fails if:
/// - the method or URI cannot be re-parsed;
/// - a request or response header value is not visible ASCII (`HeaderValue::to_str`);
/// - sending the request or reading the body fails;
/// - the response cannot be rebuilt.
async fn execute_oidc_request_with_awc(
    client: Client,
    request: openidconnect::HttpRequest,
) -> Result<openidconnect::http::Response<Vec<u8>>, anyhow::Error> {
    let awc_method = awc::http::Method::from_bytes(request.method().as_str().as_bytes())?;
    let awc_uri = awc::http::Uri::from_str(&request.uri().to_string())?;
    log::debug!("Executing OIDC request: {awc_method} {awc_uri}");
    let mut req = client.request(awc_method, awc_uri);
    // `insert_header` replaces earlier values with the same name, so only the last value of a
    // repeated request header is sent.
    for (name, value) in request.headers() {
        req = req.insert_header((name.as_str(), value.to_str()?));
    }
    let (req_head, body) = request.into_parts();
    // The awc send error is flattened to its message; method and URI are attached as context.
    let mut response = req.send_body(body).await.map_err(|e| {
        anyhow!(e.to_string()).context(format!(
            "Failed to send request: {} {}",
            &req_head.method, &req_head.uri
        ))
    })?;
    let head = response.headers();
    let mut resp_builder =
        openidconnect::http::Response::builder().status(response.status().as_u16());
    for (name, value) in head {
        resp_builder = resp_builder.header(name.as_str(), value.to_str()?);
    }
    let body = response
        .body()
        .await
        .with_context(|| format!("Couldnt read from {}", &req_head.uri))?;
    // NOTE(review): the full response body is logged at debug level; for token-endpoint
    // responses this includes the issued tokens.
    log::debug!(
        "Received OIDC response with status {}: {}",
        response.status(),
        String::from_utf8_lossy(&body)
    );
    let resp = resp_builder.body(body.to_vec())?;
    Ok(resp)
}
584
/// Error type of [`AwcHttpClient`].
///
/// Wraps the `anyhow::Error` produced while executing a request, because `anyhow::Error` does
/// not itself implement `std::error::Error`.
#[derive(Debug)]
pub struct AwcWrapperError(anyhow::Error);

impl std::fmt::Display for AwcWrapperError {
    // Forwards the formatter unchanged, so flags such as `{:#}` reach the inner anyhow error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
593
/// Token endpoint response: standard OAuth2 fields plus an ID token carrying
/// [`OidcAdditionalClaims`], with no extra token fields.
type OidcTokenResponse = StandardTokenResponse<
    IdTokenFields<
        OidcAdditionalClaims,
        EmptyExtraTokenFields,
        CoreGenderClaim,
        CoreJweContentEncryptionAlgorithm,
        CoreJwsSigningAlgorithm,
    >,
    CoreTokenType,
>;

/// The OIDC client type built from discovered provider metadata.
///
/// NOTE(review): the six trailing parameters presumably mark, in openidconnect 4 order, which
/// endpoints are configured:
/// - authorization: set;
/// - device authorization, introspection, revocation: not set;
/// - token, userinfo: maybe set.
///
/// This would be why `exchange_code` returns a `Result`. Verify against the crate version.
type OidcClient = openidconnect::Client<
    OidcAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    OidcTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    EndpointSet,
    EndpointNotSet,
    EndpointNotSet,
    EndpointNotSet,
    EndpointMaybeSet,
    EndpointMaybeSet,
>;
624
// Lets `AwcWrapperError` be used where a `std::error::Error` is required.
impl std::error::Error for AwcWrapperError {
    // `Display` already shows the inner error's top-level message. The chain therefore
    // continues with the inner error's own source instead of repeating that message.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.0.source()
    }
}
630
631fn make_oidc_client(
632 config: &OidcConfig,
633 provider_metadata: openidconnect::core::CoreProviderMetadata,
634) -> anyhow::Result<OidcClient> {
635 let client_id = openidconnect::ClientId::new(config.client_id.clone());
636 let client_secret = openidconnect::ClientSecret::new(config.client_secret.clone());
637
638 let mut redirect_url = RedirectUrl::new(format!(
639 "https://{}{}",
640 config.app_host, SQLPAGE_REDIRECT_URI,
641 ))
642 .with_context(|| {
643 format!(
644 "Failed to build the redirect URL; invalid app host \"{}\"",
645 config.app_host
646 )
647 })?;
648 let needs_http = match redirect_url.url().host() {
649 Some(openidconnect::url::Host::Domain(domain)) => domain == "localhost",
650 Some(openidconnect::url::Host::Ipv4(_) | openidconnect::url::Host::Ipv6(_)) => true,
651 None => false,
652 };
653 if needs_http {
654 log::debug!("App host seems to be local, changing redirect URL to HTTP");
655 redirect_url = RedirectUrl::new(format!(
656 "http://{}{}",
657 config.app_host, SQLPAGE_REDIRECT_URI,
658 ))?;
659 }
660 log::info!("OIDC redirect URL for {}: {redirect_url}", config.client_id);
661 let client =
662 OidcClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
663 .set_redirect_uri(redirect_url);
664
665 Ok(client)
666}
667
/// Query parameters of the provider's redirect to the callback URL.
#[derive(Debug, Deserialize, Clone)]
struct OidcCallbackParams {
    /// Authorization code to exchange at the token endpoint.
    code: String,
    /// CSRF state; also the suffix of the temporary login-flow cookie name.
    state: CsrfToken,
}

/// Authorization URL to send the user to, plus the secrets generated with it.
struct AuthUrl {
    url: Url,
    params: AuthUrlParams,
}

/// Per-login secrets kept in the temporary cookie until the callback.
struct AuthUrlParams {
    /// CSRF state sent to the provider.
    csrf_token: CsrfToken,
    /// Raw nonce; only its Argon2 hash is sent to the provider.
    nonce: Nonce,
}
683
684async fn build_auth_url(oidc_state: &OidcState) -> AuthUrl {
685 let nonce_source = Nonce::new_random();
686 let hashed_nonce = Nonce::new(hash_nonce(&nonce_source));
687 let scopes = &oidc_state.config.scopes;
688 let client_lock = oidc_state.get_client().await;
689 let (url, csrf_token, _nonce) = client_lock
690 .authorize_url(
691 CoreAuthenticationFlow::AuthorizationCode,
692 CsrfToken::new_random,
693 || hashed_nonce,
694 )
695 .add_scopes(scopes.iter().cloned())
696 .url();
697 AuthUrl {
698 url,
699 params: AuthUrlParams {
700 csrf_token,
701 nonce: nonce_source,
702 },
703 }
704}
705
706fn hash_nonce(nonce: &Nonce) -> String {
707 use argon2::password_hash::{rand_core::OsRng, PasswordHasher, SaltString};
708 let salt = SaltString::generate(&mut OsRng);
709 let params = argon2::Params::new(8, 1, 1, Some(16)).expect("bug: invalid Argon2 parameters");
711 let argon2 = argon2::Argon2::new(argon2::Algorithm::Argon2id, argon2::Version::V0x13, params);
712 let hash = argon2
713 .hash_password(nonce.secret().as_bytes(), &salt)
714 .expect("bug: failed to hash nonce");
715 hash.to_string()
716}
717
718fn check_nonce(id_token_nonce: Option<&Nonce>, expected_nonce: &Nonce) -> Result<(), String> {
719 match id_token_nonce {
720 Some(id_token_nonce) => nonce_matches(id_token_nonce, expected_nonce),
721 None => Err("No nonce found in the ID token".to_string()),
722 }
723}
724
725fn nonce_matches(id_token_nonce: &Nonce, state_nonce: &Nonce) -> Result<(), String> {
726 log::debug!(
727 "Checking nonce: {} == {}",
728 id_token_nonce.secret(),
729 state_nonce.secret()
730 );
731 let hash = argon2::password_hash::PasswordHash::new(id_token_nonce.secret()).map_err(|e| {
732 format!(
733 "Failed to parse state nonce ({}): {e}",
734 id_token_nonce.secret()
735 )
736 })?;
737 argon2::password_hash::PasswordVerifier::verify_password(
738 &argon2::Argon2::default(),
739 state_nonce.secret().as_bytes(),
740 &hash,
741 )
742 .map_err(|e| format!("Failed to verify nonce ({}): {e}", state_nonce.secret()))?;
743 log::debug!("Nonce successfully verified");
744 Ok(())
745}
746
747fn create_final_nonce_cookie(nonce: &Nonce) -> Cookie<'_> {
748 Cookie::build(SQLPAGE_NONCE_COOKIE_NAME, nonce.secret())
749 .secure(true)
750 .http_only(true)
751 .same_site(actix_web::cookie::SameSite::Lax)
752 .max_age(AUTH_COOKIE_EXPIRATION)
753 .path("/")
754 .finish()
755}
756
757fn create_tmp_login_flow_state_cookie<'a>(
758 params: &'a AuthUrlParams,
759 initial_url: &'a str,
760) -> Cookie<'a> {
761 let csrf_token = ¶ms.csrf_token;
762 let cookie_name = SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX.to_owned() + csrf_token.secret();
763 let cookie_value = serde_json::to_string(&LoginFlowState {
764 nonce: params.nonce.clone(),
765 redirect_target: initial_url,
766 })
767 .expect("login flow state is always serializable");
768 Cookie::build(cookie_name, cookie_value)
769 .secure(true)
770 .http_only(true)
771 .same_site(actix_web::cookie::SameSite::Lax)
772 .path("/")
773 .max_age(LOGIN_FLOW_STATE_COOKIE_EXPIRATION)
774 .finish()
775}
776
777fn get_final_nonce_from_cookie(request: &ServiceRequest) -> anyhow::Result<Nonce> {
778 let cookie = request
779 .cookie(SQLPAGE_NONCE_COOKIE_NAME)
780 .with_context(|| format!("No {SQLPAGE_NONCE_COOKIE_NAME} cookie found"))?;
781 Ok(Nonce::new(cookie.value().to_string()))
782}
783
784fn get_tmp_login_flow_state_cookie(
785 request: &ServiceRequest,
786 csrf_token: &CsrfToken,
787) -> anyhow::Result<Cookie<'static>> {
788 let cookie_name = SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX.to_owned() + csrf_token.secret();
789 request
790 .cookie(&cookie_name)
791 .with_context(|| format!("No {cookie_name} cookie found"))
792}
793
/// Contents of the temporary login-flow cookie, serialized as JSON with one-letter keys.
///
/// The one-letter keys are presumably there to keep the cookie small.
///
/// NOTE(review): `redirect_target` borrows from the cookie value, so deserialization fails if
/// the serialized URL contained characters JSON must escape (such as `"` or `\`).
/// `Cow<'a, str>` with `#[serde(borrow)]` would accept those.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct LoginFlowState<'a> {
    /// Raw nonce of this login attempt.
    #[serde(rename = "n")]
    nonce: Nonce,
    /// URL originally requested; the user is sent back there after login.
    #[serde(rename = "r")]
    redirect_target: &'a str,
}
801
802fn parse_login_flow_state<'a>(cookie: &'a Cookie<'_>) -> anyhow::Result<LoginFlowState<'a>> {
803 serde_json::from_str(cookie.value())
804 .with_context(|| format!("Invalid login flow state cookie: {}", cookie.value()))
805}
806
807#[derive(Clone, Debug)]
809pub struct AudienceVerifier(Option<HashSet<String>>);
810
811impl AudienceVerifier {
812 fn new(additional_trusted_audiences: Option<Vec<String>>) -> Self {
816 AudienceVerifier(additional_trusted_audiences.map(HashSet::from_iter))
817 }
818
819 fn as_fn(&self) -> impl Fn(&Audience) -> bool + '_ {
821 move |aud: &Audience| -> bool {
822 let Some(trusted_set) = &self.0 else {
823 return true;
824 };
825 trusted_set.contains(aud.as_str())
826 }
827 }
828}
829
830fn validate_redirect_url(url: String) -> String {
832 if url.starts_with('/') && !url.starts_with("//") && !url.starts_with(SQLPAGE_REDIRECT_URI) {
833 return url;
834 }
835 log::warn!("Refusing to redirect to {url}");
836 '/'.to_string()
837}