sqlpage/webserver/
oidc.rs

1use std::collections::HashSet;
2use std::future::ready;
3use std::rc::Rc;
4use std::time::{Duration, Instant};
5use std::{future::Future, pin::Pin, str::FromStr, sync::Arc};
6
7use crate::webserver::http_client::get_http_client_from_appdata;
8use crate::{app_config::AppConfig, AppState};
9use actix_web::http::header;
10use actix_web::{
11    body::BoxBody,
12    cookie::Cookie,
13    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
14    middleware::Condition,
15    web::{self, Query},
16    Error, HttpMessage, HttpResponse,
17};
18use anyhow::{anyhow, Context};
19use awc::Client;
20use openidconnect::core::{
21    CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey,
22    CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreRevocableToken,
23    CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType,
24};
25use openidconnect::{
26    core::CoreAuthenticationFlow, url::Url, AsyncHttpClient, Audience, CsrfToken, EndpointMaybeSet,
27    EndpointNotSet, EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, RedirectUrl, Scope,
28    TokenResponse,
29};
30use openidconnect::{
31    EmptyExtraTokenFields, IdTokenFields, IdTokenVerifier, StandardErrorResponse,
32    StandardTokenResponse,
33};
34use serde::{Deserialize, Serialize};
35use tokio::sync::{RwLock, RwLockReadGuard};
36
37use super::http_client::make_http_client;
38
39type LocalBoxFuture<T> = Pin<Box<dyn Future<Output = T> + 'static>>;
40
41const SQLPAGE_AUTH_COOKIE_NAME: &str = "sqlpage_auth";
42const SQLPAGE_REDIRECT_URI: &str = "/sqlpage/oidc_callback";
43const SQLPAGE_NONCE_COOKIE_NAME: &str = "sqlpage_oidc_nonce";
44const SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX: &str = "sqlpage_oidc_state_";
45const OIDC_CLIENT_MAX_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60);
46const OIDC_CLIENT_MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5);
47const AUTH_COOKIE_EXPIRATION: awc::cookie::time::Duration =
48    actix_web::cookie::time::Duration::days(7);
49const LOGIN_FLOW_STATE_COOKIE_EXPIRATION: awc::cookie::time::Duration =
50    actix_web::cookie::time::Duration::minutes(10);
51
52#[derive(Clone, Debug, Serialize, Deserialize)]
53#[serde(transparent)]
54pub struct OidcAdditionalClaims(pub(crate) serde_json::Map<String, serde_json::Value>);
55
56impl openidconnect::AdditionalClaims for OidcAdditionalClaims {}
57type OidcToken = openidconnect::IdToken<
58    OidcAdditionalClaims,
59    openidconnect::core::CoreGenderClaim,
60    openidconnect::core::CoreJweContentEncryptionAlgorithm,
61    openidconnect::core::CoreJwsSigningAlgorithm,
62>;
63pub type OidcClaims =
64    openidconnect::IdTokenClaims<OidcAdditionalClaims, openidconnect::core::CoreGenderClaim>;
65
/// OIDC settings extracted from the application configuration.
#[derive(Clone, Debug)]
pub struct OidcConfig {
    /// Identity provider URL, used for metadata discovery.
    pub issuer_url: IssuerUrl,
    pub client_id: String,
    pub client_secret: String,
    /// Path prefixes that require an authenticated user.
    pub protected_paths: Vec<String>,
    /// Path prefixes that stay accessible anonymously, even inside a protected prefix.
    pub public_paths: Vec<String>,
    /// Host (possibly with a port) used to build the OIDC redirect URL.
    pub app_host: String,
    /// Scopes requested from the provider at login.
    pub scopes: Vec<Scope>,
    /// Decides which audiences, besides the client id, are accepted in ID tokens.
    pub additional_audience_verifier: AudienceVerifier,
}
77
78impl TryFrom<&AppConfig> for OidcConfig {
79    type Error = Option<&'static str>;
80
81    fn try_from(config: &AppConfig) -> Result<Self, Self::Error> {
82        let issuer_url = config.oidc_issuer_url.as_ref().ok_or(None)?;
83        let client_secret = config.oidc_client_secret.as_ref().ok_or(Some(
84            "The \"oidc_client_secret\" setting is required to authenticate with the OIDC provider",
85        ))?;
86        let protected_paths: Vec<String> = config.oidc_protected_paths.clone();
87        let public_paths: Vec<String> = config.oidc_public_paths.clone();
88
89        let app_host = get_app_host(config);
90
91        Ok(Self {
92            issuer_url: issuer_url.clone(),
93            client_id: config.oidc_client_id.clone(),
94            client_secret: client_secret.clone(),
95            protected_paths,
96            public_paths,
97            scopes: config
98                .oidc_scopes
99                .split_whitespace()
100                .map(|s| Scope::new(s.to_string()))
101                .collect(),
102            app_host: app_host.clone(),
103            additional_audience_verifier: AudienceVerifier::new(
104                config.oidc_additional_trusted_audiences.clone(),
105            ),
106        })
107    }
108}
109
110impl OidcConfig {
111    #[must_use]
112    pub fn is_public_path(&self, path: &str) -> bool {
113        !self.protected_paths.iter().any(|p| path.starts_with(p))
114            || self.public_paths.iter().any(|p| path.starts_with(p))
115    }
116
117    /// Creates a custom ID token verifier that supports multiple issuers
118    fn create_id_token_verifier<'a>(
119        &'a self,
120        oidc_client: &'a OidcClient,
121    ) -> IdTokenVerifier<'a, CoreJsonWebKey> {
122        oidc_client
123            .id_token_verifier()
124            .set_other_audience_verifier_fn(self.additional_audience_verifier.as_fn())
125    }
126}
127
128fn get_app_host(config: &AppConfig) -> String {
129    if let Some(host) = &config.host {
130        return host.clone();
131    }
132    if let Some(https_domain) = &config.https_domain {
133        return https_domain.clone();
134    }
135
136    let socket_addr = config.listen_on();
137    let ip = socket_addr.ip();
138    let host = if ip.is_unspecified() || ip.is_loopback() {
139        format!("localhost:{}", socket_addr.port())
140    } else {
141        socket_addr.to_string()
142    };
143    log::warn!(
144        "No host or https_domain provided in the configuration, \
145         using \"{host}\" as the app host to build the redirect URL. \
146         This will only work locally. \
147         Disable this warning by providing a value for the \"host\" setting."
148    );
149    host
150}
151
/// An OIDC client together with the moment it was built from the provider metadata.
pub struct ClientWithTime {
    client: OidcClient,
    /// Compared against the refresh intervals to decide when to rebuild the client.
    last_update: Instant,
}

