// torii_core/session.rs

//! Session management.
//!
//! This module defines the core [`Session`] type, its [`SessionId`] identifier,
//! the [`SessionBuilder`] used to construct sessions, and the [`SessionManager`]
//! trait together with [`DefaultSessionManager`], an implementation backed by a
//! [`SessionStorage`] backend.
//!
//! Sessions are used to track and authenticate users. The core [`Session`] struct has
//! the following fields:
//!
//! | Field        | Type             | Description                                            |
//! | ------------ | ---------------- | ------------------------------------------------------ |
//! | `id`         | `SessionId`      | The unique identifier for the session.                 |
//! | `user_id`    | `UserId`         | The unique identifier for the user.                    |
//! | `user_agent` | `Option<String>` | The user agent of the client that created the session. |
//! | `ip_address` | `Option<String>` | The IP address of the client that created the session. |
//! | `created_at` | `DateTime<Utc>`  | The timestamp when the session was created.            |
//! | `updated_at` | `DateTime<Utc>`  | The timestamp when the session was last updated.       |
//! | `expires_at` | `DateTime<Utc>`  | The timestamp when the session will expire.            |
16use std::sync::Arc;
17
18use crate::{
19    Error, SessionStorage,
20    error::{SessionError, StorageError, ValidationError},
21    user::UserId,
22};
23use async_trait::async_trait;
24use chrono::{DateTime, Duration, Utc};
25use serde::{Deserialize, Serialize};
26use uuid::Uuid;
27
28#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
29pub struct SessionId(String);
30
31impl SessionId {
32    pub fn new(id: &str) -> Self {
33        Self(id.to_string())
34    }
35
36    pub fn new_random() -> Self {
37        Self(Uuid::new_v4().to_string())
38    }
39
40    pub fn into_inner(self) -> String {
41        self.0
42    }
43
44    pub fn as_str(&self) -> &str {
45        &self.0
46    }
47}
48
49impl Default for SessionId {
50    fn default() -> Self {
51        Self::new_random()
52    }
53}
54
55impl From<String> for SessionId {
56    fn from(s: String) -> Self {
57        Self(s)
58    }
59}
60
61impl From<&str> for SessionId {
62    fn from(s: &str) -> Self {
63        Self(s.to_string())
64    }
65}
66
67impl std::fmt::Display for SessionId {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "{}", self.0)
70    }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct Session {
75    /// The unique identifier for the session.
76    pub id: SessionId,
77
78    /// The unique identifier for the user.
79    pub user_id: UserId,
80
81    /// The user agent of the client that created the session.
82    pub user_agent: Option<String>,
83
84    /// The IP address of the client that created the session.
85    pub ip_address: Option<String>,
86
87    /// The timestamp when the session was created.
88    pub created_at: DateTime<Utc>,
89
90    /// The timestamp when the session was last updated.
91    pub updated_at: DateTime<Utc>,
92
93    /// The timestamp when the session will expire.
94    pub expires_at: DateTime<Utc>,
95}
96
97impl Session {
98    pub fn builder() -> SessionBuilder {
99        SessionBuilder::default()
100    }
101
102    pub fn is_expired(&self) -> bool {
103        Utc::now() > self.expires_at
104    }
105}
106
107#[derive(Default)]
108pub struct SessionBuilder {
109    id: Option<SessionId>,
110    user_id: Option<UserId>,
111    user_agent: Option<String>,
112    ip_address: Option<String>,
113    created_at: Option<DateTime<Utc>>,
114    updated_at: Option<DateTime<Utc>>,
115    expires_at: Option<DateTime<Utc>>,
116}
117
118impl SessionBuilder {
119    pub fn id(mut self, id: SessionId) -> Self {
120        self.id = Some(id);
121        self
122    }
123
124    pub fn user_id(mut self, user_id: UserId) -> Self {
125        self.user_id = Some(user_id);
126        self
127    }
128
129    pub fn user_agent(mut self, user_agent: Option<String>) -> Self {
130        self.user_agent = user_agent;
131        self
132    }
133
134    pub fn ip_address(mut self, ip_address: Option<String>) -> Self {
135        self.ip_address = ip_address;
136        self
137    }
138
139    pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
140        self.created_at = Some(created_at);
141        self
142    }
143
144    pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
145        self.updated_at = Some(updated_at);
146        self
147    }
148
149    pub fn expires_at(mut self, expires_at: DateTime<Utc>) -> Self {
150        self.expires_at = Some(expires_at);
151        self
152    }
153
154    pub fn build(self) -> Result<Session, Error> {
155        let now = Utc::now();
156        Ok(Session {
157            id: self.id.unwrap_or(SessionId::new_random()),
158            user_id: self.user_id.ok_or(ValidationError::MissingField(
159                "User ID is required".to_string(),
160            ))?,
161            user_agent: self.user_agent,
162            ip_address: self.ip_address,
163            created_at: self.created_at.unwrap_or(now),
164            updated_at: self.updated_at.unwrap_or(now),
165            expires_at: self.expires_at.unwrap_or(now + Duration::days(30)),
166        })
167    }
168}
169
170#[async_trait]
171pub trait SessionManager {
172    async fn create_session(
173        &self,
174        user_id: &UserId,
175        user_agent: Option<String>,
176        ip_address: Option<String>,
177        duration: Duration,
178    ) -> Result<Session, Error>;
179    async fn get_session(&self, id: &SessionId) -> Result<Session, Error>;
180    async fn delete_session(&self, id: &SessionId) -> Result<(), Error>;
181    async fn cleanup_expired_sessions(&self) -> Result<(), Error>;
182    async fn delete_sessions_for_user(&self, user_id: &UserId) -> Result<(), Error>;
183}
184
/// [`SessionManager`] implementation that delegates persistence to a shared
/// storage backend `S`.
///
/// The struct itself places no bounds on `S`; only the [`SessionManager`]
/// impl requires `S: SessionStorage`.
pub struct DefaultSessionManager<S> {
    storage: Arc<S>,
}

