torii_storage_postgres/
lib.rs

1mod migrations;
2mod oauth;
3mod passkey;
4mod password;
5mod session;
6
7use async_trait::async_trait;
8use chrono::DateTime;
9use chrono::Utc;
10use migrations::CreateIndexes;
11use migrations::CreateOAuthAccountsTable;
12use migrations::CreatePasskeyChallengesTable;
13use migrations::CreatePasskeysTable;
14use migrations::CreateSessionsTable;
15use migrations::CreateUsersTable;
16use migrations::PostgresMigrationManager;
17use sqlx::PgPool;
18use torii_core::error::StorageError;
19use torii_core::{
20    User, UserId,
21    storage::{NewUser, UserStorage},
22};
23use torii_migration::Migration;
24use torii_migration::MigrationManager;
25
26#[derive(Debug)]
27pub struct PostgresStorage {
28    pool: PgPool,
29}
30
31impl PostgresStorage {
32    pub fn new(pool: PgPool) -> Self {
33        Self { pool }
34    }
35
36    pub async fn migrate(&self) -> Result<(), StorageError> {
37        let manager = PostgresMigrationManager::new(self.pool.clone());
38        manager.initialize().await.map_err(|e| {
39            tracing::error!(error = %e, "Failed to initialize migrations");
40            StorageError::Migration("Failed to initialize migrations".to_string())
41        })?;
42
43        let migrations: Vec<Box<dyn Migration<_>>> = vec![
44            Box::new(CreateUsersTable),
45            Box::new(CreateSessionsTable),
46            Box::new(CreateOAuthAccountsTable),
47            Box::new(CreatePasskeysTable),
48            Box::new(CreatePasskeyChallengesTable),
49            Box::new(CreateIndexes),
50        ];
51        manager.up(&migrations).await.map_err(|e| {
52            tracing::error!(error = %e, "Failed to run migrations");
53            StorageError::Migration("Failed to run migrations".to_string())
54        })?;
55
56        Ok(())
57    }
58}
59
60#[derive(Debug, Clone, sqlx::FromRow)]
61pub struct PostgresUser {
62    id: String,
63    email: String,
64    name: Option<String>,
65    email_verified_at: Option<DateTime<Utc>>,
66    created_at: DateTime<Utc>,
67    updated_at: DateTime<Utc>,
68}
69
70impl From<PostgresUser> for User {
71    fn from(user: PostgresUser) -> Self {
72        User::builder()
73            .id(UserId::new(&user.id))
74            .email(user.email)
75            .name(user.name)
76            .email_verified_at(user.email_verified_at)
77            .created_at(user.created_at)
78            .updated_at(user.updated_at)
79            .build()
80            .unwrap()
81    }
82}
83
84impl From<User> for PostgresUser {
85    fn from(user: User) -> Self {
86        PostgresUser {
87            id: user.id.into_inner(),
88            email: user.email,
89            name: user.name,
90            email_verified_at: user.email_verified_at,
91            created_at: user.created_at,
92            updated_at: user.updated_at,
93        }
94    }
95}
96
97#[async_trait]
98impl UserStorage for PostgresStorage {
99    type Error = torii_core::Error;
100
101    async fn create_user(&self, user: &NewUser) -> Result<User, Self::Error> {
102        let user = sqlx::query_as::<_, PostgresUser>(
103            r#"
104            INSERT INTO users (id, email) 
105            VALUES ($1::uuid, $2) 
106            RETURNING id::text, email, name, email_verified_at, created_at, updated_at
107            "#,
108        )
109        .bind(user.id.as_ref())
110        .bind(&user.email)
111        .fetch_one(&self.pool)
112        .await
113        .map_err(|e| {
114            tracing::error!(error = %e, "Failed to create user");
115            StorageError::Database("Failed to create user".to_string())
116        })?;
117
118        Ok(user.into())
119    }
120
121    async fn get_user(&self, id: &UserId) -> Result<Option<User>, Self::Error> {
122        let user = sqlx::query_as::<_, PostgresUser>(
123            r#"
124            SELECT id::text, email, name, email_verified_at, created_at, updated_at 
125            FROM users 
126            WHERE id::text = $1
127            "#,
128        )
129        .bind(id.as_ref())
130        .fetch_optional(&self.pool)
131        .await
132        .map_err(|e| {
133            tracing::error!(error = %e, "Failed to get user");
134            StorageError::Database("Failed to get user".to_string())
135        })?;
136
137        match user {
138            Some(user) => Ok(Some(user.into())),
139            None => Ok(None),
140        }
141    }
142
143    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>, Self::Error> {
144        let user = sqlx::query_as::<_, PostgresUser>(
145            r#"
146            SELECT id::text, email, name, email_verified_at, created_at, updated_at 
147            FROM users 
148            WHERE email = $1
149            "#,
150        )
151        .bind(email)
152        .fetch_optional(&self.pool)
153        .await
154        .map_err(|e| {
155            tracing::error!(error = %e, "Failed to get user by email");
156            StorageError::Database("Failed to get user by email".to_string())
157        })?;
158
159        match user {
160            Some(user) => Ok(Some(user.into())),
161            None => Ok(None),
162        }
163    }
164
165    async fn get_or_create_user_by_email(&self, email: &str) -> Result<User, Self::Error> {
166        let user = self.get_user_by_email(email).await?;
167        if let Some(user) = user {
168            return Ok(user);
169        }
170
171        let user = self
172            .create_user(
173                &NewUser::builder()
174                    .id(UserId::new_random())
175                    .email(email.to_string())
176                    .build()
177                    .unwrap(),
178            )
179            .await
180            .map_err(|e| {
181                tracing::error!(error = %e, "Failed to get or create user by email");
182                StorageError::Database("Failed to get or create user by email".to_string())
183            })?;
184
185        Ok(user)
186    }
187
188    async fn update_user(&self, user: &User) -> Result<User, Self::Error> {
189        let user = sqlx::query_as::<_, PostgresUser>(
190            r#"
191            UPDATE users 
192            SET email = $1, name = $2, email_verified_at = $3, updated_at = $4 
193            WHERE id::text = $5
194            RETURNING id::text, email, name, email_verified_at, created_at, updated_at
195            "#,
196        )
197        .bind(&user.email)
198        .bind(&user.name)
199        .bind(user.email_verified_at)
200        .bind(user.updated_at)
201        .bind(user.id.as_ref())
202        .fetch_one(&self.pool)
203        .await
204        .map_err(|e| {
205            tracing::error!(error = %e, "Failed to update user");
206            StorageError::Database("Failed to update user".to_string())
207        })?;
208
209        Ok(user.into())
210    }
211
212    async fn delete_user(&self, id: &UserId) -> Result<(), Self::Error> {
213        sqlx::query("DELETE FROM users WHERE id::text = $1")
214            .bind(id.as_ref())
215            .execute(&self.pool)
216            .await
217            .map_err(|e| {
218                tracing::error!(error = %e, "Failed to delete user");
219                StorageError::Database("Failed to delete user".to_string())
220            })?;
221
222        Ok(())
223    }
224
225    async fn set_user_email_verified(&self, user_id: &UserId) -> Result<(), Self::Error> {
226        sqlx::query("UPDATE users SET email_verified_at = $1 WHERE id::text = $2")
227            .bind(Utc::now())
228            .bind(user_id.as_ref())
229            .execute(&self.pool)
230            .await
231            .map_err(|e| {
232                tracing::error!(error = %e, "Failed to set user email verified");
233                StorageError::Database("Failed to set user email verified".to_string())
234            })?;
235
236        Ok(())
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use sqlx::types::chrono::Utc;
    use std::time::Duration;
    use torii_core::session::SessionId;
    use torii_core::{Session, SessionStorage};

    /// Creates a uniquely named database on the local PostgreSQL server, runs the migrations
    /// on it, and returns a [`PostgresStorage`] connected to it.
    ///
    /// Expects a server at `localhost:5432` that accepts the `postgres`/`postgres` credentials.
    ///
    /// # Panics
    /// Panics if connecting, dropping or creating the database, or migrating fails.
    pub(crate) async fn setup_test_db() -> PostgresStorage {
        // TODO: this function is leaking postgres databases after the test is done.
        // We should find a way to clean up the database after the test is done.

        // `try_init` errors when a global subscriber is already installed; that is fine here.
        let _ = tracing_subscriber::fmt().try_init();

        let admin_pool = PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres")
            .await
            .expect("Failed to create pool");

        // The name is generated locally (fixed prefix + integer), so interpolating it into
        // SQL below is safe.
        let suffix: i64 = rand::rng().random_range(1..i64::MAX);
        let db_name = format!("torii_test_{suffix}");

        // Drop the database if it exists
        let drop_sql = format!("DROP DATABASE IF EXISTS {db_name}");
        sqlx::query(drop_sql.as_str())
            .execute(&admin_pool)
            .await
            .expect("Failed to drop database");

        // Create a new database for the test
        let create_sql = format!("CREATE DATABASE {db_name}");
        sqlx::query(create_sql.as_str())
            .execute(&admin_pool)
            .await
            .expect("Failed to create database");

        let test_url = format!("postgres://postgres:postgres@localhost:5432/{db_name}");
        let test_pool = PgPool::connect(test_url.as_str())
            .await
            .expect("Failed to create pool");

        let storage = PostgresStorage::new(test_pool);
        storage.migrate().await.expect("Failed to run migrations");
        storage
    }

    /// Inserts a user with the given `id` and the email `test{id}@example.com`.
    pub(crate) async fn create_test_user(
        storage: &PostgresStorage,
        id: &UserId,
    ) -> Result<User, torii_core::Error> {
        let new_user = NewUser::builder()
            .id(id.clone())
            .email(format!("test{id}@example.com"))
            .build()
            .expect("Failed to build user");

        storage.create_user(&new_user).await
    }

    /// Inserts a session for `user_id` that is created now and expires `expires_in` from now,
    /// with fixed test values for the user agent and IP address.
    pub(crate) async fn create_test_session(
        storage: &PostgresStorage,
        session_id: &SessionId,
        user_id: &UserId,
        expires_in: Duration,
    ) -> Result<Session, torii_core::Error> {
        let now = Utc::now();
        let expires_at = now + expires_in;

        let session = Session::builder()
            .id(session_id.clone())
            .user_id(user_id.clone())
            .user_agent(Some("test".to_string()))
            .ip_address(Some("127.0.0.1".to_string()))
            .created_at(now)
            .updated_at(now)
            .expires_at(expires_at)
            .build()
            .expect("Failed to build session");

        storage.create_session(&session).await
    }
}
319            .await
320    }
321}