1use crate::{SqliteStorage, SqliteUser};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use torii_core::error::StorageError;
5use torii_core::storage::OAuthStorage;
6use torii_core::{OAuthAccount, User, UserId};
7
8#[derive(Debug, Clone, sqlx::FromRow)]
9pub struct SqliteOAuthAccount {
10 #[allow(dead_code)]
11 id: Option<i64>,
12 user_id: String,
13 provider: String,
14 subject: String,
15 created_at: i64,
16 updated_at: i64,
17}
18
19impl From<SqliteOAuthAccount> for OAuthAccount {
20 fn from(oauth_account: SqliteOAuthAccount) -> Self {
21 OAuthAccount::builder()
22 .user_id(UserId::new(&oauth_account.user_id))
23 .provider(oauth_account.provider)
24 .subject(oauth_account.subject)
25 .created_at(
26 DateTime::from_timestamp(oauth_account.created_at, 0).expect("Invalid timestamp"),
27 )
28 .updated_at(
29 DateTime::from_timestamp(oauth_account.updated_at, 0).expect("Invalid timestamp"),
30 )
31 .build()
32 .unwrap()
33 }
34}
35
36impl From<OAuthAccount> for SqliteOAuthAccount {
37 fn from(oauth_account: OAuthAccount) -> Self {
38 SqliteOAuthAccount {
39 id: None,
40 user_id: oauth_account.user_id.into_inner(),
41 provider: oauth_account.provider,
42 subject: oauth_account.subject,
43 created_at: oauth_account.created_at.timestamp(),
44 updated_at: oauth_account.updated_at.timestamp(),
45 }
46 }
47}
48
49#[async_trait]
50impl OAuthStorage for SqliteStorage {
51 async fn create_oauth_account(
52 &self,
53 provider: &str,
54 subject: &str,
55 user_id: &UserId,
56 ) -> Result<OAuthAccount, torii_core::Error> {
57 let now = Utc::now();
58 let oauth_account = sqlx::query_as::<_, SqliteOAuthAccount>(
59 r#"
60 INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at)
61 VALUES (?, ?, ?, ?, ?)
62 RETURNING id, user_id, provider, subject, created_at, updated_at
63 "#,
64 )
65 .bind(user_id.as_str())
66 .bind(provider)
67 .bind(subject)
68 .bind(now.timestamp())
69 .bind(now.timestamp())
70 .fetch_one(&self.pool)
71 .await
72 .map_err(|e| {
73 tracing::error!(error = %e, "Failed to create oauth account");
74 StorageError::Database("Failed to create oauth account".to_string())
75 })?;
76
77 Ok(oauth_account.into())
78 }
79
80 async fn get_user_by_provider_and_subject(
81 &self,
82 provider: &str,
83 subject: &str,
84 ) -> Result<Option<User>, torii_core::Error> {
85 let user = sqlx::query_as::<_, SqliteUser>(
86 r#"
87 SELECT id, email, name, email_verified_at, created_at, updated_at
88 FROM users
89 WHERE provider = ? AND subject = ?
90 "#,
91 )
92 .bind(provider)
93 .bind(subject)
94 .fetch_optional(&self.pool)
95 .await
96 .map_err(|e| {
97 tracing::error!(error = %e, "Failed to get user by provider and subject");
98 StorageError::Database("Failed to get user by provider and subject".to_string())
99 })?;
100
101 if let Some(user) = user {
102 Ok(Some(user.into()))
103 } else {
104 Ok(None)
105 }
106 }
107
108 async fn get_oauth_account_by_provider_and_subject(
109 &self,
110 provider: &str,
111 subject: &str,
112 ) -> Result<Option<OAuthAccount>, torii_core::Error> {
113 let oauth_account = sqlx::query_as::<_, SqliteOAuthAccount>(
114 r#"
115 SELECT id, user_id, provider, subject, created_at, updated_at
116 FROM oauth_accounts
117 WHERE provider = ? AND subject = ?
118 "#,
119 )
120 .bind(provider)
121 .bind(subject)
122 .fetch_optional(&self.pool)
123 .await
124 .map_err(|e| {
125 tracing::error!(error = %e, "Failed to get oauth account");
126 StorageError::Database("Failed to get oauth account".to_string())
127 })?;
128
129 if let Some(oauth_account) = oauth_account {
130 Ok(Some(oauth_account.into()))
131 } else {
132 Ok(None)
133 }
134 }
135
136 async fn link_oauth_account(
137 &self,
138 user_id: &UserId,
139 provider: &str,
140 subject: &str,
141 ) -> Result<(), torii_core::Error> {
142 let now = Utc::now();
143 sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES (?, ?, ?, ?, ?)")
144 .bind(user_id.as_str())
145 .bind(provider)
146 .bind(subject)
147 .bind(now.timestamp())
148 .bind(now.timestamp())
149 .execute(&self.pool)
150 .await
151 .map_err(|e| {
152 tracing::error!(error = %e, "Failed to link oauth account");
153 StorageError::Database("Failed to link oauth account".to_string())
154 })?;
155
156 Ok(())
157 }
158
159 async fn store_pkce_verifier(
160 &self,
161 csrf_state: &str,
162 pkce_verifier: &str,
163 expires_in: chrono::Duration,
164 ) -> Result<(), torii_core::Error> {
165 sqlx::query(
166 "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES (?, ?, ?)",
167 )
168 .bind(csrf_state)
169 .bind(pkce_verifier)
170 .bind((Utc::now() + expires_in).timestamp())
171 .execute(&self.pool)
172 .await
173 .map_err(|e| {
174 tracing::error!(error = %e, "Failed to save pkce verifier");
175 StorageError::Database("Failed to save pkce verifier".to_string())
176 })?;
177
178 Ok(())
179 }
180
181 async fn get_pkce_verifier(
182 &self,
183 csrf_state: &str,
184 ) -> Result<Option<String>, torii_core::Error> {
185 let pkce_verifier: Option<String> =
186 sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = ?")
187 .bind(csrf_state)
188 .fetch_optional(&self.pool)
189 .await
190 .map_err(|e| {
191 tracing::error!(error = %e, "Failed to get pkce verifier");
192 StorageError::Database("Failed to get pkce verifier".to_string())
193 })?;
194
195 Ok(pkce_verifier)
196 }
197}
198
#[cfg(test)]
mod tests {
    // Tests for the SQLite `OAuthStorage` implementation, run against the
    // storage returned by the crate's shared `setup_sqlite_storage` helper.

