1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
17use chrono::{DateTime, Duration, Utc};
18use rand::RngCore;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21
22use crate::error::Result;
23use crate::orm::{Db, Row};
24
25use super::role::Role;
26use super::users::Identity;
27
/// Name of the cookie that carries the session token; see
/// [`session_token_from_cookie`], which looks for `SESSION_COOKIE=` pairs.
pub const SESSION_COOKIE: &str = "rustio_session";

/// Lifetime of a newly created session, in days: `create_session` sets
/// `expires_at` to the current time plus this many days.
const SESSION_LENGTH_DAYS: i64 = 14;
33
/// Trust tier attached to a session.
///
/// Tiers are totally ordered via [`SessionTrust::rank`]:
/// `Authenticated` < `Elevated` < `MfaVerified`. Serde uses `snake_case`
/// names, which match the values allowed in the `trust_level` column's
/// CHECK constraint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionTrust {
    /// Lowest tier; also the column default and the fallback used by
    /// [`SessionTrust::parse`] for unrecognised strings.
    Authenticated,
    /// Middle tier (rank 2).
    Elevated,
    /// Highest tier (rank 3).
    MfaVerified,
}
48
49impl SessionTrust {
50 pub const fn as_str(self) -> &'static str {
53 match self {
54 Self::Authenticated => "authenticated",
55 Self::Elevated => "elevated",
56 Self::MfaVerified => "mfa_verified",
57 }
58 }
59
60 pub const fn rank(self) -> u8 {
62 match self {
63 Self::Authenticated => 1,
64 Self::Elevated => 2,
65 Self::MfaVerified => 3,
66 }
67 }
68
69 pub const fn satisfies(self, other: SessionTrust) -> bool {
71 self.rank() >= other.rank()
72 }
73
74 pub fn parse(s: &str) -> Self {
78 match s {
79 "elevated" => Self::Elevated,
80 "mfa_verified" => Self::MfaVerified,
81 _ => Self::Authenticated,
82 }
83 }
84}
85
/// Why one or more sessions were revoked.
///
/// [`invalidate_sessions`] writes the label from
/// [`SessionInvalidationReason::as_str`] into `rustio_sessions.revoked_reason`,
/// so every variant must map to a distinct, stable string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionInvalidationReason {
    /// Used by [`logout_session`].
    Logout,
    Expired,
    UserRequested,
    AdministrativeRevoke,
    PasswordReset,
    PasswordResetByOther,
    MfaEnabled,
    MfaDisabled,
    MfaDisabledByOther,
    AuthorityEscalation,
    EmergencyRecovery,
    RoleChangedByOther,
    TrustEscalation,
}
119
120impl SessionInvalidationReason {
121 pub const fn as_str(self) -> &'static str {
125 match self {
126 Self::Logout => "logout",
127 Self::Expired => "expired",
128 Self::UserRequested => "user_requested",
129 Self::AdministrativeRevoke => "administrative_revoke",
130 Self::PasswordReset => "password_reset",
131 Self::PasswordResetByOther => "password_reset_by_other",
132 Self::MfaEnabled => "mfa_enabled",
133 Self::MfaDisabled => "mfa_disabled",
134 Self::MfaDisabledByOther => "mfa_disabled_by_other",
135 Self::AuthorityEscalation => "authority_escalation",
136 Self::EmergencyRecovery => "emergency_recovery",
137 Self::RoleChangedByOther => "role_changed_by_other",
138 Self::TrustEscalation => "trust_escalation",
139 }
140 }
141}
142
/// Selects which sessions [`invalidate_sessions`] revokes.
///
/// In every case only sessions that are not already revoked are affected.
#[derive(Debug, Clone, Copy)]
pub enum SessionTarget {
    /// All of `user_id`'s sessions.
    User { user_id: i64 },
    /// All of `user_id`'s sessions except the one whose `session_id` is
    /// `current_session_id`.
    UserExceptCurrent {
        user_id: i64,
        current_session_id: i64,
    },
    /// Exactly one session, identified by its `session_id`.
    Single { session_id: i64 },
}
159
/// A session row as exposed to callers, e.g. by [`list_active_for_user`].
///
/// Neither the raw token nor its hash is part of this struct.
#[derive(Debug, Clone, Serialize)]
pub struct Session {
    /// Numeric id from the `session_id` column.
    pub session_id: i64,
    /// Owner of the session (`rustio_users.id`).
    pub user_id: i64,
    /// Trust tier; unrecognised stored values are read as `Authenticated`.
    pub trust_level: SessionTrust,
    pub created_at: DateTime<Utc>,
    /// Last activity time; `identity_from_session` refreshes it.
    pub last_seen: DateTime<Utc>,
    pub expires_at: DateTime<Utc>,
    /// `None` when the column is NULL.
    pub elevated_until: Option<DateTime<Utc>>,
    /// Client IP, if one was recorded for this session.
    pub ip: Option<String>,
    /// Client user agent, if one was recorded for this session.
    pub user_agent: Option<String>,
}
174
/// Result of [`invalidate_sessions`].
#[derive(Debug, Clone, Default)]
pub struct InvalidationOutcome {
    /// `session_id`s that this call moved to the revoked state. Sessions that
    /// were already revoked are not touched and so are not listed.
    pub revoked_session_ids: Vec<i64>,
    /// Reason written to the revoked rows. `invalidate_sessions` always sets
    /// `Some`; `None` only arises from `Default`.
    pub reason: Option<SessionInvalidationReason>,
}
185
186pub async fn init_session_tables(db: &Db) -> Result<()> {
187 sqlx::query(
188 "CREATE TABLE IF NOT EXISTS rustio_sessions (
189 token TEXT PRIMARY KEY,
190 user_id BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE,
191 expires_at TIMESTAMPTZ NOT NULL,
192 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
193 last_seen TIMESTAMPTZ NOT NULL DEFAULT NOW()
194 )",
195 )
196 .execute(db.pool())
197 .await?;
198
199 sqlx::query("CREATE INDEX IF NOT EXISTS rustio_sessions_user_idx ON rustio_sessions (user_id)")
200 .execute(db.pool())
201 .await?;
202
203 sqlx::query(
204 "CREATE INDEX IF NOT EXISTS rustio_sessions_expires_idx ON rustio_sessions (expires_at)",
205 )
206 .execute(db.pool())
207 .await?;
208
209 Ok(())
210}
211
212pub(crate) async fn migrate_session_schema(db: &Db) -> Result<()> {
216 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS ip TEXT")
217 .execute(db.pool())
218 .await?;
219 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS user_agent TEXT")
220 .execute(db.pool())
221 .await?;
222 Ok(())
223}
224
225pub(crate) async fn migrate_session_lifecycle(db: &Db) -> Result<()> {
250 sqlx::query("CREATE SEQUENCE IF NOT EXISTS rustio_sessions_session_id_seq")
251 .execute(db.pool())
252 .await?;
253 sqlx::query(
254 "ALTER TABLE rustio_sessions \
255 ADD COLUMN IF NOT EXISTS session_id BIGINT NOT NULL DEFAULT \
256 nextval('rustio_sessions_session_id_seq')",
257 )
258 .execute(db.pool())
259 .await?;
260 sqlx::query(
261 "ALTER SEQUENCE rustio_sessions_session_id_seq OWNED BY rustio_sessions.session_id",
262 )
263 .execute(db.pool())
264 .await?;
265 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS token_hash TEXT")
266 .execute(db.pool())
267 .await?;
268 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS device_id TEXT")
269 .execute(db.pool())
270 .await?;
271 sqlx::query(
272 "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS trust_level TEXT \
273 NOT NULL DEFAULT 'authenticated'",
274 )
275 .execute(db.pool())
276 .await?;
277 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS elevated_until TIMESTAMPTZ")
278 .execute(db.pool())
279 .await?;
280 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS parent_session_id BIGINT")
281 .execute(db.pool())
282 .await?;
283 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ")
284 .execute(db.pool())
285 .await?;
286 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_reason TEXT")
287 .execute(db.pool())
288 .await?;
289
290 sqlx::query(
293 "DO $$ BEGIN \
294 IF NOT EXISTS ( \
295 SELECT 1 FROM pg_constraint \
296 WHERE conname = 'rustio_sessions_trust_level_check' \
297 ) THEN \
298 ALTER TABLE rustio_sessions \
299 ADD CONSTRAINT rustio_sessions_trust_level_check \
300 CHECK (trust_level IN ('authenticated', 'elevated', 'mfa_verified')); \
301 END IF; \
302 END $$",
303 )
304 .execute(db.pool())
305 .await?;
306
307 sqlx::query(
308 "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_session_id_uq \
309 ON rustio_sessions (session_id)",
310 )
311 .execute(db.pool())
312 .await?;
313 sqlx::query(
314 "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_token_hash_uq \
315 ON rustio_sessions (token_hash) \
316 WHERE revoked_at IS NULL AND token_hash IS NOT NULL",
317 )
318 .execute(db.pool())
319 .await?;
320 sqlx::query(
321 "CREATE INDEX IF NOT EXISTS rustio_sessions_user_active_idx \
322 ON rustio_sessions (user_id) WHERE revoked_at IS NULL",
323 )
324 .execute(db.pool())
325 .await?;
326 sqlx::query(
327 "CREATE INDEX IF NOT EXISTS rustio_sessions_parent_idx \
328 ON rustio_sessions (parent_session_id) WHERE parent_session_id IS NOT NULL",
329 )
330 .execute(db.pool())
331 .await?;
332
333 Ok(())
334}
335
336pub async fn create_session(db: &Db, user_id: i64) -> Result<String> {
337 let token = random_token();
338 let token_hash = hash_token_for_storage(&token);
339 let expires = Utc::now() + Duration::days(SESSION_LENGTH_DAYS);
340 sqlx::query(
346 "INSERT INTO rustio_sessions (token, token_hash, user_id, expires_at) \
347 VALUES ($1, $2, $3, $4)",
348 )
349 .bind(&token)
350 .bind(&token_hash)
351 .bind(user_id)
352 .bind(expires)
353 .execute(db.pool())
354 .await?;
355 Ok(token)
356}
357
358pub async fn delete_session(db: &Db, token: &str) -> Result<()> {
366 sqlx::query("DELETE FROM rustio_sessions WHERE token = $1 OR token_hash = $2")
367 .bind(token)
368 .bind(hash_token_for_storage(token))
369 .execute(db.pool())
370 .await?;
371 Ok(())
372}
373
374pub async fn invalidate_sessions(
397 db: &Db,
398 target: SessionTarget,
399 reason: SessionInvalidationReason,
400) -> Result<InvalidationOutcome> {
401 let reason_str = reason.as_str();
402 let revoked_ids: Vec<i64> = match target {
403 SessionTarget::User { user_id } => {
404 sqlx::query_scalar::<_, i64>(
405 "UPDATE rustio_sessions \
406 SET revoked_at = NOW(), revoked_reason = $2 \
407 WHERE user_id = $1 AND revoked_at IS NULL \
408 RETURNING session_id",
409 )
410 .bind(user_id)
411 .bind(reason_str)
412 .fetch_all(db.pool())
413 .await?
414 }
415 SessionTarget::UserExceptCurrent {
416 user_id,
417 current_session_id,
418 } => {
419 sqlx::query_scalar::<_, i64>(
420 "UPDATE rustio_sessions \
421 SET revoked_at = NOW(), revoked_reason = $3 \
422 WHERE user_id = $1 AND session_id <> $2 AND revoked_at IS NULL \
423 RETURNING session_id",
424 )
425 .bind(user_id)
426 .bind(current_session_id)
427 .bind(reason_str)
428 .fetch_all(db.pool())
429 .await?
430 }
431 SessionTarget::Single { session_id } => {
432 sqlx::query_scalar::<_, i64>(
433 "UPDATE rustio_sessions \
434 SET revoked_at = NOW(), revoked_reason = $2 \
435 WHERE session_id = $1 AND revoked_at IS NULL \
436 RETURNING session_id",
437 )
438 .bind(session_id)
439 .bind(reason_str)
440 .fetch_all(db.pool())
441 .await?
442 }
443 };
444
445 Ok(InvalidationOutcome {
446 revoked_session_ids: revoked_ids,
447 reason: Some(reason),
448 })
449}
450
451pub async fn logout_session(db: &Db, token: &str) -> Result<()> {
459 let token_hash = hash_token_for_storage(token);
460 let session_id: Option<i64> = sqlx::query_scalar::<_, i64>(
461 "SELECT session_id FROM rustio_sessions \
462 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
463 AND revoked_at IS NULL \
464 LIMIT 1",
465 )
466 .bind(&token_hash)
467 .bind(token)
468 .fetch_optional(db.pool())
469 .await?;
470
471 if let Some(sid) = session_id {
472 invalidate_sessions(
473 db,
474 SessionTarget::Single { session_id: sid },
475 SessionInvalidationReason::Logout,
476 )
477 .await?;
478 }
479 Ok(())
480}
481
482pub async fn list_active_for_user(db: &Db, user_id: i64) -> Result<Vec<Session>> {
486 let rows = sqlx::query(
487 "SELECT session_id, user_id, trust_level, created_at, last_seen, expires_at, \
488 elevated_until, ip, user_agent \
489 FROM rustio_sessions \
490 WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW() \
491 ORDER BY last_seen DESC",
492 )
493 .bind(user_id)
494 .fetch_all(db.pool())
495 .await?;
496
497 rows.iter()
498 .map(|r| {
499 let r = Row::from_pg(r);
500 Ok(Session {
501 session_id: r.get_i64("session_id")?,
502 user_id: r.get_i64("user_id")?,
503 trust_level: SessionTrust::parse(&r.get_string("trust_level")?),
504 created_at: r.get_datetime("created_at")?,
505 last_seen: r.get_datetime("last_seen")?,
506 expires_at: r.get_datetime("expires_at")?,
507 elevated_until: r.get_optional_datetime("elevated_until")?,
508 ip: r.get_optional_string("ip")?,
509 user_agent: r.get_optional_string("user_agent")?,
510 })
511 })
512 .collect()
513}
514
515pub async fn current_session_id(db: &Db, token: &str) -> Result<Option<i64>> {
519 let token_hash = hash_token_for_storage(token);
520 let id: Option<i64> = sqlx::query_scalar::<_, i64>(
521 "SELECT session_id FROM rustio_sessions \
522 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
523 AND revoked_at IS NULL AND expires_at > NOW() \
524 LIMIT 1",
525 )
526 .bind(&token_hash)
527 .bind(token)
528 .fetch_optional(db.pool())
529 .await?;
530 Ok(id)
531}
532
533pub async fn identity_from_session(db: &Db, token: &str) -> Result<Option<Identity>> {
534 let token_hash = hash_token_for_storage(token);
540 let row = sqlx::query(
541 "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
542 u.must_change_password, u.mfa_enabled, \
543 s.expires_at, s.trust_level, \
544 s.token_hash IS NOT NULL AS hashed \
545 FROM rustio_sessions s \
546 JOIN rustio_users u ON u.id = s.user_id \
547 WHERE s.token_hash = $1 AND s.revoked_at IS NULL",
548 )
549 .bind(&token_hash)
550 .fetch_optional(db.pool())
551 .await?;
552
553 let row = match row {
554 Some(r) => Some(r),
555 None => {
563 sqlx::query(
564 "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
565 u.must_change_password, u.mfa_enabled, \
566 s.expires_at, s.trust_level, FALSE AS hashed \
567 FROM rustio_sessions s \
568 JOIN rustio_users u ON u.id = s.user_id \
569 WHERE s.token = $1 AND s.token_hash IS NULL AND s.revoked_at IS NULL",
570 )
571 .bind(token)
572 .fetch_optional(db.pool())
573 .await?
574 }
575 };
576 let row = match row {
577 Some(r) => r,
578 None => return Ok(None),
579 };
580 let r = Row::from_pg(&row);
581 let expires_at = r.get_datetime("expires_at")?;
582 if expires_at < Utc::now() {
583 let _ = delete_session(db, token).await;
589 return Ok(None);
590 }
591
592 let db_clone = db.clone();
596 let token_owned = token.to_string();
597 let token_hash_owned = token_hash.clone();
598 tokio::spawn(async move {
599 let _ = sqlx::query(
600 "UPDATE rustio_sessions SET last_seen = NOW() \
601 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
602 AND revoked_at IS NULL",
603 )
604 .bind(&token_hash_owned)
605 .bind(&token_owned)
606 .execute(db_clone.pool())
607 .await;
608 });
609
610 Ok(Some(Identity {
611 user_id: r.get_i64("id")?,
612 email: r.get_string("email")?,
613 role: Role::parse(&r.get_string("role")?)?,
614 is_active: r.get_bool("is_active")?,
615 is_demo: r.get_bool("is_demo")?,
616 demo_label: r.get_optional_string("demo_label")?,
617 must_change_password: r.get_bool("must_change_password")?,
618 mfa_enabled: r.get_bool("mfa_enabled")?,
619 trust_level: SessionTrust::parse(&r.get_string("trust_level")?),
620 }))
621}
622
623pub async fn purge_expired_sessions(db: &Db) -> Result<u64> {
626 let result = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at < NOW()")
627 .execute(db.pool())
628 .await?;
629 Ok(result.rows_affected())
630}
631
632pub fn session_token_from_cookie(cookie_header: &str) -> Option<String> {
633 let prefix = format!("{SESSION_COOKIE}=");
634 for part in cookie_header.split(';') {
635 let part = part.trim();
636 if let Some(v) = part.strip_prefix(&prefix) {
637 return Some(v.to_string());
638 }
639 }
640 None
641}
642
643pub(crate) fn random_token() -> String {
650 let mut bytes = [0u8; 32];
651 rand::thread_rng().fill_bytes(&mut bytes);
652 URL_SAFE_NO_PAD.encode(bytes)
653}
654
655pub(crate) fn hash_token_for_storage(token: &str) -> String {
666 let digest = Sha256::digest(token.as_bytes());
667 URL_SAFE_NO_PAD.encode(digest)
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_token_from_cookie_header() {
        let h = "foo=bar; rustio_session=abc123; other=x";
        assert_eq!(session_token_from_cookie(h), Some("abc123".into()));
    }

    #[test]
    fn returns_none_when_cookie_missing() {
        let h = "foo=bar; other=x";
        assert!(session_token_from_cookie(h).is_none());
    }

    #[test]
    fn random_token_has_reasonable_entropy() {
        // Two independent draws colliding would indicate a broken RNG.
        assert_ne!(random_token(), random_token());
        // 32 random bytes as unpadded base64 encode to 43 characters.
        assert_eq!(random_token().len(), 43);
    }

    #[test]
    fn hash_token_is_deterministic() {
        // Lookups re-hash the presented token, so hashing must be stable.
        let token = random_token();
        assert_eq!(
            hash_token_for_storage(&token),
            hash_token_for_storage(&token)
        );
    }

    #[test]
    fn hash_token_differs_per_token() {
        // Even a one-character difference must change the stored hash.
        let a = hash_token_for_storage("aaaa");
        let b = hash_token_for_storage("aaab");
        assert_ne!(a, b);
    }

    #[test]
    fn hash_token_output_is_url_safe_base64() {
        let h = hash_token_for_storage("anything");
        // A 32-byte SHA-256 digest as unpadded base64 is 43 characters.
        assert_eq!(h.len(), 43);
        assert!(h
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn hash_token_does_not_leak_plaintext() {
        // The stored form must not contain recognisable pieces of the token.
        let plaintext = "secret-cookie-value-12345";
        let h = hash_token_for_storage(plaintext);
        assert!(!h.contains("secret"));
        assert!(!h.contains("12345"));
    }

    #[test]
    fn session_trust_orders_correctly() {
        assert!(SessionTrust::Authenticated.rank() < SessionTrust::Elevated.rank());
        assert!(SessionTrust::Elevated.rank() < SessionTrust::MfaVerified.rank());
        assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Elevated));
        assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Authenticated));
        assert!(SessionTrust::Authenticated.satisfies(SessionTrust::Authenticated));
        assert!(!SessionTrust::Authenticated.satisfies(SessionTrust::Elevated));
        assert!(!SessionTrust::Elevated.satisfies(SessionTrust::MfaVerified));
    }

    #[test]
    fn session_trust_round_trips_through_sql() {
        for tier in [
            SessionTrust::Authenticated,
            SessionTrust::Elevated,
            SessionTrust::MfaVerified,
        ] {
            assert_eq!(SessionTrust::parse(tier.as_str()), tier);
        }
    }

    #[test]
    fn session_trust_parse_defaults_safely_on_unknown() {
        // Unknown values must fall back to the least-privileged tier.
        assert_eq!(SessionTrust::parse("garbage"), SessionTrust::Authenticated);
        assert_eq!(SessionTrust::parse(""), SessionTrust::Authenticated);
    }

    #[test]
    fn invalidation_reason_strings_are_distinct() {
        // Every variant must be listed: a duplicated label would make the
        // persisted `revoked_reason` ambiguous.
        let reasons = [
            SessionInvalidationReason::Logout,
            SessionInvalidationReason::Expired,
            SessionInvalidationReason::UserRequested,
            SessionInvalidationReason::AdministrativeRevoke,
            SessionInvalidationReason::PasswordReset,
            SessionInvalidationReason::PasswordResetByOther,
            SessionInvalidationReason::MfaEnabled,
            SessionInvalidationReason::MfaDisabled,
            SessionInvalidationReason::MfaDisabledByOther,
            SessionInvalidationReason::AuthorityEscalation,
            SessionInvalidationReason::EmergencyRecovery,
            SessionInvalidationReason::RoleChangedByOther,
            SessionInvalidationReason::TrustEscalation,
        ];
        let mut set = std::collections::HashSet::new();
        for r in reasons {
            assert!(set.insert(r.as_str()), "duplicate as_str() for {r:?}");
        }
        assert_eq!(set.len(), reasons.len());
    }
}