Skip to main content

rustauth_plugins/generic_oauth/
config.rs

1use rustauth_oauth::oauth2::{
2    ClientAuthentication, ClientId, ClientSecret, OAuth2Tokens, OAuth2UserInfo, OAuthError,
3    OAuthHttpClient, ProviderOptions, SocialIdTokenRequest,
4};
5use serde_json::{json, Value};
6use std::collections::BTreeMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11pub type GenericOAuthTokenFuture =
12    Pin<Box<dyn Future<Output = Result<OAuth2Tokens, OAuthError>> + Send>>;
13pub type GenericOAuthGetToken =
14    Arc<dyn Fn(GenericOAuthTokenRequest) -> GenericOAuthTokenFuture + Send + Sync>;
15pub type GenericOAuthUserInfoFuture =
16    Pin<Box<dyn Future<Output = Result<Option<OAuth2UserInfo>, OAuthError>> + Send>>;
17pub type GenericOAuthGetUserInfo =
18    Arc<dyn Fn(OAuth2Tokens) -> GenericOAuthUserInfoFuture + Send + Sync>;
19pub type GenericOAuthMapProfileFuture =
20    Pin<Box<dyn Future<Output = Result<OAuth2UserInfo, OAuthError>> + Send>>;
21pub type GenericOAuthMapProfileToUser =
22    Arc<dyn Fn(OAuth2UserInfo) -> GenericOAuthMapProfileFuture + Send + Sync>;
23pub type GenericOAuthRefreshAccessToken =
24    Arc<dyn Fn(String) -> GenericOAuthTokenFuture + Send + Sync>;
25pub type GenericOAuthVerifyIdTokenFuture =
26    Pin<Box<dyn Future<Output = Result<bool, OAuthError>> + Send>>;
27pub type GenericOAuthVerifyIdToken =
28    Arc<dyn Fn(SocialIdTokenRequest) -> GenericOAuthVerifyIdTokenFuture + Send + Sync>;
29pub type GenericOAuthRevokeTokenFuture =
30    Pin<Box<dyn Future<Output = Result<(), OAuthError>> + Send>>;
31pub type GenericOAuthRevokeToken =
32    Arc<dyn Fn(String) -> GenericOAuthRevokeTokenFuture + Send + Sync>;
33pub type GenericOAuthParams = BTreeMap<String, String>;
34pub type GenericOAuthParamsFuture =
35    Pin<Box<dyn Future<Output = Result<GenericOAuthParams, OAuthError>> + Send>>;
36pub type GenericOAuthParamsCallback =
37    Arc<dyn Fn(GenericOAuthParamsContext) -> GenericOAuthParamsFuture + Send + Sync>;
38
/// Phase of the generic OAuth flow, reported to parameter callbacks through
/// `GenericOAuthParamsContext::flow`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GenericOAuthFlow {
    /// Sign-in flow.
    SignIn,
    /// Account-linking flow.
    Link,
    /// Provider callback handling.
    Callback,
}
45
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub struct GenericOAuthParamsContext {
48    pub provider_id: String,
49    pub flow: GenericOAuthFlow,
50    pub redirect_uri: String,
51}
52
/// Inputs for a token-exchange hook.
///
/// `Debug` is implemented by hand so that the authorization `code` and the PKCE
/// `code_verifier` never end up in logs; both are credentials for the exchange.
#[derive(Clone, Default, PartialEq, Eq)]
pub struct GenericOAuthTokenRequest {
    /// Authorization code returned by the provider.
    pub code: String,
    /// Redirect URI used for the authorization request.
    pub redirect_uri: String,
    /// PKCE code verifier, when PKCE is in use.
    pub code_verifier: Option<String>,
    /// Optional device identifier.
    pub device_id: Option<String>,
}

