1use std::sync::Arc;
17
18use crate::{
19 Error, SessionStorage,
20 error::{SessionError, StorageError, ValidationError},
21 user::UserId,
22};
23use async_trait::async_trait;
24use chrono::{DateTime, Duration, Utc};
25use serde::{Deserialize, Serialize};
26use uuid::Uuid;
27
28#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
29pub struct SessionId(String);
30
31impl SessionId {
32 pub fn new(id: &str) -> Self {
33 Self(id.to_string())
34 }
35
36 pub fn new_random() -> Self {
37 Self(Uuid::new_v4().to_string())
38 }
39
40 pub fn into_inner(self) -> String {
41 self.0
42 }
43
44 pub fn as_str(&self) -> &str {
45 &self.0
46 }
47}
48
49impl Default for SessionId {
50 fn default() -> Self {
51 Self::new_random()
52 }
53}
54
55impl From<String> for SessionId {
56 fn from(s: String) -> Self {
57 Self(s)
58 }
59}
60
61impl From<&str> for SessionId {
62 fn from(s: &str) -> Self {
63 Self(s.to_string())
64 }
65}
66
67impl std::fmt::Display for SessionId {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.0)
70 }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct Session {
75 pub id: SessionId,
77
78 pub user_id: UserId,
80
81 pub user_agent: Option<String>,
83
84 pub ip_address: Option<String>,
86
87 pub created_at: DateTime<Utc>,
89
90 pub updated_at: DateTime<Utc>,
92
93 pub expires_at: DateTime<Utc>,
95}
96
97impl Session {
98 pub fn builder() -> SessionBuilder {
99 SessionBuilder::default()
100 }
101
102 pub fn is_expired(&self) -> bool {
103 Utc::now() > self.expires_at
104 }
105}
106
107#[derive(Default)]
108pub struct SessionBuilder {
109 id: Option<SessionId>,
110 user_id: Option<UserId>,
111 user_agent: Option<String>,
112 ip_address: Option<String>,
113 created_at: Option<DateTime<Utc>>,
114 updated_at: Option<DateTime<Utc>>,
115 expires_at: Option<DateTime<Utc>>,
116}
117
118impl SessionBuilder {
119 pub fn id(mut self, id: SessionId) -> Self {
120 self.id = Some(id);
121 self
122 }
123
124 pub fn user_id(mut self, user_id: UserId) -> Self {
125 self.user_id = Some(user_id);
126 self
127 }
128
129 pub fn user_agent(mut self, user_agent: Option<String>) -> Self {
130 self.user_agent = user_agent;
131 self
132 }
133
134 pub fn ip_address(mut self, ip_address: Option<String>) -> Self {
135 self.ip_address = ip_address;
136 self
137 }
138
139 pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
140 self.created_at = Some(created_at);
141 self
142 }
143
144 pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
145 self.updated_at = Some(updated_at);
146 self
147 }
148
149 pub fn expires_at(mut self, expires_at: DateTime<Utc>) -> Self {
150 self.expires_at = Some(expires_at);
151 self
152 }
153
154 pub fn build(self) -> Result<Session, Error> {
155 let now = Utc::now();
156 Ok(Session {
157 id: self.id.unwrap_or(SessionId::new_random()),
158 user_id: self.user_id.ok_or(ValidationError::MissingField(
159 "User ID is required".to_string(),
160 ))?,
161 user_agent: self.user_agent,
162 ip_address: self.ip_address,
163 created_at: self.created_at.unwrap_or(now),
164 updated_at: self.updated_at.unwrap_or(now),
165 expires_at: self.expires_at.unwrap_or(now + Duration::days(30)),
166 })
167 }
168}
169
170#[async_trait]
171pub trait SessionManager {
172 async fn create_session(
173 &self,
174 user_id: &UserId,
175 user_agent: Option<String>,
176 ip_address: Option<String>,
177 duration: Duration,
178 ) -> Result<Session, Error>;
179 async fn get_session(&self, id: &SessionId) -> Result<Session, Error>;
180 async fn delete_session(&self, id: &SessionId) -> Result<(), Error>;
181 async fn cleanup_expired_sessions(&self) -> Result<(), Error>;
182 async fn delete_sessions_for_user(&self, user_id: &UserId) -> Result<(), Error>;
183}
184
185pub struct DefaultSessionManager<S: SessionStorage> {
186 storage: Arc<S>,
187}
188
189impl<S: SessionStorage> DefaultSessionManager<S> {
190 pub fn new(storage: Arc<S>) -> Self {
191 Self { storage }
192 }
193}
194
195#[async_trait]
196impl<S: SessionStorage> SessionManager for DefaultSessionManager<S> {
197 async fn create_session(
198 &self,
199 user_id: &UserId,
200 user_agent: Option<String>,
201 ip_address: Option<String>,
202 duration: Duration,
203 ) -> Result<Session, Error> {
204 let session = Session::builder()
205 .user_id(user_id.clone())
206 .user_agent(user_agent)
207 .ip_address(ip_address)
208 .expires_at(Utc::now() + duration)
209 .build()?;
210
211 let session = self
212 .storage
213 .create_session(&session)
214 .await
215 .map_err(|e| StorageError::Database(e.to_string()))?;
216
217 Ok(session)
218 }
219
220 async fn get_session(&self, id: &SessionId) -> Result<Session, Error> {
221 let session = self
222 .storage
223 .get_session(id)
224 .await
225 .map_err(|e| StorageError::Database(e.to_string()))?;
226
227 if session.is_expired() {
228 self.delete_session(id).await?;
229 return Err(Error::Session(SessionError::Expired));
230 }
231
232 Ok(session)
233 }
234
235 async fn delete_session(&self, id: &SessionId) -> Result<(), Error> {
236 self.storage
237 .delete_session(id)
238 .await
239 .map_err(|e| StorageError::Database(e.to_string()))?;
240
241 Ok(())
242 }
243
244 async fn cleanup_expired_sessions(&self) -> Result<(), Error> {
245 self.storage
246 .cleanup_expired_sessions()
247 .await
248 .map_err(|e| StorageError::Database(e.to_string()))?;
249
250 Ok(())
251 }
252
253 async fn delete_sessions_for_user(&self, user_id: &UserId) -> Result<(), Error> {
254 self.storage
255 .delete_sessions_for_user(user_id)
256 .await
257 .map_err(|e| StorageError::Database(e.to_string()))?;
258
259 Ok(())
260 }
261}
262
#[cfg(test)]
mod tests {
    use chrono::Duration;

    use super::*;

    /// The `Display` output of a `SessionId` equals the wrapped string.
    #[test]
    fn test_session_id() {
        let id = SessionId::new_random();
        let rendered = id.to_string();
        assert_eq!(rendered, id.0.to_string());
    }

    /// A fully specified builder yields a session whose id displays as its
    /// wrapped string.
    #[test]
    fn test_session() {
        let expiry = Utc::now() + Duration::days(30);
        let built = Session::builder()
            .user_id(UserId::new_random())
            .user_agent(Some("test".to_string()))
            .ip_address(Some("127.0.0.1".to_string()))
            .expires_at(expiry)
            .build();
        let session = built.unwrap();

        let rendered = session.id.to_string();
        assert_eq!(rendered, session.id.0.to_string());
    }
}