// torii_core/repositories/adapter.rs

1use crate::{
2    Error, OAuthAccount, Session, SessionStorage, User, UserId,
3    repositories::{
4        OAuthRepository, PasskeyCredential, PasskeyRepository, PasswordRepository,
5        RepositoryProvider, SessionRepository, TokenRepository, UserRepository,
6    },
7    session::SessionToken,
8    storage::{NewUser, SecureToken, TokenPurpose},
9};
10use async_trait::async_trait;
11use chrono::Duration;
12use std::sync::Arc;
13
14/// Adapter that wraps a RepositoryProvider and implements individual repository traits
15pub struct UserRepositoryAdapter<R: RepositoryProvider> {
16    provider: Arc<R>,
17}
18
19impl<R: RepositoryProvider> UserRepositoryAdapter<R> {
20    pub fn new(provider: Arc<R>) -> Self {
21        Self { provider }
22    }
23}
24
25#[async_trait]
26impl<R: RepositoryProvider> UserRepository for UserRepositoryAdapter<R> {
27    async fn create(&self, user: NewUser) -> Result<User, Error> {
28        self.provider.user().create(user).await
29    }
30
31    async fn find_by_id(&self, id: &UserId) -> Result<Option<User>, Error> {
32        self.provider.user().find_by_id(id).await
33    }
34
35    async fn find_by_email(&self, email: &str) -> Result<Option<User>, Error> {
36        self.provider.user().find_by_email(email).await
37    }
38
39    async fn find_or_create_by_email(&self, email: &str) -> Result<User, Error> {
40        self.provider.user().find_or_create_by_email(email).await
41    }
42
43    async fn update(&self, user: &User) -> Result<User, Error> {
44        self.provider.user().update(user).await
45    }
46
47    async fn delete(&self, id: &UserId) -> Result<(), Error> {
48        self.provider.user().delete(id).await
49    }
50
51    async fn mark_email_verified(&self, user_id: &UserId) -> Result<(), Error> {
52        self.provider.user().mark_email_verified(user_id).await
53    }
54}
55
56pub struct SessionRepositoryAdapter<R: RepositoryProvider> {
57    provider: Arc<R>,
58}
59
60impl<R: RepositoryProvider> SessionRepositoryAdapter<R> {
61    pub fn new(provider: Arc<R>) -> Self {
62        Self { provider }
63    }
64}
65
66#[async_trait]
67impl<R: RepositoryProvider> SessionRepository for SessionRepositoryAdapter<R> {
68    async fn create(&self, session: Session) -> Result<Session, Error> {
69        self.provider.session().create(session).await
70    }
71
72    async fn find_by_token(&self, token: &SessionToken) -> Result<Option<Session>, Error> {
73        self.provider.session().find_by_token(token).await
74    }
75
76    async fn delete(&self, token: &SessionToken) -> Result<(), Error> {
77        self.provider.session().delete(token).await
78    }
79
80    async fn delete_by_user_id(&self, user_id: &UserId) -> Result<(), Error> {
81        self.provider.session().delete_by_user_id(user_id).await
82    }
83
84    async fn cleanup_expired(&self) -> Result<(), Error> {
85        self.provider.session().cleanup_expired().await
86    }
87}
88
89/// Implementation of SessionStorage for SessionRepositoryAdapter
90/// This allows the adapter to be used with the OpaqueSessionProvider
91#[async_trait]
92impl<R: RepositoryProvider> SessionStorage for SessionRepositoryAdapter<R> {
93    async fn create_session(&self, session: &Session) -> Result<Session, Error> {
94        self.create(session.clone()).await
95    }
96
97    async fn get_session(&self, token: &SessionToken) -> Result<Option<Session>, Error> {
98        self.find_by_token(token).await
99    }
100
101    async fn delete_session(&self, token: &SessionToken) -> Result<(), Error> {
102        self.delete(token).await
103    }
104
105    async fn cleanup_expired_sessions(&self) -> Result<(), Error> {
106        self.cleanup_expired().await
107    }
108
109    async fn delete_sessions_for_user(&self, user_id: &UserId) -> Result<(), Error> {
110        self.delete_by_user_id(user_id).await
111    }
112}
113
114pub struct PasswordRepositoryAdapter<R: RepositoryProvider> {
115    provider: Arc<R>,
116}
117
118impl<R: RepositoryProvider> PasswordRepositoryAdapter<R> {
119    pub fn new(provider: Arc<R>) -> Self {
120        Self { provider }
121    }
122}
123
124#[async_trait]
125impl<R: RepositoryProvider> PasswordRepository for PasswordRepositoryAdapter<R> {
126    async fn set_password_hash(&self, user_id: &UserId, hash: &str) -> Result<(), Error> {
127        self.provider
128            .password()
129            .set_password_hash(user_id, hash)
130            .await
131    }
132
133    async fn get_password_hash(&self, user_id: &UserId) -> Result<Option<String>, Error> {
134        self.provider.password().get_password_hash(user_id).await
135    }
136
137    async fn remove_password_hash(&self, user_id: &UserId) -> Result<(), Error> {
138        self.provider.password().remove_password_hash(user_id).await
139    }
140}
141
142pub struct OAuthRepositoryAdapter<R: RepositoryProvider> {
143    provider: Arc<R>,
144}
145
146impl<R: RepositoryProvider> OAuthRepositoryAdapter<R> {
147    pub fn new(provider: Arc<R>) -> Self {
148        Self { provider }
149    }
150}
151
152#[async_trait]
153impl<R: RepositoryProvider> OAuthRepository for OAuthRepositoryAdapter<R> {
154    async fn create_account(
155        &self,
156        provider: &str,
157        subject: &str,
158        user_id: &UserId,
159    ) -> Result<OAuthAccount, Error> {
160        self.provider
161            .oauth()
162            .create_account(provider, subject, user_id)
163            .await
164    }
165
166    async fn find_user_by_provider(
167        &self,
168        provider: &str,
169        subject: &str,
170    ) -> Result<Option<User>, Error> {
171        self.provider
172            .oauth()
173            .find_user_by_provider(provider, subject)
174            .await
175    }
176
177    async fn find_account_by_provider(
178        &self,
179        provider: &str,
180        subject: &str,
181    ) -> Result<Option<OAuthAccount>, Error> {
182        self.provider
183            .oauth()
184            .find_account_by_provider(provider, subject)
185            .await
186    }
187
188    async fn link_account(
189        &self,
190        user_id: &UserId,
191        provider: &str,
192        subject: &str,
193    ) -> Result<(), Error> {
194        self.provider
195            .oauth()
196            .link_account(user_id, provider, subject)
197            .await
198    }
199
200    async fn store_pkce_verifier(
201        &self,
202        csrf_state: &str,
203        pkce_verifier: &str,
204        expires_in: Duration,
205    ) -> Result<(), Error> {
206        self.provider
207            .oauth()
208            .store_pkce_verifier(csrf_state, pkce_verifier, expires_in)
209            .await
210    }
211
212    async fn get_pkce_verifier(&self, csrf_state: &str) -> Result<Option<String>, Error> {
213        self.provider.oauth().get_pkce_verifier(csrf_state).await
214    }
215
216    async fn delete_pkce_verifier(&self, csrf_state: &str) -> Result<(), Error> {
217        self.provider.oauth().delete_pkce_verifier(csrf_state).await
218    }
219}
220
221pub struct PasskeyRepositoryAdapter<R: RepositoryProvider> {
222    provider: Arc<R>,
223}
224
225impl<R: RepositoryProvider> PasskeyRepositoryAdapter<R> {
226    pub fn new(provider: Arc<R>) -> Self {
227        Self { provider }
228    }
229}
230
231#[async_trait]
232impl<R: RepositoryProvider> PasskeyRepository for PasskeyRepositoryAdapter<R> {
233    async fn add_credential(
234        &self,
235        user_id: &UserId,
236        credential_id: Vec<u8>,
237        public_key: Vec<u8>,
238        name: Option<String>,
239    ) -> Result<PasskeyCredential, Error> {
240        self.provider
241            .passkey()
242            .add_credential(user_id, credential_id, public_key, name)
243            .await
244    }
245
246    async fn get_credentials_for_user(
247        &self,
248        user_id: &UserId,
249    ) -> Result<Vec<PasskeyCredential>, Error> {
250        self.provider
251            .passkey()
252            .get_credentials_for_user(user_id)
253            .await
254    }
255
256    async fn get_credential(
257        &self,
258        credential_id: &[u8],
259    ) -> Result<Option<PasskeyCredential>, Error> {
260        self.provider.passkey().get_credential(credential_id).await
261    }
262
263    async fn update_last_used(&self, credential_id: &[u8]) -> Result<(), Error> {
264        self.provider
265            .passkey()
266            .update_last_used(credential_id)
267            .await
268    }
269
270    async fn delete_credential(&self, credential_id: &[u8]) -> Result<(), Error> {
271        self.provider
272            .passkey()
273            .delete_credential(credential_id)
274            .await
275    }
276
277    async fn delete_all_for_user(&self, user_id: &UserId) -> Result<(), Error> {
278        self.provider.passkey().delete_all_for_user(user_id).await
279    }
280}
281
282/// Adapter that wraps a RepositoryProvider and implements TokenRepository
283pub struct TokenRepositoryAdapter<R: RepositoryProvider> {
284    provider: Arc<R>,
285}
286
287impl<R: RepositoryProvider> TokenRepositoryAdapter<R> {
288    pub fn new(provider: Arc<R>) -> Self {
289        Self { provider }
290    }
291}
292
293#[async_trait]
294impl<R: RepositoryProvider> TokenRepository for TokenRepositoryAdapter<R> {
295    async fn create_token(
296        &self,
297        user_id: &UserId,
298        purpose: TokenPurpose,
299        expires_in: Duration,
300    ) -> Result<SecureToken, Error> {
301        self.provider
302            .token()
303            .create_token(user_id, purpose, expires_in)
304            .await
305    }
306
307    async fn verify_token(
308        &self,
309        token: &str,
310        purpose: TokenPurpose,
311    ) -> Result<Option<SecureToken>, Error> {
312        self.provider.token().verify_token(token, purpose).await
313    }
314
315    async fn check_token(&self, token: &str, purpose: TokenPurpose) -> Result<bool, Error> {
316        self.provider.token().check_token(token, purpose).await
317    }
318
319    async fn cleanup_expired_tokens(&self) -> Result<(), Error> {
320        self.provider.token().cleanup_expired_tokens().await
321    }
322}