1pub mod providers;
2
3use std::sync::Arc;
4
5use oauth2::TokenResponse;
6
7use providers::{Provider, UserInfo};
8use torii_core::error::AuthError;
9use torii_core::{Error, NewUser, Plugin, User, UserId, UserManager};
10use torii_core::{
11 events::{Event, EventBus},
12 storage::OAuthStorage,
13};
14
/// An authorization URL produced by an OAuth provider, paired with the CSRF
/// state value generated for it.
///
/// The CSRF state is also the key under which the plugin stores the PKCE
/// verifier, so callers need to keep it until the provider redirects back and
/// the authorization code is exchanged.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthorizationUrl {
    /// The provider URL to send the user to.
    url: String,
    /// The CSRF state value generated alongside `url`.
    csrf_state: String,
}
37
38impl AuthorizationUrl {
39 pub fn new(url: &str, csrf_state: &str) -> Self {
40 Self {
41 url: url.to_string(),
42 csrf_state: csrf_state.to_string(),
43 }
44 }
45
46 pub fn url(&self) -> &str {
47 &self.url
48 }
49
50 pub fn csrf_state(&self) -> &str {
51 &self.csrf_state
52 }
53}
54
55pub struct OAuthPlugin<M, S>
56where
57 M: UserManager,
58 S: OAuthStorage,
59{
60 provider: Provider,
62 user_manager: Arc<M>,
64 oauth_storage: Arc<S>,
66 event_bus: Option<EventBus>,
68}
69
70impl<M, S> Plugin for OAuthPlugin<M, S>
71where
72 M: UserManager,
73 S: OAuthStorage,
74{
75 fn name(&self) -> String {
76 self.provider.name().to_string()
77 }
78}
79
80pub struct OAuthPluginBuilder<M, S>
81where
82 M: UserManager,
83 S: OAuthStorage,
84{
85 provider: Provider,
86 user_manager: Arc<M>,
87 oauth_storage: Arc<S>,
88 event_bus: Option<EventBus>,
89}
90
91impl<M, S> OAuthPluginBuilder<M, S>
92where
93 M: UserManager,
94 S: OAuthStorage,
95{
96 pub fn new(provider: Provider, user_manager: Arc<M>, oauth_storage: Arc<S>) -> Self {
97 Self {
98 provider,
99 user_manager,
100 oauth_storage,
101 event_bus: None,
102 }
103 }
104
105 pub fn event_bus(mut self, event_bus: EventBus) -> Self {
106 self.event_bus = Some(event_bus);
107 self
108 }
109
110 pub fn build(self) -> OAuthPlugin<M, S> {
111 OAuthPlugin {
112 provider: self.provider,
113 user_manager: self.user_manager,
114 oauth_storage: self.oauth_storage,
115 event_bus: self.event_bus,
116 }
117 }
118}
119
120impl<M, S> OAuthPlugin<M, S>
121where
122 M: UserManager,
123 S: OAuthStorage,
124{
125 pub fn builder(
126 provider: Provider,
127 user_manager: Arc<M>,
128 oauth_storage: Arc<S>,
129 ) -> OAuthPluginBuilder<M, S> {
130 OAuthPluginBuilder::new(provider, user_manager, oauth_storage)
131 }
132
133 pub fn new(provider: Provider, user_manager: Arc<M>, oauth_storage: Arc<S>) -> Self {
134 Self {
135 provider,
136 user_manager,
137 oauth_storage,
138 event_bus: None,
139 }
140 }
141
142 pub fn with_event_bus(mut self, event_bus: EventBus) -> Self {
143 self.event_bus = Some(event_bus);
144 self
145 }
146
147 pub fn google(
149 client_id: &str,
150 client_secret: &str,
151 redirect_uri: &str,
152 user_manager: Arc<M>,
153 oauth_storage: Arc<S>,
154 ) -> Self {
155 OAuthPluginBuilder::new(
156 Provider::google(client_id, client_secret, redirect_uri),
157 user_manager,
158 oauth_storage,
159 )
160 .build()
161 }
162
163 pub fn github(
165 client_id: &str,
166 client_secret: &str,
167 redirect_uri: &str,
168 user_manager: Arc<M>,
169 oauth_storage: Arc<S>,
170 ) -> Self {
171 OAuthPluginBuilder::new(
172 Provider::github(client_id, client_secret, redirect_uri),
173 user_manager,
174 oauth_storage,
175 )
176 .build()
177 }
178}
179
180impl<M, S> OAuthPlugin<M, S>
181where
182 M: UserManager,
183 S: OAuthStorage,
184{
185 pub async fn get_authorization_url(&self) -> Result<AuthorizationUrl, Error> {
201 let (authorization_url, pkce_verifier) = self.provider.get_authorization_url()?;
202
203 self.oauth_storage
205 .store_pkce_verifier(
206 &authorization_url.csrf_state,
207 &pkce_verifier,
208 chrono::Duration::minutes(5),
209 )
210 .await
211 .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
212
213 Ok(authorization_url)
214 }
215
216 pub async fn get_or_create_user(&self, email: String, subject: String) -> Result<User, Error> {
225 let oauth_account = self
227 .oauth_storage
228 .get_oauth_account_by_provider_and_subject(self.provider.name(), &subject)
229 .await
230 .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
231
232 if let Some(oauth_account) = oauth_account {
233 tracing::info!(
234 user_id = ?oauth_account.user_id,
235 "User already exists in database"
236 );
237
238 let user = self
239 .user_manager
240 .get_user(&oauth_account.user_id)
241 .await?
242 .ok_or(Error::Auth(AuthError::UserNotFound))?;
243
244 return Ok(user);
245 }
246
247 let new_user = NewUser::builder()
249 .id(UserId::new_random())
250 .email(email)
251 .email_verified_at(Some(chrono::Utc::now()))
252 .build()
253 .unwrap();
254
255 let user = self.user_manager.create_user(&new_user).await?;
256
257 self.oauth_storage
259 .create_oauth_account(self.provider.name(), &subject, &user.id)
260 .await
261 .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
262
263 tracing::info!(
264 user_id = ?user.id,
265 provider = ?self.provider.name(),
266 subject = ?subject,
267 "Successfully created link between user and provider"
268 );
269
270 self.emit_event(&Event::UserCreated(user.clone())).await?;
271
272 Ok(user)
273 }
274
275 pub async fn exchange_code(
289 &self,
290 code: String,
291 csrf_state: String,
292 ) -> Result<(User, UserInfo), Error> {
293 let pkce_verifier = self
294 .oauth_storage
295 .get_pkce_verifier(&csrf_state)
296 .await
297 .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?
298 .ok_or(Error::Auth(AuthError::InvalidCredentials))?;
299
300 tracing::debug!(
301 pkce_verifier = ?pkce_verifier,
302 csrf_state = ?csrf_state,
303 "Exchanging code for token"
304 );
305
306 let token_response = self.provider.exchange_code(&code, &pkce_verifier).await?;
307
308 let access_token = token_response.access_token();
309
310 tracing::debug!(
311 access_token = ?access_token,
312 "Getting user info"
313 );
314
315 let user_info = self.provider.get_user_info(access_token.secret()).await?;
316
317 tracing::debug!(
318 user_info = ?user_info,
319 "Got user info"
320 );
321
322 let email = match &user_info {
323 UserInfo::Google(user_info) => user_info.email.clone(),
324 UserInfo::Github(user_info) => {
325 user_info.email.clone().expect("No email found for user")
326 }
327 };
328
329 let subject = match &user_info {
330 UserInfo::Google(user_info) => user_info.sub.clone(),
331 UserInfo::Github(user_info) => user_info.id.to_string(),
332 };
333
334 tracing::debug!(
335 email = ?email,
336 subject = ?subject,
337 "Getting or creating user"
338 );
339
340 let user = self
341 .get_or_create_user(email, subject)
342 .await
343 .map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
344
345 Ok((user, user_info))
346 }
347
348 async fn emit_event(&self, event: &Event) -> Result<(), Error> {
349 if let Some(event_bus) = &self.event_bus {
350 event_bus.emit(event).await?;
351 }
352 Ok(())
353 }
354}