Skip to main content

rustauth_core/options/
account.rs

1use std::fmt;
2use std::sync::Arc;
3
4use crate::api::ApiRequest;
5use crate::error::RustAuthError;
6
7use super::model_schema::ModelSchemaOptions;
8
9pub trait TrustedProvidersProvider: Send + Sync + 'static {
10    fn trusted_providers(&self) -> Result<Vec<String>, RustAuthError>;
11}
12
13impl<F> TrustedProvidersProvider for F
14where
15    F: Fn() -> Result<Vec<String>, RustAuthError> + Send + Sync + 'static,
16{
17    fn trusted_providers(&self) -> Result<Vec<String>, RustAuthError> {
18        self()
19    }
20}
21
22pub trait TrustedProvidersRequestProvider: Send + Sync + 'static {
23    fn trusted_providers_for_request(
24        &self,
25        request: Option<&ApiRequest>,
26    ) -> Result<Vec<String>, RustAuthError>;
27}
28
29impl<F> TrustedProvidersRequestProvider for F
30where
31    F: for<'a> Fn(Option<&'a ApiRequest>) -> Result<Vec<String>, RustAuthError>
32        + Send
33        + Sync
34        + 'static,
35{
36    fn trusted_providers_for_request(
37        &self,
38        request: Option<&ApiRequest>,
39    ) -> Result<Vec<String>, RustAuthError> {
40        self(request)
41    }
42}
43
44/// Account and OAuth account behavior.
45#[derive(Debug, Clone, PartialEq, Eq)]
46pub struct AccountOptions {
47    pub schema: ModelSchemaOptions,
48    pub update_account_on_sign_in: bool,
49    pub encrypt_oauth_tokens: bool,
50    pub store_account_cookie: bool,
51    pub store_state_strategy: OAuthStateStoreStrategy,
52    pub skip_state_cookie_check: bool,
53    pub account_linking: AccountLinkingOptions,
54}
55
56impl Default for AccountOptions {
57    fn default() -> Self {
58        Self {
59            schema: ModelSchemaOptions::default(),
60            update_account_on_sign_in: true,
61            encrypt_oauth_tokens: false,
62            store_account_cookie: false,
63            store_state_strategy: OAuthStateStoreStrategy::Cookie,
64            skip_state_cookie_check: false,
65            account_linking: AccountLinkingOptions::default(),
66        }
67    }
68}
69
70impl AccountOptions {
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    #[must_use]
76    pub fn schema(mut self, schema: ModelSchemaOptions) -> Self {
77        self.schema = schema;
78        self
79    }
80
81    #[must_use]
82    pub fn update_account_on_sign_in(mut self, enabled: bool) -> Self {
83        self.update_account_on_sign_in = enabled;
84        self
85    }
86
87    #[must_use]
88    pub fn encrypt_oauth_tokens(mut self, enabled: bool) -> Self {
89        self.encrypt_oauth_tokens = enabled;
90        self
91    }
92
93    #[must_use]
94    pub fn store_account_cookie(mut self, enabled: bool) -> Self {
95        self.store_account_cookie = enabled;
96        self
97    }
98
99    #[must_use]
100    pub fn store_state_strategy(mut self, strategy: OAuthStateStoreStrategy) -> Self {
101        self.store_state_strategy = strategy;
102        self
103    }
104
105    #[must_use]
106    pub fn skip_state_cookie_check(mut self, skip: bool) -> Self {
107        self.skip_state_cookie_check = skip;
108        self
109    }
110
111    #[must_use]
112    pub fn account_linking(mut self, account_linking: AccountLinkingOptions) -> Self {
113        self.account_linking = account_linking;
114        self
115    }
116}
117
/// Where the OAuth `state` — together with the PKCE verifier / OIDC nonce it
/// carries — is kept between the authorization redirect and the callback.
///
/// Either strategy is single-use: the `state` is consumed by the first successful
/// callback, so a captured value cannot be replayed within its TTL.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OAuthStateStoreStrategy {
    /// The payload lives in an encrypted, client-held value, bound to a short
    /// server-side single-use marker. This is the default.
    #[default]
    Cookie,
    /// The full payload is stored server-side and deleted on first use.
    Database,
}
132
133#[derive(Clone)]
134pub struct AccountLinkingOptions {
135    pub enabled: bool,
136    pub disable_implicit_linking: bool,
137    pub trusted_providers: Vec<String>,
138    pub trusted_providers_provider: Option<Arc<dyn TrustedProvidersProvider>>,
139    pub trusted_providers_request_provider: Option<Arc<dyn TrustedProvidersRequestProvider>>,
140    pub allow_different_emails: bool,
141    pub allow_unlinking_all: bool,
142    pub update_user_info_on_link: bool,
143}
144
145impl Default for AccountLinkingOptions {
146    fn default() -> Self {
147        Self {
148            enabled: true,
149            disable_implicit_linking: false,
150            trusted_providers: Vec::new(),
151            trusted_providers_provider: None,
152            trusted_providers_request_provider: None,
153            allow_different_emails: false,
154            allow_unlinking_all: false,
155            update_user_info_on_link: false,
156        }
157    }
158}
159
160impl AccountLinkingOptions {
161    pub fn new() -> Self {
162        Self::default()
163    }
164
165    #[must_use]
166    pub fn enabled(mut self, enabled: bool) -> Self {
167        self.enabled = enabled;
168        self
169    }
170
171    #[must_use]
172    pub fn disable_implicit_linking(mut self, enabled: bool) -> Self {
173        self.disable_implicit_linking = enabled;
174        self
175    }
176
177    #[must_use]
178    pub fn trusted_provider(mut self, provider: impl Into<String>) -> Self {
179        self.trusted_providers.push(provider.into());
180        self
181    }
182
183    #[must_use]
184    pub fn trusted_providers<I, S>(mut self, providers: I) -> Self
185    where
186        I: IntoIterator<Item = S>,
187        S: Into<String>,
188    {
189        self.trusted_providers
190            .extend(providers.into_iter().map(Into::into));
191        self
192    }
193
194    #[must_use]
195    pub fn trusted_providers_provider<P>(mut self, provider: P) -> Self
196    where
197        P: TrustedProvidersProvider,
198    {
199        self.trusted_providers_provider = Some(Arc::new(provider));
200        self
201    }
202
203    #[must_use]
204    pub fn trusted_providers_for_request_provider<P>(mut self, provider: P) -> Self
205    where
206        P: TrustedProvidersRequestProvider,
207    {
208        self.trusted_providers_request_provider = Some(Arc::new(provider));
209        self
210    }
211
212    #[must_use]
213    pub fn allow_different_emails(mut self, enabled: bool) -> Self {
214        self.allow_different_emails = enabled;
215        self
216    }
217
218    #[must_use]
219    pub fn allow_unlinking_all(mut self, enabled: bool) -> Self {
220        self.allow_unlinking_all = enabled;
221        self
222    }
223
224    #[must_use]
225    pub fn update_user_info_on_link(mut self, enabled: bool) -> Self {
226        self.update_user_info_on_link = enabled;
227        self
228    }
229}
230
231impl fmt::Debug for AccountLinkingOptions {
232    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
233        formatter
234            .debug_struct("AccountLinkingOptions")
235            .field("enabled", &self.enabled)
236            .field("disable_implicit_linking", &self.disable_implicit_linking)
237            .field("trusted_providers", &self.trusted_providers)
238            .field(
239                "trusted_providers_provider",
240                &self
241                    .trusted_providers_provider
242                    .as_ref()
243                    .map(|_| "<dynamic>"),
244            )
245            .field(
246                "trusted_providers_request_provider",
247                &self
248                    .trusted_providers_request_provider
249                    .as_ref()
250                    .map(|_| "<request-dynamic>"),
251            )
252            .field("allow_different_emails", &self.allow_different_emails)
253            .field("allow_unlinking_all", &self.allow_unlinking_all)
254            .field("update_user_info_on_link", &self.update_user_info_on_link)
255            .finish()
256    }
257}
258
259impl PartialEq for AccountLinkingOptions {
260    fn eq(&self, other: &Self) -> bool {
261        self.enabled == other.enabled
262            && self.disable_implicit_linking == other.disable_implicit_linking
263            && self.trusted_providers == other.trusted_providers
264            && self.trusted_providers_provider.is_some()
265                == other.trusted_providers_provider.is_some()
266            && self.trusted_providers_request_provider.is_some()
267                == other.trusted_providers_request_provider.is_some()
268            && self.allow_different_emails == other.allow_different_emails
269            && self.allow_unlinking_all == other.allow_unlinking_all
270            && self.update_user_info_on_link == other.update_user_info_on_link
271    }
272}
273
274impl Eq for AccountLinkingOptions {}