// torii_core/storage.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    Error, OAuthAccount, Session, User, UserId, error::ValidationError, session::SessionToken,
7};
8
9#[async_trait]
10pub trait StoragePlugin: Send + Sync + 'static {
11    type Config;
12
13    /// Initialize storage with config
14    async fn initialize(&self, config: Self::Config) -> Result<(), Error>;
15
16    /// Storage health check
17    async fn health_check(&self) -> Result<(), Error>;
18
19    /// Clean up expired data
20    async fn cleanup(&self) -> Result<(), Error>;
21}
22
23#[async_trait]
24pub trait UserStorage: Send + Sync + 'static {
25    type Error: std::error::Error + Send + Sync + 'static;
26
27    async fn create_user(&self, user: &NewUser) -> Result<User, Self::Error>;
28    async fn get_user(&self, id: &UserId) -> Result<Option<User>, Self::Error>;
29    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>, Self::Error>;
30    async fn get_or_create_user_by_email(&self, email: &str) -> Result<User, Self::Error>;
31    async fn update_user(&self, user: &User) -> Result<User, Self::Error>;
32    async fn delete_user(&self, id: &UserId) -> Result<(), Self::Error>;
33    async fn set_user_email_verified(&self, user_id: &UserId) -> Result<(), Self::Error>;
34}
35
36#[async_trait]
37pub trait SessionStorage: Send + Sync + 'static {
38    type Error: std::error::Error + Send + Sync + 'static;
39
40    async fn create_session(&self, session: &Session) -> Result<Session, Self::Error>;
41    async fn get_session(&self, token: &SessionToken) -> Result<Option<Session>, Self::Error>;
42    async fn delete_session(&self, token: &SessionToken) -> Result<(), Self::Error>;
43    async fn cleanup_expired_sessions(&self) -> Result<(), Self::Error>;
44    async fn delete_sessions_for_user(&self, user_id: &UserId) -> Result<(), Self::Error>;
45}
46
47/// Storage methods specific to email/password authentication
48///
49/// This trait extends the base `UserStorage` trait with methods needed for
50/// storing and retrieving password hashes.
51#[async_trait]
52pub trait PasswordStorage: UserStorage {
53    type Error: std::error::Error + Send + Sync + 'static;
54    /// Store a password hash for a user
55    async fn set_password_hash(
56        &self,
57        user_id: &UserId,
58        hash: &str,
59    ) -> Result<(), <Self as PasswordStorage>::Error>;
60
61    /// Retrieve a user's password hash
62    async fn get_password_hash(
63        &self,
64        user_id: &UserId,
65    ) -> Result<Option<String>, <Self as PasswordStorage>::Error>;
66}
67
68/// Storage methods specific to OAuth authentication
69///
70/// This trait extends the base `UserStorage` trait with methods needed for
71/// OAuth account management and PKCE verifier storage.
72#[async_trait]
73pub trait OAuthStorage: UserStorage {
74    type Error: std::error::Error + Send + Sync + 'static;
75    /// Create a new OAuth account linked to a user
76    async fn create_oauth_account(
77        &self,
78        provider: &str,
79        subject: &str,
80        user_id: &UserId,
81    ) -> Result<OAuthAccount, <Self as OAuthStorage>::Error>;
82
83    /// Find a user by their OAuth provider and subject
84    async fn get_user_by_provider_and_subject(
85        &self,
86        provider: &str,
87        subject: &str,
88    ) -> Result<Option<User>, <Self as OAuthStorage>::Error>;
89
90    /// Find an OAuth account by provider and subject
91    async fn get_oauth_account_by_provider_and_subject(
92        &self,
93        provider: &str,
94        subject: &str,
95    ) -> Result<Option<OAuthAccount>, <Self as OAuthStorage>::Error>;
96
97    /// Link an existing user to an OAuth account
98    async fn link_oauth_account(
99        &self,
100        user_id: &UserId,
101        provider: &str,
102        subject: &str,
103    ) -> Result<(), <Self as OAuthStorage>::Error>;
104
105    /// Store a PKCE verifier with an expiration time
106    async fn store_pkce_verifier(
107        &self,
108        csrf_state: &str,
109        pkce_verifier: &str,
110        expires_in: chrono::Duration,
111    ) -> Result<(), <Self as OAuthStorage>::Error>;
112
113    /// Retrieve a stored PKCE verifier by CSRF state
114    async fn get_pkce_verifier(
115        &self,
116        csrf_state: &str,
117    ) -> Result<Option<String>, <Self as OAuthStorage>::Error>;
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct NewUser {
122    pub id: UserId,
123    pub email: String,
124    pub name: Option<String>,
125    pub email_verified_at: Option<DateTime<Utc>>,
126}
127
128impl NewUser {
129    pub fn builder() -> NewUserBuilder {
130        NewUserBuilder::default()
131    }
132
133    pub fn new(email: String) -> Self {
134        NewUserBuilder::default()
135            .email(email)
136            .build()
137            .expect("Default builder should never fail")
138    }
139
140    pub fn with_id(id: UserId, email: String) -> Self {
141        NewUserBuilder::default()
142            .id(id)
143            .email(email)
144            .build()
145            .expect("Default builder should never fail")
146    }
147}
148
149#[derive(Default)]
150pub struct NewUserBuilder {
151    id: Option<UserId>,
152    email: Option<String>,
153    name: Option<String>,
154    email_verified_at: Option<DateTime<Utc>>,
155}
156
157impl NewUserBuilder {
158    pub fn id(mut self, id: UserId) -> Self {
159        self.id = Some(id);
160        self
161    }
162
163    pub fn email(mut self, email: String) -> Self {
164        self.email = Some(email);
165        self
166    }
167
168    pub fn name(mut self, name: String) -> Self {
169        self.name = Some(name);
170        self
171    }
172
173    pub fn email_verified_at(mut self, email_verified_at: Option<DateTime<Utc>>) -> Self {
174        self.email_verified_at = email_verified_at;
175        self
176    }
177
178    pub fn build(self) -> Result<NewUser, Error> {
179        Ok(NewUser {
180            id: self.id.unwrap_or_default(),
181            email: self.email.ok_or(ValidationError::MissingField(
182                "Email is required".to_string(),
183            ))?,
184            name: self.name,
185            email_verified_at: self.email_verified_at,
186        })
187    }
188}
189
190/// Storage methods specific to passkey authentication
191///
192/// This trait extends the base `UserStorage` trait with methods needed for
193/// storing and retrieving passkey credentials for a user.
194#[async_trait]
195pub trait PasskeyStorage: UserStorage {
196    type Error: std::error::Error + Send + Sync + 'static;
197
198    /// Add a passkey credential for a user
199    async fn add_passkey(
200        &self,
201        user_id: &UserId,
202        credential_id: &str,
203        passkey_json: &str,
204    ) -> Result<(), <Self as PasskeyStorage>::Error>;
205
206    /// Get a passkey by credential ID
207    async fn get_passkey_by_credential_id(
208        &self,
209        credential_id: &str,
210    ) -> Result<Option<String>, <Self as PasskeyStorage>::Error>;
211
212    /// Get all passkeys for a user
213    async fn get_passkeys(
214        &self,
215        user_id: &UserId,
216    ) -> Result<Vec<String>, <Self as PasskeyStorage>::Error>;
217
218    /// Set a passkey challenge for a user
219    async fn set_passkey_challenge(
220        &self,
221        challenge_id: &str,
222        challenge: &str,
223        expires_in: chrono::Duration,
224    ) -> Result<(), <Self as PasskeyStorage>::Error>;
225
226    /// Get a passkey challenge
227    async fn get_passkey_challenge(
228        &self,
229        challenge_id: &str,
230    ) -> Result<Option<String>, <Self as PasskeyStorage>::Error>;
231}
232
233#[derive(Debug, Clone)]
234pub struct MagicToken {
235    pub user_id: UserId,
236    pub token: String,
237    pub used_at: Option<DateTime<Utc>>,
238    pub expires_at: DateTime<Utc>,
239    pub created_at: DateTime<Utc>,
240    pub updated_at: DateTime<Utc>,
241}
242
243impl MagicToken {
244    pub fn new(
245        user_id: UserId,
246        token: String,
247        used_at: Option<DateTime<Utc>>,
248        expires_at: DateTime<Utc>,
249        created_at: DateTime<Utc>,
250        updated_at: DateTime<Utc>,
251    ) -> Self {
252        Self {
253            user_id,
254            token,
255            used_at,
256            expires_at,
257            created_at,
258            updated_at,
259        }
260    }
261
262    pub fn used(&self) -> bool {
263        self.used_at.is_some()
264    }
265}
266
267impl PartialEq for MagicToken {
268    fn eq(&self, other: &Self) -> bool {
269        self.user_id == other.user_id
270            && self.token == other.token
271            && self.used_at == other.used_at
272            // Some databases may not store the timestamp with more precision than seconds, so we compare the timestamps as integers
273            && self.expires_at.timestamp() == other.expires_at.timestamp()
274            && self.created_at.timestamp() == other.created_at.timestamp()
275            && self.updated_at.timestamp() == other.updated_at.timestamp()
276    }
277}
278
279#[async_trait]
280pub trait MagicLinkStorage: UserStorage {
281    type Error: std::error::Error + Send + Sync + 'static;
282
283    async fn save_magic_token(
284        &self,
285        token: &MagicToken,
286    ) -> Result<(), <Self as MagicLinkStorage>::Error>;
287    async fn get_magic_token(
288        &self,
289        token: &str,
290    ) -> Result<Option<MagicToken>, <Self as MagicLinkStorage>::Error>;
291    async fn set_magic_token_used(
292        &self,
293        token: &str,
294    ) -> Result<(), <Self as MagicLinkStorage>::Error>;
295}