1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
17use chrono::{DateTime, Duration, Utc};
18use rand::RngCore;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21
22use crate::error::Result;
23use crate::orm::{Db, Row};
24
25use super::role::Role;
26use super::users::Identity;
27
/// Name of the cookie that carries the opaque session token.
pub const SESSION_COOKIE: &str = "rustio_session";

/// Lifetime of a newly created session, in days: `create_session` sets
/// `expires_at` to the current time plus this many days.
const SESSION_LENGTH_DAYS: i64 = 14;
33
/// Trust tier recorded on a session row (`rustio_sessions.trust_level`).
///
/// Serialized in snake_case (`authenticated`, `elevated`, `mfa_verified`), the
/// same strings produced by `SessionTrust::as_str` and allowed by the
/// `rustio_sessions_trust_level_check` constraint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionTrust {
    /// Baseline tier; also the column default for `trust_level`.
    Authenticated,
    /// Raised tier; presumably bounded in time by `elevated_until` — confirm.
    Elevated,
    /// Highest tier; presumably granted after a second factor is verified — confirm.
    MfaVerified,
}
48
49impl SessionTrust {
50 pub const fn as_str(self) -> &'static str {
53 match self {
54 Self::Authenticated => "authenticated",
55 Self::Elevated => "elevated",
56 Self::MfaVerified => "mfa_verified",
57 }
58 }
59
60 pub const fn rank(self) -> u8 {
62 match self {
63 Self::Authenticated => 1,
64 Self::Elevated => 2,
65 Self::MfaVerified => 3,
66 }
67 }
68
69 pub const fn satisfies(self, other: SessionTrust) -> bool {
71 self.rank() >= other.rank()
72 }
73
74 pub fn parse(s: &str) -> Self {
78 match s {
79 "elevated" => Self::Elevated,
80 "mfa_verified" => Self::MfaVerified,
81 _ => Self::Authenticated,
82 }
83 }
84}
85
/// Why a set of sessions was revoked.
///
/// `SessionInvalidationReason::as_str` yields the label that
/// `invalidate_sessions` writes to `rustio_sessions.revoked_reason`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionInvalidationReason {
    /// Label `logout`; used by `logout_session` for the session being signed out.
    Logout,
    /// Label `expired`.
    Expired,
    /// Label `user_requested`; presumably the owner revoked their own sessions — confirm.
    UserRequested,
    /// Label `administrative_revoke`; presumably an admin action — confirm.
    AdministrativeRevoke,
    /// Label `password_reset`.
    PasswordReset,
    /// Label `password_reset_by_other`; presumably a reset done by someone
    /// other than the account owner — confirm.
    PasswordResetByOther,
    /// Label `mfa_enabled`.
    MfaEnabled,
    /// Label `mfa_disabled`.
    MfaDisabled,
    /// Label `mfa_disabled_by_other`; presumably disabled by someone other
    /// than the account owner — confirm.
    MfaDisabledByOther,
    /// Label `authority_escalation`.
    AuthorityEscalation,
    /// Label `emergency_recovery`.
    EmergencyRecovery,
    /// Label `trust_escalation`.
    TrustEscalation,
}
112
113impl SessionInvalidationReason {
114 pub const fn as_str(self) -> &'static str {
118 match self {
119 Self::Logout => "logout",
120 Self::Expired => "expired",
121 Self::UserRequested => "user_requested",
122 Self::AdministrativeRevoke => "administrative_revoke",
123 Self::PasswordReset => "password_reset",
124 Self::PasswordResetByOther => "password_reset_by_other",
125 Self::MfaEnabled => "mfa_enabled",
126 Self::MfaDisabled => "mfa_disabled",
127 Self::MfaDisabledByOther => "mfa_disabled_by_other",
128 Self::AuthorityEscalation => "authority_escalation",
129 Self::EmergencyRecovery => "emergency_recovery",
130 Self::TrustEscalation => "trust_escalation",
131 }
132 }
133}
134
/// Selects which sessions `invalidate_sessions` revokes.
///
/// In every case, rows that are already revoked (`revoked_at IS NOT NULL`)
/// are never selected again.
#[derive(Debug, Clone, Copy)]
pub enum SessionTarget {
    /// Every non-revoked session belonging to `user_id`.
    User { user_id: i64 },
    /// Every non-revoked session of `user_id` except the one whose
    /// `session_id` equals `current_session_id` (the caller's own session is kept).
    UserExceptCurrent {
        user_id: i64,
        current_session_id: i64,
    },
    /// The single session with this `session_id`, if it is not already revoked.
    Single { session_id: i64 },
}
151
/// Read-only view of one session row, as returned by `list_active_for_user`.
///
/// The session token (plaintext or hash) is deliberately not part of this
/// struct, so serializing it never exposes a credential. Field order here is
/// the key order of the serialized output.
#[derive(Debug, Clone, Serialize)]
pub struct Session {
    /// Numeric id from `rustio_sessions.session_id`.
    pub session_id: i64,
    /// Owner of the session (`rustio_users.id`).
    pub user_id: i64,
    /// Stored trust tier, decoded with `SessionTrust::parse`.
    pub trust_level: SessionTrust,
    pub created_at: DateTime<Utc>,
    pub last_seen: DateTime<Utc>,
    pub expires_at: DateTime<Utc>,
    /// Optional `elevated_until` column; `None` when the column is NULL.
    pub elevated_until: Option<DateTime<Utc>>,
    /// Optional client IP as stored in the `ip` column.
    pub ip: Option<String>,
    /// Optional client user agent as stored in the `user_agent` column.
    pub user_agent: Option<String>,
}
166
/// Result of `invalidate_sessions`.
#[derive(Debug, Clone, Default)]
pub struct InvalidationOutcome {
    /// Ids returned by the revoking `UPDATE … RETURNING session_id`; empty when
    /// no not-yet-revoked session matched the target.
    pub revoked_session_ids: Vec<i64>,
    /// The reason that was recorded. `invalidate_sessions` always sets `Some`;
    /// `None` only arises from `Default`.
    pub reason: Option<SessionInvalidationReason>,
}
177
178pub async fn init_session_tables(db: &Db) -> Result<()> {
179 sqlx::query(
180 "CREATE TABLE IF NOT EXISTS rustio_sessions (
181 token TEXT PRIMARY KEY,
182 user_id BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE,
183 expires_at TIMESTAMPTZ NOT NULL,
184 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
185 last_seen TIMESTAMPTZ NOT NULL DEFAULT NOW()
186 )",
187 )
188 .execute(db.pool())
189 .await?;
190
191 sqlx::query("CREATE INDEX IF NOT EXISTS rustio_sessions_user_idx ON rustio_sessions (user_id)")
192 .execute(db.pool())
193 .await?;
194
195 sqlx::query(
196 "CREATE INDEX IF NOT EXISTS rustio_sessions_expires_idx ON rustio_sessions (expires_at)",
197 )
198 .execute(db.pool())
199 .await?;
200
201 Ok(())
202}
203
204pub(crate) async fn migrate_session_schema(db: &Db) -> Result<()> {
208 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS ip TEXT")
209 .execute(db.pool())
210 .await?;
211 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS user_agent TEXT")
212 .execute(db.pool())
213 .await?;
214 Ok(())
215}
216
217pub(crate) async fn migrate_session_lifecycle(db: &Db) -> Result<()> {
242 sqlx::query("CREATE SEQUENCE IF NOT EXISTS rustio_sessions_session_id_seq")
243 .execute(db.pool())
244 .await?;
245 sqlx::query(
246 "ALTER TABLE rustio_sessions \
247 ADD COLUMN IF NOT EXISTS session_id BIGINT NOT NULL DEFAULT \
248 nextval('rustio_sessions_session_id_seq')",
249 )
250 .execute(db.pool())
251 .await?;
252 sqlx::query(
253 "ALTER SEQUENCE rustio_sessions_session_id_seq OWNED BY rustio_sessions.session_id",
254 )
255 .execute(db.pool())
256 .await?;
257 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS token_hash TEXT")
258 .execute(db.pool())
259 .await?;
260 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS device_id TEXT")
261 .execute(db.pool())
262 .await?;
263 sqlx::query(
264 "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS trust_level TEXT \
265 NOT NULL DEFAULT 'authenticated'",
266 )
267 .execute(db.pool())
268 .await?;
269 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS elevated_until TIMESTAMPTZ")
270 .execute(db.pool())
271 .await?;
272 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS parent_session_id BIGINT")
273 .execute(db.pool())
274 .await?;
275 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ")
276 .execute(db.pool())
277 .await?;
278 sqlx::query("ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_reason TEXT")
279 .execute(db.pool())
280 .await?;
281
282 sqlx::query(
285 "DO $$ BEGIN \
286 IF NOT EXISTS ( \
287 SELECT 1 FROM pg_constraint \
288 WHERE conname = 'rustio_sessions_trust_level_check' \
289 ) THEN \
290 ALTER TABLE rustio_sessions \
291 ADD CONSTRAINT rustio_sessions_trust_level_check \
292 CHECK (trust_level IN ('authenticated', 'elevated', 'mfa_verified')); \
293 END IF; \
294 END $$",
295 )
296 .execute(db.pool())
297 .await?;
298
299 sqlx::query(
300 "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_session_id_uq \
301 ON rustio_sessions (session_id)",
302 )
303 .execute(db.pool())
304 .await?;
305 sqlx::query(
306 "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_token_hash_uq \
307 ON rustio_sessions (token_hash) \
308 WHERE revoked_at IS NULL AND token_hash IS NOT NULL",
309 )
310 .execute(db.pool())
311 .await?;
312 sqlx::query(
313 "CREATE INDEX IF NOT EXISTS rustio_sessions_user_active_idx \
314 ON rustio_sessions (user_id) WHERE revoked_at IS NULL",
315 )
316 .execute(db.pool())
317 .await?;
318 sqlx::query(
319 "CREATE INDEX IF NOT EXISTS rustio_sessions_parent_idx \
320 ON rustio_sessions (parent_session_id) WHERE parent_session_id IS NOT NULL",
321 )
322 .execute(db.pool())
323 .await?;
324
325 Ok(())
326}
327
328pub async fn create_session(db: &Db, user_id: i64) -> Result<String> {
329 let token = random_token();
330 let token_hash = hash_token_for_storage(&token);
331 let expires = Utc::now() + Duration::days(SESSION_LENGTH_DAYS);
332 sqlx::query(
338 "INSERT INTO rustio_sessions (token, token_hash, user_id, expires_at) \
339 VALUES ($1, $2, $3, $4)",
340 )
341 .bind(&token)
342 .bind(&token_hash)
343 .bind(user_id)
344 .bind(expires)
345 .execute(db.pool())
346 .await?;
347 Ok(token)
348}
349
350pub async fn delete_session(db: &Db, token: &str) -> Result<()> {
358 sqlx::query("DELETE FROM rustio_sessions WHERE token = $1 OR token_hash = $2")
359 .bind(token)
360 .bind(hash_token_for_storage(token))
361 .execute(db.pool())
362 .await?;
363 Ok(())
364}
365
366pub async fn invalidate_sessions(
389 db: &Db,
390 target: SessionTarget,
391 reason: SessionInvalidationReason,
392) -> Result<InvalidationOutcome> {
393 let reason_str = reason.as_str();
394 let revoked_ids: Vec<i64> = match target {
395 SessionTarget::User { user_id } => {
396 sqlx::query_scalar::<_, i64>(
397 "UPDATE rustio_sessions \
398 SET revoked_at = NOW(), revoked_reason = $2 \
399 WHERE user_id = $1 AND revoked_at IS NULL \
400 RETURNING session_id",
401 )
402 .bind(user_id)
403 .bind(reason_str)
404 .fetch_all(db.pool())
405 .await?
406 }
407 SessionTarget::UserExceptCurrent {
408 user_id,
409 current_session_id,
410 } => {
411 sqlx::query_scalar::<_, i64>(
412 "UPDATE rustio_sessions \
413 SET revoked_at = NOW(), revoked_reason = $3 \
414 WHERE user_id = $1 AND session_id <> $2 AND revoked_at IS NULL \
415 RETURNING session_id",
416 )
417 .bind(user_id)
418 .bind(current_session_id)
419 .bind(reason_str)
420 .fetch_all(db.pool())
421 .await?
422 }
423 SessionTarget::Single { session_id } => {
424 sqlx::query_scalar::<_, i64>(
425 "UPDATE rustio_sessions \
426 SET revoked_at = NOW(), revoked_reason = $2 \
427 WHERE session_id = $1 AND revoked_at IS NULL \
428 RETURNING session_id",
429 )
430 .bind(session_id)
431 .bind(reason_str)
432 .fetch_all(db.pool())
433 .await?
434 }
435 };
436
437 Ok(InvalidationOutcome {
438 revoked_session_ids: revoked_ids,
439 reason: Some(reason),
440 })
441}
442
443pub async fn logout_session(db: &Db, token: &str) -> Result<()> {
451 let token_hash = hash_token_for_storage(token);
452 let session_id: Option<i64> = sqlx::query_scalar::<_, i64>(
453 "SELECT session_id FROM rustio_sessions \
454 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
455 AND revoked_at IS NULL \
456 LIMIT 1",
457 )
458 .bind(&token_hash)
459 .bind(token)
460 .fetch_optional(db.pool())
461 .await?;
462
463 if let Some(sid) = session_id {
464 invalidate_sessions(
465 db,
466 SessionTarget::Single { session_id: sid },
467 SessionInvalidationReason::Logout,
468 )
469 .await?;
470 }
471 Ok(())
472}
473
474pub async fn list_active_for_user(db: &Db, user_id: i64) -> Result<Vec<Session>> {
478 let rows = sqlx::query(
479 "SELECT session_id, user_id, trust_level, created_at, last_seen, expires_at, \
480 elevated_until, ip, user_agent \
481 FROM rustio_sessions \
482 WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW() \
483 ORDER BY last_seen DESC",
484 )
485 .bind(user_id)
486 .fetch_all(db.pool())
487 .await?;
488
489 rows.iter()
490 .map(|r| {
491 let r = Row::from_pg(r);
492 Ok(Session {
493 session_id: r.get_i64("session_id")?,
494 user_id: r.get_i64("user_id")?,
495 trust_level: SessionTrust::parse(&r.get_string("trust_level")?),
496 created_at: r.get_datetime("created_at")?,
497 last_seen: r.get_datetime("last_seen")?,
498 expires_at: r.get_datetime("expires_at")?,
499 elevated_until: r.get_optional_datetime("elevated_until")?,
500 ip: r.get_optional_string("ip")?,
501 user_agent: r.get_optional_string("user_agent")?,
502 })
503 })
504 .collect()
505}
506
507pub async fn current_session_id(db: &Db, token: &str) -> Result<Option<i64>> {
511 let token_hash = hash_token_for_storage(token);
512 let id: Option<i64> = sqlx::query_scalar::<_, i64>(
513 "SELECT session_id FROM rustio_sessions \
514 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
515 AND revoked_at IS NULL AND expires_at > NOW() \
516 LIMIT 1",
517 )
518 .bind(&token_hash)
519 .bind(token)
520 .fetch_optional(db.pool())
521 .await?;
522 Ok(id)
523}
524
525pub async fn identity_from_session(db: &Db, token: &str) -> Result<Option<Identity>> {
526 let token_hash = hash_token_for_storage(token);
532 let row = sqlx::query(
533 "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
534 u.must_change_password, u.mfa_enabled, \
535 s.expires_at, s.trust_level, \
536 s.token_hash IS NOT NULL AS hashed \
537 FROM rustio_sessions s \
538 JOIN rustio_users u ON u.id = s.user_id \
539 WHERE s.token_hash = $1 AND s.revoked_at IS NULL",
540 )
541 .bind(&token_hash)
542 .fetch_optional(db.pool())
543 .await?;
544
545 let row = match row {
546 Some(r) => Some(r),
547 None => {
555 sqlx::query(
556 "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
557 u.must_change_password, u.mfa_enabled, \
558 s.expires_at, s.trust_level, FALSE AS hashed \
559 FROM rustio_sessions s \
560 JOIN rustio_users u ON u.id = s.user_id \
561 WHERE s.token = $1 AND s.token_hash IS NULL AND s.revoked_at IS NULL",
562 )
563 .bind(token)
564 .fetch_optional(db.pool())
565 .await?
566 }
567 };
568 let row = match row {
569 Some(r) => r,
570 None => return Ok(None),
571 };
572 let r = Row::from_pg(&row);
573 let expires_at = r.get_datetime("expires_at")?;
574 if expires_at < Utc::now() {
575 let _ = delete_session(db, token).await;
581 return Ok(None);
582 }
583
584 let db_clone = db.clone();
588 let token_owned = token.to_string();
589 let token_hash_owned = token_hash.clone();
590 tokio::spawn(async move {
591 let _ = sqlx::query(
592 "UPDATE rustio_sessions SET last_seen = NOW() \
593 WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
594 AND revoked_at IS NULL",
595 )
596 .bind(&token_hash_owned)
597 .bind(&token_owned)
598 .execute(db_clone.pool())
599 .await;
600 });
601
602 Ok(Some(Identity {
603 user_id: r.get_i64("id")?,
604 email: r.get_string("email")?,
605 role: Role::parse(&r.get_string("role")?)?,
606 is_active: r.get_bool("is_active")?,
607 is_demo: r.get_bool("is_demo")?,
608 demo_label: r.get_optional_string("demo_label")?,
609 must_change_password: r.get_bool("must_change_password")?,
610 mfa_enabled: r.get_bool("mfa_enabled")?,
611 trust_level: SessionTrust::parse(&r.get_string("trust_level")?),
612 }))
613}
614
615pub async fn purge_expired_sessions(db: &Db) -> Result<u64> {
618 let result = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at < NOW()")
619 .execute(db.pool())
620 .await?;
621 Ok(result.rows_affected())
622}
623
624pub fn session_token_from_cookie(cookie_header: &str) -> Option<String> {
625 let prefix = format!("{SESSION_COOKIE}=");
626 for part in cookie_header.split(';') {
627 let part = part.trim();
628 if let Some(v) = part.strip_prefix(&prefix) {
629 return Some(v.to_string());
630 }
631 }
632 None
633}
634
635pub(crate) fn random_token() -> String {
642 let mut bytes = [0u8; 32];
643 rand::thread_rng().fill_bytes(&mut bytes);
644 URL_SAFE_NO_PAD.encode(bytes)
645}
646
647pub(crate) fn hash_token_for_storage(token: &str) -> String {
658 let digest = Sha256::digest(token.as_bytes());
659 URL_SAFE_NO_PAD.encode(digest)
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Every invalidation reason, listed once so the label tests below cover
    /// the whole enum.
    const ALL_REASONS: [SessionInvalidationReason; 12] = [
        SessionInvalidationReason::Logout,
        SessionInvalidationReason::Expired,
        SessionInvalidationReason::UserRequested,
        SessionInvalidationReason::AdministrativeRevoke,
        SessionInvalidationReason::PasswordReset,
        SessionInvalidationReason::PasswordResetByOther,
        SessionInvalidationReason::MfaEnabled,
        SessionInvalidationReason::MfaDisabled,
        SessionInvalidationReason::MfaDisabledByOther,
        SessionInvalidationReason::AuthorityEscalation,
        SessionInvalidationReason::EmergencyRecovery,
        SessionInvalidationReason::TrustEscalation,
    ];

    fn is_url_safe_base64(s: &str) -> bool {
        s.chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
    }

    #[test]
    fn extracts_token_from_cookie_header() {
        let h = "foo=bar; rustio_session=abc123; other=x";
        assert_eq!(session_token_from_cookie(h), Some("abc123".into()));
    }

    #[test]
    fn returns_none_when_cookie_missing() {
        let h = "foo=bar; other=x";
        assert!(session_token_from_cookie(h).is_none());
    }

    #[test]
    fn random_token_has_reasonable_entropy() {
        let a = random_token();
        let b = random_token();
        // 32 random bytes encode to 43 unpadded URL-safe base64 characters.
        assert_eq!(a.len(), 43);
        assert!(is_url_safe_base64(&a));
        assert_ne!(a, b);
    }

    #[test]
    fn hash_token_is_deterministic() {
        let token = random_token();
        assert_eq!(
            hash_token_for_storage(&token),
            hash_token_for_storage(&token)
        );
    }

    #[test]
    fn hash_token_differs_per_token() {
        let a = hash_token_for_storage("aaaa");
        let b = hash_token_for_storage("aaab");
        assert_ne!(a, b);
    }

    #[test]
    fn hash_token_output_is_url_safe_base64() {
        let h = hash_token_for_storage("anything");
        // SHA-256 yields 32 bytes -> 43 unpadded base64 characters.
        assert_eq!(h.len(), 43);
        assert!(is_url_safe_base64(&h));
    }

    #[test]
    fn hash_token_does_not_leak_plaintext() {
        let plaintext = "secret-cookie-value-12345";
        let h = hash_token_for_storage(plaintext);
        assert!(!h.contains("secret"));
        assert!(!h.contains("12345"));
    }

    #[test]
    fn session_trust_orders_correctly() {
        assert!(SessionTrust::Authenticated.rank() < SessionTrust::Elevated.rank());
        assert!(SessionTrust::Elevated.rank() < SessionTrust::MfaVerified.rank());
        assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Elevated));
        assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Authenticated));
        assert!(SessionTrust::Authenticated.satisfies(SessionTrust::Authenticated));
        assert!(!SessionTrust::Authenticated.satisfies(SessionTrust::Elevated));
        assert!(!SessionTrust::Elevated.satisfies(SessionTrust::MfaVerified));
    }

    #[test]
    fn session_trust_round_trips_through_sql() {
        for tier in [
            SessionTrust::Authenticated,
            SessionTrust::Elevated,
            SessionTrust::MfaVerified,
        ] {
            assert_eq!(SessionTrust::parse(tier.as_str()), tier);
        }
    }

    #[test]
    fn session_trust_parse_defaults_safely_on_unknown() {
        assert_eq!(SessionTrust::parse("garbage"), SessionTrust::Authenticated);
        assert_eq!(SessionTrust::parse(""), SessionTrust::Authenticated);
    }

    #[test]
    fn session_trust_parse_is_case_sensitive() {
        // Near-miss spellings must not grant elevated trust.
        assert_eq!(SessionTrust::parse("ELEVATED"), SessionTrust::Authenticated);
        assert_eq!(SessionTrust::parse("Mfa_Verified"), SessionTrust::Authenticated);
    }

    #[test]
    fn invalidation_reason_strings_are_distinct() {
        let mut set = std::collections::HashSet::new();
        for r in ALL_REASONS {
            assert!(set.insert(r.as_str()), "duplicate as_str() for {r:?}");
        }
        assert_eq!(set.len(), ALL_REASONS.len());
    }

    #[test]
    fn invalidation_reason_strings_are_snake_case() {
        // Labels are persisted in `revoked_reason`; keep them lowercase ASCII.
        for r in ALL_REASONS {
            let label = r.as_str();
            assert!(!label.is_empty(), "empty label for {r:?}");
            assert!(
                label.chars().all(|c| c.is_ascii_lowercase() || c == '_'),
                "non-snake_case label {label:?} for {r:?}"
            );
            assert!(
                !label.starts_with('_') && !label.ends_with('_'),
                "label {label:?} for {r:?} has a stray underscore"
            );
        }
    }
}