impl std::fmt::Debug for GenericOAuthTokenRequest {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same redaction style as `client_secret` in `GenericOAuthConfig`'s Debug impl.
        formatter
            .debug_struct("GenericOAuthTokenRequest")
            .field("code", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .field(
                "code_verifier",
                &self.code_verifier.as_ref().map(|_| "<redacted>"),
            )
            .field("device_id", &self.device_id)
            .finish()
    }
}
60
61#[derive(Clone, Default)]
62pub struct GenericOAuthOptions {
63    pub config: Vec<GenericOAuthConfig>,
64}
65
66impl std::fmt::Debug for GenericOAuthOptions {
67    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        formatter
69            .debug_struct("GenericOAuthOptions")
70            .field("config", &self.config)
71            .finish()
72    }
73}
74
75impl GenericOAuthOptions {
76    #[must_use]
77    pub fn builder() -> GenericOAuthOptionsBuilder {
78        GenericOAuthOptionsBuilder::default()
79    }
80
81    pub(crate) fn to_json(&self) -> Value {
82        json!({
83            "config": self.config.iter().map(GenericOAuthConfig::public_json).collect::<Vec<_>>(),
84        })
85    }
86
87    pub(crate) fn find(&self, provider_id: &str) -> Option<&GenericOAuthConfig> {
88        self.config
89            .iter()
90            .find(|config| config.provider_id == provider_id)
91    }
92}
93
/// Configuration for a single generic OAuth provider.
///
/// Create one with [`GenericOAuthConfig::new`] (explicit endpoints) or
/// [`GenericOAuthConfig::discovery`] (discovery URL) and adjust fields directly;
/// anything not set comes from the `Default` impl.
///
/// `Debug` output redacts `client_secret` and reports callbacks and the HTTP
/// client only as present/absent.
#[derive(Clone)]
pub struct GenericOAuthConfig {
    /// Provider identifier; `GenericOAuthOptions::find` matches it exactly.
    pub provider_id: String,
    /// Discovery document URL, set by [`GenericOAuthConfig::discovery`].
    pub discovery_url: Option<String>,
    /// Issuer identifier, if known.
    /// NOTE(review): how it is checked is decided outside this file — confirm.
    pub issuer: Option<String>,
    /// Issuer-validation flag (default `false`); enforced outside this file.
    pub require_issuer_validation: bool,
    /// Authorization endpoint; forwarded as `ProviderOptions::authorization_endpoint`.
    pub authorization_url: Option<String>,
    /// Token endpoint, set by [`GenericOAuthConfig::new`].
    pub token_url: Option<String>,
    /// User-info endpoint, if any.
    pub user_info_url: Option<String>,
    /// OAuth client identifier; forwarded as `ClientId::Single`.
    pub client_id: String,
    /// Optional client secret; omitted from the public JSON view and redacted in `Debug`.
    pub client_secret: Option<String>,
    /// Configured scopes, appended after any request-supplied scopes by `scopes()`.
    pub scopes: Vec<String>,
    /// Redirect URI override; forwarded as `ProviderOptions::redirect_uri`.
    pub redirect_uri: Option<String>,
    /// Optional `response_type` value; not consumed in this file.
    pub response_type: Option<String>,
    /// Optional `response_mode`; forwarded as `ProviderOptions::response_mode`.
    pub response_mode: Option<String>,
    /// Optional `prompt`; forwarded as `ProviderOptions::prompt`.
    pub prompt: Option<String>,
    /// PKCE toggle (default `false`); presumably enables a code challenge — confirm.
    pub pkce: bool,
    /// Optional `access_type` value; not consumed in this file.
    pub access_type: Option<String>,
    /// Static extra parameters for the authorization URL (consumed outside this file).
    pub authorization_url_params: BTreeMap<String, String>,
    /// Static extra parameters for the token request (consumed outside this file).
    pub token_url_params: BTreeMap<String, String>,
    /// Async callback, given a `GenericOAuthParamsContext`, returning extra
    /// authorization-URL parameters.
    pub authorization_url_params_callback: Option<GenericOAuthParamsCallback>,
    /// Async callback, given a `GenericOAuthParamsContext`, returning extra
    /// token-request parameters.
    pub token_url_params_callback: Option<GenericOAuthParamsCallback>,
    /// Forwarded as `ProviderOptions::disable_implicit_sign_up`.
    pub disable_implicit_sign_up: bool,
    /// Forwarded as `ProviderOptions::disable_sign_up`.
    pub disable_sign_up: bool,
    /// Client authentication method; defaults to `ClientAuthentication::Post`.
    pub authentication: ClientAuthentication,
    /// Extra headers, presumably sent with the discovery request — confirm.
    pub discovery_headers: BTreeMap<String, String>,
    /// Extra headers, presumably sent with authorization/token requests — confirm.
    pub authorization_headers: BTreeMap<String, String>,
    /// Forwarded as `ProviderOptions::override_user_info_on_sign_in`.
    pub override_user_info: bool,
    /// Custom token hook; presumably replaces the built-in code exchange when set.
    pub get_token: Option<GenericOAuthGetToken>,
    /// Custom user-info hook; presumably replaces the built-in user-info fetch when set.
    pub get_user_info: Option<GenericOAuthGetUserInfo>,
    /// Optional hook mapping the fetched profile to another `OAuth2UserInfo`.
    pub map_profile_to_user: Option<GenericOAuthMapProfileToUser>,
    /// Optional hook for refreshing access tokens.
    pub refresh_access_token: Option<GenericOAuthRefreshAccessToken>,
    /// Optional hook for checking ID tokens.
    pub verify_id_token: Option<GenericOAuthVerifyIdToken>,
    /// Optional hook for revoking tokens.
    pub revoke_token: Option<GenericOAuthRevokeToken>,
    /// Optional outbound HTTP client for discovery, token, and userinfo requests.
    /// When unset, the SSRF-guarded default client is used.
    pub http_client: Option<OAuthHttpClient>,
}
132
133impl std::fmt::Debug for GenericOAuthConfig {
134    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        formatter
136            .debug_struct("GenericOAuthConfig")
137            .field("provider_id", &self.provider_id)
138            .field("discovery_url", &self.discovery_url)
139            .field("issuer", &self.issuer)
140            .field("require_issuer_validation", &self.require_issuer_validation)
141            .field("authorization_url", &self.authorization_url)
142            .field("token_url", &self.token_url)
143            .field("user_info_url", &self.user_info_url)
144            .field("client_id", &self.client_id)
145            .field(
146                "client_secret",
147                &self.client_secret.as_ref().map(|_| "<redacted>"),
148            )
149            .field("scopes", &self.scopes)
150            .field("redirect_uri", &self.redirect_uri)
151            .field("response_type", &self.response_type)
152            .field("response_mode", &self.response_mode)
153            .field("prompt", &self.prompt)
154            .field("pkce", &self.pkce)
155            .field("access_type", &self.access_type)
156            .field("authorization_url_params", &self.authorization_url_params)
157            .field("token_url_params", &self.token_url_params)
158            .field(
159                "authorization_url_params_callback",
160                &self.authorization_url_params_callback.is_some(),
161            )
162            .field(
163                "token_url_params_callback",
164                &self.token_url_params_callback.is_some(),
165            )
166            .field("disable_implicit_sign_up", &self.disable_implicit_sign_up)
167            .field("disable_sign_up", &self.disable_sign_up)
168            .field("authentication", &self.authentication)
169            .field("discovery_headers", &self.discovery_headers)
170            .field("authorization_headers", &self.authorization_headers)
171            .field("override_user_info", &self.override_user_info)
172            .field("get_token", &self.get_token.is_some())
173            .field("get_user_info", &self.get_user_info.is_some())
174            .field("map_profile_to_user", &self.map_profile_to_user.is_some())
175            .field("refresh_access_token", &self.refresh_access_token.is_some())
176            .field("verify_id_token", &self.verify_id_token.is_some())
177            .field("revoke_token", &self.revoke_token.is_some())
178            .field("http_client", &self.http_client.is_some())
179            .finish()
180    }
181}
182
183impl GenericOAuthConfig {
184    pub fn new(
185        provider_id: impl Into<String>,
186        client_id: impl Into<String>,
187        client_secret: Option<impl Into<String>>,
188        authorization_url: impl Into<String>,
189        token_url: impl Into<String>,
190    ) -> Self {
191        Self {
192            provider_id: provider_id.into(),
193            client_id: client_id.into(),
194            client_secret: client_secret.map(Into::into),
195            authorization_url: Some(authorization_url.into()),
196            token_url: Some(token_url.into()),
197            ..Self::default()
198        }
199    }
200
201    pub fn discovery(
202        provider_id: impl Into<String>,
203        client_id: impl Into<String>,
204        client_secret: Option<impl Into<String>>,
205        discovery_url: impl Into<String>,
206    ) -> Self {
207        Self {
208            provider_id: provider_id.into(),
209            client_id: client_id.into(),
210            client_secret: client_secret.map(Into::into),
211            discovery_url: Some(discovery_url.into()),
212            ..Self::default()
213        }
214    }
215
216    pub(crate) fn provider_options(&self) -> ProviderOptions {
217        ProviderOptions {
218            client_id: Some(ClientId::Single(self.client_id.clone())),
219            client_secret: self
220                .client_secret
221                .as_ref()
222                .and_then(|secret| ClientSecret::new(secret.clone()).ok()),
223            scope: self.scopes.clone(),
224            redirect_uri: self.redirect_uri.clone(),
225            authorization_endpoint: self.authorization_url.clone(),
226            disable_implicit_sign_up: self.disable_implicit_sign_up,
227            disable_sign_up: self.disable_sign_up,
228            prompt: self.prompt.clone(),
229            response_mode: self.response_mode.clone(),
230            override_user_info_on_sign_in: self.override_user_info,
231            ..ProviderOptions::default()
232        }
233    }
234
235    pub(crate) fn scopes(&self, request_scopes: Vec<String>) -> Vec<String> {
236        if request_scopes.is_empty() {
237            return self.scopes.clone();
238        }
239        let mut scopes = request_scopes;
240        scopes.extend(self.scopes.clone());
241        scopes
242    }
243
244    fn public_json(&self) -> Value {
245        json!({
246            "providerId": self.provider_id,
247            "discoveryUrl": self.discovery_url,
248            "issuer": self.issuer,
249            "requireIssuerValidation": self.require_issuer_validation,
250            "authorizationUrl": self.authorization_url,
251            "tokenUrl": self.token_url,
252            "userInfoUrl": self.user_info_url,
253            "clientId": self.client_id,
254            "scopes": self.scopes,
255            "redirectURI": self.redirect_uri,
256            "pkce": self.pkce,
257            "disableImplicitSignUp": self.disable_implicit_sign_up,
258            "disableSignUp": self.disable_sign_up,
259            "overrideUserInfo": self.override_user_info,
260        })
261    }
262}
263
264impl Default for GenericOAuthConfig {
265    fn default() -> Self {
266        Self {
267            provider_id: String::new(),
268            discovery_url: None,
269            issuer: None,
270            require_issuer_validation: false,
271            authorization_url: None,
272            token_url: None,
273            user_info_url: None,
274            client_id: String::new(),
275            client_secret: None,
276            scopes: Vec::new(),
277            redirect_uri: None,
278            response_type: None,
279            response_mode: None,
280            prompt: None,
281            pkce: false,
282            access_type: None,
283            authorization_url_params: BTreeMap::new(),
284            token_url_params: BTreeMap::new(),
285            authorization_url_params_callback: None,
286            token_url_params_callback: None,
287            disable_implicit_sign_up: false,
288            disable_sign_up: false,
289            authentication: ClientAuthentication::Post,
290            discovery_headers: BTreeMap::new(),
291            authorization_headers: BTreeMap::new(),
292            override_user_info: false,
293            get_token: None,
294            get_user_info: None,
295            map_profile_to_user: None,
296            refresh_access_token: None,
297            verify_id_token: None,
298            revoke_token: None,
299            http_client: None,
300        }
301    }
302}
303
304#[derive(Debug, Clone, Default)]
305pub struct GenericOAuthOptionsBuilder {
306    config: Option<Vec<GenericOAuthConfig>>,
307}
308
309impl GenericOAuthOptionsBuilder {
310    #[must_use]
311    pub fn config(mut self, config: Vec<GenericOAuthConfig>) -> Self {
312        self.config = Some(config);
313        self
314    }
315
316    #[must_use]
317    pub fn provider(mut self, provider: GenericOAuthConfig) -> Self {
318        self.config.get_or_insert_with(Vec::new).push(provider);
319        self
320    }
321
322    #[must_use]
323    pub fn build(self) -> GenericOAuthOptions {
324        let defaults = GenericOAuthOptions::default();
325        GenericOAuthOptions {
326            config: self.config.unwrap_or(defaults.config),
327        }
328    }
329}