torii_storage_postgres/
oauth.rs

1use crate::{PostgresStorage, PostgresUser};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use std::time::Duration;
5use torii_core::error::{StorageError, ValidationError};
6use torii_core::storage::OAuthStorage;
7use torii_core::{OAuthAccount, User, UserId};
8
9#[derive(Default)]
10pub struct PostgresOAuthAccountBuilder {
11    user_id: Option<UserId>,
12    provider: Option<String>,
13    subject: Option<String>,
14    created_at: Option<DateTime<Utc>>,
15    updated_at: Option<DateTime<Utc>>,
16}
17
18impl PostgresOAuthAccountBuilder {
19    pub fn user_id(mut self, user_id: UserId) -> Self {
20        self.user_id = Some(user_id);
21        self
22    }
23
24    pub fn provider(mut self, provider: String) -> Self {
25        self.provider = Some(provider);
26        self
27    }
28
29    pub fn subject(mut self, subject: String) -> Self {
30        self.subject = Some(subject);
31        self
32    }
33
34    pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
35        self.created_at = Some(created_at);
36        self
37    }
38
39    pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
40        self.updated_at = Some(updated_at);
41        self
42    }
43
44    pub fn build(self) -> Result<PostgresOAuthAccount, torii_core::Error> {
45        let now = Utc::now();
46        Ok(PostgresOAuthAccount {
47            id: None,
48            user_id: self
49                .user_id
50                .ok_or(ValidationError::MissingField(
51                    "User ID is required".to_string(),
52                ))?
53                .to_string(),
54            provider: self.provider.ok_or(ValidationError::MissingField(
55                "Provider is required".to_string(),
56            ))?,
57            subject: self.subject.ok_or(ValidationError::MissingField(
58                "Subject is required".to_string(),
59            ))?,
60            created_at: self.created_at.unwrap_or(now),
61            updated_at: self.updated_at.unwrap_or(now),
62        })
63    }
64}
65
66#[derive(Debug, Clone, sqlx::FromRow)]
67pub struct PostgresOAuthAccount {
68    pub id: Option<i64>,
69    pub user_id: String,
70    pub provider: String,
71    pub subject: String,
72    pub created_at: DateTime<Utc>,
73    pub updated_at: DateTime<Utc>,
74}
75
76impl PostgresOAuthAccount {
77    pub fn builder() -> PostgresOAuthAccountBuilder {
78        PostgresOAuthAccountBuilder::default()
79    }
80
81    pub fn new(user_id: UserId, provider: impl Into<String>, subject: impl Into<String>) -> Self {
82        PostgresOAuthAccountBuilder::default()
83            .user_id(user_id)
84            .provider(provider.into())
85            .subject(subject.into())
86            .build()
87            .expect("Default builder should never fail")
88    }
89
90    pub fn is_expired(&self, ttl: Duration) -> bool {
91        Utc::now() > self.created_at + ttl
92    }
93}
94
95impl From<PostgresOAuthAccount> for OAuthAccount {
96    fn from(oauth_account: PostgresOAuthAccount) -> Self {
97        OAuthAccount::builder()
98            .user_id(UserId::new(&oauth_account.user_id))
99            .provider(oauth_account.provider)
100            .subject(oauth_account.subject)
101            .created_at(oauth_account.created_at)
102            .updated_at(oauth_account.updated_at)
103            .build()
104            .expect("Default builder should never fail")
105    }
106}
107
108impl From<OAuthAccount> for PostgresOAuthAccount {
109    fn from(oauth_account: OAuthAccount) -> Self {
110        PostgresOAuthAccount::builder()
111            .user_id(oauth_account.user_id)
112            .provider(oauth_account.provider)
113            .subject(oauth_account.subject)
114            .created_at(oauth_account.created_at)
115            .updated_at(oauth_account.updated_at)
116            .build()
117            .expect("Default builder should never fail")
118    }
119}
120
121#[async_trait]
122impl OAuthStorage for PostgresStorage {
123    async fn create_oauth_account(
124        &self,
125        provider: &str,
126        subject: &str,
127        user_id: &UserId,
128    ) -> Result<OAuthAccount, torii_core::Error> {
129        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
130            .bind(user_id.as_str())
131            .bind(provider)
132            .bind(subject)
133            .bind(Utc::now())
134            .bind(Utc::now())
135            .execute(&self.pool)
136            .await
137            .map_err(|e| {
138                tracing::error!(error = %e, "Failed to create oauth account");
139                StorageError::Database("Failed to create oauth account".to_string())
140            })?;
141
142        let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
143            r#"
144            SELECT id, user_id, provider, subject, created_at, updated_at
145            FROM oauth_accounts
146            WHERE user_id = $1
147            "#,
148        )
149        .bind(user_id.as_str())
150        .fetch_one(&self.pool)
151        .await
152        .map_err(|e| {
153            tracing::error!(error = %e, "Failed to get oauth account");
154            StorageError::Database("Failed to get oauth account".to_string())
155        })?;
156
157        Ok(oauth_account.into())
158    }
159
160    async fn store_pkce_verifier(
161        &self,
162        csrf_state: &str,
163        pkce_verifier: &str,
164        expires_in: chrono::Duration,
165    ) -> Result<(), torii_core::Error> {
166        sqlx::query(
167            "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES ($1, $2, $3) RETURNING value",
168        )
169        .bind(csrf_state)
170        .bind(pkce_verifier)
171        .bind(Utc::now() + expires_in)
172        .fetch_optional(&self.pool)
173        .await
174        .map_err(|e| {
175            tracing::error!(error = %e, "Failed to store pkce verifier");
176            StorageError::Database("Failed to store pkce verifier".to_string())
177        })?;
178
179        Ok(())
180    }
181
182    async fn get_pkce_verifier(
183        &self,
184        csrf_state: &str,
185    ) -> Result<Option<String>, torii_core::Error> {
186        let pkce_verifier =
187            sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = $1")
188                .bind(csrf_state)
189                .fetch_optional(&self.pool)
190                .await
191                .map_err(|e| {
192                    tracing::error!(error = %e, "Failed to get pkce verifier");
193                    StorageError::Database("Failed to get pkce verifier".to_string())
194                })?;
195
196        Ok(pkce_verifier)
197    }
198
199    async fn get_oauth_account_by_provider_and_subject(
200        &self,
201        provider: &str,
202        subject: &str,
203    ) -> Result<Option<OAuthAccount>, torii_core::Error> {
204        let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
205            r#"
206            SELECT id, user_id, provider, subject, created_at, updated_at
207            FROM oauth_accounts
208            WHERE provider = $1 AND subject = $2
209            "#,
210        )
211        .bind(provider)
212        .bind(subject)
213        .fetch_optional(&self.pool)
214        .await
215        .map_err(|e| {
216            tracing::error!(error = %e, "Failed to get oauth account");
217            StorageError::Database("Failed to get oauth account".to_string())
218        })?;
219
220        if let Some(oauth_account) = oauth_account {
221            Ok(Some(oauth_account.into()))
222        } else {
223            Ok(None)
224        }
225    }
226
227    async fn get_user_by_provider_and_subject(
228        &self,
229        provider: &str,
230        subject: &str,
231    ) -> Result<Option<User>, torii_core::Error> {
232        let user = sqlx::query_as::<_, PostgresUser>(
233            r#"
234            SELECT id, email, name, email_verified_at, created_at, updated_at
235            FROM users
236            WHERE provider = $1 AND subject = $2
237            "#,
238        )
239        .bind(provider)
240        .bind(subject)
241        .fetch_optional(&self.pool)
242        .await
243        .map_err(|e| {
244            tracing::error!(error = %e, "Failed to get user by provider and subject");
245            StorageError::Database("Failed to get user by provider and subject".to_string())
246        })?;
247
248        if let Some(user) = user {
249            Ok(Some(user.into()))
250        } else {
251            Ok(None)
252        }
253    }
254
255    async fn link_oauth_account(
256        &self,
257        user_id: &UserId,
258        provider: &str,
259        subject: &str,
260    ) -> Result<(), torii_core::Error> {
261        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
262            .bind(user_id.as_str())
263            .bind(provider)
264            .bind(subject)
265            .bind(Utc::now())
266            .bind(Utc::now())
267            .execute(&self.pool)
268            .await
269            .map_err(|e| {
270                tracing::error!(error = %e, "Failed to link oauth account");
271                StorageError::Database("Failed to link oauth account".to_string())
272            })?;
273
274        Ok(())
275    }
276}