/// Shared OIDC state: the static configuration plus the periodically rebuilt client.
pub struct OidcState {
    pub config: OidcConfig,
    client: RwLock<ClientWithTime>,
}
161
162impl OidcState {
163    pub async fn new(oidc_cfg: OidcConfig, app_config: AppConfig) -> anyhow::Result<Self> {
164        let http_client = make_http_client(&app_config)?;
165        let client = build_oidc_client(&oidc_cfg, &http_client).await?;
166
167        Ok(Self {
168            config: oidc_cfg,
169            client: RwLock::new(ClientWithTime {
170                client,
171                last_update: Instant::now(),
172            }),
173        })
174    }
175
176    async fn refresh(&self, service_request: &ServiceRequest) {
177        // Obtain a write lock to prevent concurrent OIDC client refreshes.
178        let mut write_guard = self.client.write().await;
179        match build_oidc_client_from_appdata(&self.config, service_request).await {
180            Ok(http_client) => {
181                *write_guard = ClientWithTime {
182                    client: http_client,
183                    last_update: Instant::now(),
184                }
185            }
186            Err(e) => log::error!("Failed to refresh OIDC client: {e:#}"),
187        }
188    }
189
190    /// Refreshes the OIDC client from the provider metadata URL if it has expired.
191    /// Most providers update their signing keys periodically.
192    pub async fn refresh_if_expired(&self, service_request: &ServiceRequest) {
193        if self.client.read().await.last_update.elapsed() > OIDC_CLIENT_MAX_REFRESH_INTERVAL {
194            self.refresh(service_request).await;
195        }
196    }
197
198    /// When an authentication error is encountered, refresh the OIDC client info faster
199    pub async fn refresh_on_error(&self, service_request: &ServiceRequest) {
200        if self.client.read().await.last_update.elapsed() > OIDC_CLIENT_MIN_REFRESH_INTERVAL {
201            self.refresh(service_request).await;
202        }
203    }
204
205    /// Gets a reference to the oidc client, potentially generating a new one if needed
206    pub async fn get_client(&self) -> RwLockReadGuard<'_, OidcClient> {
207        RwLockReadGuard::map(
208            self.client.read().await,
209            |ClientWithTime { client, .. }| client,
210        )
211    }
212
213    /// Validate and decode the claims of an OIDC token, without refreshing the client.
214    async fn get_token_claims(
215        &self,
216        id_token: OidcToken,
217        expected_nonce: &Nonce,
218    ) -> anyhow::Result<OidcClaims> {
219        let client = &self.get_client().await;
220        let verifier = self.config.create_id_token_verifier(client);
221        let nonce_verifier = |nonce: Option<&Nonce>| check_nonce(nonce, expected_nonce);
222        let claims: OidcClaims = id_token
223            .into_claims(&verifier, nonce_verifier)
224            .map_err(|e| anyhow::anyhow!("Could not verify the ID token: {}", e))?;
225        Ok(claims)
226    }
227}
228
229pub async fn initialize_oidc_state(
230    app_config: &AppConfig,
231) -> anyhow::Result<Option<Arc<OidcState>>> {
232    let oidc_cfg = match OidcConfig::try_from(app_config) {
233        Ok(c) => c,
234        Err(None) => return Ok(None), // OIDC not configured
235        Err(Some(e)) => return Err(anyhow::anyhow!(e)),
236    };
237
238    Ok(Some(Arc::new(
239        OidcState::new(oidc_cfg, app_config.clone()).await?,
240    )))
241}
242
243async fn build_oidc_client_from_appdata(
244    cfg: &OidcConfig,
245    req: &ServiceRequest,
246) -> anyhow::Result<OidcClient> {
247    let http_client = get_http_client_from_appdata(req)?;
248    build_oidc_client(cfg, http_client).await
249}
250
251async fn build_oidc_client(
252    oidc_cfg: &OidcConfig,
253    http_client: &Client,
254) -> anyhow::Result<OidcClient> {
255    let issuer_url = oidc_cfg.issuer_url.clone();
256    let provider_metadata = discover_provider_metadata(http_client, issuer_url.clone()).await?;
257    let client = make_oidc_client(oidc_cfg, provider_metadata)?;
258    Ok(client)
259}
260
261pub struct OidcMiddleware {
262    oidc_state: Option<Arc<OidcState>>,
263}
264
265impl OidcMiddleware {
266    #[must_use]
267    pub fn new(app_state: &web::Data<AppState>) -> Condition<Self> {
268        let oidc_state = app_state.oidc_state.clone();
269        Condition::new(oidc_state.is_some(), Self { oidc_state })
270    }
271}
272
273async fn discover_provider_metadata(
274    http_client: &awc::Client,
275    issuer_url: IssuerUrl,
276) -> anyhow::Result<openidconnect::core::CoreProviderMetadata> {
277    log::debug!("Discovering provider metadata for {issuer_url}");
278    let provider_metadata = openidconnect::core::CoreProviderMetadata::discover_async(
279        issuer_url,
280        &AwcHttpClient::from_client(http_client),
281    )
282    .await
283    .with_context(|| "Failed to discover OIDC provider metadata".to_string())?;
284    log::debug!("Provider metadata discovered: {provider_metadata:?}");
285    Ok(provider_metadata)
286}
287
288impl<S> Transform<S, ServiceRequest> for OidcMiddleware
289where
290    S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + 'static,
291    S::Future: 'static,
292{
293    type Response = ServiceResponse<BoxBody>;
294    type Error = Error;
295    type InitError = ();
296    type Transform = OidcService<S>;
297    type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
298
299    fn new_transform(&self, service: S) -> Self::Future {
300        match &self.oidc_state {
301            Some(state) => ready(Ok(OidcService::new(service, Arc::clone(state)))),
302            None => ready(Err(())),
303        }
304    }
305}
306
307#[derive(Clone)]
308pub struct OidcService<S> {
309    service: Rc<S>,
310    oidc_state: Arc<OidcState>,
311}
312
313impl<S> OidcService<S>
314where
315    S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error>,
316    S::Future: 'static,
317{
318    pub fn new(service: S, oidc_state: Arc<OidcState>) -> Self {
319        Self {
320            service: Rc::new(service),
321            oidc_state,
322        }
323    }
324}
325
/// Outcome of the OIDC checks for one request.
enum MiddlewareResponse {
    /// Hand the request over to the wrapped service.
    Forward(ServiceRequest),
    /// Answer directly (e.g. with a redirect), without calling the wrapped service.
    Respond(ServiceResponse),
}
330
331async fn handle_request(oidc_state: &OidcState, request: ServiceRequest) -> MiddlewareResponse {
332    log::trace!("Started OIDC middleware request handling");
333    oidc_state.refresh_if_expired(&request).await;
334
335    if request.path() == SQLPAGE_REDIRECT_URI {
336        let response = handle_oidc_callback(oidc_state, request).await;
337        return MiddlewareResponse::Respond(response);
338    }
339
340    match get_authenticated_user_info(oidc_state, &request).await {
341        Ok(Some(claims)) => {
342            log::trace!("Storing authenticated user info in request extensions: {claims:?}");
343            request.extensions_mut().insert(claims);
344            MiddlewareResponse::Forward(request)
345        }
346        Ok(None) => {
347            log::trace!("No authenticated user found");
348            handle_unauthenticated_request(oidc_state, request).await
349        }
350        Err(e) => {
351            log::debug!("An auth cookie is present but could not be verified. Redirecting to OIDC provider to re-authenticate. {e:?}");
352            oidc_state.refresh_on_error(&request).await;
353            handle_unauthenticated_request(oidc_state, request).await
354        }
355    }
356}
357
358async fn handle_unauthenticated_request(
359    oidc_state: &OidcState,
360    request: ServiceRequest,
361) -> MiddlewareResponse {
362    log::debug!("Handling unauthenticated request to {}", request.path());
363
364    if oidc_state.config.is_public_path(request.path()) {
365        return MiddlewareResponse::Forward(request);
366    }
367
368    log::debug!("Redirecting to OIDC provider");
369
370    let initial_url = request.uri().to_string();
371    let response = build_auth_provider_redirect_response(oidc_state, &initial_url).await;
372    MiddlewareResponse::Respond(request.into_response(response))
373}
374
375async fn handle_oidc_callback(oidc_state: &OidcState, request: ServiceRequest) -> ServiceResponse {
376    match process_oidc_callback(oidc_state, &request).await {
377        Ok(response) => request.into_response(response),
378        Err(e) => {
379            log::error!("Failed to process OIDC callback. Refreshing oidc provider metadata, then redirecting to home page: {e:#}");
380            oidc_state.refresh_on_error(&request).await;
381            let resp = build_auth_provider_redirect_response(oidc_state, "/").await;
382            request.into_response(resp)
383        }
384    }
385}
386
387impl<S> Service<ServiceRequest> for OidcService<S>
388where
389    S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + 'static,
390    S::Future: 'static,
391{
392    type Response = ServiceResponse<BoxBody>;
393    type Error = Error;
394    type Future = LocalBoxFuture<Result<Self::Response, Self::Error>>;
395
396    forward_ready!(service);
397
398    fn call(&self, request: ServiceRequest) -> Self::Future {
399        let srv = Rc::clone(&self.service);
400        let oidc_state = Arc::clone(&self.oidc_state);
401        Box::pin(async move {
402            match handle_request(&oidc_state, request).await {
403                MiddlewareResponse::Respond(response) => Ok(response),
404                MiddlewareResponse::Forward(request) => srv.call(request).await,
405            }
406        })
407    }
408}
409
410async fn process_oidc_callback(
411    oidc_state: &OidcState,
412    request: &ServiceRequest,
413) -> anyhow::Result<HttpResponse> {
414    let params = Query::<OidcCallbackParams>::from_query(request.query_string())
415        .with_context(|| format!("{SQLPAGE_REDIRECT_URI}: invalid url parameters"))?
416        .into_inner();
417    log::debug!("Processing OIDC callback with params: {params:?}. Requesting token...");
418    let mut tmp_login_flow_state_cookie = get_tmp_login_flow_state_cookie(request, &params.state)?;
419    let client = oidc_state.get_client().await;
420    let http_client = get_http_client_from_appdata(request)?;
421    let id_token = exchange_code_for_token(&client, http_client, params.clone()).await?;
422    log::debug!("Received token response: {id_token:?}");
423    let LoginFlowState {
424        nonce,
425        redirect_target,
426    } = parse_login_flow_state(&tmp_login_flow_state_cookie)?;
427    let redirect_target = validate_redirect_url(redirect_target.to_string());
428
429    log::info!("Redirecting to {redirect_target} after a successful login");
430    let mut response = build_redirect_response(redirect_target);
431    set_auth_cookie(&mut response, &id_token);
432    let claims = oidc_state
433        .get_token_claims(id_token, &nonce)
434        .await
435        .context("The identity provider returned an invalid ID token")?;
436    log::debug!("{} successfully logged in", claims.subject().as_str());
437    let nonce_cookie = create_final_nonce_cookie(&nonce);
438    response.add_cookie(&nonce_cookie)?;
439    tmp_login_flow_state_cookie.set_path("/"); // Required to clean up the cookie
440    response.add_removal_cookie(&tmp_login_flow_state_cookie)?;
441    Ok(response)
442}
443
444async fn exchange_code_for_token(
445    oidc_client: &OidcClient,
446    http_client: &awc::Client,
447    oidc_callback_params: OidcCallbackParams,
448) -> anyhow::Result<OidcToken> {
449    let token_response = oidc_client
450        .exchange_code(openidconnect::AuthorizationCode::new(
451            oidc_callback_params.code,
452        ))?
453        .request_async(&AwcHttpClient::from_client(http_client))
454        .await
455        .context("Failed to exchange code for token")?;
456    let access_token = token_response.access_token();
457    log::trace!("Received access token: {}", access_token.secret());
458    let id_token = token_response
459        .id_token()
460        .context("No ID token found in the token response. You may have specified an oauth2 provider that does not support OIDC.")?;
461    Ok(id_token.clone())
462}
463
464fn set_auth_cookie(response: &mut HttpResponse, id_token: &OidcToken) {
465    let id_token_str = id_token.to_string();
466    log::trace!("Setting auth cookie: {SQLPAGE_AUTH_COOKIE_NAME}=\"{id_token_str}\"");
467    let id_token_size_kb = id_token_str.len() / 1024;
468    if id_token_size_kb > 4 {
469        log::warn!(
470            "The ID token cookie from the OIDC provider is {id_token_size_kb}kb. \
471             Large cookies can cause performance issues and may be rejected by browsers or by reverse proxies."
472        );
473    }
474    let cookie = Cookie::build(SQLPAGE_AUTH_COOKIE_NAME, id_token_str)
475        .secure(true)
476        .http_only(true)
477        .max_age(AUTH_COOKIE_EXPIRATION)
478        .same_site(actix_web::cookie::SameSite::Lax)
479        .path("/")
480        .finish();
481
482    response.add_cookie(&cookie).unwrap();
483}
484
485async fn build_auth_provider_redirect_response(
486    oidc_state: &OidcState,
487    initial_url: &str,
488) -> HttpResponse {
489    let AuthUrl { url, params } = build_auth_url(oidc_state).await;
490    let tmp_login_flow_state_cookie = create_tmp_login_flow_state_cookie(&params, initial_url);
491    HttpResponse::TemporaryRedirect()
492        .append_header((header::LOCATION, url.to_string()))
493        .cookie(tmp_login_flow_state_cookie)
494        .body("Redirecting...")
495}
496
497fn build_redirect_response(target_url: String) -> HttpResponse {
498    HttpResponse::TemporaryRedirect()
499        .append_header(("Location", target_url))
500        .body("Redirecting...")
501}
502
503/// Returns the claims from the ID token in the `SQLPage` auth cookie.
504async fn get_authenticated_user_info(
505    oidc_state: &OidcState,
506    request: &ServiceRequest,
507) -> anyhow::Result<Option<OidcClaims>> {
508    let Some(cookie) = request.cookie(SQLPAGE_AUTH_COOKIE_NAME) else {
509        return Ok(None);
510    };
511    let cookie_value = cookie.value().to_string();
512    let id_token = OidcToken::from_str(&cookie_value)
513        .with_context(|| format!("Invalid SQLPage auth cookie: {cookie_value:?}"))?;
514
515    let nonce = get_final_nonce_from_cookie(request)?;
516    log::debug!("Verifying id token: {id_token:?}");
517    let claims = oidc_state.get_token_claims(id_token, &nonce).await?;
518    log::debug!("The current user is: {claims:?}");
519    Ok(Some(claims))
520}
521
522pub struct AwcHttpClient<'c> {
523    client: &'c awc::Client,
524}
525
526impl<'c> AwcHttpClient<'c> {
527    #[must_use]
528    pub fn from_client(client: &'c awc::Client) -> Self {
529        Self { client }
530    }
531}
532
533impl<'c> AsyncHttpClient<'c> for AwcHttpClient<'c> {
534    type Error = AwcWrapperError;
535    type Future =
536        Pin<Box<dyn Future<Output = Result<openidconnect::HttpResponse, Self::Error>> + 'c>>;
537
538    fn call(&'c self, request: openidconnect::HttpRequest) -> Self::Future {
539        let client = self.client.clone();
540        Box::pin(async move {
541            execute_oidc_request_with_awc(client, request)
542                .await
543                .map_err(AwcWrapperError)
544        })
545    }
546}
547
/// Sends an HTTP request built by `openidconnect` with the given `awc` client and converts
/// the answer into the `http::Response<Vec<u8>>` that `openidconnect` expects.
///
/// The method, URI and header values are rebuilt from their string forms.
///
/// # Errors
/// Fails in any of these cases:
/// - the method or URI cannot be converted;
/// - a request or response header value cannot be turned into a `&str` (`to_str`);
/// - sending the request fails;
/// - the response body cannot be read.
async fn execute_oidc_request_with_awc(
    client: Client,
    request: openidconnect::HttpRequest,
) -> Result<openidconnect::http::Response<Vec<u8>>, anyhow::Error> {
    let awc_method = awc::http::Method::from_bytes(request.method().as_str().as_bytes())?;
    let awc_uri = awc::http::Uri::from_str(&request.uri().to_string())?;
    log::debug!("Executing OIDC request: {awc_method} {awc_uri}");
    let mut req = client.request(awc_method, awc_uri);
    // NOTE(review): `insert_header` replaces same-named headers, so a repeated request
    // header would keep only its last value — confirm openidconnect never repeats one.
    for (name, value) in request.headers() {
        req = req.insert_header((name.as_str(), value.to_str()?));
    }
    // The head is kept to mention the method and URI in the error messages below.
    let (req_head, body) = request.into_parts();
    let mut response = req.send_body(body).await.map_err(|e| {
        anyhow!(e.to_string()).context(format!(
            "Failed to send request: {} {}",
            &req_head.method, &req_head.uri
        ))
    })?;
    let head = response.headers();
    let mut resp_builder =
        openidconnect::http::Response::builder().status(response.status().as_u16());
    // Copy the response headers before the body is read.
    for (name, value) in head {
        resp_builder = resp_builder.header(name.as_str(), value.to_str()?);
    }
    let body = response
        .body()
        .await
        .with_context(|| format!("Couldnt read from {}", &req_head.uri))?;
    log::debug!(
        "Received OIDC response with status {}: {}",
        response.status(),
        String::from_utf8_lossy(&body)
    );
    let resp = resp_builder.body(body.to_vec())?;
    Ok(resp)
}
584
/// Error type of [`AwcHttpClient`], wrapping the `anyhow` error of the underlying request.
#[derive(Debug)]
pub struct AwcWrapperError(anyhow::Error);

