1use chrono::{Duration, Utc};
8use rand::RngCore;
9use sha2::{Digest, Sha256};
10
11use super::error::AuthError;
12use crate::storage::DbPool;
13
/// Number of days a session stays valid after `create_session` writes it.
///
/// `validate_session` only accepts rows whose `expires_at` is strictly later than
/// the current time, and `cleanup_expired` deletes rows at or past `expires_at`.
const SESSION_LIFETIME_DAYS: i64 = 7;
16
/// An unexpired session row, as returned by `validate_session`.
///
/// All timestamps are UTC strings in the `%Y-%m-%dT%H:%M:%SZ` format that this
/// module writes. That format is fixed-width, so string comparison orders the
/// timestamps chronologically.
///
/// `Debug` is implemented by hand so the CSRF token never ends up in log output.
#[derive(Clone)]
pub struct Session {
    /// Random hex identifier of the session row.
    pub id: String,
    /// Per-session CSRF token (secret; redacted from `Debug` output).
    pub csrf_token: String,
    /// When the session was created.
    pub created_at: String,
    /// When the session stops being accepted.
    pub expires_at: String,
    /// Last recorded access time, as stored in the row.
    pub last_accessed_at: String,
}

impl std::fmt::Debug for Session {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Session")
            .field("id", &self.id)
            // Never print the real token: Debug output routinely lands in logs.
            .field("csrf_token", &"<redacted>")
            .field("created_at", &self.created_at)
            .field("expires_at", &self.expires_at)
            .field("last_accessed_at", &self.last_accessed_at)
            .finish()
    }
}
26
/// Credentials returned by `create_session` for a freshly created session.
///
/// `raw_token` is only available here. The database stores just its SHA-256 hash,
/// so the caller must hand the token to the client now or it is lost.
///
/// `Debug` is implemented by hand and redacts both tokens so they cannot leak into logs.
pub struct NewSession {
    /// Bearer token to give to the client (hex-encoded random bytes).
    pub raw_token: String,
    /// CSRF token stored alongside the session (hex-encoded random bytes).
    pub csrf_token: String,
    /// Expiry timestamp as stored in the row (UTC, `%Y-%m-%dT%H:%M:%SZ`).
    pub expires_at: String,
}

