torii_auth_oauth/
lib.rs

1pub mod providers;
2
3use std::sync::Arc;
4
5use oauth2::TokenResponse;
6
7use providers::{Provider, UserInfo};
8use torii_core::error::AuthError;
9use torii_core::{Error, NewUser, Plugin, User, UserId, UserManager};
10use torii_core::{
11    events::{Event, EventBus},
12    storage::OAuthStorage,
13};
14
/// The information an application needs to start an OAuth2 authorization code flow.
///
/// # Fields
/// - `url`: The authorization URL that the user should be redirected to in order to authenticate
/// - `csrf_state`: The state value that should be stored and verified when the user returns
///   from the authorization flow to prevent CSRF attacks
///
/// The PKCE verifier is intentionally *not* part of this struct:
/// [`OAuthPlugin::get_authorization_url`] stores it in the OAuth storage under the `csrf_state`
/// key, and [`OAuthPlugin::exchange_code`] looks it up again with that same state value.
///
/// # Usage
/// 1. Obtain an `AuthorizationUrl` from [`OAuthPlugin::get_authorization_url`]
/// 2. Store the `csrf_state` value securely (e.g. in the user's session or a cookie)
/// 3. Redirect the user to the `url` to begin the OAuth flow
/// 4. When the user returns to your redirect URI, verify that the state parameter matches the stored `csrf_state`
/// 5. Pass the authorization code together with the `csrf_state` to [`OAuthPlugin::exchange_code`]
#[derive(Debug, Clone)]
pub struct AuthorizationUrl {
    /// The authorization URL to redirect the user to
    url: String,
    /// The CSRF state. This is typically set as a cookie in the user's browser to use when the user returns
    /// from the authorization flow
    csrf_state: String,
}
37
38impl AuthorizationUrl {
39    pub fn new(url: &str, csrf_state: &str) -> Self {
40        Self {
41            url: url.to_string(),
42            csrf_state: csrf_state.to_string(),
43        }
44    }
45
46    pub fn url(&self) -> &str {
47        &self.url
48    }
49
50    pub fn csrf_state(&self) -> &str {
51        &self.csrf_state
52    }
53}
54
55pub struct OAuthPlugin<M, S>
56where
57    M: UserManager,
58    S: OAuthStorage,
59{
60    /// The provider.
61    provider: Provider,
62    /// The user manager.
63    user_manager: Arc<M>,
64    /// The OAuth storage.
65    oauth_storage: Arc<S>,
66    /// The event bus.
67    event_bus: Option<EventBus>,
68}
69
70impl<M, S> Plugin for OAuthPlugin<M, S>
71where
72    M: UserManager,
73    S: OAuthStorage,
74{
75    fn name(&self) -> String {
76        self.provider.name().to_string()
77    }
78}
79
80pub struct OAuthPluginBuilder<M, S>
81where
82    M: UserManager,
83    S: OAuthStorage,
84{
85    provider: Provider,
86    user_manager: Arc<M>,
87    oauth_storage: Arc<S>,
88    event_bus: Option<EventBus>,
89}
90
91impl<M, S> OAuthPluginBuilder<M, S>
92where
93    M: UserManager,
94    S: OAuthStorage,
95{
96    pub fn new(provider: Provider, user_manager: Arc<M>, oauth_storage: Arc<S>) -> Self {
97        Self {
98            provider,
99            user_manager,
100            oauth_storage,
101            event_bus: None,
102        }
103    }
104
105    pub fn event_bus(mut self, event_bus: EventBus) -> Self {
106        self.event_bus = Some(event_bus);
107        self
108    }
109
110    pub fn build(self) -> OAuthPlugin<M, S> {
111        OAuthPlugin {
112            provider: self.provider,
113            user_manager: self.user_manager,
114            oauth_storage: self.oauth_storage,
115            event_bus: self.event_bus,
116        }
117    }
118}
119
120impl<M, S> OAuthPlugin<M, S>
121where
122    M: UserManager,
123    S: OAuthStorage,
124{
125    pub fn builder(
126        provider: Provider,
127        user_manager: Arc<M>,
128        oauth_storage: Arc<S>,
129    ) -> OAuthPluginBuilder<M, S> {
130        OAuthPluginBuilder::new(provider, user_manager, oauth_storage)
131    }
132
133    pub fn new(provider: Provider, user_manager: Arc<M>, oauth_storage: Arc<S>) -> Self {
134        Self {
135            provider,
136            user_manager,
137            oauth_storage,
138            event_bus: None,
139        }
140    }
141
142    pub fn with_event_bus(mut self, event_bus: EventBus) -> Self {
143        self.event_bus = Some(event_bus);
144        self
145    }
146
147    /// Create a new OAuth plugin for Google
148    pub fn google(
149        client_id: &str,
150        client_secret: &str,
151        redirect_uri: &str,
152        user_manager: Arc<M>,
153        oauth_storage: Arc<S>,
154    ) -> Self {
155        OAuthPluginBuilder::new(
156            Provider::google(client_id, client_secret, redirect_uri),
157            user_manager,
158            oauth_storage,
159        )
160        .build()
161    }
162
163    /// Create a new OAuth plugin for GitHub
164    pub fn github(
165        client_id: &str,
166        client_secret: &str,
167        redirect_uri: &str,
168        user_manager: Arc<M>,
169        oauth_storage: Arc<S>,
170    ) -> Self {
171        OAuthPluginBuilder::new(
172            Provider::github(client_id, client_secret, redirect_uri),
173            user_manager,
174            oauth_storage,
175        )
176        .build()
177    }
178}
179
180impl<M, S> OAuthPlugin<M, S>
181where
182    M: UserManager,
183    S: OAuthStorage,
184{
185    /// Begin the authentication process by generating a new CSRF state and redirecting the user to the provider's authorization URL.
186    ///
187    /// This method is the first step in the oauth authorization code flow. It will:
188    /// 1. Generate a CSRF token for security
189    /// 2. Generate the authorization URL to redirect the user to
190    ///
191    /// # Returns
192    /// Returns an `AuthFlowBegin` containing:
193    /// * The CSRF state to prevent cross-site request forgery
194    /// * The authorization URI to redirect the user to
195    ///
196    /// # Errors
197    /// Returns an error if:
198    /// * The provider metadata discovery fails
199    /// * The HTTP client cannot be created
200    pub async fn get_authorization_url(&self) -> Result<AuthorizationUrl, Error> {
201        let (authorization_url, pkce_verifier) = self.provider.get_authorization_url()?;
202
203        // Store the PKCE verifier in the storage using the CSRF state as the key
204        self.oauth_storage
205            .store_pkce_verifier(
206                &authorization_url.csrf_state,
207                &pkce_verifier,
208                chrono::Duration::minutes(5),
209            )
210            .await
211            .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
212
213        Ok(authorization_url)
214    }
215
216    /// Creates or retrieves an existing user based on oauth account information
217    ///
218    /// # Arguments
219    /// * `email` - The email address of the user
220    /// * `subject` - The subject of the oauth account
221    ///
222    /// # Returns
223    /// Returns a [`User`] struct containing the user's information.
224    pub async fn get_or_create_user(&self, email: String, subject: String) -> Result<User, Error> {
225        // Check if user exists in database by provider and subject
226        let oauth_account = self
227            .oauth_storage
228            .get_oauth_account_by_provider_and_subject(self.provider.name(), &subject)
229            .await
230            .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
231
232        if let Some(oauth_account) = oauth_account {
233            tracing::info!(
234                user_id = ?oauth_account.user_id,
235                "User already exists in database"
236            );
237
238            let user = self
239                .user_manager
240                .get_user(&oauth_account.user_id)
241                .await?
242                .ok_or(Error::Auth(AuthError::UserNotFound))?;
243
244            return Ok(user);
245        }
246
247        // Create new user with verified email
248        let new_user = NewUser::builder()
249            .id(UserId::new_random())
250            .email(email)
251            .email_verified_at(Some(chrono::Utc::now()))
252            .build()
253            .unwrap();
254
255        let user = self.user_manager.create_user(&new_user).await?;
256
257        // Create link between user and provider
258        self.oauth_storage
259            .create_oauth_account(self.provider.name(), &subject, &user.id)
260            .await
261            .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
262
263        tracing::info!(
264            user_id = ?user.id,
265            provider = ?self.provider.name(),
266            subject = ?subject,
267            "Successfully created link between user and provider"
268        );
269
270        self.emit_event(&Event::UserCreated(user.clone())).await?;
271
272        Ok(user)
273    }
274
275    /// Complete the authentication process by exchanging the authorization code for an access token and user information.
276    ///
277    /// This method is the second step in the oauth authorization code flow. It will:
278    /// 1. Exchange the authorization code for an access token and user information
279    /// 2. Create a new user if they don't exist
280    /// 3. Create a link between the user and the provider
281    ///
282    /// # Arguments
283    /// * `code` - The authorization code
284    /// * `csrf_state` - The CSRF state
285    ///
286    /// # Returns
287    /// Returns a [`User`] struct containing the user's information.
288    pub async fn exchange_code(
289        &self,
290        code: String,
291        csrf_state: String,
292    ) -> Result<(User, UserInfo), Error> {
293        let pkce_verifier = self
294            .oauth_storage
295            .get_pkce_verifier(&csrf_state)
296            .await
297            .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?
298            .ok_or(Error::Auth(AuthError::InvalidCredentials))?;
299
300        tracing::debug!(
301            pkce_verifier = ?pkce_verifier,
302            csrf_state = ?csrf_state,
303            "Exchanging code for token"
304        );
305
306        let token_response = self.provider.exchange_code(&code, &pkce_verifier).await?;
307
308        let access_token = token_response.access_token();
309
310        tracing::debug!(
311            access_token = ?access_token,
312            "Getting user info"
313        );
314
315        let user_info = self.provider.get_user_info(access_token.secret()).await?;
316
317        tracing::debug!(
318            user_info = ?user_info,
319            "Got user info"
320        );
321
322        let email = match &user_info {
323            UserInfo::Google(user_info) => user_info.email.clone(),
324            UserInfo::Github(user_info) => {
325                user_info.email.clone().expect("No email found for user")
326            }
327        };
328
329        let subject = match &user_info {
330            UserInfo::Google(user_info) => user_info.sub.clone(),
331            UserInfo::Github(user_info) => user_info.id.to_string(),
332        };
333
334        tracing::debug!(
335            email = ?email,
336            subject = ?subject,
337            "Getting or creating user"
338        );
339
340        let user = self
341            .get_or_create_user(email, subject)
342            .await
343            .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
344
345        Ok((user, user_info))
346    }
347
348    async fn emit_event(&self, event: &Event) -> Result<(), Error> {
349        if let Some(event_bus) = &self.event_bus {
350            event_bus.emit(event).await?;
351        }
352        Ok(())
353    }
354}