1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use time::{Duration, OffsetDateTime};
4
5use super::error::OAuthError;
6use super::types::ClientSecret;
7
/// Largest token lifetime, in seconds, accepted from a token response
/// (ten 365-day years); larger `*_expires_in` values are rejected as invalid.
const MAX_TOKEN_EXPIRY_SECONDS: i64 = 10 * 365 * 86_400;
9
/// One or more OAuth client identifiers configured for a provider.
///
/// Because of `#[serde(untagged)]`, the serialized form is either a bare JSON
/// string (`Single`) or a JSON array of strings (`Multiple`). When
/// deserializing, serde tries the variants in declaration order, so keep
/// `Single` first.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ClientId {
    /// Exactly one client identifier.
    Single(String),
    /// Several client identifiers; `ClientId::primary` only looks at the first.
    Multiple(Vec<String>),
}
16
17impl ClientId {
18 pub fn primary(&self) -> Option<&str> {
19 match self {
20 Self::Single(value) if !value.is_empty() => Some(value),
21 Self::Single(_) => None,
22 Self::Multiple(values) => values
23 .first()
24 .map(String::as_str)
25 .filter(|value| !value.is_empty()),
26 }
27 }
28}
29
30impl From<&str> for ClientId {
31 fn from(value: &str) -> Self {
32 Self::Single(value.to_owned())
33 }
34}
35
36impl From<String> for ClientId {
37 fn from(value: String) -> Self {
38 Self::Single(value)
39 }
40}
41
42impl From<Vec<String>> for ClientId {
43 fn from(value: Vec<String>) -> Self {
44 Self::Multiple(value)
45 }
46}
47
48#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
49pub struct ProviderOptions {
50 pub client_id: Option<ClientId>,
51 #[serde(
52 default,
53 serialize_with = "serialize_client_secret",
54 deserialize_with = "deserialize_client_secret"
55 )]
56 pub client_secret: Option<ClientSecret>,
57 pub scope: Vec<String>,
58 pub disable_default_scope: bool,
59 pub redirect_uri: Option<String>,
60 pub authorization_endpoint: Option<String>,
61 pub client_key: Option<String>,
62 pub disable_id_token_sign_in: bool,
63 pub disable_implicit_sign_up: bool,
64 pub disable_sign_up: bool,
65 pub prompt: Option<String>,
66 pub response_mode: Option<String>,
67 pub override_user_info_on_sign_in: bool,
68}
69
70impl ProviderOptions {
71 pub fn client_secret_str(&self) -> Option<&str> {
72 self.client_secret.as_ref().map(ClientSecret::expose_secret)
73 }
74
75 pub fn with_client_secret(mut self, secret: impl Into<String>) -> Result<Self, OAuthError> {
76 self.client_secret = Some(ClientSecret::new(secret)?);
77 Ok(self)
78 }
79}
80
81fn serialize_client_secret<S: Serializer>(
82 secret: &Option<ClientSecret>,
83 serializer: S,
84) -> Result<S::Ok, S::Error> {
85 match secret {
86 Some(secret) => serializer.serialize_some(secret.expose_secret()),
87 None => serializer.serialize_none(),
88 }
89}
90
91fn deserialize_client_secret<'de, D: Deserializer<'de>>(
92 deserializer: D,
93) -> Result<Option<ClientSecret>, D::Error> {
94 Option::<String>::deserialize(deserializer)?
95 .map(|value| ClientSecret::new(value).map_err(serde::de::Error::custom))
96 .transpose()
97}
98
/// Tokens extracted from an OAuth 2.0 token-endpoint response.
///
/// Instances built by `get_oauth2_tokens` contain at least one of
/// `access_token`, `refresh_token` or `id_token`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OAuth2Tokens {
    /// The response's `token_type` member, copied verbatim.
    pub token_type: Option<String>,
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
    /// Absolute expiry computed as parse time + `expires_in` seconds.
    pub access_token_expires_at: Option<OffsetDateTime>,
    /// Absolute expiry computed as parse time + `refresh_token_expires_in` seconds.
    pub refresh_token_expires_at: Option<OffsetDateTime>,
    /// Scopes parsed from `scope` (space-separated string or string array).
    pub scopes: Vec<String>,
    pub id_token: Option<String>,
    /// The complete, unmodified JSON response the tokens were parsed from.
    pub raw: Value,
}
110
111impl Default for OAuth2Tokens {
112 fn default() -> Self {
113 Self {
114 token_type: None,
115 access_token: None,
116 refresh_token: None,
117 access_token_expires_at: None,
118 refresh_token_expires_at: None,
119 scopes: Vec::new(),
120 id_token: None,
121 raw: Value::Null,
122 }
123 }
124}
125
/// User profile information for an OAuth account.
///
/// NOTE(review): presumably populated from a provider's user-info / ID-token
/// claims; confirm the mapping at the call sites.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuth2UserInfo {
    /// Identifier of the user (assumed provider-scoped — confirm).
    pub id: String,
    pub name: Option<String>,
    pub email: Option<String>,
    /// Profile image reference (presumably a URL — confirm).
    pub image: Option<String>,
    pub email_verified: bool,
}
134
135pub fn get_primary_client_id(client_id: &Option<ClientId>) -> Option<&str> {
136 client_id.as_ref().and_then(ClientId::primary)
137}
138
139pub fn get_oauth2_tokens(data: Value) -> Result<OAuth2Tokens, OAuthError> {
140 let object = data.as_object().ok_or_else(|| {
141 OAuthError::InvalidTokenResponse("token response must be a JSON object".to_owned())
142 })?;
143 let now = OffsetDateTime::now_utc();
144 let access_token = optional_string_field(object, "access_token")?;
145 let refresh_token = optional_string_field(object, "refresh_token")?;
146 let id_token = optional_string_field(object, "id_token")?;
147 if access_token.is_none() && refresh_token.is_none() && id_token.is_none() {
148 return Err(OAuthError::InvalidTokenResponse(
149 "token response must include access_token, refresh_token, or id_token".to_owned(),
150 ));
151 }
152
153 Ok(OAuth2Tokens {
154 token_type: optional_string_field(object, "token_type")?,
155 access_token,
156 refresh_token,
157 access_token_expires_at: expires_at(object, "expires_in", now)?,
158 refresh_token_expires_at: expires_at(object, "refresh_token_expires_in", now)?,
159 scopes: scopes_field(object.get("scope"))?,
160 id_token,
161 raw: data,
162 })
163}
164
165fn optional_string_field(
166 object: &serde_json::Map<String, Value>,
167 key: &'static str,
168) -> Result<Option<String>, OAuthError> {
169 match object.get(key) {
170 Some(Value::String(value)) => Ok(Some(value.clone())),
171 Some(_) => Err(OAuthError::InvalidTokenResponse(format!(
172 "`{key}` must be a string"
173 ))),
174 None => Ok(None),
175 }
176}
177
178fn expires_at(
179 object: &serde_json::Map<String, Value>,
180 key: &'static str,
181 now: OffsetDateTime,
182) -> Result<Option<OffsetDateTime>, OAuthError> {
183 let Some(value) = object.get(key) else {
184 return Ok(None);
185 };
186 let seconds = value.as_i64().ok_or_else(|| {
187 OAuthError::InvalidTokenResponse(format!("`{key}` must be an integer number of seconds"))
188 })?;
189 if !(0..=MAX_TOKEN_EXPIRY_SECONDS).contains(&seconds) {
190 return Err(OAuthError::InvalidTokenResponse(format!(
191 "`{key}` must be between 0 and {MAX_TOKEN_EXPIRY_SECONDS} seconds"
192 )));
193 }
194 now.checked_add(Duration::seconds(seconds))
195 .ok_or_else(|| OAuthError::InvalidTokenResponse(format!("`{key}` is out of range")))
196 .map(Some)
197}
198
199fn scopes_field(value: Option<&Value>) -> Result<Vec<String>, OAuthError> {
200 match value {
201 Some(Value::String(scope)) => Ok(scope.split_whitespace().map(str::to_owned).collect()),
202 Some(Value::Array(scopes)) => scopes
203 .iter()
204 .map(|value| {
205 value.as_str().map(str::to_owned).ok_or_else(|| {
206 OAuthError::InvalidTokenResponse(
207 "`scope` array values must be strings".to_owned(),
208 )
209 })
210 })
211 .collect(),
212 Some(_) => Err(OAuthError::InvalidTokenResponse(
213 "`scope` must be a string or string array".to_owned(),
214 )),
215 None => Ok(Vec::new()),
216 }
217}