torii_core/
storage.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    Error, OAuthAccount, Session, User, UserId, error::ValidationError, session::SessionId,
7};
8
9#[async_trait]
10pub trait StoragePlugin: Send + Sync + 'static {
11    type Config;
12
13    /// Initialize storage with config
14    async fn initialize(&self, config: Self::Config) -> Result<(), Error>;
15
16    /// Storage health check
17    async fn health_check(&self) -> Result<(), Error>;
18
19    /// Clean up expired data
20    async fn cleanup(&self) -> Result<(), Error>;
21}
22
23#[async_trait]
24pub trait UserStorage: Send + Sync + 'static {
25    type Error: std::error::Error + Send + Sync + 'static;
26
27    async fn create_user(&self, user: &NewUser) -> Result<User, Self::Error>;
28    async fn get_user(&self, id: &UserId) -> Result<Option<User>, Self::Error>;
29    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>, Self::Error>;
30    async fn get_or_create_user_by_email(&self, email: &str) -> Result<User, Self::Error>;
31    async fn update_user(&self, user: &User) -> Result<User, Self::Error>;
32    async fn delete_user(&self, id: &UserId) -> Result<(), Self::Error>;
33    async fn set_user_email_verified(&self, user_id: &UserId) -> Result<(), Self::Error>;
34}
35
36#[async_trait]
37pub trait SessionStorage: Send + Sync + 'static {
38    type Error: std::error::Error + Send + Sync + 'static;
39
40    async fn create_session(&self, session: &Session) -> Result<Session, Self::Error>;
41    async fn get_session(&self, id: &SessionId) -> Result<Session, Self::Error>;
42    async fn delete_session(&self, id: &SessionId) -> Result<(), Self::Error>;
43    async fn cleanup_expired_sessions(&self) -> Result<(), Self::Error>;
44    async fn delete_sessions_for_user(&self, user_id: &UserId) -> Result<(), Self::Error>;
45}
46
47/// Storage methods specific to email/password authentication
48///
49/// This trait extends the base `UserStorage` trait with methods needed for
50/// storing and retrieving password hashes.
51#[async_trait]
52pub trait PasswordStorage: UserStorage {
53    /// Store a password hash for a user
54    async fn set_password_hash(&self, user_id: &UserId, hash: &str) -> Result<(), Self::Error>;
55
56    /// Retrieve a user's password hash
57    async fn get_password_hash(&self, user_id: &UserId) -> Result<Option<String>, Self::Error>;
58}
59
60/// Storage methods specific to OAuth authentication
61///
62/// This trait extends the base `UserStorage` trait with methods needed for
63/// OAuth account management and PKCE verifier storage.
64#[async_trait]
65pub trait OAuthStorage: UserStorage {
66    /// Create a new OAuth account linked to a user
67    async fn create_oauth_account(
68        &self,
69        provider: &str,
70        subject: &str,
71        user_id: &UserId,
72    ) -> Result<OAuthAccount, Self::Error>;
73
74    /// Find a user by their OAuth provider and subject
75    async fn get_user_by_provider_and_subject(
76        &self,
77        provider: &str,
78        subject: &str,
79    ) -> Result<Option<User>, Self::Error>;
80
81    /// Find an OAuth account by provider and subject
82    async fn get_oauth_account_by_provider_and_subject(
83        &self,
84        provider: &str,
85        subject: &str,
86    ) -> Result<Option<OAuthAccount>, Self::Error>;
87
88    /// Link an existing user to an OAuth account
89    async fn link_oauth_account(
90        &self,
91        user_id: &UserId,
92        provider: &str,
93        subject: &str,
94    ) -> Result<(), Self::Error>;
95
96    /// Store a PKCE verifier with an expiration time
97    async fn store_pkce_verifier(
98        &self,
99        csrf_state: &str,
100        pkce_verifier: &str,
101        expires_in: chrono::Duration,
102    ) -> Result<(), Self::Error>;
103
104    /// Retrieve a stored PKCE verifier by CSRF state
105    async fn get_pkce_verifier(&self, csrf_state: &str) -> Result<Option<String>, Self::Error>;
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct NewUser {
110    pub id: UserId,
111    pub email: String,
112    pub name: Option<String>,
113    pub email_verified_at: Option<DateTime<Utc>>,
114}
115
116impl NewUser {
117    pub fn builder() -> NewUserBuilder {
118        NewUserBuilder::default()
119    }
120
121    pub fn new(email: String) -> Self {
122        NewUserBuilder::default()
123            .email(email)
124            .build()
125            .expect("Default builder should never fail")
126    }
127
128    pub fn with_id(id: UserId, email: String) -> Self {
129        NewUserBuilder::default()
130            .id(id)
131            .email(email)
132            .build()
133            .expect("Default builder should never fail")
134    }
135}
136
137#[derive(Default)]
138pub struct NewUserBuilder {
139    id: Option<UserId>,
140    email: Option<String>,
141    name: Option<String>,
142    email_verified_at: Option<DateTime<Utc>>,
143}
144
145impl NewUserBuilder {
146    pub fn id(mut self, id: UserId) -> Self {
147        self.id = Some(id);
148        self
149    }
150
151    pub fn email(mut self, email: String) -> Self {
152        self.email = Some(email);
153        self
154    }
155
156    pub fn name(mut self, name: String) -> Self {
157        self.name = Some(name);
158        self
159    }
160
161    pub fn email_verified_at(mut self, email_verified_at: Option<DateTime<Utc>>) -> Self {
162        self.email_verified_at = email_verified_at;
163        self
164    }
165
166    pub fn build(self) -> Result<NewUser, Error> {
167        Ok(NewUser {
168            id: self.id.unwrap_or(UserId::new_random()),
169            email: self.email.ok_or(ValidationError::MissingField(
170                "Email is required".to_string(),
171            ))?,
172            name: self.name,
173            email_verified_at: self.email_verified_at,
174        })
175    }
176}
177
178/// Storage methods specific to passkey authentication
179///
180/// This trait extends the base `UserStorage` trait with methods needed for
181/// storing and retrieving passkey credentials for a user.
182#[async_trait]
183pub trait PasskeyStorage: UserStorage {
184    /// Add a passkey credential for a user
185    async fn add_passkey(
186        &self,
187        user_id: &UserId,
188        credential_id: &str,
189        passkey_json: &str,
190    ) -> Result<(), Self::Error>;
191
192    /// Get a passkey by credential ID
193    async fn get_passkey_by_credential_id(
194        &self,
195        credential_id: &str,
196    ) -> Result<Option<String>, Self::Error>;
197
198    /// Get all passkeys for a user
199    async fn get_passkeys(&self, user_id: &UserId) -> Result<Vec<String>, Self::Error>;
200
201    /// Set a passkey challenge for a user
202    async fn set_passkey_challenge(
203        &self,
204        challenge_id: &str,
205        challenge: &str,
206        expires_in: chrono::Duration,
207    ) -> Result<(), Self::Error>;
208
209    /// Get a passkey challenge
210    async fn get_passkey_challenge(
211        &self,
212        challenge_id: &str,
213    ) -> Result<Option<String>, Self::Error>;
214}
215
216#[derive(Debug, Clone)]
217pub struct MagicToken {
218    pub user_id: UserId,
219    pub token: String,
220    pub used_at: Option<DateTime<Utc>>,
221    pub expires_at: DateTime<Utc>,
222    pub created_at: DateTime<Utc>,
223    pub updated_at: DateTime<Utc>,
224}
225
226impl MagicToken {
227    pub fn new(
228        user_id: UserId,
229        token: String,
230        used_at: Option<DateTime<Utc>>,
231        expires_at: DateTime<Utc>,
232        created_at: DateTime<Utc>,
233        updated_at: DateTime<Utc>,
234    ) -> Self {
235        Self {
236            user_id,
237            token,
238            used_at,
239            expires_at,
240            created_at,
241            updated_at,
242        }
243    }
244
245    pub fn used(&self) -> bool {
246        self.used_at.is_some()
247    }
248}
249
250impl PartialEq for MagicToken {
251    fn eq(&self, other: &Self) -> bool {
252        self.user_id == other.user_id
253            && self.token == other.token
254            && self.used_at == other.used_at
255            // Some databases may not store the timestamp with more precision than seconds, so we compare the timestamps as integers
256            && self.expires_at.timestamp() == other.expires_at.timestamp()
257            && self.created_at.timestamp() == other.created_at.timestamp()
258            && self.updated_at.timestamp() == other.updated_at.timestamp()
259    }
260}
261
262#[async_trait]
263pub trait MagicLinkStorage: UserStorage {
264    async fn save_magic_token(&self, token: &MagicToken) -> Result<(), Self::Error>;
265    async fn get_magic_token(&self, token: &str) -> Result<Option<MagicToken>, Self::Error>;
266    async fn set_magic_token_used(&self, token: &str) -> Result<(), Self::Error>;
267}