impl std::fmt::Display for AwcWrapperError {
    /// Delegates to the wrapped error, forwarding the formatter flags (e.g. `{:#}`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
593
/// Token endpoint response: standard fields plus an ID token with `OidcAdditionalClaims`.
type OidcTokenResponse = StandardTokenResponse<
    IdTokenFields<
        OidcAdditionalClaims,
        EmptyExtraTokenFields,
        CoreGenderClaim,
        CoreJweContentEncryptionAlgorithm,
        CoreJwsSigningAlgorithm,
    >,
    CoreTokenType,
>;

/// OIDC client type used throughout this module.
///
/// The six trailing `Endpoint*` parameters record, at the type level, which endpoints are
/// configured. NOTE(review): presumably in `openidconnect::Client` order — auth URL (set);
/// device authorization, introspection and revocation (not set); token and user info
/// (maybe set, as they come from the discovered metadata). Confirm against the crate docs.
type OidcClient = openidconnect::Client<
    OidcAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    OidcTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    EndpointSet,
    EndpointNotSet,
    EndpointNotSet,
    EndpointNotSet,
    EndpointMaybeSet,
    EndpointMaybeSet,
>;
624
impl std::error::Error for AwcWrapperError {
    /// The wrapper stands for the wrapped error itself, so it reports that error's source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.0.source()
    }
}
630
631fn make_oidc_client(
632    config: &OidcConfig,
633    provider_metadata: openidconnect::core::CoreProviderMetadata,
634) -> anyhow::Result<OidcClient> {
635    let client_id = openidconnect::ClientId::new(config.client_id.clone());
636    let client_secret = openidconnect::ClientSecret::new(config.client_secret.clone());
637
638    let mut redirect_url = RedirectUrl::new(format!(
639        "https://{}{}",
640        config.app_host, SQLPAGE_REDIRECT_URI,
641    ))
642    .with_context(|| {
643        format!(
644            "Failed to build the redirect URL; invalid app host \"{}\"",
645            config.app_host
646        )
647    })?;
648    let needs_http = match redirect_url.url().host() {
649        Some(openidconnect::url::Host::Domain(domain)) => domain == "localhost",
650        Some(openidconnect::url::Host::Ipv4(_) | openidconnect::url::Host::Ipv6(_)) => true,
651        None => false,
652    };
653    if needs_http {
654        log::debug!("App host seems to be local, changing redirect URL to HTTP");
655        redirect_url = RedirectUrl::new(format!(
656            "http://{}{}",
657            config.app_host, SQLPAGE_REDIRECT_URI,
658        ))?;
659    }
660    log::info!("OIDC redirect URL for {}: {redirect_url}", config.client_id);
661    let client =
662        OidcClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
663            .set_redirect_uri(redirect_url);
664
665    Ok(client)
666}
667
/// Query parameters sent by the identity provider to the callback URL.
#[derive(Debug, Deserialize, Clone)]
struct OidcCallbackParams {
    /// Authorization code, exchanged for tokens.
    code: String,
    /// CSRF token sent as `state`; names the matching login state cookie.
    state: CsrfToken,
}

