1use crate::{PostgresStorage, PostgresUser};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use std::time::Duration;
5use torii_core::error::{StorageError, ValidationError};
6use torii_core::storage::OAuthStorage;
7use torii_core::{OAuthAccount, User, UserId};
8
9#[derive(Default)]
10pub struct PostgresOAuthAccountBuilder {
11 user_id: Option<UserId>,
12 provider: Option<String>,
13 subject: Option<String>,
14 created_at: Option<DateTime<Utc>>,
15 updated_at: Option<DateTime<Utc>>,
16}
17
18impl PostgresOAuthAccountBuilder {
19 pub fn user_id(mut self, user_id: UserId) -> Self {
20 self.user_id = Some(user_id);
21 self
22 }
23
24 pub fn provider(mut self, provider: String) -> Self {
25 self.provider = Some(provider);
26 self
27 }
28
29 pub fn subject(mut self, subject: String) -> Self {
30 self.subject = Some(subject);
31 self
32 }
33
34 pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
35 self.created_at = Some(created_at);
36 self
37 }
38
39 pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
40 self.updated_at = Some(updated_at);
41 self
42 }
43
44 pub fn build(self) -> Result<PostgresOAuthAccount, torii_core::Error> {
45 let now = Utc::now();
46 Ok(PostgresOAuthAccount {
47 user_id: self
48 .user_id
49 .ok_or(ValidationError::MissingField(
50 "User ID is required".to_string(),
51 ))?
52 .to_string(),
53 provider: self.provider.ok_or(ValidationError::MissingField(
54 "Provider is required".to_string(),
55 ))?,
56 subject: self.subject.ok_or(ValidationError::MissingField(
57 "Subject is required".to_string(),
58 ))?,
59 created_at: self.created_at.unwrap_or(now),
60 updated_at: self.updated_at.unwrap_or(now),
61 })
62 }
63}
64
65#[derive(Debug, Clone, sqlx::FromRow)]
66pub struct PostgresOAuthAccount {
67 user_id: String,
68 provider: String,
69 subject: String,
70 created_at: DateTime<Utc>,
71 updated_at: DateTime<Utc>,
72}
73
74impl PostgresOAuthAccount {
75 pub fn builder() -> PostgresOAuthAccountBuilder {
76 PostgresOAuthAccountBuilder::default()
77 }
78
79 pub fn new(user_id: UserId, provider: impl Into<String>, subject: impl Into<String>) -> Self {
80 PostgresOAuthAccountBuilder::default()
81 .user_id(user_id)
82 .provider(provider.into())
83 .subject(subject.into())
84 .build()
85 .expect("Default builder should never fail")
86 }
87
88 pub fn is_expired(&self, ttl: Duration) -> bool {
89 Utc::now() > self.created_at + ttl
90 }
91}
92
93impl From<PostgresOAuthAccount> for OAuthAccount {
94 fn from(oauth_account: PostgresOAuthAccount) -> Self {
95 OAuthAccount::builder()
96 .user_id(UserId::new(&oauth_account.user_id))
97 .provider(oauth_account.provider)
98 .subject(oauth_account.subject)
99 .created_at(oauth_account.created_at)
100 .updated_at(oauth_account.updated_at)
101 .build()
102 .expect("Default builder should never fail")
103 }
104}
105
106impl From<OAuthAccount> for PostgresOAuthAccount {
107 fn from(oauth_account: OAuthAccount) -> Self {
108 PostgresOAuthAccount::builder()
109 .user_id(oauth_account.user_id)
110 .provider(oauth_account.provider)
111 .subject(oauth_account.subject)
112 .created_at(oauth_account.created_at)
113 .updated_at(oauth_account.updated_at)
114 .build()
115 .expect("Default builder should never fail")
116 }
117}
118
119#[async_trait]
120impl OAuthStorage for PostgresStorage {
121 async fn create_oauth_account(
122 &self,
123 provider: &str,
124 subject: &str,
125 user_id: &UserId,
126 ) -> Result<OAuthAccount, Self::Error> {
127 sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1::uuid, $2, $3, $4, $5)")
128 .bind(user_id.as_str())
129 .bind(provider)
130 .bind(subject)
131 .bind(Utc::now())
132 .bind(Utc::now())
133 .execute(&self.pool)
134 .await
135 .map_err(|e| {
136 tracing::error!(error = %e, "Failed to create oauth account");
137 StorageError::Database("Failed to create oauth account".to_string())
138 })?;
139
140 let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
141 r#"
142 SELECT user_id::text, provider, subject, created_at, updated_at
143 FROM oauth_accounts
144 WHERE user_id::text = $1
145 "#,
146 )
147 .bind(user_id.as_str())
148 .fetch_one(&self.pool)
149 .await
150 .map_err(|e| {
151 tracing::error!(error = %e, "Failed to get oauth account");
152 StorageError::Database("Failed to get oauth account".to_string())
153 })?;
154
155 Ok(oauth_account.into())
156 }
157
158 async fn store_pkce_verifier(
159 &self,
160 csrf_state: &str,
161 pkce_verifier: &str,
162 expires_in: chrono::Duration,
163 ) -> Result<(), Self::Error> {
164 sqlx::query(
165 "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES ($1::text, $2, $3) RETURNING value",
166 )
167 .bind(csrf_state)
168 .bind(pkce_verifier)
169 .bind(Utc::now() + expires_in)
170 .fetch_optional(&self.pool)
171 .await
172 .map_err(|e| {
173 tracing::error!(error = %e, "Failed to store pkce verifier");
174 StorageError::Database("Failed to store pkce verifier".to_string())
175 })?;
176
177 Ok(())
178 }
179
180 async fn get_pkce_verifier(&self, csrf_state: &str) -> Result<Option<String>, Self::Error> {
181 let pkce_verifier =
182 sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = $1")
183 .bind(csrf_state)
184 .fetch_optional(&self.pool)
185 .await
186 .map_err(|e| {
187 tracing::error!(error = %e, "Failed to get pkce verifier");
188 StorageError::Database("Failed to get pkce verifier".to_string())
189 })?;
190
191 Ok(pkce_verifier)
192 }
193
194 async fn get_oauth_account_by_provider_and_subject(
195 &self,
196 provider: &str,
197 subject: &str,
198 ) -> Result<Option<OAuthAccount>, Self::Error> {
199 let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
200 r#"
201 SELECT user_id::text, provider, subject, created_at, updated_at
202 FROM oauth_accounts
203 WHERE provider = $1 AND subject = $2
204 "#,
205 )
206 .bind(provider)
207 .bind(subject)
208 .fetch_optional(&self.pool)
209 .await
210 .map_err(|e| {
211 tracing::error!(error = %e, "Failed to get oauth account");
212 StorageError::Database("Failed to get oauth account".to_string())
213 })?;
214
215 if let Some(oauth_account) = oauth_account {
216 Ok(Some(oauth_account.into()))
217 } else {
218 Ok(None)
219 }
220 }
221
222 async fn get_user_by_provider_and_subject(
223 &self,
224 provider: &str,
225 subject: &str,
226 ) -> Result<Option<User>, Self::Error> {
227 let user = sqlx::query_as::<_, PostgresUser>(
228 r#"
229 SELECT id::text, email, name, email_verified_at, created_at, updated_at
230 FROM users
231 WHERE provider = $1 AND subject = $2
232 "#,
233 )
234 .bind(provider)
235 .bind(subject)
236 .fetch_optional(&self.pool)
237 .await
238 .map_err(|e| {
239 tracing::error!(error = %e, "Failed to get user by provider and subject");
240 StorageError::Database("Failed to get user by provider and subject".to_string())
241 })?;
242
243 if let Some(user) = user {
244 Ok(Some(user.into()))
245 } else {
246 Ok(None)
247 }
248 }
249
250 async fn link_oauth_account(
251 &self,
252 user_id: &UserId,
253 provider: &str,
254 subject: &str,
255 ) -> Result<(), Self::Error> {
256 sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1::uuid, $2, $3, $4, $5)")
257 .bind(user_id.as_str())
258 .bind(provider)
259 .bind(subject)
260 .bind(Utc::now())
261 .bind(Utc::now())
262 .execute(&self.pool)
263 .await
264 .map_err(|e| {
265 tracing::error!(error = %e, "Failed to link oauth account");
266 StorageError::Database("Failed to link oauth account".to_string())
267 })?;
268
269 Ok(())
270 }
271}