//! Async callback wrappers and hook inputs for the SSO options
//! (`rustauth_sso/options/callbacks.rs`).

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use rustauth_core::db::User;
6use rustauth_core::error::RustAuthError;
7use rustauth_oauth::oauth2::OAuth2Tokens;
8
9use crate::linking_impl::NormalizedSsoProfile;
10use crate::store::SsoProviderRecord;
11
12type TxtResolverFuture = Pin<Box<dyn Future<Output = Result<Vec<String>, RustAuthError>> + Send>>;
13type ProvidersLimitFuture = Pin<Box<dyn Future<Output = Result<usize, RustAuthError>> + Send>>;
14type OrganizationRoleFuture = Pin<Box<dyn Future<Output = Result<String, RustAuthError>> + Send>>;
15type ProvisionUserFuture = Pin<Box<dyn Future<Output = Result<(), RustAuthError>> + Send>>;
16
17#[derive(Clone)]
18/// Async resolver used to verify domain ownership through DNS TXT records.
19pub struct DnsTxtResolver {
20    resolver: Arc<dyn Fn(String) -> TxtResolverFuture + Send + Sync>,
21}
22
23impl DnsTxtResolver {
24    /// Create a resolver from an async function receiving the DNS name to query.
25    pub fn new<F, Fut>(resolver: F) -> Self
26    where
27        F: Fn(String) -> Fut + Send + Sync + 'static,
28        Fut: Future<Output = Result<Vec<String>, RustAuthError>> + Send + 'static,
29    {
30        Self {
31            resolver: Arc::new(move |name| Box::pin(resolver(name))),
32        }
33    }
34
35    /// Resolve TXT values for the provided DNS name.
36    pub async fn resolve(&self, name: &str) -> Result<Vec<String>, RustAuthError> {
37        (self.resolver)(name.to_owned()).await
38    }
39}
40
41impl std::fmt::Debug for DnsTxtResolver {
42    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        formatter.write_str("DnsTxtResolver(..)")
44    }
45}
46
47impl PartialEq for DnsTxtResolver {
48    fn eq(&self, _other: &Self) -> bool {
49        true
50    }
51}
52
53impl Eq for DnsTxtResolver {}
54
55#[derive(Clone)]
56/// Async resolver used to compute a per-user dynamic provider limit.
57pub struct ProvidersLimitResolver {
58    resolver: Arc<dyn Fn(User) -> ProvidersLimitFuture + Send + Sync>,
59}
60
61impl ProvidersLimitResolver {
62    /// Create a provider-limit resolver from an async function.
63    pub fn new<F, Fut>(resolver: F) -> Self
64    where
65        F: Fn(User) -> Fut + Send + Sync + 'static,
66        Fut: Future<Output = Result<usize, RustAuthError>> + Send + 'static,
67    {
68        Self {
69            resolver: Arc::new(move |user| Box::pin(resolver(user))),
70        }
71    }
72
73    /// Resolve the maximum number of providers the user may register.
74    pub async fn resolve(&self, user: User) -> Result<usize, RustAuthError> {
75        (self.resolver)(user).await
76    }
77}
78
79impl std::fmt::Debug for ProvidersLimitResolver {
80    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        formatter.write_str("ProvidersLimitResolver(..)")
82    }
83}
84
85impl PartialEq for ProvidersLimitResolver {
86    fn eq(&self, _other: &Self) -> bool {
87        true
88    }
89}
90
91impl Eq for ProvidersLimitResolver {}
92
93#[derive(Debug, Clone, PartialEq)]
94/// Input passed to organization role resolution after a successful SSO login.
95pub struct OrganizationRoleInput {
96    /// User created or linked by the SSO flow.
97    pub user: User,
98    /// Normalized profile extracted from OIDC UserInfo or SAML attributes.
99    pub profile: NormalizedSsoProfile,
100    /// SSO provider that authenticated the user.
101    pub provider: SsoProviderRecord,
102    /// OAuth tokens for OIDC flows; `None` for SAML flows.
103    pub token: Option<OAuth2Tokens>,
104}
105
106#[derive(Clone)]
107/// Async callback that maps an SSO login to an organization role.
108pub struct OrganizationRoleResolver {
109    resolver: Arc<dyn Fn(OrganizationRoleInput) -> OrganizationRoleFuture + Send + Sync>,
110}
111
112impl OrganizationRoleResolver {
113    /// Create a role resolver from an async function.
114    pub fn new<F, Fut>(resolver: F) -> Self
115    where
116        F: Fn(OrganizationRoleInput) -> Fut + Send + Sync + 'static,
117        Fut: Future<Output = Result<String, RustAuthError>> + Send + 'static,
118    {
119        Self {
120            resolver: Arc::new(move |input| Box::pin(resolver(input))),
121        }
122    }
123
124    /// Resolve the organization role for the login.
125    pub async fn resolve(&self, input: OrganizationRoleInput) -> Result<String, RustAuthError> {
126        (self.resolver)(input).await
127    }
128}
129
130impl std::fmt::Debug for OrganizationRoleResolver {
131    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        formatter.write_str("OrganizationRoleResolver(..)")
133    }
134}
135
136impl PartialEq for OrganizationRoleResolver {
137    fn eq(&self, _other: &Self) -> bool {
138        true
139    }
140}
141
142impl Eq for OrganizationRoleResolver {}
143
144#[derive(Debug, Clone, PartialEq)]
145/// Input passed to the `provision_user` hook.
146pub struct ProvisionUserInput {
147    /// User created or linked by the SSO flow.
148    pub user: User,
149    /// Normalized identity profile from the identity provider.
150    pub profile: NormalizedSsoProfile,
151    /// SSO provider that authenticated the user.
152    pub provider: SsoProviderRecord,
153    /// OAuth tokens for OIDC flows; `None` for SAML flows.
154    pub token: Option<OAuth2Tokens>,
155    /// Whether this login came from an explicit SSO registration request.
156    pub is_register: bool,
157}
158
159#[derive(Clone)]
160/// Async hook invoked after an SSO user is created or linked.
161pub struct ProvisionUserResolver {
162    resolver: Arc<dyn Fn(ProvisionUserInput) -> ProvisionUserFuture + Send + Sync>,
163}
164
165impl ProvisionUserResolver {
166    /// Create a provisioning resolver from an async function.
167    pub fn new<F, Fut>(resolver: F) -> Self
168    where
169        F: Fn(ProvisionUserInput) -> Fut + Send + Sync + 'static,
170        Fut: Future<Output = Result<(), RustAuthError>> + Send + 'static,
171    {
172        Self {
173            resolver: Arc::new(move |input| Box::pin(resolver(input))),
174        }
175    }
176
177    /// Run user provisioning for the completed SSO login.
178    pub async fn resolve(&self, input: ProvisionUserInput) -> Result<(), RustAuthError> {
179        (self.resolver)(input).await
180    }
181}
182
183impl std::fmt::Debug for ProvisionUserResolver {
184    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        formatter.write_str("ProvisionUserResolver(..)")
186    }
187}
188
189impl PartialEq for ProvisionUserResolver {
190    fn eq(&self, _other: &Self) -> bool {
191        true
192    }
193}
194
195impl Eq for ProvisionUserResolver {}