Skip to main content

torii_storage_sqlite/
oauth.rs

1use crate::{SqliteStorage, SqliteUser};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use torii_core::error::StorageError;
5use torii_core::storage::OAuthStorage;
6use torii_core::{OAuthAccount, User, UserId};
7
/// Row shape of the `oauth_accounts` table as decoded by sqlx.
///
/// sqlx's derived `FromRow` matches result columns to these fields by name, so
/// every query that decodes into this type must select columns with exactly
/// these names. Timestamps are stored as whole Unix seconds.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct SqliteOAuthAccount {
    // Filled from `RETURNING id` / `SELECT id`, but never read by Rust code in
    // this module (the conversion to `OAuthAccount` drops it), hence the lint
    // allowance. It is `None` when a row is built from an `OAuthAccount`.
    #[allow(dead_code)]
    id: Option<i64>,
    /// String form of the owning user's `UserId`.
    user_id: String,
    /// OAuth provider name, e.g. "google" or "github".
    provider: String,
    /// Provider-side subject identifier of the external account.
    subject: String,
    /// Creation time in Unix seconds.
    created_at: i64,
    /// Last update time in Unix seconds.
    updated_at: i64,
}
18
19impl From<SqliteOAuthAccount> for OAuthAccount {
20    fn from(oauth_account: SqliteOAuthAccount) -> Self {
21        OAuthAccount::builder()
22            .user_id(UserId::new(&oauth_account.user_id))
23            .provider(oauth_account.provider)
24            .subject(oauth_account.subject)
25            .created_at(
26                DateTime::from_timestamp(oauth_account.created_at, 0).expect("Invalid timestamp"),
27            )
28            .updated_at(
29                DateTime::from_timestamp(oauth_account.updated_at, 0).expect("Invalid timestamp"),
30            )
31            .build()
32            .unwrap()
33    }
34}
35
36impl From<OAuthAccount> for SqliteOAuthAccount {
37    fn from(oauth_account: OAuthAccount) -> Self {
38        SqliteOAuthAccount {
39            id: None,
40            user_id: oauth_account.user_id.into_inner(),
41            provider: oauth_account.provider,
42            subject: oauth_account.subject,
43            created_at: oauth_account.created_at.timestamp(),
44            updated_at: oauth_account.updated_at.timestamp(),
45        }
46    }
47}
48
49#[async_trait]
50impl OAuthStorage for SqliteStorage {
51    async fn create_oauth_account(
52        &self,
53        provider: &str,
54        subject: &str,
55        user_id: &UserId,
56    ) -> Result<OAuthAccount, torii_core::Error> {
57        let now = Utc::now();
58        let oauth_account = sqlx::query_as::<_, SqliteOAuthAccount>(
59            r#"
60            INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at)
61            VALUES (?, ?, ?, ?, ?)
62            RETURNING id, user_id, provider, subject, created_at, updated_at
63            "#,
64        )
65        .bind(user_id.as_str())
66        .bind(provider)
67        .bind(subject)
68        .bind(now.timestamp())
69        .bind(now.timestamp())
70        .fetch_one(&self.pool)
71        .await
72        .map_err(|e| {
73            tracing::error!(error = %e, "Failed to create oauth account");
74            StorageError::Database("Failed to create oauth account".to_string())
75        })?;
76
77        Ok(oauth_account.into())
78    }
79
80    async fn get_user_by_provider_and_subject(
81        &self,
82        provider: &str,
83        subject: &str,
84    ) -> Result<Option<User>, torii_core::Error> {
85        let user = sqlx::query_as::<_, SqliteUser>(
86            r#"
87            SELECT id, email, name, email_verified_at, created_at, updated_at
88            FROM users
89            WHERE provider = ? AND subject = ?
90            "#,
91        )
92        .bind(provider)
93        .bind(subject)
94        .fetch_optional(&self.pool)
95        .await
96        .map_err(|e| {
97            tracing::error!(error = %e, "Failed to get user by provider and subject");
98            StorageError::Database("Failed to get user by provider and subject".to_string())
99        })?;
100
101        if let Some(user) = user {
102            Ok(Some(user.into()))
103        } else {
104            Ok(None)
105        }
106    }
107
108    async fn get_oauth_account_by_provider_and_subject(
109        &self,
110        provider: &str,
111        subject: &str,
112    ) -> Result<Option<OAuthAccount>, torii_core::Error> {
113        let oauth_account = sqlx::query_as::<_, SqliteOAuthAccount>(
114            r#"
115            SELECT id, user_id, provider, subject, created_at, updated_at
116            FROM oauth_accounts
117            WHERE provider = ? AND subject = ?
118            "#,
119        )
120        .bind(provider)
121        .bind(subject)
122        .fetch_optional(&self.pool)
123        .await
124        .map_err(|e| {
125            tracing::error!(error = %e, "Failed to get oauth account");
126            StorageError::Database("Failed to get oauth account".to_string())
127        })?;
128
129        if let Some(oauth_account) = oauth_account {
130            Ok(Some(oauth_account.into()))
131        } else {
132            Ok(None)
133        }
134    }
135
136    async fn link_oauth_account(
137        &self,
138        user_id: &UserId,
139        provider: &str,
140        subject: &str,
141    ) -> Result<(), torii_core::Error> {
142        let now = Utc::now();
143        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES (?, ?, ?, ?, ?)")
144            .bind(user_id.as_str())
145            .bind(provider)
146            .bind(subject)
147            .bind(now.timestamp())
148            .bind(now.timestamp())
149            .execute(&self.pool)
150            .await
151            .map_err(|e| {
152                tracing::error!(error = %e, "Failed to link oauth account");
153                StorageError::Database("Failed to link oauth account".to_string())
154            })?;
155
156        Ok(())
157    }
158
159    async fn store_pkce_verifier(
160        &self,
161        csrf_state: &str,
162        pkce_verifier: &str,
163        expires_in: chrono::Duration,
164    ) -> Result<(), torii_core::Error> {
165        sqlx::query(
166            "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES (?, ?, ?)",
167        )
168        .bind(csrf_state)
169        .bind(pkce_verifier)
170        .bind((Utc::now() + expires_in).timestamp())
171        .execute(&self.pool)
172        .await
173        .map_err(|e| {
174            tracing::error!(error = %e, "Failed to save pkce verifier");
175            StorageError::Database("Failed to save pkce verifier".to_string())
176        })?;
177
178        Ok(())
179    }
180
181    async fn get_pkce_verifier(
182        &self,
183        csrf_state: &str,
184    ) -> Result<Option<String>, torii_core::Error> {
185        let pkce_verifier: Option<String> =
186            sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = ?")
187                .bind(csrf_state)
188                .fetch_optional(&self.pool)
189                .await
190                .map_err(|e| {
191                    tracing::error!(error = %e, "Failed to get pkce verifier");
192                    StorageError::Database("Failed to get pkce verifier".to_string())
193                })?;
194
195        Ok(pkce_verifier)
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{create_test_user, setup_sqlite_storage};

    /// A (provider, subject) pair can be linked once; linking it again is
    /// rejected, and the first link remains retrievable.
    #[tokio::test]
    async fn test_oauth_account_linking() {
        let storage = setup_sqlite_storage()
            .await
            .expect("Failed to setup storage");
        let user = create_test_user(&storage, "1")
            .await
            .expect("Failed to create user");
        let (provider, subject) = ("google", "oauth_id_123");

        storage
            .link_oauth_account(&user.id, provider, subject)
            .await
            .expect("Failed to link oauth account");

        // Re-linking the very same identity must fail.
        let relink = storage
            .link_oauth_account(&user.id, provider, subject)
            .await;
        assert!(relink.is_err());

        let linked = storage
            .get_oauth_account_by_provider_and_subject(provider, subject)
            .await
            .expect("Failed to get oauth account");
        assert_eq!(
            linked.as_ref().map(|account| &account.user_id),
            Some(&user.id)
        );
    }

    /// A stored PKCE verifier is returned for its state; an unknown state yields nothing.
    #[tokio::test]
    async fn test_pkce_verifier() {
        let storage = setup_sqlite_storage()
            .await
            .expect("Failed to setup storage");
        let (state, verifier) = ("test_state", "test_verifier");

        storage
            .store_pkce_verifier(state, verifier, chrono::Duration::seconds(3600))
            .await
            .expect("Failed to store pkce verifier");

        let found = storage
            .get_pkce_verifier(state)
            .await
            .expect("Failed to get pkce verifier");
        assert_eq!(found.as_deref(), Some(verifier));

        let missing = storage
            .get_pkce_verifier("non_existent")
            .await
            .expect("Failed to get pkce verifier");
        assert!(missing.is_none());
    }

    /// One user can hold links to several providers at once.
    #[tokio::test]
    async fn test_multiple_oauth_providers() {
        // (provider, subject, link failure message, lookup failure message)
        const ACCOUNTS: [(&str, &str, &str, &str); 2] = [
            (
                "google",
                "google_id_123",
                "Failed to link Google account",
                "Failed to get Google oauth account",
            ),
            (
                "github",
                "github_id_123",
                "Failed to link GitHub account",
                "Failed to get GitHub oauth account",
            ),
        ];

        let storage = setup_sqlite_storage()
            .await
            .expect("Failed to setup storage");
        let user = create_test_user(&storage, "1")
            .await
            .expect("Failed to create user");

        for (provider, subject, link_error, _) in ACCOUNTS {
            storage
                .link_oauth_account(&user.id, provider, subject)
                .await
                .expect(link_error);
        }

        for (provider, subject, _, lookup_error) in ACCOUNTS {
            let account = storage
                .get_oauth_account_by_provider_and_subject(provider, subject)
                .await
                .expect(lookup_error);
            assert_eq!(account.as_ref().map(|a| &a.user_id), Some(&user.id));
        }
    }
}