1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
17use chrono::{DateTime, Duration, Utc};
18use rand::RngCore;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21
22use crate::error::Result;
23use crate::orm::{Db, Row};
24
25use super::role::Role;
26use super::users::Identity;
27
/// Name of the cookie that carries the raw session token
/// (read back by [`session_token_from_cookie`]).
pub const SESSION_COOKIE: &str = "rustio_session";

/// Lifetime of a newly created session, in days, counted from creation
/// (applied by [`create_session`] when it sets `expires_at`). Refreshing
/// `last_seen` does not move `expires_at`.
const SESSION_LENGTH_DAYS: i64 = 14;
34
/// Trust tier attached to a session, ordered from least to most trusted
/// (see [`SessionTrust::rank`] / [`SessionTrust::satisfies`]).
///
/// Serialized with serde as snake_case, and stored in the
/// `rustio_sessions.trust_level` column as `"authenticated"`, `"elevated"`
/// or `"mfa_verified"` (the column's CHECK constraint allows only these).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionTrust {
    /// Lowest tier. `create_session` does not set `trust_level`, so new
    /// sessions get this through the column default `'authenticated'`.
    Authenticated,
    /// Middle tier. NOTE(review): how a session is raised to this tier is
    /// not shown here; presumably it is related to `elevated_until`. Confirm.
    Elevated,
    /// Highest tier. The name suggests a completed MFA challenge. Confirm
    /// against the code that sets it.
    MfaVerified,
}
50
51impl SessionTrust {
52 pub const fn as_str(self) -> &'static str {
56 match self {
57 Self::Authenticated => "authenticated",
58 Self::Elevated => "elevated",
59 Self::MfaVerified => "mfa_verified",
60 }
61 }
62
63 pub const fn rank(self) -> u8 {
66 match self {
67 Self::Authenticated => 1,
68 Self::Elevated => 2,
69 Self::MfaVerified => 3,
70 }
71 }
72
73 pub const fn satisfies(self, other: SessionTrust) -> bool {
76 self.rank() >= other.rank()
77 }
78
79 pub fn parse(s: &str) -> Self {
84 match s {
85 "elevated" => Self::Elevated,
86 "mfa_verified" => Self::MfaVerified,
87 _ => Self::Authenticated,
88 }
89 }
90}
91
/// Why a session was revoked.
///
/// [`SessionInvalidationReason::as_str`] gives the string that
/// [`invalidate_sessions`] writes to `rustio_sessions.revoked_reason`.
/// Renaming one of those strings changes what gets stored for future
/// revocations.
///
/// NOTE(review): apart from `Logout` (used by [`logout_session`]), the call
/// sites for these variants are not in this file. The notes below are read
/// off the variant names and should be confirmed at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionInvalidationReason {
    /// The holder of the session token logged out of that session.
    Logout,
    /// Presumably: the session passed its expiry.
    Expired,
    /// Presumably: the owner explicitly revoked the session, e.g. from a
    /// session list.
    UserRequested,
    /// Presumably: an administrator revoked the session.
    AdministrativeRevoke,
    /// Presumably: the owner reset or changed their own password.
    PasswordReset,
    /// Presumably: another account reset this user's password. This applies
    /// to every `*ByOther` variant.
    PasswordResetByOther,
    /// Presumably: MFA was turned on for the account.
    MfaEnabled,
    /// Presumably: the owner turned MFA off.
    MfaDisabled,
    /// Presumably: another account turned this user's MFA off.
    MfaDisabledByOther,
    /// Presumably: the account's authority/role was raised.
    AuthorityEscalation,
    /// Presumably: an emergency/break-glass recovery flow ran.
    EmergencyRecovery,
    /// Presumably: another account changed this user's role.
    RoleChangedByOther,
    /// Presumably: the session was replaced by one with a higher
    /// [`SessionTrust`] tier.
    TrustEscalation,
}
126
127impl SessionInvalidationReason {
128 pub const fn as_str(self) -> &'static str {
133 match self {
134 Self::Logout => "logout",
135 Self::Expired => "expired",
136 Self::UserRequested => "user_requested",
137 Self::AdministrativeRevoke => "administrative_revoke",
138 Self::PasswordReset => "password_reset",
139 Self::PasswordResetByOther => "password_reset_by_other",
140 Self::MfaEnabled => "mfa_enabled",
141 Self::MfaDisabled => "mfa_disabled",
142 Self::MfaDisabledByOther => "mfa_disabled_by_other",
143 Self::AuthorityEscalation => "authority_escalation",
144 Self::EmergencyRecovery => "emergency_recovery",
145 Self::RoleChangedByOther => "role_changed_by_other",
146 Self::TrustEscalation => "trust_escalation",
147 }
148 }
149}
150
/// Selects which sessions [`invalidate_sessions`] revokes. Every variant
/// only affects sessions that are not already revoked.
#[derive(Debug, Clone, Copy)]
pub enum SessionTarget {
    /// Every unrevoked session belonging to `user_id`.
    User { user_id: i64 },
    /// Every unrevoked session of `user_id` except the one identified by
    /// `current_session_id` (for example, "sign out my other devices").
    UserExceptCurrent {
        /// Owner whose sessions are revoked.
        user_id: i64,
        /// `session_id` of the session to keep alive.
        current_session_id: i64,
    },
    /// Exactly one session, identified by its `session_id`.
    Single { session_id: i64 },
}
168
/// One `rustio_sessions` row as exposed to callers (built by
/// [`list_active_for_user`]).
///
/// It carries neither the raw token nor its hash, so serializing it (it
/// derives `Serialize`) does not expose a usable credential.
#[derive(Debug, Clone, Serialize)]
pub struct Session {
    /// Stable numeric id of the session, used by [`SessionTarget`].
    pub session_id: i64,
    /// Owning user (`rustio_users.id`).
    pub user_id: i64,
    /// Decoded with [`SessionTrust::parse`]; unknown stored values become
    /// `Authenticated`.
    pub trust_level: SessionTrust,
    /// When the row was inserted.
    pub created_at: DateTime<Utc>,
    /// Last time the session was used. [`identity_from_session`] refreshes
    /// it in the background.
    pub last_seen: DateTime<Utc>,
    /// Hard expiry; the session is not considered active at or after this
    /// instant.
    pub expires_at: DateTime<Utc>,
    /// End of an elevation window, if one was recorded. NOTE(review): how
    /// this is set is not shown in this file.
    pub elevated_until: Option<DateTime<Utc>>,
    /// Client IP, if recorded. `create_session` does not set it, so it is
    /// presumably filled elsewhere. Confirm.
    pub ip: Option<String>,
    /// Client user agent, if recorded. It is filled the same way as `ip`.
    pub user_agent: Option<String>,
}
184
/// Result of [`invalidate_sessions`].
#[derive(Debug, Clone, Default)]
pub struct InvalidationOutcome {
    /// `session_id`s revoked by this call. Sessions that were already
    /// revoked are not included.
    pub revoked_session_ids: Vec<i64>,
    /// Reason recorded for the revocation. `invalidate_sessions` always
    /// sets it to `Some`; it is `None` in the `Default` value.
    pub reason: Option<SessionInvalidationReason>,
}
196
197pub async fn init_session_tables(db: &Db) -> Result<()> {
199 sqlx::query(
200 "CREATE TABLE IF NOT EXISTS rustio_sessions (
201 token TEXT PRIMARY KEY,
202 user_id BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE,
203 expires_at TIMESTAMPTZ NOT NULL,
204 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
205 last_seen TIMESTAMPTZ NOT NULL DEFAULT NOW()
206 )",
207 )
208 .execute(db.pool())
209 .await?;
210
211 sqlx::query("CREATE INDEX IF NOT EXISTS rustio_sessions_user_idx ON rustio_sessions (user_id)")
212 .execute(db.pool())
213 .await?;
214
215 sqlx::query(
216 "CREATE INDEX IF NOT EXISTS rustio_sessions_expires_idx ON rustio_sessions (expires_at)",
217 )
218 .execute(db.pool())
219 .await?;
220
221 Ok(())
222}
223
224pub(crate) async fn migrate_session_schema(db: &Db) -> Result<()> {
228 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS ip TEXT")
229 .execute(db.pool())
230 .await?;
231 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS user_agent TEXT")
232 .execute(db.pool())
233 .await?;
234 Ok(())
235}
236
237pub(crate) async fn migrate_session_lifecycle(db: &Db) -> Result<()> {
262 sqlx::query("CREATE SEQUENCE IF NOT EXISTS rustio_sessions_session_id_seq")
263 .execute(db.pool())
264 .await?;
265 sqlx::query(
266 "ALTER TABLE rustio_sessions \
267 ADD COLUMN IF NOT EXISTS session_id BIGINT NOT NULL DEFAULT \
268 nextval('rustio_sessions_session_id_seq')",
269 )
270 .execute(db.pool())
271 .await?;
272 sqlx::query(
273 "ALTER SEQUENCE rustio_sessions_session_id_seq OWNED BY rustio_sessions.session_id",
274 )
275 .execute(db.pool())
276 .await?;
277 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS token_hash TEXT")
278 .execute(db.pool())
279 .await?;
280 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS device_id TEXT")
281 .execute(db.pool())
282 .await?;
283 sqlx::query(
284 "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS trust_level TEXT \
285 NOT NULL DEFAULT 'authenticated'",
286 )
287 .execute(db.pool())
288 .await?;
289 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS elevated_until TIMESTAMPTZ")
290 .execute(db.pool())
291 .await?;
292 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS parent_session_id BIGINT")
293 .execute(db.pool())
294 .await?;
295 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ")
296 .execute(db.pool())
297 .await?;
298 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_reason TEXT")
299 .execute(db.pool())
300 .await?;
301
302 sqlx::query(
305 "DO $$ BEGIN \
306 IF NOT EXISTS ( \
307 SELECT 1 FROM pg_constraint \
308 WHERE conname = 'rustio_sessions_trust_level_check' \
309 ) THEN \
310 ALTER TABLE rustio_sessions \
311 ADD CONSTRAINT rustio_sessions_trust_level_check \
312 CHECK (trust_level IN ('authenticated', 'elevated', 'mfa_verified')); \
313 END IF; \
314 END $$",
315 )
316 .execute(db.pool())
317 .await?;
318
319 sqlx::query(
320 "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_session_id_uq \
321 ON rustio_sessions (session_id)",
322 )
323 .execute(db.pool())
324 .await?;
325 sqlx::query(
326 "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_token_hash_uq \
327 ON rustio_sessions (token_hash) \
328 WHERE revoked_at IS NULL AND token_hash IS NOT NULL",
329 )
330 .execute(db.pool())
331 .await?;
332 sqlx::query(
333 "CREATE INDEX IF NOT EXISTS rustio_sessions_user_active_idx \
334 ON rustio_sessions (user_id) WHERE revoked_at IS NULL",
335 )
336 .execute(db.pool())
337 .await?;
338 sqlx::query(
339 "CREATE INDEX IF NOT EXISTS rustio_sessions_parent_idx \
340 ON rustio_sessions (parent_session_id) WHERE parent_session_id IS NOT NULL",
341 )
342 .execute(db.pool())
343 .await?;
344
345 Ok(())
346}
347
348pub async fn create_session(db: &Db, user_id: i64) -> Result<String> {
350 let token = random_token();
351 let token_hash = hash_token_for_storage(&token);
352 let expires = Utc::now() + Duration::days(SESSION_LENGTH_DAYS);
353 sqlx::query(
359 "INSERT INTO rustio_sessions (token, token_hash, user_id, expires_at) \
360 VALUES ($1, $2, $3, $4)",
361 )
362 .bind(&token)
363 .bind(&token_hash)
364 .bind(user_id)
365 .bind(expires)
366 .execute(db.pool())
367 .await?;
368 Ok(token)
369}
370
371pub async fn delete_session(db: &Db, token: &str) -> Result<()> {
380 sqlx::query("DELETE FROM rustio_sessions WHERE token = $1 OR token_hash = $2")
381 .bind(token)
382 .bind(hash_token_for_storage(token))
383 .execute(db.pool())
384 .await?;
385 Ok(())
386}
387
388pub async fn invalidate_sessions(
412 db: &Db,
413 target: SessionTarget,
414 reason: SessionInvalidationReason,
415) -> Result<InvalidationOutcome> {
416 let reason_str = reason.as_str();
417 let revoked_ids: Vec<i64> = match target {
418 SessionTarget::User { user_id } => {
419 sqlx::query_scalar::<_, i64>(
420 "UPDATE rustio_sessions \
421 SET revoked_at = NOW(), revoked_reason = $2 \
422 WHERE user_id = $1 AND revoked_at IS NULL \
423 RETURNING session_id",
424 )
425 .bind(user_id)
426 .bind(reason_str)
427 .fetch_all(db.pool())
428 .await?
429 }
430 SessionTarget::UserExceptCurrent {
431 user_id,
432 current_session_id,
433 } => {
434 sqlx::query_scalar::<_, i64>(
435 "UPDATE rustio_sessions \
436 SET revoked_at = NOW(), revoked_reason = $3 \
437 WHERE user_id = $1 AND session_id <> $2 AND revoked_at IS NULL \
438 RETURNING session_id",
439 )
440 .bind(user_id)
441 .bind(current_session_id)
442 .bind(reason_str)
443 .fetch_all(db.pool())
444 .await?
445 }
446 SessionTarget::Single { session_id } => {
447 sqlx::query_scalar::<_, i64>(
448 "UPDATE rustio_sessions \
449 SET revoked_at = NOW(), revoked_reason = $2 \
450 WHERE session_id = $1 AND revoked_at IS NULL \
451 RETURNING session_id",
452 )
453 .bind(session_id)
454 .bind(reason_str)
455 .fetch_all(db.pool())
456 .await?
457 }
458 };
459
460 Ok(InvalidationOutcome {
461 revoked_session_ids: revoked_ids,
462 reason: Some(reason),
463 })
464}
465
466pub async fn logout_session(db: &Db, token: &str) -> Result<()> {
475 let token_hash = hash_token_for_storage(token);
476 let session_id: Option<i64> = sqlx::query_scalar::<_, i64>(
477 "SELECT session_id FROM rustio_sessions \
478 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
479 AND revoked_at IS NULL \
480 LIMIT 1",
481 )
482 .bind(&token_hash)
483 .bind(token)
484 .fetch_optional(db.pool())
485 .await?;
486
487 if let Some(sid) = session_id {
488 invalidate_sessions(
489 db,
490 SessionTarget::Single { session_id: sid },
491 SessionInvalidationReason::Logout,
492 )
493 .await?;
494 }
495 Ok(())
496}
497
498pub async fn list_active_for_user(db: &Db, user_id: i64) -> Result<Vec<Session>> {
503 let rows = sqlx::query(
504 "SELECT session_id, user_id, trust_level, created_at, last_seen, expires_at, \
505 elevated_until, ip, user_agent \
506 FROM rustio_sessions \
507 WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW() \
508 ORDER BY last_seen DESC",
509 )
510 .bind(user_id)
511 .fetch_all(db.pool())
512 .await?;
513
514 rows.iter()
515 .map(|r| {
516 let r = Row::from_pg(r);
517 Ok(Session {
518 session_id: r.get_i64("session_id")?,
519 user_id: r.get_i64("user_id")?,
520 trust_level: SessionTrust::parse(&r.get_string("trust_level")?),
521 created_at: r.get_datetime("created_at")?,
522 last_seen: r.get_datetime("last_seen")?,
523 expires_at: r.get_datetime("expires_at")?,
524 elevated_until: r.get_optional_datetime("elevated_until")?,
525 ip: r.get_optional_string("ip")?,
526 user_agent: r.get_optional_string("user_agent")?,
527 })
528 })
529 .collect()
530}
531
532pub async fn current_session_id(db: &Db, token: &str) -> Result<Option<i64>> {
537 let token_hash = hash_token_for_storage(token);
538 let id: Option<i64> = sqlx::query_scalar::<_, i64>(
539 "SELECT session_id FROM rustio_sessions \
540 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
541 AND revoked_at IS NULL AND expires_at > NOW() \
542 LIMIT 1",
543 )
544 .bind(&token_hash)
545 .bind(token)
546 .fetch_optional(db.pool())
547 .await?;
548 Ok(id)
549}
550
551pub async fn identity_from_session(db: &Db, token: &str) -> Result<Option<Identity>> {
553 let token_hash = hash_token_for_storage(token);
559 let row = sqlx::query(
560 "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
561 u.must_change_password, u.mfa_enabled, \
562 s.expires_at, s.trust_level, \
563 s.token_hash IS NOT NULL AS hashed \
564 FROM rustio_sessions s \
565 JOIN rustio_users u ON u.id = s.user_id \
566 WHERE s.token_hash = $1 AND s.revoked_at IS NULL",
567 )
568 .bind(&token_hash)
569 .fetch_optional(db.pool())
570 .await?;
571
572 let row = match row {
573 Some(r) => Some(r),
574 None => {
582 sqlx::query(
583 "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
584 u.must_change_password, u.mfa_enabled, \
585 s.expires_at, s.trust_level, FALSE AS hashed \
586 FROM rustio_sessions s \
587 JOIN rustio_users u ON u.id = s.user_id \
588 WHERE s.token = $1 AND s.token_hash IS NULL AND s.revoked_at IS NULL",
589 )
590 .bind(token)
591 .fetch_optional(db.pool())
592 .await?
593 }
594 };
595 let row = match row {
596 Some(r) => r,
597 None => return Ok(None),
598 };
599 let r = Row::from_pg(&row);
600 let expires_at = r.get_datetime("expires_at")?;
601 if expires_at < Utc::now() {
602 let _ = delete_session(db, token).await;
608 return Ok(None);
609 }
610
611 let db_clone = db.clone();
615 let token_owned = token.to_string();
616 let token_hash_owned = token_hash.clone();
617 tokio::spawn(async move {
618 let _ = sqlx::query(
619 "UPDATE rustio_sessions SET last_seen = NOW() \
620 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
621 AND revoked_at IS NULL",
622 )
623 .bind(&token_hash_owned)
624 .bind(&token_owned)
625 .execute(db_clone.pool())
626 .await;
627 });
628
629 Ok(Some(Identity {
630 user_id: r.get_i64("id")?,
631 email: r.get_string("email")?,
632 role: Role::parse(&r.get_string("role")?)?,
633 is_active: r.get_bool("is_active")?,
634 is_demo: r.get_bool("is_demo")?,
635 demo_label: r.get_optional_string("demo_label")?,
636 must_change_password: r.get_bool("must_change_password")?,
637 mfa_enabled: r.get_bool("mfa_enabled")?,
638 trust_level: SessionTrust::parse(&r.get_string("trust_level")?),
639 }))
640}
641
642pub async fn purge_expired_sessions(db: &Db) -> Result<u64> {
646 let result = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at < NOW()")
647 .execute(db.pool())
648 .await?;
649 Ok(result.rows_affected())
650}
651
652pub fn session_token_from_cookie(cookie_header: &str) -> Option<String> {
654 let prefix = format!("{SESSION_COOKIE}=");
655 for part in cookie_header.split(';') {
656 let part = part.trim();
657 if let Some(v) = part.strip_prefix(&prefix) {
658 return Some(v.to_string());
659 }
660 }
661 None
662}
663
664pub(crate) fn random_token() -> String {
671 let mut bytes = [0u8; 32];
672 rand::thread_rng().fill_bytes(&mut bytes);
673 URL_SAFE_NO_PAD.encode(bytes)
674}
675
676pub fn hash_token_for_storage(token: &str) -> String {
687 let digest = Sha256::digest(token.as_bytes());
688 URL_SAFE_NO_PAD.encode(digest)
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_token_from_cookie_header() {
        let h = "foo=bar; rustio_session=abc123; other=x";
        assert_eq!(session_token_from_cookie(h), Some("abc123".into()));
    }

    #[test]
    fn returns_none_when_cookie_missing() {
        let h = "foo=bar; other=x";
        assert!(session_token_from_cookie(h).is_none());
    }

    #[test]
    fn ignores_cookie_whose_name_only_starts_with_session_cookie() {
        // `rustio_session_old` must not be mistaken for `rustio_session`.
        let h = "rustio_session_old=stale; rustio_session=fresh";
        assert_eq!(session_token_from_cookie(h), Some("fresh".into()));
    }

    #[test]
    fn random_token_has_reasonable_entropy() {
        // Smoke test only: two draws colliding would indicate a broken RNG.
        assert_ne!(random_token(), random_token());
    }

    #[test]
    fn random_token_is_43_char_url_safe_base64() {
        // 32 random bytes encode to 43 unpadded base64url characters.
        let t = random_token();
        assert_eq!(t.len(), 43);
        assert!(t
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn hash_token_is_deterministic() {
        // Lookups depend on the same token always hashing to the same value.
        let token = random_token();
        assert_eq!(
            hash_token_for_storage(&token),
            hash_token_for_storage(&token)
        );
    }

    #[test]
    fn hash_token_differs_per_token() {
        // Inputs differing in one character must not collide.
        let a = hash_token_for_storage("aaaa");
        let b = hash_token_for_storage("aaab");
        assert_ne!(a, b);
    }

    #[test]
    fn hash_token_output_is_url_safe_base64() {
        let h = hash_token_for_storage("anything");
        // A 32-byte SHA-256 digest encodes to 43 unpadded base64 characters.
        assert_eq!(h.len(), 43);
        assert!(h
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn hash_token_does_not_leak_plaintext() {
        let plaintext = "secret-cookie-value-12345";
        let h = hash_token_for_storage(plaintext);
        assert!(!h.contains("secret"));
        assert!(!h.contains("12345"));
    }

    #[test]
    fn session_trust_orders_correctly() {
        assert!(SessionTrust::Authenticated.rank() < SessionTrust::Elevated.rank());
        assert!(SessionTrust::Elevated.rank() < SessionTrust::MfaVerified.rank());
        assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Elevated));
        assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Authenticated));
        assert!(SessionTrust::Authenticated.satisfies(SessionTrust::Authenticated));
        assert!(!SessionTrust::Authenticated.satisfies(SessionTrust::Elevated));
        assert!(!SessionTrust::Elevated.satisfies(SessionTrust::MfaVerified));
    }

    #[test]
    fn session_trust_round_trips_through_sql() {
        for tier in [
            SessionTrust::Authenticated,
            SessionTrust::Elevated,
            SessionTrust::MfaVerified,
        ] {
            assert_eq!(SessionTrust::parse(tier.as_str()), tier);
        }
    }

    #[test]
    fn session_trust_parse_defaults_safely_on_unknown() {
        // Unknown stored values must never grant more than the lowest tier.
        assert_eq!(SessionTrust::parse("garbage"), SessionTrust::Authenticated);
        assert_eq!(SessionTrust::parse(""), SessionTrust::Authenticated);
    }

    #[test]
    fn invalidation_reason_strings_are_distinct() {
        // Keep this list in sync with every variant of
        // `SessionInvalidationReason`. A missing variant hides duplicates.
        let reasons = [
            SessionInvalidationReason::Logout,
            SessionInvalidationReason::Expired,
            SessionInvalidationReason::UserRequested,
            SessionInvalidationReason::AdministrativeRevoke,
            SessionInvalidationReason::PasswordReset,
            SessionInvalidationReason::PasswordResetByOther,
            SessionInvalidationReason::MfaEnabled,
            SessionInvalidationReason::MfaDisabled,
            SessionInvalidationReason::MfaDisabledByOther,
            SessionInvalidationReason::AuthorityEscalation,
            SessionInvalidationReason::EmergencyRecovery,
            SessionInvalidationReason::RoleChangedByOther,
            SessionInvalidationReason::TrustEscalation,
        ];
        let mut set = std::collections::HashSet::new();
        for r in reasons {
            assert!(set.insert(r.as_str()), "duplicate as_str() for {r:?}");
        }
        assert_eq!(set.len(), reasons.len());
    }
}