Skip to main content

rustauth_oauth/oauth2/
tokens.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use time::{Duration, OffsetDateTime};
4
5use super::error::OAuthError;
6use super::types::ClientSecret;
7
/// Largest lifetime, in seconds, accepted for the `expires_in` and
/// `refresh_token_expires_in` members of a token response: ten 365-day years.
const MAX_TOKEN_EXPIRY_SECONDS: i64 = 10 * 365 * 86_400;
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11#[serde(untagged)]
12pub enum ClientId {
13    Single(String),
14    Multiple(Vec<String>),
15}
16
17impl ClientId {
18    pub fn primary(&self) -> Option<&str> {
19        match self {
20            Self::Single(value) if !value.is_empty() => Some(value),
21            Self::Single(_) => None,
22            Self::Multiple(values) => values
23                .first()
24                .map(String::as_str)
25                .filter(|value| !value.is_empty()),
26        }
27    }
28}
29
30impl From<&str> for ClientId {
31    fn from(value: &str) -> Self {
32        Self::Single(value.to_owned())
33    }
34}
35
36impl From<String> for ClientId {
37    fn from(value: String) -> Self {
38        Self::Single(value)
39    }
40}
41
42impl From<Vec<String>> for ClientId {
43    fn from(value: Vec<String>) -> Self {
44        Self::Multiple(value)
45    }
46}
47
48#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
49pub struct ProviderOptions {
50    pub client_id: Option<ClientId>,
51    #[serde(
52        default,
53        serialize_with = "serialize_client_secret",
54        deserialize_with = "deserialize_client_secret"
55    )]
56    pub client_secret: Option<ClientSecret>,
57    pub scope: Vec<String>,
58    pub disable_default_scope: bool,
59    pub redirect_uri: Option<String>,
60    pub authorization_endpoint: Option<String>,
61    pub client_key: Option<String>,
62    pub disable_id_token_sign_in: bool,
63    pub disable_implicit_sign_up: bool,
64    pub disable_sign_up: bool,
65    pub prompt: Option<String>,
66    pub response_mode: Option<String>,
67    pub override_user_info_on_sign_in: bool,
68}
69
70impl ProviderOptions {
71    pub fn client_secret_str(&self) -> Option<&str> {
72        self.client_secret.as_ref().map(ClientSecret::expose_secret)
73    }
74
75    pub fn with_client_secret(mut self, secret: impl Into<String>) -> Result<Self, OAuthError> {
76        self.client_secret = Some(ClientSecret::new(secret)?);
77        Ok(self)
78    }
79}
80
81fn serialize_client_secret<S: Serializer>(
82    secret: &Option<ClientSecret>,
83    serializer: S,
84) -> Result<S::Ok, S::Error> {
85    match secret {
86        Some(secret) => serializer.serialize_some(secret.expose_secret()),
87        None => serializer.serialize_none(),
88    }
89}
90
91fn deserialize_client_secret<'de, D: Deserializer<'de>>(
92    deserializer: D,
93) -> Result<Option<ClientSecret>, D::Error> {
94    Option::<String>::deserialize(deserializer)?
95        .map(|value| ClientSecret::new(value).map_err(serde::de::Error::custom))
96        .transpose()
97}
98
99#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
100pub struct OAuth2Tokens {
101    pub token_type: Option<String>,
102    pub access_token: Option<String>,
103    pub refresh_token: Option<String>,
104    pub access_token_expires_at: Option<OffsetDateTime>,
105    pub refresh_token_expires_at: Option<OffsetDateTime>,
106    pub scopes: Vec<String>,
107    pub id_token: Option<String>,
108    pub raw: Value,
109}
110
111impl Default for OAuth2Tokens {
112    fn default() -> Self {
113        Self {
114            token_type: None,
115            access_token: None,
116            refresh_token: None,
117            access_token_expires_at: None,
118            refresh_token_expires_at: None,
119            scopes: Vec::new(),
120            id_token: None,
121            raw: Value::Null,
122        }
123    }
124}
125
126#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
127pub struct OAuth2UserInfo {
128    pub id: String,
129    pub name: Option<String>,
130    pub email: Option<String>,
131    pub image: Option<String>,
132    pub email_verified: bool,
133}
134
135pub fn get_primary_client_id(client_id: &Option<ClientId>) -> Option<&str> {
136    client_id.as_ref().and_then(ClientId::primary)
137}
138
139pub fn get_oauth2_tokens(data: Value) -> Result<OAuth2Tokens, OAuthError> {
140    let object = data.as_object().ok_or_else(|| {
141        OAuthError::InvalidTokenResponse("token response must be a JSON object".to_owned())
142    })?;
143    let now = OffsetDateTime::now_utc();
144    let access_token = optional_string_field(object, "access_token")?;
145    let refresh_token = optional_string_field(object, "refresh_token")?;
146    let id_token = optional_string_field(object, "id_token")?;
147    if access_token.is_none() && refresh_token.is_none() && id_token.is_none() {
148        return Err(OAuthError::InvalidTokenResponse(
149            "token response must include access_token, refresh_token, or id_token".to_owned(),
150        ));
151    }
152
153    Ok(OAuth2Tokens {
154        token_type: optional_string_field(object, "token_type")?,
155        access_token,
156        refresh_token,
157        access_token_expires_at: expires_at(object, "expires_in", now)?,
158        refresh_token_expires_at: expires_at(object, "refresh_token_expires_in", now)?,
159        scopes: scopes_field(object.get("scope"))?,
160        id_token,
161        raw: data,
162    })
163}
164
165fn optional_string_field(
166    object: &serde_json::Map<String, Value>,
167    key: &'static str,
168) -> Result<Option<String>, OAuthError> {
169    match object.get(key) {
170        Some(Value::String(value)) => Ok(Some(value.clone())),
171        Some(_) => Err(OAuthError::InvalidTokenResponse(format!(
172            "`{key}` must be a string"
173        ))),
174        None => Ok(None),
175    }
176}
177
178fn expires_at(
179    object: &serde_json::Map<String, Value>,
180    key: &'static str,
181    now: OffsetDateTime,
182) -> Result<Option<OffsetDateTime>, OAuthError> {
183    let Some(value) = object.get(key) else {
184        return Ok(None);
185    };
186    let seconds = value.as_i64().ok_or_else(|| {
187        OAuthError::InvalidTokenResponse(format!("`{key}` must be an integer number of seconds"))
188    })?;
189    if !(0..=MAX_TOKEN_EXPIRY_SECONDS).contains(&seconds) {
190        return Err(OAuthError::InvalidTokenResponse(format!(
191            "`{key}` must be between 0 and {MAX_TOKEN_EXPIRY_SECONDS} seconds"
192        )));
193    }
194    now.checked_add(Duration::seconds(seconds))
195        .ok_or_else(|| OAuthError::InvalidTokenResponse(format!("`{key}` is out of range")))
196        .map(Some)
197}
198
199fn scopes_field(value: Option<&Value>) -> Result<Vec<String>, OAuthError> {
200    match value {
201        Some(Value::String(scope)) => Ok(scope.split_whitespace().map(str::to_owned).collect()),
202        Some(Value::Array(scopes)) => scopes
203            .iter()
204            .map(|value| {
205                value.as_str().map(str::to_owned).ok_or_else(|| {
206                    OAuthError::InvalidTokenResponse(
207                        "`scope` array values must be strings".to_owned(),
208                    )
209                })
210            })
211            .collect(),
212        Some(_) => Err(OAuthError::InvalidTokenResponse(
213            "`scope` must be a string or string array".to_owned(),
214        )),
215        None => Ok(Vec::new()),
216    }
217}