1use crate::{PostgresStorage, PostgresUser};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use std::time::Duration;
5use torii_core::error::{StorageError, ValidationError};
6use torii_core::storage::OAuthStorage;
7use torii_core::{OAuthAccount, User, UserId};
8
9#[derive(Default)]
10pub struct PostgresOAuthAccountBuilder {
11 user_id: Option<UserId>,
12 provider: Option<String>,
13 subject: Option<String>,
14 created_at: Option<DateTime<Utc>>,
15 updated_at: Option<DateTime<Utc>>,
16}
17
18impl PostgresOAuthAccountBuilder {
19 pub fn user_id(mut self, user_id: UserId) -> Self {
20 self.user_id = Some(user_id);
21 self
22 }
23
24 pub fn provider(mut self, provider: String) -> Self {
25 self.provider = Some(provider);
26 self
27 }
28
29 pub fn subject(mut self, subject: String) -> Self {
30 self.subject = Some(subject);
31 self
32 }
33
34 pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
35 self.created_at = Some(created_at);
36 self
37 }
38
39 pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
40 self.updated_at = Some(updated_at);
41 self
42 }
43
44 pub fn build(self) -> Result<PostgresOAuthAccount, torii_core::Error> {
45 let now = Utc::now();
46 Ok(PostgresOAuthAccount {
47 id: None,
48 user_id: self
49 .user_id
50 .ok_or(ValidationError::MissingField(
51 "User ID is required".to_string(),
52 ))?
53 .to_string(),
54 provider: self.provider.ok_or(ValidationError::MissingField(
55 "Provider is required".to_string(),
56 ))?,
57 subject: self.subject.ok_or(ValidationError::MissingField(
58 "Subject is required".to_string(),
59 ))?,
60 created_at: self.created_at.unwrap_or(now),
61 updated_at: self.updated_at.unwrap_or(now),
62 })
63 }
64}
65
66#[derive(Debug, Clone, sqlx::FromRow)]
67pub struct PostgresOAuthAccount {
68 pub id: Option<i64>,
69 pub user_id: String,
70 pub provider: String,
71 pub subject: String,
72 pub created_at: DateTime<Utc>,
73 pub updated_at: DateTime<Utc>,
74}
75
76impl PostgresOAuthAccount {
77 pub fn builder() -> PostgresOAuthAccountBuilder {
78 PostgresOAuthAccountBuilder::default()
79 }
80
81 pub fn new(user_id: UserId, provider: impl Into<String>, subject: impl Into<String>) -> Self {
82 PostgresOAuthAccountBuilder::default()
83 .user_id(user_id)
84 .provider(provider.into())
85 .subject(subject.into())
86 .build()
87 .expect("Default builder should never fail")
88 }
89
90 pub fn is_expired(&self, ttl: Duration) -> bool {
91 Utc::now() > self.created_at + ttl
92 }
93}
94
95impl From<PostgresOAuthAccount> for OAuthAccount {
96 fn from(oauth_account: PostgresOAuthAccount) -> Self {
97 OAuthAccount::builder()
98 .user_id(UserId::new(&oauth_account.user_id))
99 .provider(oauth_account.provider)
100 .subject(oauth_account.subject)
101 .created_at(oauth_account.created_at)
102 .updated_at(oauth_account.updated_at)
103 .build()
104 .expect("Default builder should never fail")
105 }
106}
107
108impl From<OAuthAccount> for PostgresOAuthAccount {
109 fn from(oauth_account: OAuthAccount) -> Self {
110 PostgresOAuthAccount::builder()
111 .user_id(oauth_account.user_id)
112 .provider(oauth_account.provider)
113 .subject(oauth_account.subject)
114 .created_at(oauth_account.created_at)
115 .updated_at(oauth_account.updated_at)
116 .build()
117 .expect("Default builder should never fail")
118 }
119}
120
121#[async_trait]
122impl OAuthStorage for PostgresStorage {
123 async fn create_oauth_account(
124 &self,
125 provider: &str,
126 subject: &str,
127 user_id: &UserId,
128 ) -> Result<OAuthAccount, torii_core::Error> {
129 sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
130 .bind(user_id.as_str())
131 .bind(provider)
132 .bind(subject)
133 .bind(Utc::now())
134 .bind(Utc::now())
135 .execute(&self.pool)
136 .await
137 .map_err(|e| {
138 tracing::error!(error = %e, "Failed to create oauth account");
139 StorageError::Database("Failed to create oauth account".to_string())
140 })?;
141
142 let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
143 r#"
144 SELECT id, user_id, provider, subject, created_at, updated_at
145 FROM oauth_accounts
146 WHERE user_id = $1
147 "#,
148 )
149 .bind(user_id.as_str())
150 .fetch_one(&self.pool)
151 .await
152 .map_err(|e| {
153 tracing::error!(error = %e, "Failed to get oauth account");
154 StorageError::Database("Failed to get oauth account".to_string())
155 })?;
156
157 Ok(oauth_account.into())
158 }
159
160 async fn store_pkce_verifier(
161 &self,
162 csrf_state: &str,
163 pkce_verifier: &str,
164 expires_in: chrono::Duration,
165 ) -> Result<(), torii_core::Error> {
166 sqlx::query(
167 "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES ($1, $2, $3) RETURNING value",
168 )
169 .bind(csrf_state)
170 .bind(pkce_verifier)
171 .bind(Utc::now() + expires_in)
172 .fetch_optional(&self.pool)
173 .await
174 .map_err(|e| {
175 tracing::error!(error = %e, "Failed to store pkce verifier");
176 StorageError::Database("Failed to store pkce verifier".to_string())
177 })?;
178
179 Ok(())
180 }
181
182 async fn get_pkce_verifier(
183 &self,
184 csrf_state: &str,
185 ) -> Result<Option<String>, torii_core::Error> {
186 let pkce_verifier =
187 sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = $1")
188 .bind(csrf_state)
189 .fetch_optional(&self.pool)
190 .await
191 .map_err(|e| {
192 tracing::error!(error = %e, "Failed to get pkce verifier");
193 StorageError::Database("Failed to get pkce verifier".to_string())
194 })?;
195
196 Ok(pkce_verifier)
197 }
198
199 async fn get_oauth_account_by_provider_and_subject(
200 &self,
201 provider: &str,
202 subject: &str,
203 ) -> Result<Option<OAuthAccount>, torii_core::Error> {
204 let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
205 r#"
206 SELECT id, user_id, provider, subject, created_at, updated_at
207 FROM oauth_accounts
208 WHERE provider = $1 AND subject = $2
209 "#,
210 )
211 .bind(provider)
212 .bind(subject)
213 .fetch_optional(&self.pool)
214 .await
215 .map_err(|e| {
216 tracing::error!(error = %e, "Failed to get oauth account");
217 StorageError::Database("Failed to get oauth account".to_string())
218 })?;
219
220 if let Some(oauth_account) = oauth_account {
221 Ok(Some(oauth_account.into()))
222 } else {
223 Ok(None)
224 }
225 }
226
227 async fn get_user_by_provider_and_subject(
228 &self,
229 provider: &str,
230 subject: &str,
231 ) -> Result<Option<User>, torii_core::Error> {
232 let user = sqlx::query_as::<_, PostgresUser>(
233 r#"
234 SELECT id, email, name, email_verified_at, created_at, updated_at
235 FROM users
236 WHERE provider = $1 AND subject = $2
237 "#,
238 )
239 .bind(provider)
240 .bind(subject)
241 .fetch_optional(&self.pool)
242 .await
243 .map_err(|e| {
244 tracing::error!(error = %e, "Failed to get user by provider and subject");
245 StorageError::Database("Failed to get user by provider and subject".to_string())
246 })?;
247
248 if let Some(user) = user {
249 Ok(Some(user.into()))
250 } else {
251 Ok(None)
252 }
253 }
254
255 async fn link_oauth_account(
256 &self,
257 user_id: &UserId,
258 provider: &str,
259 subject: &str,
260 ) -> Result<(), torii_core::Error> {
261 sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
262 .bind(user_id.as_str())
263 .bind(provider)
264 .bind(subject)
265 .bind(Utc::now())
266 .bind(Utc::now())
267 .execute(&self.pool)
268 .await
269 .map_err(|e| {
270 tracing::error!(error = %e, "Failed to link oauth account");
271 StorageError::Database("Failed to link oauth account".to_string())
272 })?;
273
274 Ok(())
275 }
276}