1use rustauth_oauth::oauth2::{
2 ClientAuthentication, ClientId, ClientSecret, OAuth2Tokens, OAuth2UserInfo, OAuthError,
3 OAuthHttpClient, ProviderOptions, SocialIdTokenRequest,
4};
5use serde_json::{json, Value};
6use std::collections::BTreeMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
/// Boxed, `Send` future resolving to [`OAuth2Tokens`] or an [`OAuthError`].
pub type GenericOAuthTokenFuture =
    Pin<Box<dyn Future<Output = Result<OAuth2Tokens, OAuthError>> + Send>>;
/// Shared, thread-safe callback that takes a [`GenericOAuthTokenRequest`] and returns a
/// [`GenericOAuthTokenFuture`]. Stored in [`GenericOAuthConfig::get_token`].
pub type GenericOAuthGetToken =
    Arc<dyn Fn(GenericOAuthTokenRequest) -> GenericOAuthTokenFuture + Send + Sync>;
/// Boxed, `Send` future resolving to an optional [`OAuth2UserInfo`] or an [`OAuthError`].
pub type GenericOAuthUserInfoFuture =
    Pin<Box<dyn Future<Output = Result<Option<OAuth2UserInfo>, OAuthError>> + Send>>;
/// Shared, thread-safe callback that takes [`OAuth2Tokens`] and returns a
/// [`GenericOAuthUserInfoFuture`]. Stored in [`GenericOAuthConfig::get_user_info`].
pub type GenericOAuthGetUserInfo =
    Arc<dyn Fn(OAuth2Tokens) -> GenericOAuthUserInfoFuture + Send + Sync>;
/// Boxed, `Send` future resolving to an [`OAuth2UserInfo`] or an [`OAuthError`].
pub type GenericOAuthMapProfileFuture =
    Pin<Box<dyn Future<Output = Result<OAuth2UserInfo, OAuthError>> + Send>>;
/// Shared, thread-safe callback that takes an [`OAuth2UserInfo`] and returns a
/// [`GenericOAuthMapProfileFuture`]. Stored in [`GenericOAuthConfig::map_profile_to_user`].
pub type GenericOAuthMapProfileToUser =
    Arc<dyn Fn(OAuth2UserInfo) -> GenericOAuthMapProfileFuture + Send + Sync>;
/// Shared, thread-safe callback that takes a `String` and returns a
/// [`GenericOAuthTokenFuture`]. Stored in [`GenericOAuthConfig::refresh_access_token`].
// NOTE(review): the `String` argument is presumably the refresh token — confirm against the caller.
pub type GenericOAuthRefreshAccessToken =
    Arc<dyn Fn(String) -> GenericOAuthTokenFuture + Send + Sync>;
/// Boxed, `Send` future resolving to a `bool` or an [`OAuthError`].
// NOTE(review): the `bool` presumably reports whether the ID token is valid — confirm.
pub type GenericOAuthVerifyIdTokenFuture =
    Pin<Box<dyn Future<Output = Result<bool, OAuthError>> + Send>>;
/// Shared, thread-safe callback that takes a [`SocialIdTokenRequest`] and returns a
/// [`GenericOAuthVerifyIdTokenFuture`]. Stored in [`GenericOAuthConfig::verify_id_token`].
pub type GenericOAuthVerifyIdToken =
    Arc<dyn Fn(SocialIdTokenRequest) -> GenericOAuthVerifyIdTokenFuture + Send + Sync>;
/// Boxed, `Send` future resolving to `()` or an [`OAuthError`].
pub type GenericOAuthRevokeTokenFuture =
    Pin<Box<dyn Future<Output = Result<(), OAuthError>> + Send>>;
/// Shared, thread-safe callback that takes a `String` and returns a
/// [`GenericOAuthRevokeTokenFuture`]. Stored in [`GenericOAuthConfig::revoke_token`].
// NOTE(review): the `String` argument is presumably the token to revoke — confirm against the caller.
pub type GenericOAuthRevokeToken =
    Arc<dyn Fn(String) -> GenericOAuthRevokeTokenFuture + Send + Sync>;
/// String-to-string parameter map, ordered by key.
pub type GenericOAuthParams = BTreeMap<String, String>;
/// Boxed, `Send` future resolving to [`GenericOAuthParams`] or an [`OAuthError`].
pub type GenericOAuthParamsFuture =
    Pin<Box<dyn Future<Output = Result<GenericOAuthParams, OAuthError>> + Send>>;
/// Shared, thread-safe callback that takes a [`GenericOAuthParamsContext`] and returns a
/// [`GenericOAuthParamsFuture`]. Used for both
/// [`GenericOAuthConfig::authorization_url_params_callback`] and
/// [`GenericOAuthConfig::token_url_params_callback`].
pub type GenericOAuthParamsCallback =
    Arc<dyn Fn(GenericOAuthParamsContext) -> GenericOAuthParamsFuture + Send + Sync>;
