Skip to main content

rustio_core/auth/
sessions.rs

1//! DB-backed sessions with a background expiry sweeper.
2
3use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
4use chrono::{Duration, Utc};
5use rand::RngCore;
6
7use crate::error::Result;
8use crate::orm::{Db, Row};
9
10use super::role::Role;
11use super::users::Identity;
12
/// The cookie name we look for and set. Constant so middleware and
/// handlers stay in sync.
pub const SESSION_COOKIE: &str = "rustio_session";

/// Number of days `create_session` adds to the current time when it
/// computes a new session's `expires_at`.
const SESSION_LENGTH_DAYS: i64 = 14;
18
19pub async fn init_session_tables(db: &Db) -> Result<()> {
20    sqlx::query(
21        "CREATE TABLE IF NOT EXISTS rustio_sessions (
22            token      TEXT PRIMARY KEY,
23            user_id    BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE,
24            expires_at TIMESTAMPTZ NOT NULL,
25            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
26            last_seen  TIMESTAMPTZ NOT NULL DEFAULT NOW()
27        )",
28    )
29    .execute(db.pool())
30    .await?;
31
32    sqlx::query(
33        "CREATE INDEX IF NOT EXISTS rustio_sessions_user_idx ON rustio_sessions (user_id)",
34    )
35    .execute(db.pool())
36    .await?;
37
38    sqlx::query(
39        "CREATE INDEX IF NOT EXISTS rustio_sessions_expires_idx ON rustio_sessions (expires_at)",
40    )
41    .execute(db.pool())
42    .await?;
43
44    Ok(())
45}
46
47pub async fn create_session(db: &Db, user_id: i64) -> Result<String> {
48    let token = random_token();
49    let expires = Utc::now() + Duration::days(SESSION_LENGTH_DAYS);
50    sqlx::query(
51        "INSERT INTO rustio_sessions (token, user_id, expires_at) VALUES ($1, $2, $3)",
52    )
53    .bind(&token)
54    .bind(user_id)
55    .bind(expires)
56    .execute(db.pool())
57    .await?;
58    Ok(token)
59}
60
61pub async fn delete_session(db: &Db, token: &str) -> Result<()> {
62    sqlx::query("DELETE FROM rustio_sessions WHERE token = $1")
63        .bind(token)
64        .execute(db.pool())
65        .await?;
66    Ok(())
67}
68
69pub async fn identity_from_session(db: &Db, token: &str) -> Result<Option<Identity>> {
70    let row = sqlx::query(
71        "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, s.expires_at
72           FROM rustio_sessions s
73           JOIN rustio_users u ON u.id = s.user_id
74          WHERE s.token = $1",
75    )
76    .bind(token)
77    .fetch_optional(db.pool())
78    .await?;
79
80    let row = match row {
81        Some(r) => r,
82        None => return Ok(None),
83    };
84    let r = Row::from_pg(&row);
85    let expires_at = r.get_datetime("expires_at")?;
86    if expires_at < Utc::now() {
87        // Don't bother keeping the stale row around. Fire-and-forget.
88        let _ = delete_session(db, token).await;
89        return Ok(None);
90    }
91
92    // Touch last_seen without holding the request back.
93    let db_clone = db.clone();
94    let token_owned = token.to_string();
95    tokio::spawn(async move {
96        let _ = sqlx::query("UPDATE rustio_sessions SET last_seen = NOW() WHERE token = $1")
97            .bind(&token_owned)
98            .execute(db_clone.pool())
99            .await;
100    });
101
102    Ok(Some(Identity {
103        user_id: r.get_i64("id")?,
104        email: r.get_string("email")?,
105        role: Role::parse(&r.get_string("role")?)?,
106        is_active: r.get_bool("is_active")?,
107        is_demo: r.get_bool("is_demo")?,
108        demo_label: r.get_optional_string("demo_label")?,
109    }))
110}
111
112/// Delete all expired sessions. Intended to be called periodically
113/// from a background task (see `background::spawn_session_sweeper`).
114pub async fn purge_expired_sessions(db: &Db) -> Result<u64> {
115    let result = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at < NOW()")
116        .execute(db.pool())
117        .await?;
118    Ok(result.rows_affected())
119}
120
121pub fn session_token_from_cookie(cookie_header: &str) -> Option<String> {
122    let prefix = format!("{SESSION_COOKIE}=");
123    for part in cookie_header.split(';') {
124        let part = part.trim();
125        if let Some(v) = part.strip_prefix(&prefix) {
126            return Some(v.to_string());
127        }
128    }
129    None
130}
131
132fn random_token() -> String {
133    let mut bytes = [0u8; 32];
134    rand::thread_rng().fill_bytes(&mut bytes);
135    URL_SAFE_NO_PAD.encode(bytes)
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// The session cookie is found even when surrounded by other cookies.
    #[test]
    fn extracts_token_from_cookie_header() {
        let h = "foo=bar; rustio_session=abc123; other=x";
        assert_eq!(session_token_from_cookie(h), Some("abc123".into()));
    }

    /// A header without the session cookie yields `None`.
    #[test]
    fn returns_none_when_cookie_missing() {
        let h = "foo=bar; other=x";
        assert!(session_token_from_cookie(h).is_none());
    }

    /// A cookie whose name merely starts with the session cookie's name
    /// must not be mistaken for it.
    #[test]
    fn ignores_cookies_with_similar_names() {
        let h = "rustio_session_old=zzz; rustio_session=abc123";
        assert_eq!(session_token_from_cookie(h), Some("abc123".into()));
        assert!(session_token_from_cookie("rustio_session_old=zzz").is_none());
    }

    /// Two consecutive tokens differ, and a token has the shape the cookie
    /// relies on: 32 bytes as unpadded URL-safe base64.
    #[test]
    fn random_token_has_reasonable_entropy() {
        let first = random_token();
        let second = random_token();
        // Rough sanity check: two consecutive tokens should differ.
        assert_ne!(first, second);

        // 32 bytes encode to ceil(32 * 4 / 3) = 43 characters without padding.
        assert_eq!(first.len(), 43);
        assert!(first
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_'));
    }
}