1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
17use chrono::{DateTime, Duration, Utc};
18use rand::RngCore;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21
22use crate::error::Result;
23use crate::orm::{Db, Row};
24
25use super::role::Role;
26use super::users::Identity;
27
/// Name of the cookie that carries the session token
/// (read back by `session_token_from_cookie`).
pub const SESSION_COOKIE: &str = "rustio_session";

/// Lifetime, in days, added to "now" to compute `expires_at` in `create_session`.
const SESSION_LENGTH_DAYS: i64 = 14;
33
/// Trust tier attached to a session row (`rustio_sessions.trust_level`).
///
/// Tiers are ordered by `SessionTrust::rank`. The serde representation uses
/// snake_case names, which are the same strings `SessionTrust::as_str` returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionTrust {
    /// Lowest tier (rank 1); also what `SessionTrust::parse` returns for
    /// unrecognised input.
    Authenticated,
    /// Middle tier (rank 2).
    Elevated,
    /// Highest tier (rank 3).
    MfaVerified,
}
48
49impl SessionTrust {
50 pub const fn as_str(self) -> &'static str {
53 match self {
54 Self::Authenticated => "authenticated",
55 Self::Elevated => "elevated",
56 Self::MfaVerified => "mfa_verified",
57 }
58 }
59
60 pub const fn rank(self) -> u8 {
62 match self {
63 Self::Authenticated => 1,
64 Self::Elevated => 2,
65 Self::MfaVerified => 3,
66 }
67 }
68
69 pub const fn satisfies(self, other: SessionTrust) -> bool {
71 self.rank() >= other.rank()
72 }
73
74 pub fn parse(s: &str) -> Self {
78 match s {
79 "elevated" => Self::Elevated,
80 "mfa_verified" => Self::MfaVerified,
81 _ => Self::Authenticated,
82 }
83 }
84}
85
/// Why a session was revoked.
///
/// `invalidate_sessions` writes `SessionInvalidationReason::as_str` into the
/// `rustio_sessions.revoked_reason` column of every row it revokes.
// NOTE(review): apart from `Logout` (used by `logout_session`), the call sites
// that pick each variant are not in this module — confirm their exact
// triggers before documenting them individually.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionInvalidationReason {
    /// Used by `logout_session`.
    Logout,
    Expired,
    UserRequested,
    AdministrativeRevoke,
    PasswordReset,
    PasswordResetByOther,
    MfaEnabled,
    MfaDisabled,
    MfaDisabledByOther,
    AuthorityEscalation,
    EmergencyRecovery,
    TrustEscalation,
}
112
113impl SessionInvalidationReason {
114 pub const fn as_str(self) -> &'static str {
118 match self {
119 Self::Logout => "logout",
120 Self::Expired => "expired",
121 Self::UserRequested => "user_requested",
122 Self::AdministrativeRevoke => "administrative_revoke",
123 Self::PasswordReset => "password_reset",
124 Self::PasswordResetByOther => "password_reset_by_other",
125 Self::MfaEnabled => "mfa_enabled",
126 Self::MfaDisabled => "mfa_disabled",
127 Self::MfaDisabledByOther => "mfa_disabled_by_other",
128 Self::AuthorityEscalation => "authority_escalation",
129 Self::EmergencyRecovery => "emergency_recovery",
130 Self::TrustEscalation => "trust_escalation",
131 }
132 }
133}
134
/// Selects which sessions `invalidate_sessions` revokes.
///
/// Only rows that are not yet revoked (`revoked_at IS NULL`) are touched.
#[derive(Debug, Clone, Copy)]
pub enum SessionTarget {
    /// Every not-yet-revoked session belonging to `user_id`.
    User { user_id: i64 },
    /// Every not-yet-revoked session of `user_id` except the one whose
    /// `session_id` equals `current_session_id`.
    UserExceptCurrent {
        user_id: i64,
        current_session_id: i64,
    },
    /// The single session with this `session_id`.
    Single { session_id: i64 },
}
151
/// One session row as returned by `list_active_for_user`.
///
/// The struct has no token or token-hash field; it carries only identifiers,
/// timestamps and client metadata.
#[derive(Debug, Clone, Serialize)]
pub struct Session {
    /// Stable numeric id (`rustio_sessions.session_id`).
    pub session_id: i64,
    /// Owning user (`rustio_users.id`).
    pub user_id: i64,
    /// Trust tier parsed from the `trust_level` column.
    pub trust_level: SessionTrust,
    pub created_at: DateTime<Utc>,
    /// Last time the session was used to resolve an identity.
    pub last_seen: DateTime<Utc>,
    pub expires_at: DateTime<Utc>,
    /// Value of the nullable `elevated_until` column.
    pub elevated_until: Option<DateTime<Utc>>,
    /// Client IP recorded for the session, if any.
    pub ip: Option<String>,
    /// Client `User-Agent` recorded for the session, if any.
    pub user_agent: Option<String>,
}
166
/// Result of `invalidate_sessions`.
#[derive(Debug, Clone, Default)]
pub struct InvalidationOutcome {
    /// `session_id`s revoked by the call; empty when nothing matched.
    pub revoked_session_ids: Vec<i64>,
    /// Reason recorded on the revoked rows. `invalidate_sessions` always sets
    /// `Some(reason)`; the `Default` value is `None`.
    pub reason: Option<SessionInvalidationReason>,
}
177
178pub async fn init_session_tables(db: &Db) -> Result<()> {
179 sqlx::query(
180 "CREATE TABLE IF NOT EXISTS rustio_sessions (
181 token TEXT PRIMARY KEY,
182 user_id BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE,
183 expires_at TIMESTAMPTZ NOT NULL,
184 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
185 last_seen TIMESTAMPTZ NOT NULL DEFAULT NOW()
186 )",
187 )
188 .execute(db.pool())
189 .await?;
190
191 sqlx::query("CREATE INDEX IF NOT EXISTS rustio_sessions_user_idx ON rustio_sessions (user_id)")
192 .execute(db.pool())
193 .await?;
194
195 sqlx::query(
196 "CREATE INDEX IF NOT EXISTS rustio_sessions_expires_idx ON rustio_sessions (expires_at)",
197 )
198 .execute(db.pool())
199 .await?;
200
201 Ok(())
202}
203
204pub(crate) async fn migrate_session_schema(db: &Db) -> Result<()> {
208 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS ip TEXT")
209 .execute(db.pool())
210 .await?;
211 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS user_agent TEXT")
212 .execute(db.pool())
213 .await?;
214 Ok(())
215}
216
217pub(crate) async fn migrate_session_lifecycle(db: &Db) -> Result<()> {
242 sqlx::query("CREATE SEQUENCE IF NOT EXISTS rustio_sessions_session_id_seq")
243 .execute(db.pool())
244 .await?;
245 sqlx::query(
246 "ALTER TABLE rustio_sessions \
247 ADD COLUMN IF NOT EXISTS session_id BIGINT NOT NULL DEFAULT \
248 nextval('rustio_sessions_session_id_seq')",
249 )
250 .execute(db.pool())
251 .await?;
252 sqlx::query(
253 "ALTER SEQUENCE rustio_sessions_session_id_seq OWNED BY rustio_sessions.session_id",
254 )
255 .execute(db.pool())
256 .await?;
257 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS token_hash TEXT")
258 .execute(db.pool())
259 .await?;
260 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS device_id TEXT")
261 .execute(db.pool())
262 .await?;
263 sqlx::query(
264 "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS trust_level TEXT \
265 NOT NULL DEFAULT 'authenticated'",
266 )
267 .execute(db.pool())
268 .await?;
269 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS elevated_until TIMESTAMPTZ")
270 .execute(db.pool())
271 .await?;
272 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS parent_session_id BIGINT")
273 .execute(db.pool())
274 .await?;
275 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ")
276 .execute(db.pool())
277 .await?;
278 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_reason TEXT")
279 .execute(db.pool())
280 .await?;
281
282 sqlx::query(
285 "DO $$ BEGIN \
286 IF NOT EXISTS ( \
287 SELECT 1 FROM pg_constraint \
288 WHERE conname = 'rustio_sessions_trust_level_check' \
289 ) THEN \
290 ALTER TABLE rustio_sessions \
291 ADD CONSTRAINT rustio_sessions_trust_level_check \
292 CHECK (trust_level IN ('authenticated', 'elevated', 'mfa_verified')); \
293 END IF; \
294 END $$",
295 )
296 .execute(db.pool())
297 .await?;
298
299 sqlx::query(
300 "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_session_id_uq \
301 ON rustio_sessions (session_id)",
302 )
303 .execute(db.pool())
304 .await?;
305 sqlx::query(
306 "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_token_hash_uq \
307 ON rustio_sessions (token_hash) \
308 WHERE revoked_at IS NULL AND token_hash IS NOT NULL",
309 )
310 .execute(db.pool())
311 .await?;
312 sqlx::query(
313 "CREATE INDEX IF NOT EXISTS rustio_sessions_user_active_idx \
314 ON rustio_sessions (user_id) WHERE revoked_at IS NULL",
315 )
316 .execute(db.pool())
317 .await?;
318 sqlx::query(
319 "CREATE INDEX IF NOT EXISTS rustio_sessions_parent_idx \
320 ON rustio_sessions (parent_session_id) WHERE parent_session_id IS NOT NULL",
321 )
322 .execute(db.pool())
323 .await?;
324
325 Ok(())
326}
327
328pub async fn create_session(db: &Db, user_id: i64) -> Result<String> {
329 let token = random_token();
330 let token_hash = hash_token_for_storage(&token);
331 let expires = Utc::now() + Duration::days(SESSION_LENGTH_DAYS);
332 sqlx::query(
338 "INSERT INTO rustio_sessions (token, token_hash, user_id, expires_at) \
339 VALUES ($1, $2, $3, $4)",
340 )
341 .bind(&token)
342 .bind(&token_hash)
343 .bind(user_id)
344 .bind(expires)
345 .execute(db.pool())
346 .await?;
347 Ok(token)
348}
349
350pub async fn delete_session(db: &Db, token: &str) -> Result<()> {
358 sqlx::query("DELETE FROM rustio_sessions WHERE token = $1 OR token_hash = $2")
359 .bind(token)
360 .bind(hash_token_for_storage(token))
361 .execute(db.pool())
362 .await?;
363 Ok(())
364}
365
366pub async fn invalidate_sessions(
389 db: &Db,
390 target: SessionTarget,
391 reason: SessionInvalidationReason,
392) -> Result<InvalidationOutcome> {
393 let reason_str = reason.as_str();
394 let revoked_ids: Vec<i64> = match target {
395 SessionTarget::User { user_id } => {
396 sqlx::query_scalar::<_, i64>(
397 "UPDATE rustio_sessions \
398 SET revoked_at = NOW(), revoked_reason = $2 \
399 WHERE user_id = $1 AND revoked_at IS NULL \
400 RETURNING session_id",
401 )
402 .bind(user_id)
403 .bind(reason_str)
404 .fetch_all(db.pool())
405 .await?
406 }
407 SessionTarget::UserExceptCurrent {
408 user_id,
409 current_session_id,
410 } => {
411 sqlx::query_scalar::<_, i64>(
412 "UPDATE rustio_sessions \
413 SET revoked_at = NOW(), revoked_reason = $3 \
414 WHERE user_id = $1 AND session_id <> $2 AND revoked_at IS NULL \
415 RETURNING session_id",
416 )
417 .bind(user_id)
418 .bind(current_session_id)
419 .bind(reason_str)
420 .fetch_all(db.pool())
421 .await?
422 }
423 SessionTarget::Single { session_id } => {
424 sqlx::query_scalar::<_, i64>(
425 "UPDATE rustio_sessions \
426 SET revoked_at = NOW(), revoked_reason = $2 \
427 WHERE session_id = $1 AND revoked_at IS NULL \
428 RETURNING session_id",
429 )
430 .bind(session_id)
431 .bind(reason_str)
432 .fetch_all(db.pool())
433 .await?
434 }
435 };
436
437 Ok(InvalidationOutcome {
438 revoked_session_ids: revoked_ids,
439 reason: Some(reason),
440 })
441}
442
443pub async fn logout_session(db: &Db, token: &str) -> Result<()> {
451 let token_hash = hash_token_for_storage(token);
452 let session_id: Option<i64> = sqlx::query_scalar::<_, i64>(
453 "SELECT session_id FROM rustio_sessions \
454 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
455 AND revoked_at IS NULL \
456 LIMIT 1",
457 )
458 .bind(&token_hash)
459 .bind(token)
460 .fetch_optional(db.pool())
461 .await?;
462
463 if let Some(sid) = session_id {
464 invalidate_sessions(
465 db,
466 SessionTarget::Single { session_id: sid },
467 SessionInvalidationReason::Logout,
468 )
469 .await?;
470 }
471 Ok(())
472}
473
474pub async fn list_active_for_user(db: &Db, user_id: i64) -> Result<Vec<Session>> {
478 let rows = sqlx::query(
479 "SELECT session_id, user_id, trust_level, created_at, last_seen, expires_at, \
480 elevated_until, ip, user_agent \
481 FROM rustio_sessions \
482 WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW() \
483 ORDER BY last_seen DESC",
484 )
485 .bind(user_id)
486 .fetch_all(db.pool())
487 .await?;
488
489 rows.iter()
490 .map(|r| {
491 let r = Row::from_pg(r);
492 Ok(Session {
493 session_id: r.get_i64("session_id")?,
494 user_id: r.get_i64("user_id")?,
495 trust_level: SessionTrust::parse(&r.get_string("trust_level")?),
496 created_at: r.get_datetime("created_at")?,
497 last_seen: r.get_datetime("last_seen")?,
498 expires_at: r.get_datetime("expires_at")?,
499 elevated_until: None, ip: r.get_optional_string("ip")?,
501 user_agent: r.get_optional_string("user_agent")?,
502 })
503 })
504 .collect()
505}
506
507pub async fn current_session_id(db: &Db, token: &str) -> Result<Option<i64>> {
511 let token_hash = hash_token_for_storage(token);
512 let id: Option<i64> = sqlx::query_scalar::<_, i64>(
513 "SELECT session_id FROM rustio_sessions \
514 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
515 AND revoked_at IS NULL AND expires_at > NOW() \
516 LIMIT 1",
517 )
518 .bind(&token_hash)
519 .bind(token)
520 .fetch_optional(db.pool())
521 .await?;
522 Ok(id)
523}
524
525pub async fn identity_from_session(db: &Db, token: &str) -> Result<Option<Identity>> {
526 let token_hash = hash_token_for_storage(token);
532 let row = sqlx::query(
533 "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
534 s.expires_at, s.token_hash IS NOT NULL AS hashed \
535 FROM rustio_sessions s \
536 JOIN rustio_users u ON u.id = s.user_id \
537 WHERE s.token_hash = $1 AND s.revoked_at IS NULL",
538 )
539 .bind(&token_hash)
540 .fetch_optional(db.pool())
541 .await?;
542
543 let row = match row {
544 Some(r) => Some(r),
545 None => {
553 sqlx::query(
554 "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
555 s.expires_at, FALSE AS hashed \
556 FROM rustio_sessions s \
557 JOIN rustio_users u ON u.id = s.user_id \
558 WHERE s.token = $1 AND s.token_hash IS NULL AND s.revoked_at IS NULL",
559 )
560 .bind(token)
561 .fetch_optional(db.pool())
562 .await?
563 }
564 };
565 let row = match row {
566 Some(r) => r,
567 None => return Ok(None),
568 };
569 let r = Row::from_pg(&row);
570 let expires_at = r.get_datetime("expires_at")?;
571 if expires_at < Utc::now() {
572 let _ = delete_session(db, token).await;
578 return Ok(None);
579 }
580
581 let db_clone = db.clone();
585 let token_owned = token.to_string();
586 let token_hash_owned = token_hash.clone();
587 tokio::spawn(async move {
588 let _ = sqlx::query(
589 "UPDATE rustio_sessions SET last_seen = NOW() \
590 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
591 AND revoked_at IS NULL",
592 )
593 .bind(&token_hash_owned)
594 .bind(&token_owned)
595 .execute(db_clone.pool())
596 .await;
597 });
598
599 Ok(Some(Identity {
600 user_id: r.get_i64("id")?,
601 email: r.get_string("email")?,
602 role: Role::parse(&r.get_string("role")?)?,
603 is_active: r.get_bool("is_active")?,
604 is_demo: r.get_bool("is_demo")?,
605 demo_label: r.get_optional_string("demo_label")?,
606 }))
607}
608
609pub async fn purge_expired_sessions(db: &Db) -> Result<u64> {
612 let result = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at < NOW()")
613 .execute(db.pool())
614 .await?;
615 Ok(result.rows_affected())
616}
617
618pub fn session_token_from_cookie(cookie_header: &str) -> Option<String> {
619 let prefix = format!("{SESSION_COOKIE}=");
620 for part in cookie_header.split(';') {
621 let part = part.trim();
622 if let Some(v) = part.strip_prefix(&prefix) {
623 return Some(v.to_string());
624 }
625 }
626 None
627}
628
629fn random_token() -> String {
630 let mut bytes = [0u8; 32];
631 rand::thread_rng().fill_bytes(&mut bytes);
632 URL_SAFE_NO_PAD.encode(bytes)
633}
634
635pub(crate) fn hash_token_for_storage(token: &str) -> String {
646 let digest = Sha256::digest(token.as_bytes());
647 URL_SAFE_NO_PAD.encode(digest)
648}
649
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_token_from_cookie_header() {
        let h = "foo=bar; rustio_session=abc123; other=x";
        assert_eq!(session_token_from_cookie(h), Some("abc123".into()));
    }

    #[test]
    fn returns_none_when_cookie_missing() {
        let h = "foo=bar; other=x";
        assert!(session_token_from_cookie(h).is_none());
    }

    /// A cookie whose name merely ends with `rustio_session` must not match.
    #[test]
    fn cookie_name_must_match_exactly() {
        assert!(session_token_from_cookie("xrustio_session=evil").is_none());
        assert_eq!(
            session_token_from_cookie("xrustio_session=evil; rustio_session=good"),
            Some("good".into())
        );
    }

    #[test]
    fn random_token_has_reasonable_entropy() {
        assert_ne!(random_token(), random_token());
    }

    /// 32 random bytes encode to 43 unpadded URL-safe base64 characters.
    #[test]
    fn random_token_is_url_safe_base64_of_32_bytes() {
        let t = random_token();
        assert_eq!(t.len(), 43);
        assert!(t
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn hash_token_is_deterministic() {
        let token = random_token();
        assert_eq!(
            hash_token_for_storage(&token),
            hash_token_for_storage(&token)
        );
    }

    /// Pins the exact digest and encoding. SHA-256("") is
    /// e3b0c442…7852b855, and its URL-safe unpadded base64 form is below.
    /// Changing either the digest or the encoding would orphan every
    /// stored `token_hash`.
    #[test]
    fn hash_token_matches_known_sha256_vector() {
        assert_eq!(
            hash_token_for_storage(""),
            "47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU"
        );
    }

    #[test]
    fn hash_token_differs_per_token() {
        let a = hash_token_for_storage("aaaa");
        let b = hash_token_for_storage("aaab");
        assert_ne!(a, b);
    }

    #[test]
    fn hash_token_output_is_url_safe_base64() {
        let h = hash_token_for_storage("anything");
        assert_eq!(h.len(), 43);
        assert!(h
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn hash_token_does_not_leak_plaintext() {
        let plaintext = "secret-cookie-value-12345";
        let h = hash_token_for_storage(plaintext);
        assert!(!h.contains("secret"));
        assert!(!h.contains("12345"));
    }

    #[test]
    fn session_trust_orders_correctly() {
        assert!(SessionTrust::Authenticated.rank() < SessionTrust::Elevated.rank());
        assert!(SessionTrust::Elevated.rank() < SessionTrust::MfaVerified.rank());
        assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Elevated));
        assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Authenticated));
        assert!(SessionTrust::Authenticated.satisfies(SessionTrust::Authenticated));
        assert!(!SessionTrust::Authenticated.satisfies(SessionTrust::Elevated));
        assert!(!SessionTrust::Elevated.satisfies(SessionTrust::MfaVerified));
    }

    #[test]
    fn session_trust_round_trips_through_sql() {
        for tier in [
            SessionTrust::Authenticated,
            SessionTrust::Elevated,
            SessionTrust::MfaVerified,
        ] {
            assert_eq!(SessionTrust::parse(tier.as_str()), tier);
        }
    }

    #[test]
    fn session_trust_parse_defaults_safely_on_unknown() {
        assert_eq!(SessionTrust::parse("garbage"), SessionTrust::Authenticated);
        assert_eq!(SessionTrust::parse(""), SessionTrust::Authenticated);
    }

    #[test]
    fn invalidation_reason_strings_are_distinct() {
        let reasons = [
            SessionInvalidationReason::Logout,
            SessionInvalidationReason::Expired,
            SessionInvalidationReason::UserRequested,
            SessionInvalidationReason::AdministrativeRevoke,
            SessionInvalidationReason::PasswordReset,
            SessionInvalidationReason::PasswordResetByOther,
            SessionInvalidationReason::MfaEnabled,
            SessionInvalidationReason::MfaDisabled,
            SessionInvalidationReason::MfaDisabledByOther,
            SessionInvalidationReason::AuthorityEscalation,
            SessionInvalidationReason::EmergencyRecovery,
            SessionInvalidationReason::TrustEscalation,
        ];
        let mut set = std::collections::HashSet::new();
        for r in reasons {
            assert!(set.insert(r.as_str()), "duplicate as_str() for {r:?}");
        }
        assert_eq!(set.len(), reasons.len());
    }
}