impl<S> DefaultSessionManager<S> {
    /// Creates a manager backed by `storage`.
    ///
    /// The backend is held behind an `Arc`, so the caller may keep its own
    /// handle to the same storage.
    pub fn new(storage: Arc<S>) -> Self {
        Self { storage }
    }
}

195#[async_trait]
196impl<S: SessionStorage> SessionManager for DefaultSessionManager<S> {
197    async fn create_session(
198        &self,
199        user_id: &UserId,
200        user_agent: Option<String>,
201        ip_address: Option<String>,
202        duration: Duration,
203    ) -> Result<Session, Error> {
204        let session = Session::builder()
205            .user_id(user_id.clone())
206            .user_agent(user_agent)
207            .ip_address(ip_address)
208            .expires_at(Utc::now() + duration)
209            .build()?;
210
211        let session = self
212            .storage
213            .create_session(&session)
214            .await
215            .map_err(|e| StorageError::Database(e.to_string()))?;
216
217        Ok(session)
218    }
219
220    async fn get_session(&self, id: &SessionId) -> Result<Session, Error> {
221        let session = self
222            .storage
223            .get_session(id)
224            .await
225            .map_err(|e| StorageError::Database(e.to_string()))?;
226
227        if session.is_expired() {
228            self.delete_session(id).await?;
229            return Err(Error::Session(SessionError::Expired));
230        }
231
232        Ok(session)
233    }
234
235    async fn delete_session(&self, id: &SessionId) -> Result<(), Error> {
236        self.storage
237            .delete_session(id)
238            .await
239            .map_err(|e| StorageError::Database(e.to_string()))?;
240
241        Ok(())
242    }
243
244    async fn cleanup_expired_sessions(&self) -> Result<(), Error> {
245        self.storage
246            .cleanup_expired_sessions()
247            .await
248            .map_err(|e| StorageError::Database(e.to_string()))?;
249
250        Ok(())
251    }
252
253    async fn delete_sessions_for_user(&self, user_id: &UserId) -> Result<(), Error> {
254        self.storage
255            .delete_sessions_for_user(user_id)
256            .await
257            .map_err(|e| StorageError::Database(e.to_string()))?;
258
259        Ok(())
260    }
261}
262
#[cfg(test)]
mod tests {
    // `Duration`, `Utc` and `Uuid` come from the parent module's imports.
    use super::*;

    #[test]
    fn session_id_round_trips_its_string() {
        let id = SessionId::new("abc");
        assert_eq!(id.as_str(), "abc");
        assert_eq!(id.to_string(), "abc");
        assert_eq!(SessionId::from("abc"), id);
        assert_eq!(SessionId::from(String::from("abc")), id);
        assert_eq!(id.into_inner(), "abc");
    }

    #[test]
    fn random_session_ids_are_uuids_and_distinct() {
        let a = SessionId::new_random();
        let b = SessionId::default();
        assert!(Uuid::parse_str(a.as_str()).is_ok());
        assert!(Uuid::parse_str(b.as_str()).is_ok());
        assert_ne!(a, b);
    }

    #[test]
    fn builder_fills_in_defaults() {
        let session = Session::builder()
            .user_id(UserId::new_random())
            .build()
            .unwrap();

        assert!(Uuid::parse_str(session.id.as_str()).is_ok());
        assert_eq!(session.user_agent, None);
        assert_eq!(session.ip_address, None);
        assert_eq!(session.created_at, session.updated_at);
        assert_eq!(session.expires_at, session.created_at + Duration::days(30));
        assert!(!session.is_expired());
    }

    #[test]
    fn builder_keeps_explicit_values() {
        let created = Utc::now() - Duration::hours(1);
        let expires = created + Duration::days(7);
        let session = Session::builder()
            .id(SessionId::new("fixed"))
            .user_id(UserId::new_random())
            .user_agent(Some("test".to_string()))
            .ip_address(Some("127.0.0.1".to_string()))
            .created_at(created)
            .updated_at(created)
            .expires_at(expires)
            .build()
            .unwrap();

        assert_eq!(session.id, SessionId::new("fixed"));
        assert_eq!(session.user_agent.as_deref(), Some("test"));
        assert_eq!(session.ip_address.as_deref(), Some("127.0.0.1"));
        assert_eq!(session.created_at, created);
        assert_eq!(session.updated_at, created);
        assert_eq!(session.expires_at, expires);
    }

    #[test]
    fn builder_requires_user_id() {
        assert!(Session::builder().build().is_err());
    }

    #[test]
    fn session_past_expiry_is_expired() {
        let session = Session::builder()
            .user_id(UserId::new_random())
            .expires_at(Utc::now() - Duration::seconds(1))
            .build()
            .unwrap();

        assert!(session.is_expired());
    }
}