38
/// Flow marker carried in [`GenericOAuthParamsContext::flow`].
// NOTE(review): the variants presumably correspond to the sign-in request, the account-link
// request, and the provider callback respectively — confirm against the code that builds
// `GenericOAuthParamsContext`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GenericOAuthFlow {
    SignIn,
    Link,
    Callback,
}
45
/// Argument passed to a [`GenericOAuthParamsCallback`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenericOAuthParamsContext {
    /// Identifier of the provider the parameters are being computed for.
    pub provider_id: String,
    /// Flow the callback is being invoked for.
    pub flow: GenericOAuthFlow,
    /// Redirect URI associated with the request.
    pub redirect_uri: String,
}
52
/// Argument passed to a [`GenericOAuthGetToken`] callback.
///
/// `Default` yields empty strings and `None` for the optional fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GenericOAuthTokenRequest {
    /// Code to exchange (presumably the authorization code from the callback — confirm).
    pub code: String,
    /// Redirect URI associated with the request.
    pub redirect_uri: String,
    /// Optional code verifier (presumably the PKCE verifier — confirm).
    pub code_verifier: Option<String>,
    /// Optional device identifier.
    pub device_id: Option<String>,
}
60
/// Options for the generic OAuth integration: a list of provider configurations.
///
/// `Default` yields an empty list.
#[derive(Clone, Default)]
pub struct GenericOAuthOptions {
    /// Provider configurations; `find` returns the first entry whose `provider_id` matches.
    pub config: Vec<GenericOAuthConfig>,
}
65
66impl std::fmt::Debug for GenericOAuthOptions {
67 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 formatter
69 .debug_struct("GenericOAuthOptions")
70 .field("config", &self.config)
71 .finish()
72 }
73}
74
75impl GenericOAuthOptions {
76 #[must_use]
77 pub fn builder() -> GenericOAuthOptionsBuilder {
78 GenericOAuthOptionsBuilder::default()
79 }
80
81 pub(crate) fn to_json(&self) -> Value {
82 json!({
83 "config": self.config.iter().map(GenericOAuthConfig::public_json).collect::<Vec<_>>(),
84 })
85 }
86
87 pub(crate) fn find(&self, provider_id: &str) -> Option<&GenericOAuthConfig> {
88 self.config
89 .iter()
90 .find(|config| config.provider_id == provider_id)
91 }
92}
93
/// Configuration of a single generic OAuth provider.
///
/// Build one with [`GenericOAuthConfig::new`] (explicit endpoints) or
/// [`GenericOAuthConfig::discovery`] (discovery URL); all remaining fields start from
/// the `Default` impl and can be overridden with struct-update syntax.
#[derive(Clone)]
pub struct GenericOAuthConfig {
    /// Provider identifier; matched by `GenericOAuthOptions::find`.
    pub provider_id: String,
    /// Discovery URL; set by the `discovery` constructor.
    pub discovery_url: Option<String>,
    /// Issuer value (presumably the expected issuer checked during validation — confirm).
    pub issuer: Option<String>,
    /// Issuer-validation flag; defaults to `false`. The validation itself is not in this file section.
    pub require_issuer_validation: bool,
    /// Authorization endpoint; forwarded as `authorization_endpoint` by `provider_options`.
    pub authorization_url: Option<String>,
    /// Token endpoint; set by the `new` constructor.
    pub token_url: Option<String>,
    /// User-info endpoint.
    pub user_info_url: Option<String>,
    /// Client identifier; forwarded as `ClientId::Single` by `provider_options`.
    pub client_id: String,
    /// Client secret. Shown as `<redacted>` by the `Debug` impl and left out of `public_json`.
    pub client_secret: Option<String>,
    /// Configured scopes; `scopes()` appends them after any request-supplied scopes.
    pub scopes: Vec<String>,
    /// Redirect URI override; forwarded by `provider_options`.
    pub redirect_uri: Option<String>,
    /// Optional `response_type` value.
    pub response_type: Option<String>,
    /// Optional `response_mode` value; forwarded by `provider_options`.
    pub response_mode: Option<String>,
    /// Optional `prompt` value; forwarded by `provider_options`.
    pub prompt: Option<String>,
    /// PKCE flag; defaults to `false`.
    pub pkce: bool,
    /// Optional `access_type` value.
    pub access_type: Option<String>,
    /// Static extra parameters (presumably appended to the authorization URL — confirm).
    pub authorization_url_params: BTreeMap<String, String>,
    /// Static extra parameters (presumably sent with the token request — confirm).
    pub token_url_params: BTreeMap<String, String>,
    /// Optional callback producing additional authorization parameters.
    pub authorization_url_params_callback: Option<GenericOAuthParamsCallback>,
    /// Optional callback producing additional token-request parameters.
    pub token_url_params_callback: Option<GenericOAuthParamsCallback>,
    /// Forwarded unchanged to `ProviderOptions::disable_implicit_sign_up`.
    pub disable_implicit_sign_up: bool,
    /// Forwarded unchanged to `ProviderOptions::disable_sign_up`.
    pub disable_sign_up: bool,
    /// Client authentication method; defaults to `ClientAuthentication::Post`.
    pub authentication: ClientAuthentication,
    /// Extra headers (presumably sent with the discovery request — confirm).
    pub discovery_headers: BTreeMap<String, String>,
    /// Extra headers (presumably sent with authorization/token requests — confirm).
    pub authorization_headers: BTreeMap<String, String>,
    /// Forwarded to `ProviderOptions::override_user_info_on_sign_in`.
    pub override_user_info: bool,
    /// Optional custom token callback; see [`GenericOAuthGetToken`].
    pub get_token: Option<GenericOAuthGetToken>,
    /// Optional custom user-info callback; see [`GenericOAuthGetUserInfo`].
    pub get_user_info: Option<GenericOAuthGetUserInfo>,
    /// Optional profile-mapping callback; see [`GenericOAuthMapProfileToUser`].
    pub map_profile_to_user: Option<GenericOAuthMapProfileToUser>,
    /// Optional custom refresh callback; see [`GenericOAuthRefreshAccessToken`].
    pub refresh_access_token: Option<GenericOAuthRefreshAccessToken>,
    /// Optional custom ID-token verification callback; see [`GenericOAuthVerifyIdToken`].
    pub verify_id_token: Option<GenericOAuthVerifyIdToken>,
    /// Optional custom revocation callback; see [`GenericOAuthRevokeToken`].
    pub revoke_token: Option<GenericOAuthRevokeToken>,
    /// Optional HTTP client override.
    pub http_client: Option<OAuthHttpClient>,
}
132
133impl std::fmt::Debug for GenericOAuthConfig {
134 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 formatter
136 .debug_struct("GenericOAuthConfig")
137 .field("provider_id", &self.provider_id)
138 .field("discovery_url", &self.discovery_url)
139 .field("issuer", &self.issuer)
140 .field("require_issuer_validation", &self.require_issuer_validation)
141 .field("authorization_url", &self.authorization_url)
142 .field("token_url", &self.token_url)
143 .field("user_info_url", &self.user_info_url)
144 .field("client_id", &self.client_id)
145 .field(
146 "client_secret",
147 &self.client_secret.as_ref().map(|_| "<redacted>"),
148 )
149 .field("scopes", &self.scopes)
150 .field("redirect_uri", &self.redirect_uri)
151 .field("response_type", &self.response_type)
152 .field("response_mode", &self.response_mode)
153 .field("prompt", &self.prompt)
154 .field("pkce", &self.pkce)
155 .field("access_type", &self.access_type)
156 .field("authorization_url_params", &self.authorization_url_params)
157 .field("token_url_params", &self.token_url_params)
158 .field(
159 "authorization_url_params_callback",
160 &self.authorization_url_params_callback.is_some(),
161 )
162 .field(
163 "token_url_params_callback",
164 &self.token_url_params_callback.is_some(),
165 )
166 .field("disable_implicit_sign_up", &self.disable_implicit_sign_up)
167 .field("disable_sign_up", &self.disable_sign_up)
168 .field("authentication", &self.authentication)
169 .field("discovery_headers", &self.discovery_headers)
170 .field("authorization_headers", &self.authorization_headers)
171 .field("override_user_info", &self.override_user_info)
172 .field("get_token", &self.get_token.is_some())
173 .field("get_user_info", &self.get_user_info.is_some())
174 .field("map_profile_to_user", &self.map_profile_to_user.is_some())
175 .field("refresh_access_token", &self.refresh_access_token.is_some())
176 .field("verify_id_token", &self.verify_id_token.is_some())
177 .field("revoke_token", &self.revoke_token.is_some())
178 .field("http_client", &self.http_client.is_some())
179 .finish()
180 }
181}
182
183impl GenericOAuthConfig {
184 pub fn new(
185 provider_id: impl Into<String>,
186 client_id: impl Into<String>,
187 client_secret: Option<impl Into<String>>,
188 authorization_url: impl Into<String>,
189 token_url: impl Into<String>,
190 ) -> Self {
191 Self {
192 provider_id: provider_id.into(),
193 client_id: client_id.into(),
194 client_secret: client_secret.map(Into::into),
195 authorization_url: Some(authorization_url.into()),
196 token_url: Some(token_url.into()),
197 ..Self::default()
198 }
199 }
200
201 pub fn discovery(
202 provider_id: impl Into<String>,
203 client_id: impl Into<String>,
204 client_secret: Option<impl Into<String>>,
205 discovery_url: impl Into<String>,
206 ) -> Self {
207 Self {
208 provider_id: provider_id.into(),
209 client_id: client_id.into(),
210 client_secret: client_secret.map(Into::into),
211 discovery_url: Some(discovery_url.into()),
212 ..Self::default()
213 }
214 }
215
216 pub(crate) fn provider_options(&self) -> ProviderOptions {
217 ProviderOptions {
218 client_id: Some(ClientId::Single(self.client_id.clone())),
219 client_secret: self
220 .client_secret
221 .as_ref()
222 .and_then(|secret| ClientSecret::new(secret.clone()).ok()),
223 scope: self.scopes.clone(),
224 redirect_uri: self.redirect_uri.clone(),
225 authorization_endpoint: self.authorization_url.clone(),
226 disable_implicit_sign_up: self.disable_implicit_sign_up,
227 disable_sign_up: self.disable_sign_up,
228 prompt: self.prompt.clone(),
229 response_mode: self.response_mode.clone(),
230 override_user_info_on_sign_in: self.override_user_info,
231 ..ProviderOptions::default()
232 }
233 }
234
235 pub(crate) fn scopes(&self, request_scopes: Vec<String>) -> Vec<String> {
236 if request_scopes.is_empty() {
237 return self.scopes.clone();
238 }
239 let mut scopes = request_scopes;
240 scopes.extend(self.scopes.clone());
241 scopes
242 }
243
244 fn public_json(&self) -> Value {
245 json!({
246 "providerId": self.provider_id,
247 "discoveryUrl": self.discovery_url,
248 "issuer": self.issuer,
249 "requireIssuerValidation": self.require_issuer_validation,
250 "authorizationUrl": self.authorization_url,
251 "tokenUrl": self.token_url,
252 "userInfoUrl": self.user_info_url,
253 "clientId": self.client_id,
254 "scopes": self.scopes,
255 "redirectURI": self.redirect_uri,
256 "pkce": self.pkce,
257 "disableImplicitSignUp": self.disable_implicit_sign_up,
258 "disableSignUp": self.disable_sign_up,
259 "overrideUserInfo": self.override_user_info,
260 })
261 }
262}
263
264impl Default for GenericOAuthConfig {
265 fn default() -> Self {
266 Self {
267 provider_id: String::new(),
268 discovery_url: None,
269 issuer: None,
270 require_issuer_validation: false,
271 authorization_url: None,
272 token_url: None,
273 user_info_url: None,
274 client_id: String::new(),
275 client_secret: None,
276 scopes: Vec::new(),
277 redirect_uri: None,
278 response_type: None,
279 response_mode: None,
280 prompt: None,
281 pkce: false,
282 access_type: None,
283 authorization_url_params: BTreeMap::new(),
284 token_url_params: BTreeMap::new(),
285 authorization_url_params_callback: None,
286 token_url_params_callback: None,
287 disable_implicit_sign_up: false,
288 disable_sign_up: false,
289 authentication: ClientAuthentication::Post,
290 discovery_headers: BTreeMap::new(),
291 authorization_headers: BTreeMap::new(),
292 override_user_info: false,
293 get_token: None,
294 get_user_info: None,
295 map_profile_to_user: None,
296 refresh_access_token: None,
297 verify_id_token: None,
298 revoke_token: None,
299 http_client: None,
300 }
301 }
302}
303
/// Builder for [`GenericOAuthOptions`]; obtain one via `GenericOAuthOptions::builder()`.
#[derive(Debug, Clone, Default)]
pub struct GenericOAuthOptionsBuilder {
    // `None` until `config` or `provider` is called; `build` then falls back to the default list.
    config: Option<Vec<GenericOAuthConfig>>,
}
308
309impl GenericOAuthOptionsBuilder {
310 #[must_use]
311 pub fn config(mut self, config: Vec<GenericOAuthConfig>) -> Self {
312 self.config = Some(config);
313 self
314 }
315
316 #[must_use]
317 pub fn provider(mut self, provider: GenericOAuthConfig) -> Self {
318 self.config.get_or_insert_with(Vec::new).push(provider);
319 self
320 }
321
322 #[must_use]
323 pub fn build(self) -> GenericOAuthOptions {
324 let defaults = GenericOAuthOptions::default();
325 GenericOAuthOptions {
326 config: self.config.unwrap_or(defaults.config),
327 }
328 }
329}