torii_storage_postgres/
oauth.rs

1use crate::{PostgresStorage, PostgresUser};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use std::time::Duration;
5use torii_core::error::{StorageError, ValidationError};
6use torii_core::storage::OAuthStorage;
7use torii_core::{OAuthAccount, User, UserId};
8
9#[derive(Default)]
10pub struct PostgresOAuthAccountBuilder {
11    user_id: Option<UserId>,
12    provider: Option<String>,
13    subject: Option<String>,
14    created_at: Option<DateTime<Utc>>,
15    updated_at: Option<DateTime<Utc>>,
16}
17
18impl PostgresOAuthAccountBuilder {
19    pub fn user_id(mut self, user_id: UserId) -> Self {
20        self.user_id = Some(user_id);
21        self
22    }
23
24    pub fn provider(mut self, provider: String) -> Self {
25        self.provider = Some(provider);
26        self
27    }
28
29    pub fn subject(mut self, subject: String) -> Self {
30        self.subject = Some(subject);
31        self
32    }
33
34    pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
35        self.created_at = Some(created_at);
36        self
37    }
38
39    pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
40        self.updated_at = Some(updated_at);
41        self
42    }
43
44    pub fn build(self) -> Result<PostgresOAuthAccount, torii_core::Error> {
45        let now = Utc::now();
46        Ok(PostgresOAuthAccount {
47            id: None,
48            user_id: self
49                .user_id
50                .ok_or(ValidationError::MissingField(
51                    "User ID is required".to_string(),
52                ))?
53                .to_string(),
54            provider: self.provider.ok_or(ValidationError::MissingField(
55                "Provider is required".to_string(),
56            ))?,
57            subject: self.subject.ok_or(ValidationError::MissingField(
58                "Subject is required".to_string(),
59            ))?,
60            created_at: self.created_at.unwrap_or(now),
61            updated_at: self.updated_at.unwrap_or(now),
62        })
63    }
64}
65
66#[derive(Debug, Clone, sqlx::FromRow)]
67pub struct PostgresOAuthAccount {
68    pub id: Option<i64>,
69    pub user_id: String,
70    pub provider: String,
71    pub subject: String,
72    pub created_at: DateTime<Utc>,
73    pub updated_at: DateTime<Utc>,
74}
75
76impl PostgresOAuthAccount {
77    pub fn builder() -> PostgresOAuthAccountBuilder {
78        PostgresOAuthAccountBuilder::default()
79    }
80
81    pub fn new(user_id: UserId, provider: impl Into<String>, subject: impl Into<String>) -> Self {
82        PostgresOAuthAccountBuilder::default()
83            .user_id(user_id)
84            .provider(provider.into())
85            .subject(subject.into())
86            .build()
87            .expect("Default builder should never fail")
88    }
89
90    pub fn is_expired(&self, ttl: Duration) -> bool {
91        Utc::now() > self.created_at + ttl
92    }
93}
94
95impl From<PostgresOAuthAccount> for OAuthAccount {
96    fn from(oauth_account: PostgresOAuthAccount) -> Self {
97        OAuthAccount::builder()
98            .user_id(UserId::new(&oauth_account.user_id))
99            .provider(oauth_account.provider)
100            .subject(oauth_account.subject)
101            .created_at(oauth_account.created_at)
102            .updated_at(oauth_account.updated_at)
103            .build()
104            .expect("Default builder should never fail")
105    }
106}
107
108impl From<OAuthAccount> for PostgresOAuthAccount {
109    fn from(oauth_account: OAuthAccount) -> Self {
110        PostgresOAuthAccount::builder()
111            .user_id(oauth_account.user_id)
112            .provider(oauth_account.provider)
113            .subject(oauth_account.subject)
114            .created_at(oauth_account.created_at)
115            .updated_at(oauth_account.updated_at)
116            .build()
117            .expect("Default builder should never fail")
118    }
119}
120
121#[async_trait]
122impl OAuthStorage for PostgresStorage {
123    type Error = StorageError;
124
125    async fn create_oauth_account(
126        &self,
127        provider: &str,
128        subject: &str,
129        user_id: &UserId,
130    ) -> Result<OAuthAccount, <Self as OAuthStorage>::Error> {
131        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
132            .bind(user_id.as_str())
133            .bind(provider)
134            .bind(subject)
135            .bind(Utc::now())
136            .bind(Utc::now())
137            .execute(&self.pool)
138            .await
139            .map_err(|e| {
140                tracing::error!(error = %e, "Failed to create oauth account");
141                StorageError::Database("Failed to create oauth account".to_string())
142            })?;
143
144        let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
145            r#"
146            SELECT id, user_id, provider, subject, created_at, updated_at
147            FROM oauth_accounts
148            WHERE user_id = $1
149            "#,
150        )
151        .bind(user_id.as_str())
152        .fetch_one(&self.pool)
153        .await
154        .map_err(|e| {
155            tracing::error!(error = %e, "Failed to get oauth account");
156            StorageError::Database("Failed to get oauth account".to_string())
157        })?;
158
159        Ok(oauth_account.into())
160    }
161
162    async fn store_pkce_verifier(
163        &self,
164        csrf_state: &str,
165        pkce_verifier: &str,
166        expires_in: chrono::Duration,
167    ) -> Result<(), <Self as OAuthStorage>::Error> {
168        sqlx::query(
169            "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES ($1, $2, $3) RETURNING value",
170        )
171        .bind(csrf_state)
172        .bind(pkce_verifier)
173        .bind(Utc::now() + expires_in)
174        .fetch_optional(&self.pool)
175        .await
176        .map_err(|e| {
177            tracing::error!(error = %e, "Failed to store pkce verifier");
178            StorageError::Database("Failed to store pkce verifier".to_string())
179        })?;
180
181        Ok(())
182    }
183
184    async fn get_pkce_verifier(
185        &self,
186        csrf_state: &str,
187    ) -> Result<Option<String>, <Self as OAuthStorage>::Error> {
188        let pkce_verifier =
189            sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = $1")
190                .bind(csrf_state)
191                .fetch_optional(&self.pool)
192                .await
193                .map_err(|e| {
194                    tracing::error!(error = %e, "Failed to get pkce verifier");
195                    StorageError::Database("Failed to get pkce verifier".to_string())
196                })?;
197
198        Ok(pkce_verifier)
199    }
200
201    async fn get_oauth_account_by_provider_and_subject(
202        &self,
203        provider: &str,
204        subject: &str,
205    ) -> Result<Option<OAuthAccount>, <Self as OAuthStorage>::Error> {
206        let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
207            r#"
208            SELECT id, user_id, provider, subject, created_at, updated_at
209            FROM oauth_accounts
210            WHERE provider = $1 AND subject = $2
211            "#,
212        )
213        .bind(provider)
214        .bind(subject)
215        .fetch_optional(&self.pool)
216        .await
217        .map_err(|e| {
218            tracing::error!(error = %e, "Failed to get oauth account");
219            StorageError::Database("Failed to get oauth account".to_string())
220        })?;
221
222        if let Some(oauth_account) = oauth_account {
223            Ok(Some(oauth_account.into()))
224        } else {
225            Ok(None)
226        }
227    }
228
229    async fn get_user_by_provider_and_subject(
230        &self,
231        provider: &str,
232        subject: &str,
233    ) -> Result<Option<User>, <Self as OAuthStorage>::Error> {
234        let user = sqlx::query_as::<_, PostgresUser>(
235            r#"
236            SELECT id, email, name, email_verified_at, created_at, updated_at
237            FROM users
238            WHERE provider = $1 AND subject = $2
239            "#,
240        )
241        .bind(provider)
242        .bind(subject)
243        .fetch_optional(&self.pool)
244        .await
245        .map_err(|e| {
246            tracing::error!(error = %e, "Failed to get user by provider and subject");
247            StorageError::Database("Failed to get user by provider and subject".to_string())
248        })?;
249
250        if let Some(user) = user {
251            Ok(Some(user.into()))
252        } else {
253            Ok(None)
254        }
255    }
256
257    async fn link_oauth_account(
258        &self,
259        user_id: &UserId,
260        provider: &str,
261        subject: &str,
262    ) -> Result<(), <Self as OAuthStorage>::Error> {
263        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
264            .bind(user_id.as_str())
265            .bind(provider)
266            .bind(subject)
267            .bind(Utc::now())
268            .bind(Utc::now())
269            .execute(&self.pool)
270            .await
271            .map_err(|e| {
272                tracing::error!(error = %e, "Failed to link oauth account");
273                StorageError::Database("Failed to link oauth account".to_string())
274            })?;
275
276        Ok(())
277    }
278}