    use super::*;
    use crate::tests::{create_test_user, setup_sqlite_storage};

    /// Linking the same provider/subject pair twice must fail, and the first
    /// link must be retrievable and point at the right user.
    #[tokio::test]
    async fn test_oauth_account_linking() {
        let db = setup_sqlite_storage()
            .await
            .expect("Failed to setup storage");
        let owner = create_test_user(&db, "1")
            .await
            .expect("Failed to create user");

        db.link_oauth_account(&owner.id, "google", "oauth_id_123")
            .await
            .expect("Failed to link oauth account");

        let second_attempt = db
            .link_oauth_account(&owner.id, "google", "oauth_id_123")
            .await;
        assert!(second_attempt.is_err());

        let linked = db
            .get_oauth_account_by_provider_and_subject("google", "oauth_id_123")
            .await
            .expect("Failed to get oauth account");

        assert!(linked.is_some());
        assert_eq!(linked.unwrap().user_id, owner.id);
    }

    /// A stored verifier is returned for its state; an unknown state yields `None`.
    #[tokio::test]
    async fn test_pkce_verifier() {
        let db = setup_sqlite_storage()
            .await
            .expect("Failed to setup storage");

        let state = "test_state";
        let verifier = "test_verifier";
        let ttl = chrono::Duration::seconds(3600);

        db.store_pkce_verifier(state, verifier, ttl)
            .await
            .expect("Failed to store pkce verifier");

        let found = db
            .get_pkce_verifier(state)
            .await
            .expect("Failed to get pkce verifier");
        assert_eq!(found, Some(verifier.to_string()));

        let missing = db
            .get_pkce_verifier("non_existent")
            .await
            .expect("Failed to get pkce verifier");
        assert_eq!(missing, None);
    }

    /// One user may be linked to several providers at once.
    #[tokio::test]
    async fn test_multiple_oauth_providers() {
        let db = setup_sqlite_storage()
            .await
            .expect("Failed to setup storage");
        let owner = create_test_user(&db, "1")
            .await
            .expect("Failed to create user");

        db.link_oauth_account(&owner.id, "google", "google_id_123")
            .await
            .expect("Failed to link Google account");
        db.link_oauth_account(&owner.id, "github", "github_id_123")
            .await
            .expect("Failed to link GitHub account");

        let via_google = db
            .get_oauth_account_by_provider_and_subject("google", "google_id_123")
            .await
            .expect("Failed to get Google oauth account");
        let via_github = db
            .get_oauth_account_by_provider_and_subject("github", "github_id_123")
            .await
            .expect("Failed to get GitHub oauth account");

        assert!(via_google.is_some());
        assert!(via_github.is_some());
        assert_eq!(via_google.unwrap().user_id, owner.id);
        assert_eq!(via_github.unwrap().user_id, owner.id);
    }
}