/// Authorization URL at the provider, with the secrets of this login attempt.
struct AuthUrl {
    url: Url,
    params: AuthUrlParams,
}

/// Secrets generated for one login attempt.
struct AuthUrlParams {
    csrf_token: CsrfToken,
    /// Unhashed nonce; the provider only receives its Argon2 hash.
    nonce: Nonce,
}
683
684async fn build_auth_url(oidc_state: &OidcState) -> AuthUrl {
685    let nonce_source = Nonce::new_random();
686    let hashed_nonce = Nonce::new(hash_nonce(&nonce_source));
687    let scopes = &oidc_state.config.scopes;
688    let client_lock = oidc_state.get_client().await;
689    let (url, csrf_token, _nonce) = client_lock
690        .authorize_url(
691            CoreAuthenticationFlow::AuthorizationCode,
692            CsrfToken::new_random,
693            || hashed_nonce,
694        )
695        .add_scopes(scopes.iter().cloned())
696        .url();
697    AuthUrl {
698        url,
699        params: AuthUrlParams {
700            csrf_token,
701            nonce: nonce_source,
702        },
703    }
704}
705
706fn hash_nonce(nonce: &Nonce) -> String {
707    use argon2::password_hash::{rand_core::OsRng, PasswordHasher, SaltString};
708    let salt = SaltString::generate(&mut OsRng);
709    // low-cost parameters: oidc tokens are short-lived and the source nonce is high-entropy
710    let params = argon2::Params::new(8, 1, 1, Some(16)).expect("bug: invalid Argon2 parameters");
711    let argon2 = argon2::Argon2::new(argon2::Algorithm::Argon2id, argon2::Version::V0x13, params);
712    let hash = argon2
713        .hash_password(nonce.secret().as_bytes(), &salt)
714        .expect("bug: failed to hash nonce");
715    hash.to_string()
716}
717
718fn check_nonce(id_token_nonce: Option<&Nonce>, expected_nonce: &Nonce) -> Result<(), String> {
719    match id_token_nonce {
720        Some(id_token_nonce) => nonce_matches(id_token_nonce, expected_nonce),
721        None => Err("No nonce found in the ID token".to_string()),
722    }
723}
724
725fn nonce_matches(id_token_nonce: &Nonce, state_nonce: &Nonce) -> Result<(), String> {
726    log::debug!(
727        "Checking nonce: {} == {}",
728        id_token_nonce.secret(),
729        state_nonce.secret()
730    );
731    let hash = argon2::password_hash::PasswordHash::new(id_token_nonce.secret()).map_err(|e| {
732        format!(
733            "Failed to parse state nonce ({}): {e}",
734            id_token_nonce.secret()
735        )
736    })?;
737    argon2::password_hash::PasswordVerifier::verify_password(
738        &argon2::Argon2::default(),
739        state_nonce.secret().as_bytes(),
740        &hash,
741    )
742    .map_err(|e| format!("Failed to verify nonce ({}): {e}", state_nonce.secret()))?;
743    log::debug!("Nonce successfully verified");
744    Ok(())
745}
746
747fn create_final_nonce_cookie(nonce: &Nonce) -> Cookie<'_> {
748    Cookie::build(SQLPAGE_NONCE_COOKIE_NAME, nonce.secret())
749        .secure(true)
750        .http_only(true)
751        .same_site(actix_web::cookie::SameSite::Lax)
752        .max_age(AUTH_COOKIE_EXPIRATION)
753        .path("/")
754        .finish()
755}
756
757fn create_tmp_login_flow_state_cookie<'a>(
758    params: &'a AuthUrlParams,
759    initial_url: &'a str,
760) -> Cookie<'a> {
761    let csrf_token = &params.csrf_token;
762    let cookie_name = SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX.to_owned() + csrf_token.secret();
763    let cookie_value = serde_json::to_string(&LoginFlowState {
764        nonce: params.nonce.clone(),
765        redirect_target: initial_url,
766    })
767    .expect("login flow state is always serializable");
768    Cookie::build(cookie_name, cookie_value)
769        .secure(true)
770        .http_only(true)
771        .same_site(actix_web::cookie::SameSite::Lax)
772        .path("/")
773        .max_age(LOGIN_FLOW_STATE_COOKIE_EXPIRATION)
774        .finish()
775}
776
777fn get_final_nonce_from_cookie(request: &ServiceRequest) -> anyhow::Result<Nonce> {
778    let cookie = request
779        .cookie(SQLPAGE_NONCE_COOKIE_NAME)
780        .with_context(|| format!("No {SQLPAGE_NONCE_COOKIE_NAME} cookie found"))?;
781    Ok(Nonce::new(cookie.value().to_string()))
782}
783
784fn get_tmp_login_flow_state_cookie(
785    request: &ServiceRequest,
786    csrf_token: &CsrfToken,
787) -> anyhow::Result<Cookie<'static>> {
788    let cookie_name = SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX.to_owned() + csrf_token.secret();
789    request
790        .cookie(&cookie_name)
791        .with_context(|| format!("No {cookie_name} cookie found"))
792}
793
/// State of one login attempt, stored as JSON in the temporary login state cookie.
/// The one-letter field names presumably keep the cookie small.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct LoginFlowState<'a> {
    /// Unhashed nonce of this login attempt.
    #[serde(rename = "n")]
    nonce: Nonce,
    /// URL initially requested by the user, to return to after login.
    #[serde(rename = "r")]
    redirect_target: &'a str,
}
801
802fn parse_login_flow_state<'a>(cookie: &'a Cookie<'_>) -> anyhow::Result<LoginFlowState<'a>> {
803    serde_json::from_str(cookie.value())
804        .with_context(|| format!("Invalid login flow state cookie: {}", cookie.value()))
805}
806
807/// Given an audience, verify if it is trusted. The `client_id` is always trusted, independently of this function.
808#[derive(Clone, Debug)]
809pub struct AudienceVerifier(Option<HashSet<String>>);
810
811impl AudienceVerifier {
812    /// JWT audiences (aud claim) are always required to contain the `client_id`, but they can also contain additional audiences.
813    /// By default we allow any additional audience.
814    /// The user can restrict the allowed additional audiences by providing a list of trusted audiences.
815    fn new(additional_trusted_audiences: Option<Vec<String>>) -> Self {
816        AudienceVerifier(additional_trusted_audiences.map(HashSet::from_iter))
817    }
818
819    /// Returns a function that given an audience, verifies if it is trusted.
820    fn as_fn(&self) -> impl Fn(&Audience) -> bool + '_ {
821        move |aud: &Audience| -> bool {
822            let Some(trusted_set) = &self.0 else {
823                return true;
824            };
825            trusted_set.contains(aud.as_str())
826        }
827    }
828}
829
830/// Validate that a redirect URL is safe to use (prevents open redirect attacks)
831fn validate_redirect_url(url: String) -> String {
832    if url.starts_with('/') && !url.starts_with("//") && !url.starts_with(SQLPAGE_REDIRECT_URI) {
833        return url;
834    }
835    log::warn!("Refusing to redirect to {url}");
836    '/'.to_string()
837}