impl std::fmt::Debug for NewSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NewSession")
            .field("raw_token", &"<redacted>")
            .field("csrf_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
34
35fn hash_token(raw_token: &str) -> String {
37 let mut hasher = Sha256::new();
38 hasher.update(raw_token.as_bytes());
39 hex::encode(hasher.finalize())
40}
41
42fn random_hex(bytes: usize) -> String {
44 let mut buf = vec![0u8; bytes];
45 rand::rng().fill_bytes(&mut buf);
46 hex::encode(&buf)
47}
48
49pub async fn create_session(pool: &DbPool) -> Result<NewSession, AuthError> {
53 let id = random_hex(16);
54 let raw_token = random_hex(32);
55 let csrf_token = random_hex(16);
56 let token_hash = hash_token(&raw_token);
57 let now = Utc::now();
58 let expires_at = now + Duration::days(SESSION_LIFETIME_DAYS);
59 let now_str = now.format("%Y-%m-%dT%H:%M:%SZ").to_string();
60 let expires_str = expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string();
61
62 sqlx::query(
63 "INSERT INTO sessions (id, token_hash, csrf_token, created_at, expires_at, last_accessed_at)
64 VALUES (?, ?, ?, ?, ?, ?)",
65 )
66 .bind(&id)
67 .bind(&token_hash)
68 .bind(&csrf_token)
69 .bind(&now_str)
70 .bind(&expires_str)
71 .bind(&now_str)
72 .execute(pool)
73 .await
74 .map_err(|e| AuthError::Database { source: e })?;
75
76 Ok(NewSession {
77 raw_token,
78 csrf_token,
79 expires_at: expires_str,
80 })
81}
82
83pub async fn validate_session(
87 pool: &DbPool,
88 raw_token: &str,
89) -> Result<Option<Session>, AuthError> {
90 let token_hash = hash_token(raw_token);
91 let now_str = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
92
93 let row = sqlx::query_as::<_, (String, String, String, String, String)>(
94 "SELECT id, csrf_token, created_at, expires_at, last_accessed_at
95 FROM sessions WHERE token_hash = ? AND expires_at > ?",
96 )
97 .bind(&token_hash)
98 .bind(&now_str)
99 .fetch_optional(pool)
100 .await
101 .map_err(|e| AuthError::Database { source: e })?;
102
103 let Some((id, csrf_token, created_at, expires_at, last_accessed_at)) = row else {
104 return Ok(None);
105 };
106
107 sqlx::query("UPDATE sessions SET last_accessed_at = ? WHERE id = ?")
109 .bind(&now_str)
110 .bind(&id)
111 .execute(pool)
112 .await
113 .map_err(|e| AuthError::Database { source: e })?;
114
115 Ok(Some(Session {
116 id,
117 csrf_token,
118 created_at,
119 expires_at,
120 last_accessed_at,
121 }))
122}
123
124pub async fn delete_session(pool: &DbPool, raw_token: &str) -> Result<(), AuthError> {
126 let token_hash = hash_token(raw_token);
127 sqlx::query("DELETE FROM sessions WHERE token_hash = ?")
128 .bind(&token_hash)
129 .execute(pool)
130 .await
131 .map_err(|e| AuthError::Database { source: e })?;
132 Ok(())
133}
134
135pub async fn cleanup_expired(pool: &DbPool) -> Result<u64, AuthError> {
137 let now_str = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
138 let result = sqlx::query("DELETE FROM sessions WHERE expires_at <= ?")
139 .bind(&now_str)
140 .execute(pool)
141 .await
142 .map_err(|e| AuthError::Database { source: e })?;
143 Ok(result.rows_affected())
144}
145
/// Integration tests for the session store, run against a fresh test database per test.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    #[tokio::test]
    async fn create_and_validate_session() {
        let pool = init_test_db().await.unwrap();
        let new = create_session(&pool).await.unwrap();
        assert!(!new.raw_token.is_empty());
        assert!(!new.csrf_token.is_empty());

        let session = validate_session(&pool, &new.raw_token).await.unwrap();
        assert!(session.is_some());
        let session = session.unwrap();
        assert_eq!(session.csrf_token, new.csrf_token);
        // The expiry handed to the caller must be exactly what was stored.
        assert_eq!(session.expires_at, new.expires_at);
    }

    #[tokio::test]
    async fn create_session_generates_hex_tokens_and_future_expiry() {
        let pool = init_test_db().await.unwrap();
        let new = create_session(&pool).await.unwrap();

        // 32 and 16 random bytes respectively, hex-encoded.
        assert_eq!(new.raw_token.len(), 64);
        assert_eq!(new.csrf_token.len(), 32);
        assert!(new.raw_token.chars().all(|c| c.is_ascii_hexdigit()));
        assert!(new.csrf_token.chars().all(|c| c.is_ascii_hexdigit()));

        // Fixed-width UTC timestamps compare chronologically as strings.
        let now = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
        assert!(new.expires_at > now, "new session must expire in the future");

        // Only the hash is persisted; the raw bearer token must never be a lookup key.
        let (stored_raw,) =
            sqlx::query_as::<_, (i64,)>("SELECT COUNT(*) FROM sessions WHERE token_hash = ?")
                .bind(&new.raw_token)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(stored_raw, 0, "raw token must not be stored");
    }

    #[tokio::test]
    async fn validate_invalid_token_returns_none() {
        let pool = init_test_db().await.unwrap();
        let session = validate_session(&pool, "nonexistent-token").await.unwrap();
        assert!(session.is_none());
    }

    #[tokio::test]
    async fn delete_session_invalidates_token() {
        let pool = init_test_db().await.unwrap();
        let new = create_session(&pool).await.unwrap();
        delete_session(&pool, &new.raw_token).await.unwrap();
        let session = validate_session(&pool, &new.raw_token).await.unwrap();
        assert!(session.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_old_sessions() {
        let pool = init_test_db().await.unwrap();

        // Insert a row that expired long ago.
        sqlx::query(
            "INSERT INTO sessions (id, token_hash, csrf_token, created_at, expires_at, last_accessed_at)
             VALUES ('old', 'oldhash', 'oldcsrf', '2020-01-01T00:00:00Z', '2020-01-02T00:00:00Z', '2020-01-01T00:00:00Z')",
        )
        .execute(&pool)
        .await
        .unwrap();

        let removed = cleanup_expired(&pool).await.unwrap();
        assert_eq!(removed, 1);

        // A second pass must find nothing, proving the row is actually gone.
        let removed_again = cleanup_expired(&pool).await.unwrap();
        assert_eq!(removed_again, 0, "expired row should already be deleted");
    }

    #[tokio::test]
    async fn cleanup_expired_preserves_active_sessions() {
        let pool = init_test_db().await.unwrap();

        sqlx::query(
            "INSERT INTO sessions (id, token_hash, csrf_token, created_at, expires_at, last_accessed_at)
             VALUES ('expired', 'hash1', 'csrf1', '2020-01-01T00:00:00Z', '2020-01-02T00:00:00Z', '2020-01-01T00:00:00Z')",
        )
        .execute(&pool)
        .await
        .unwrap();

        let new = create_session(&pool).await.unwrap();

        let removed = cleanup_expired(&pool).await.unwrap();
        assert_eq!(removed, 1);

        // The freshly created session must survive cleanup.
        let session = validate_session(&pool, &new.raw_token).await.unwrap();
        assert!(session.is_some());
    }

    #[tokio::test]
    async fn cleanup_expired_returns_zero_when_none_expired() {
        let pool = init_test_db().await.unwrap();
        create_session(&pool).await.unwrap();

        let removed = cleanup_expired(&pool).await.unwrap();
        assert_eq!(removed, 0);
    }

    #[tokio::test]
    async fn multiple_sessions_are_independent() {
        let pool = init_test_db().await.unwrap();

        let s1 = create_session(&pool).await.unwrap();
        let s2 = create_session(&pool).await.unwrap();

        assert!(validate_session(&pool, &s1.raw_token)
            .await
            .unwrap()
            .is_some());
        assert!(validate_session(&pool, &s2.raw_token)
            .await
            .unwrap()
            .is_some());

        // Deleting one session must not affect the other.
        delete_session(&pool, &s1.raw_token).await.unwrap();
        assert!(validate_session(&pool, &s1.raw_token)
            .await
            .unwrap()
            .is_none());
        assert!(validate_session(&pool, &s2.raw_token)
            .await
            .unwrap()
            .is_some());
    }

    #[tokio::test]
    async fn delete_nonexistent_session_is_noop() {
        let pool = init_test_db().await.unwrap();
        // Must succeed (not error) even though nothing matches.
        delete_session(&pool, "totally-fake-token").await.unwrap();
    }

    #[tokio::test]
    async fn session_has_unique_tokens() {
        let pool = init_test_db().await.unwrap();
        let s1 = create_session(&pool).await.unwrap();
        let s2 = create_session(&pool).await.unwrap();
        assert_ne!(s1.raw_token, s2.raw_token);
        assert_ne!(s1.csrf_token, s2.csrf_token);
    }

    #[tokio::test]
    async fn validate_expired_session_returns_none() {
        let pool = init_test_db().await.unwrap();

        // Store a known token's hash with an expiry in the past.
        let token_hash = hash_token("my-raw-token");
        sqlx::query(
            "INSERT INTO sessions (id, token_hash, csrf_token, created_at, expires_at, last_accessed_at)
             VALUES ('exp', ?, 'csrf', '2020-01-01T00:00:00Z', '2020-01-02T00:00:00Z', '2020-01-01T00:00:00Z')",
        )
        .bind(&token_hash)
        .execute(&pool)
        .await
        .unwrap();

        let session = validate_session(&pool, "my-raw-token").await.unwrap();
        assert!(session.is_none(), "expired session should not validate");
    }
}