samling/auth/sso/
microsoft.rs

1use bytes::Bytes;
2use jsonwebtoken::{
3    jwk::{AlgorithmParameters, JwkSet},
4    DecodingKey, Validation,
5};
6use reqwest::StatusCode;
7use schemars::JsonSchema;
8use serde::Deserialize;
9
10use crate::{Error, Result};
11
12#[derive(Debug, Clone, Deserialize, JsonSchema)]
13#[serde(rename_all = "camelCase")]
14pub struct MicrosoftCredentials {
15    id_token_claims: MicrosoftClaims,
16    pub id_token: String,
17    pub access_token: String,
18}
19
20impl MicrosoftCredentials {
21    pub(crate) fn unverified_id_token_claims(&self) -> &MicrosoftClaims {
22        &self.id_token_claims
23    }
24}
25
26#[derive(Debug, Clone, Deserialize, JsonSchema)]
27pub(crate) struct MicrosoftClaims {
28    // pub ver: String,
29    // pub iss: String,
30    // pub sub: String,
31    // pub aud: String,
32    // pub exp: DateTime<FixedOffset>,
33    // pub iat: DateTime<FixedOffset>,
34    // pub nbf: DateTime<FixedOffset>,
35    pub name: String,
36    // pub preferred_username: String,
37    // pub oid: String,
38    pub email: String,
39    // pub tid: String,
40    // pub nonce: String,
41    // pub aio: String,
42}
43
44pub(crate) async fn get_profile_image(access_token: &str) -> Result<Option<Bytes>> {
45    let request = reqwest::Client::new()
46        .get("https://graph.microsoft.com/v1.0/me/photo/$value")
47        .bearer_auth(access_token);
48    let resp = request.send().await?;
49    if resp.status() == StatusCode::NOT_FOUND {
50        Ok(None)
51    } else {
52        resp.error_for_status_ref()?;
53        Ok(Some(resp.bytes().await?))
54    }
55}
56
57impl MicrosoftClaims {
58    pub(crate) async fn verify(audience: &[&str], login: &MicrosoftCredentials) -> Result<Self> {
59        // TODO: Cache response for as long as cache-control header allows (22967 seconds currently)
60        let resp =
61            reqwest::get("https://login.microsoftonline.com/common/discovery/v2.0/keys").await?;
62        let jwks: JwkSet = resp.json().await?;
63        let header = jsonwebtoken::decode_header(&login.id_token)
64            .map_err(|err| Error::InvalidToken(err.to_string()))?;
65        if let Some(kid) = header.kid {
66            if let Some(jwk) = jwks.find(&kid) {
67                let mut validation = Validation::new(header.alg);
68                validation.set_audience(audience);
69                match &jwk.algorithm {
70                    AlgorithmParameters::RSA(rsa) => {
71                        let key = DecodingKey::from_rsa_components(&rsa.n, &rsa.e)
72                            .map_err(|err| Error::InvalidToken(err.to_string()))?;
73                        let decoded = jsonwebtoken::decode::<MicrosoftClaims>(
74                            &login.id_token,
75                            &key,
76                            &validation,
77                        )
78                        .map_err(|err| Error::InvalidToken(err.to_string()))?;
79                        Ok(decoded.claims)
80                    }
81                    other => Err(Error::InvalidToken(format!(
82                        "Microsoft only supports RSA but got algorithm: {other:?}",
83                    ))),
84                }
85            } else {
86                Err(Error::InvalidToken("Failed to find token key".into()))
87            }
88        } else {
89            Err(Error::InvalidToken("No `kid` value found".into()))
90        }
91    }
92}