torii_storage_postgres/
oauth.rs

1use crate::{PostgresStorage, PostgresUser};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use std::time::Duration;
5use torii_core::error::{StorageError, ValidationError};
6use torii_core::storage::OAuthStorage;
7use torii_core::{OAuthAccount, User, UserId};
8
/// Builder for [`PostgresOAuthAccount`].
///
/// All fields start out as `None` (through the derived `Default`) and are
/// filled in by the setter methods before `build` is called.
#[derive(Default)]
pub struct PostgresOAuthAccountBuilder {
    /// ID of the user the account belongs to; required by `build`.
    user_id: Option<UserId>,
    /// Provider string; required by `build`.
    provider: Option<String>,
    /// Subject string; required by `build`.
    subject: Option<String>,
    /// Creation timestamp; `build` falls back to the current time when unset.
    created_at: Option<DateTime<Utc>>,
    /// Last-update timestamp; `build` falls back to the current time when unset.
    updated_at: Option<DateTime<Utc>>,
}
17
18impl PostgresOAuthAccountBuilder {
19    pub fn user_id(mut self, user_id: UserId) -> Self {
20        self.user_id = Some(user_id);
21        self
22    }
23
24    pub fn provider(mut self, provider: String) -> Self {
25        self.provider = Some(provider);
26        self
27    }
28
29    pub fn subject(mut self, subject: String) -> Self {
30        self.subject = Some(subject);
31        self
32    }
33
34    pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
35        self.created_at = Some(created_at);
36        self
37    }
38
39    pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
40        self.updated_at = Some(updated_at);
41        self
42    }
43
44    pub fn build(self) -> Result<PostgresOAuthAccount, torii_core::Error> {
45        let now = Utc::now();
46        Ok(PostgresOAuthAccount {
47            user_id: self
48                .user_id
49                .ok_or(ValidationError::MissingField(
50                    "User ID is required".to_string(),
51                ))?
52                .to_string(),
53            provider: self.provider.ok_or(ValidationError::MissingField(
54                "Provider is required".to_string(),
55            ))?,
56            subject: self.subject.ok_or(ValidationError::MissingField(
57                "Subject is required".to_string(),
58            ))?,
59            created_at: self.created_at.unwrap_or(now),
60            updated_at: self.updated_at.unwrap_or(now),
61        })
62    }
63}
64
/// OAuth account in the shape used by the Postgres storage layer.
///
/// Derives `sqlx::FromRow`, so query results can be decoded straight into this
/// type. `user_id` is kept as a `String` here rather than a `UserId`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct PostgresOAuthAccount {
    /// User ID in string form.
    user_id: String,
    /// Provider string.
    provider: String,
    /// Subject string.
    subject: String,
    /// When the account was created.
    created_at: DateTime<Utc>,
    /// When the account was last updated.
    updated_at: DateTime<Utc>,
}
73
74impl PostgresOAuthAccount {
75    pub fn builder() -> PostgresOAuthAccountBuilder {
76        PostgresOAuthAccountBuilder::default()
77    }
78
79    pub fn new(user_id: UserId, provider: impl Into<String>, subject: impl Into<String>) -> Self {
80        PostgresOAuthAccountBuilder::default()
81            .user_id(user_id)
82            .provider(provider.into())
83            .subject(subject.into())
84            .build()
85            .expect("Default builder should never fail")
86    }
87
88    pub fn is_expired(&self, ttl: Duration) -> bool {
89        Utc::now() > self.created_at + ttl
90    }
91}
92
93impl From<PostgresOAuthAccount> for OAuthAccount {
94    fn from(oauth_account: PostgresOAuthAccount) -> Self {
95        OAuthAccount::builder()
96            .user_id(UserId::new(&oauth_account.user_id))
97            .provider(oauth_account.provider)
98            .subject(oauth_account.subject)
99            .created_at(oauth_account.created_at)
100            .updated_at(oauth_account.updated_at)
101            .build()
102            .expect("Default builder should never fail")
103    }
104}
105
106impl From<OAuthAccount> for PostgresOAuthAccount {
107    fn from(oauth_account: OAuthAccount) -> Self {
108        PostgresOAuthAccount::builder()
109            .user_id(oauth_account.user_id)
110            .provider(oauth_account.provider)
111            .subject(oauth_account.subject)
112            .created_at(oauth_account.created_at)
113            .updated_at(oauth_account.updated_at)
114            .build()
115            .expect("Default builder should never fail")
116    }
117}
118
119#[async_trait]
120impl OAuthStorage for PostgresStorage {
121    async fn create_oauth_account(
122        &self,
123        provider: &str,
124        subject: &str,
125        user_id: &UserId,
126    ) -> Result<OAuthAccount, Self::Error> {
127        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1::uuid, $2, $3, $4, $5)")
128            .bind(user_id.as_str())
129            .bind(provider)
130            .bind(subject)
131            .bind(Utc::now())
132            .bind(Utc::now())
133            .execute(&self.pool)
134            .await
135            .map_err(|e| {
136                tracing::error!(error = %e, "Failed to create oauth account");
137                StorageError::Database("Failed to create oauth account".to_string())
138            })?;
139
140        let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
141            r#"
142            SELECT user_id::text, provider, subject, created_at, updated_at
143            FROM oauth_accounts
144            WHERE user_id::text = $1
145            "#,
146        )
147        .bind(user_id.as_str())
148        .fetch_one(&self.pool)
149        .await
150        .map_err(|e| {
151            tracing::error!(error = %e, "Failed to get oauth account");
152            StorageError::Database("Failed to get oauth account".to_string())
153        })?;
154
155        Ok(oauth_account.into())
156    }
157
158    async fn store_pkce_verifier(
159        &self,
160        csrf_state: &str,
161        pkce_verifier: &str,
162        expires_in: chrono::Duration,
163    ) -> Result<(), Self::Error> {
164        sqlx::query(
165            "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES ($1::text, $2, $3) RETURNING value",
166        )
167        .bind(csrf_state)
168        .bind(pkce_verifier)
169        .bind(Utc::now() + expires_in)
170        .fetch_optional(&self.pool)
171        .await
172        .map_err(|e| {
173            tracing::error!(error = %e, "Failed to store pkce verifier");
174            StorageError::Database("Failed to store pkce verifier".to_string())
175        })?;
176
177        Ok(())
178    }
179
180    async fn get_pkce_verifier(&self, csrf_state: &str) -> Result<Option<String>, Self::Error> {
181        let pkce_verifier =
182            sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = $1")
183                .bind(csrf_state)
184                .fetch_optional(&self.pool)
185                .await
186                .map_err(|e| {
187                    tracing::error!(error = %e, "Failed to get pkce verifier");
188                    StorageError::Database("Failed to get pkce verifier".to_string())
189                })?;
190
191        Ok(pkce_verifier)
192    }
193
194    async fn get_oauth_account_by_provider_and_subject(
195        &self,
196        provider: &str,
197        subject: &str,
198    ) -> Result<Option<OAuthAccount>, Self::Error> {
199        let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
200            r#"
201            SELECT user_id::text, provider, subject, created_at, updated_at
202            FROM oauth_accounts
203            WHERE provider = $1 AND subject = $2
204            "#,
205        )
206        .bind(provider)
207        .bind(subject)
208        .fetch_optional(&self.pool)
209        .await
210        .map_err(|e| {
211            tracing::error!(error = %e, "Failed to get oauth account");
212            StorageError::Database("Failed to get oauth account".to_string())
213        })?;
214
215        if let Some(oauth_account) = oauth_account {
216            Ok(Some(oauth_account.into()))
217        } else {
218            Ok(None)
219        }
220    }
221
222    async fn get_user_by_provider_and_subject(
223        &self,
224        provider: &str,
225        subject: &str,
226    ) -> Result<Option<User>, Self::Error> {
227        let user = sqlx::query_as::<_, PostgresUser>(
228            r#"
229            SELECT id::text, email, name, email_verified_at, created_at, updated_at
230            FROM users
231            WHERE provider = $1 AND subject = $2
232            "#,
233        )
234        .bind(provider)
235        .bind(subject)
236        .fetch_optional(&self.pool)
237        .await
238        .map_err(|e| {
239            tracing::error!(error = %e, "Failed to get user by provider and subject");
240            StorageError::Database("Failed to get user by provider and subject".to_string())
241        })?;
242
243        if let Some(user) = user {
244            Ok(Some(user.into()))
245        } else {
246            Ok(None)
247        }
248    }
249
250    async fn link_oauth_account(
251        &self,
252        user_id: &UserId,
253        provider: &str,
254        subject: &str,
255    ) -> Result<(), Self::Error> {
256        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1::uuid, $2, $3, $4, $5)")
257            .bind(user_id.as_str())
258            .bind(provider)
259            .bind(subject)
260            .bind(Utc::now())
261            .bind(Utc::now())
262            .execute(&self.pool)
263            .await
264            .map_err(|e| {
265                tracing::error!(error = %e, "Failed to link oauth account");
266                StorageError::Database("Failed to link oauth account".to_string())
267            })?;
268
269        Ok(())
270    }
271}