1use crate::{PostgresStorage, PostgresUser};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use std::time::Duration;
5use torii_core::error::{StorageError, ValidationError};
6use torii_core::storage::OAuthStorage;
7use torii_core::{OAuthAccount, User, UserId};
8
9#[derive(Default)]
10pub struct PostgresOAuthAccountBuilder {
11 user_id: Option<UserId>,
12 provider: Option<String>,
13 subject: Option<String>,
14 created_at: Option<DateTime<Utc>>,
15 updated_at: Option<DateTime<Utc>>,
16}
17
18impl PostgresOAuthAccountBuilder {
19 pub fn user_id(mut self, user_id: UserId) -> Self {
20 self.user_id = Some(user_id);
21 self
22 }
23
24 pub fn provider(mut self, provider: String) -> Self {
25 self.provider = Some(provider);
26 self
27 }
28
29 pub fn subject(mut self, subject: String) -> Self {
30 self.subject = Some(subject);
31 self
32 }
33
34 pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
35 self.created_at = Some(created_at);
36 self
37 }
38
39 pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
40 self.updated_at = Some(updated_at);
41 self
42 }
43
44 pub fn build(self) -> Result<PostgresOAuthAccount, torii_core::Error> {
45 let now = Utc::now();
46 Ok(PostgresOAuthAccount {
47 id: None,
48 user_id: self
49 .user_id
50 .ok_or(ValidationError::MissingField(
51 "User ID is required".to_string(),
52 ))?
53 .to_string(),
54 provider: self.provider.ok_or(ValidationError::MissingField(
55 "Provider is required".to_string(),
56 ))?,
57 subject: self.subject.ok_or(ValidationError::MissingField(
58 "Subject is required".to_string(),
59 ))?,
60 created_at: self.created_at.unwrap_or(now),
61 updated_at: self.updated_at.unwrap_or(now),
62 })
63 }
64}
65
66#[derive(Debug, Clone, sqlx::FromRow)]
67pub struct PostgresOAuthAccount {
68 pub id: Option<i64>,
69 pub user_id: String,
70 pub provider: String,
71 pub subject: String,
72 pub created_at: DateTime<Utc>,
73 pub updated_at: DateTime<Utc>,
74}
75
76impl PostgresOAuthAccount {
77 pub fn builder() -> PostgresOAuthAccountBuilder {
78 PostgresOAuthAccountBuilder::default()
79 }
80
81 pub fn new(user_id: UserId, provider: impl Into<String>, subject: impl Into<String>) -> Self {
82 PostgresOAuthAccountBuilder::default()
83 .user_id(user_id)
84 .provider(provider.into())
85 .subject(subject.into())
86 .build()
87 .expect("Default builder should never fail")
88 }
89
90 pub fn is_expired(&self, ttl: Duration) -> bool {
91 Utc::now() > self.created_at + ttl
92 }
93}
94
95impl From<PostgresOAuthAccount> for OAuthAccount {
96 fn from(oauth_account: PostgresOAuthAccount) -> Self {
97 OAuthAccount::builder()
98 .user_id(UserId::new(&oauth_account.user_id))
99 .provider(oauth_account.provider)
100 .subject(oauth_account.subject)
101 .created_at(oauth_account.created_at)
102 .updated_at(oauth_account.updated_at)
103 .build()
104 .expect("Default builder should never fail")
105 }
106}
107
108impl From<OAuthAccount> for PostgresOAuthAccount {
109 fn from(oauth_account: OAuthAccount) -> Self {
110 PostgresOAuthAccount::builder()
111 .user_id(oauth_account.user_id)
112 .provider(oauth_account.provider)
113 .subject(oauth_account.subject)
114 .created_at(oauth_account.created_at)
115 .updated_at(oauth_account.updated_at)
116 .build()
117 .expect("Default builder should never fail")
118 }
119}
120
121#[async_trait]
122impl OAuthStorage for PostgresStorage {
123 type Error = StorageError;
124
125 async fn create_oauth_account(
126 &self,
127 provider: &str,
128 subject: &str,
129 user_id: &UserId,
130 ) -> Result<OAuthAccount, <Self as OAuthStorage>::Error> {
131 sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
132 .bind(user_id.as_str())
133 .bind(provider)
134 .bind(subject)
135 .bind(Utc::now())
136 .bind(Utc::now())
137 .execute(&self.pool)
138 .await
139 .map_err(|e| {
140 tracing::error!(error = %e, "Failed to create oauth account");
141 StorageError::Database("Failed to create oauth account".to_string())
142 })?;
143
144 let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
145 r#"
146 SELECT id, user_id, provider, subject, created_at, updated_at
147 FROM oauth_accounts
148 WHERE user_id = $1
149 "#,
150 )
151 .bind(user_id.as_str())
152 .fetch_one(&self.pool)
153 .await
154 .map_err(|e| {
155 tracing::error!(error = %e, "Failed to get oauth account");
156 StorageError::Database("Failed to get oauth account".to_string())
157 })?;
158
159 Ok(oauth_account.into())
160 }
161
162 async fn store_pkce_verifier(
163 &self,
164 csrf_state: &str,
165 pkce_verifier: &str,
166 expires_in: chrono::Duration,
167 ) -> Result<(), <Self as OAuthStorage>::Error> {
168 sqlx::query(
169 "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES ($1, $2, $3) RETURNING value",
170 )
171 .bind(csrf_state)
172 .bind(pkce_verifier)
173 .bind(Utc::now() + expires_in)
174 .fetch_optional(&self.pool)
175 .await
176 .map_err(|e| {
177 tracing::error!(error = %e, "Failed to store pkce verifier");
178 StorageError::Database("Failed to store pkce verifier".to_string())
179 })?;
180
181 Ok(())
182 }
183
184 async fn get_pkce_verifier(
185 &self,
186 csrf_state: &str,
187 ) -> Result<Option<String>, <Self as OAuthStorage>::Error> {
188 let pkce_verifier =
189 sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = $1")
190 .bind(csrf_state)
191 .fetch_optional(&self.pool)
192 .await
193 .map_err(|e| {
194 tracing::error!(error = %e, "Failed to get pkce verifier");
195 StorageError::Database("Failed to get pkce verifier".to_string())
196 })?;
197
198 Ok(pkce_verifier)
199 }
200
201 async fn get_oauth_account_by_provider_and_subject(
202 &self,
203 provider: &str,
204 subject: &str,
205 ) -> Result<Option<OAuthAccount>, <Self as OAuthStorage>::Error> {
206 let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
207 r#"
208 SELECT id, user_id, provider, subject, created_at, updated_at
209 FROM oauth_accounts
210 WHERE provider = $1 AND subject = $2
211 "#,
212 )
213 .bind(provider)
214 .bind(subject)
215 .fetch_optional(&self.pool)
216 .await
217 .map_err(|e| {
218 tracing::error!(error = %e, "Failed to get oauth account");
219 StorageError::Database("Failed to get oauth account".to_string())
220 })?;
221
222 if let Some(oauth_account) = oauth_account {
223 Ok(Some(oauth_account.into()))
224 } else {
225 Ok(None)
226 }
227 }
228
229 async fn get_user_by_provider_and_subject(
230 &self,
231 provider: &str,
232 subject: &str,
233 ) -> Result<Option<User>, <Self as OAuthStorage>::Error> {
234 let user = sqlx::query_as::<_, PostgresUser>(
235 r#"
236 SELECT id, email, name, email_verified_at, created_at, updated_at
237 FROM users
238 WHERE provider = $1 AND subject = $2
239 "#,
240 )
241 .bind(provider)
242 .bind(subject)
243 .fetch_optional(&self.pool)
244 .await
245 .map_err(|e| {
246 tracing::error!(error = %e, "Failed to get user by provider and subject");
247 StorageError::Database("Failed to get user by provider and subject".to_string())
248 })?;
249
250 if let Some(user) = user {
251 Ok(Some(user.into()))
252 } else {
253 Ok(None)
254 }
255 }
256
257 async fn link_oauth_account(
258 &self,
259 user_id: &UserId,
260 provider: &str,
261 subject: &str,
262 ) -> Result<(), <Self as OAuthStorage>::Error> {
263 sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
264 .bind(user_id.as_str())
265 .bind(provider)
266 .bind(subject)
267 .bind(Utc::now())
268 .bind(Utc::now())
269 .execute(&self.pool)
270 .await
271 .map_err(|e| {
272 tracing::error!(error = %e, "Failed to link oauth account");
273 StorageError::Database("Failed to link oauth account".to_string())
274 })?;
275
276 Ok(())
277 }
278}