Skip to main content

umbral_auth/
login_required.rs

1//! A login-required gate for umbral handlers.
2//!
3//! Gates a view behind authentication. umbral ships the idea in two
4//! composable shapes:
5//!
6//! - [`LoggedIn<U>`] — a per-handler axum extractor. Drop it in a handler
7//!   signature and the handler only runs when a valid session exists.
8//! - [`LoginRequiredLayer`] — a per-Router tower middleware layer. Every
9//!   route in the wrapped subtree is gated; unauthenticated requests never
10//!   reach the inner handler.
11//!
12//! Both shapes share [`LoginRequired`] for the redirect vs. 401 fork.
13//!
14//! ## Design decisions
15//!
16//! - `LoggedIn<U: UserModel>` is **fully generic** over the user model
17//!   (option a from the spec). The cookie/session reading is ~25 lines of
18//!   direct logic (read cookie, hash it, query session table, hydrate U).
19//!   Keeping it generic means a custom user model (`TenantUser` etc.) can
20//!   use `LoggedIn<TenantUser>` without any wrapper or code duplication.
21//!
22//! - The `LoginRequired` config is read from `request.extensions()` when
23//!   set by `LoginRequiredLayer`, or falls back to `LoginRequired::API`
24//!   (401 JSON) if the extractor is used directly without the layer.
25//!
26//! - `LoginRequiredLayer` implements `tower::Layer<S>` directly so it
27//!   works with `Router::layer(login_required())` and
28//!   `Router::layer(login_required_html("/login"))` without extra
29//!   wrapping.
30//!
31//! - The layer gate does NOT load the full user struct — it checks only
32//!   the session table (`user_id IS NOT NULL AND expires_at > now`). The
33//!   `LoggedIn<U>` extractor does the full hydration. This avoids the `U`
34//!   bound at the layer level, so `login_required()` works with any user
35//!   model without a type parameter on the layer.
36//!
37//! ## Deferred
38//!
39//! - `permission_required(perm)` and `staff_member_required` are deferred
40//!   pending gap 33 (groups + content-type model). They can be added as
41//!   thin wrappers once permission objects exist.
42
43use std::future::Future;
44use std::pin::Pin;
45use std::task::{Context, Poll};
46
47use axum::body::Body;
48use axum::http::{StatusCode, Uri};
49use axum::response::{IntoResponse, Response};
50use axum_core::extract::FromRequestParts;
51use chrono::{DateTime, Utc};
52use http::request::Parts;
53use serde_json::json;
54use sha2::{Digest, Sha256};
55use tower::{Layer, Service};
56
57use crate::UserModel;
58
59// =========================================================================
60// LoginRequired — shared config struct
61// =========================================================================
62
/// Settings, shared by the [`LoggedIn`] extractor and the
/// [`LoginRequiredLayer`] middleware, that decide how an unauthenticated
/// request is answered.
///
/// Without a `login_url` the answer is a JSON `401` (the REST/API shape);
/// with one it is a `302` redirect to that login page (the server-rendered
/// HTML shape).
#[derive(Debug, Clone)]
pub struct LoginRequired {
    /// Redirect target for unauthenticated requests. `None` answers with a
    /// 401 JSON body; `Some("/login")` answers with a 302 to
    /// `login_url?next=<uri>`.
    pub login_url: Option<String>,
    /// Name of the query parameter that carries the originally requested
    /// URI: `Some("next")` appends `?next=<uri>`, `None` redirects without
    /// it. Ignored whenever `login_url` is `None`.
    pub next_param: Option<String>,
}
78
79impl LoginRequired {
80    /// API/REST shape: return a JSON 401 with a `WWW-Authenticate: Bearer`
81    /// header.
82    pub const API: Self = Self {
83        login_url: None,
84        next_param: None,
85    };
86
87    /// HTML shape: redirect to `login_url?next=<original-uri>`. The `next`
88    /// parameter is named `"next"` by default.
89    pub fn html(login_url: impl Into<String>) -> Self {
90        Self {
91            login_url: Some(login_url.into()),
92            next_param: Some("next".to_string()),
93        }
94    }
95
96    /// Drop the `next` parameter from the redirect.
97    pub fn no_next(mut self) -> Self {
98        self.next_param = None;
99        self
100    }
101
102    /// Build the rejection response.
103    pub(crate) fn rejection_response(&self, uri: &Uri) -> Response {
104        match &self.login_url {
105            None => {
106                let body = json!({"error": "authentication required"}).to_string();
107                axum::http::Response::builder()
108                    .status(StatusCode::UNAUTHORIZED)
109                    .header("content-type", "application/json")
110                    .header("www-authenticate", "Bearer")
111                    .body(Body::from(body))
112                    .expect("building 401 response cannot fail")
113                    .into_response()
114            }
115            Some(url) => {
116                let location = match &self.next_param {
117                    Some(param) => {
118                        let original = uri.to_string();
119                        format!("{url}?{param}={}", urlencoded(original.as_str()))
120                    }
121                    None => url.clone(),
122                };
123                axum::http::Response::builder()
124                    .status(StatusCode::FOUND)
125                    .header("location", location)
126                    .body(Body::empty())
127                    .expect("building 302 response cannot fail")
128                    .into_response()
129            }
130        }
131    }
132}
133
/// Percent-encode a string for safe embedding in a query-string value.
///
/// Every byte of the UTF-8 input is written as `%XX` (uppercase hex) except
/// the RFC 3986 "unreserved" characters (`A-Z a-z 0-9 - . _ ~`) and `/`,
/// which is kept literal so encoded request paths stay readable. Encoding by
/// allow-list (instead of escaping a handful of known delimiters) also covers
/// `#`, `;`, quotes, control characters and non-ASCII text, none of which may
/// appear raw inside a query value.
fn urlencoded(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        let keep =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'/');
        if keep {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
150
151// =========================================================================
152// LoggedIn<U> extractor
153// =========================================================================
154
155/// Per-handler axum extractor that resolves the session cookie into a user
156/// of type `U`.
157///
158/// ```rust,ignore
159/// use umbral_auth::{AuthUser, login_required::LoggedIn};
160///
161/// async fn dashboard(LoggedIn(user): LoggedIn<AuthUser>) -> String {
162///     format!("Hello, {}!", user.username())
163/// }
164/// ```
165///
166/// If no valid session exists the extractor returns the configured rejection
167/// response. The config is read from `request.extensions()` (set by
168/// [`LoginRequiredLayer`]) or falls back to [`LoginRequired::API`].
169pub struct LoggedIn<U: UserModel>(pub U);
170
171// `LoggedIn` is a tuple-newtype around `U`. Drop in `Deref` /
172// `DerefMut` (so `user.username()` works directly without the
173// `.0`) and `Serialize` (so it slots into template contexts via
174// `context!(user)` without `user.0`). Closes BUG-18 from
175// bugs/tests/testBugs.md — the original ergonomic gap that
176// pushed test code to write `let username = user.0.username();`
177// for what should be the obvious shape.
178impl<U: UserModel> std::ops::Deref for LoggedIn<U> {
179    type Target = U;
180    fn deref(&self) -> &Self::Target {
181        &self.0
182    }
183}
184
185impl<U: UserModel> std::ops::DerefMut for LoggedIn<U> {
186    fn deref_mut(&mut self) -> &mut Self::Target {
187        &mut self.0
188    }
189}
190
191impl<U: UserModel + serde::Serialize> serde::Serialize for LoggedIn<U> {
192    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193    where
194        S: serde::Serializer,
195    {
196        // Forward verbatim so `LoggedIn<AuthUser>` round-trips
197        // exactly the same shape `AuthUser` would on its own.
198        self.0.serialize(serializer)
199    }
200}
201
202impl<U, S> FromRequestParts<S> for LoggedIn<U>
203where
204    U: UserModel
205        + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
206        + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
207        + umbral::orm::HydrateRelated
208        + Unpin
209        + Send,
210    // The session-row parse step is the bit that needs FromStr —
211    // an `i64`, `Uuid`, `String`, or hand-rolled PK type all
212    // implement it for free; a future PK shape with no string
213    // representation would have to override `id_string` AND
214    // provide a `FromStr` mirror to keep this extractor happy.
215    <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
216    S: Send + Sync,
217{
218    type Rejection = Response;
219
220    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
221        let config = parts
222            .extensions
223            .get::<LoginRequired>()
224            .cloned()
225            .unwrap_or(LoginRequired::API);
226
227        let uri = parts.uri.clone();
228
229        match resolve_user::<U>(&parts.headers).await {
230            Some(user) => Ok(LoggedIn(user)),
231            None => Err(config.rejection_response(&uri)),
232        }
233    }
234}
235
236// =========================================================================
237// Session resolution helpers
238// =========================================================================
239
240/// SHA-256 hash the raw session token. Mirrors `umbral-sessions`'s
241/// `hash_token`. umbral-auth must not depend on umbral-sessions (the dep
242/// arrow runs the other way), so we re-implement the trivial hash step.
243fn hash_token(raw: &str) -> String {
244    let mut h = Sha256::new();
245    h.update(raw.as_bytes());
246    format!("{:x}", h.finalize())
247}
248
249/// Extract the `umbral_session` cookie from the request headers.
250fn cookie_from_headers(headers: &http::HeaderMap) -> Option<String> {
251    let header = headers.get(http::header::COOKIE)?.to_str().ok()?;
252    for pair in header.split(';') {
253        let pair = pair.trim();
254        if let Some(value) = pair.strip_prefix("umbral_session=") {
255            return Some(value.to_string());
256        }
257    }
258    None
259}
260
261/// Load a user of type `U` from the session cookie in the given
262/// headers. The generic shape powers both [`LoggedIn`] and the
263/// public [`crate::current_user_as`] helper — apps using a custom
264/// `UserModel` reach for the latter from their own handlers when
265/// the AuthUser-flavoured [`crate::current_user`] doesn't fit.
266///
267/// **Polymorphic over `U::PrimaryKey`** — the session row stores
268/// the user PK as text (gap #59); we parse it back to the typed PK
269/// via `FromStr` before feeding it to the ORM, so a `UuidUser`
270/// stays UUID-shaped on the WHERE clause and an `AuthUser` stays
271/// `i64`-shaped. There is no `parse::<i64>()` hardcoded anywhere
272/// in the framework's session-read path; the typed PK threads
273/// through verbatim.
274///
275/// Conventions assumed: `U` has an `id` column populated by the
276/// model's PK type, and an `is_active` boolean column the filter
277/// excludes deactivated rows on. Custom user models that rename
278/// either column write their own resolver against
279/// [`umbral_sessions::current_user_id_str`] instead.
280pub async fn resolve_user<U>(headers: &http::HeaderMap) -> Option<U>
281where
282    U: UserModel
283        + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
284        + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
285        + umbral::orm::HydrateRelated
286        + Unpin
287        + Send,
288    <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
289{
290    let user_id = current_session_user_pk::<U>(headers).await?;
291    umbral::orm::Manager::<U>::default()
292        .filter(
293            umbral::orm::Predicate::<U>::col_eq("id", user_id)
294                & umbral::orm::Predicate::<U>::col_eq("is_active", true),
295        )
296        .first()
297        .await
298        .ok()
299        .flatten()
300}
301
302/// Read the request's session cookie and resolve it to the
303/// authenticated user's TYPED primary key. The generic version of
304/// [`current_session_user_id`]; this is what [`resolve_user`] and
305/// the future `permission_required_as<U>` build on.
306///
307/// Parses the text `session.user_id` (gap #59) via
308/// `<U::PrimaryKey as FromStr>::from_str`. A non-parseable value
309/// (the row was written by a different `UserModel` impl) resolves
310/// to `None` — same as missing cookie or expired session, so the
311/// caller's "anonymous" branch fires.
312pub async fn current_session_user_pk<U>(
313    headers: &http::HeaderMap,
314) -> Option<<U as umbral::orm::Model>::PrimaryKey>
315where
316    U: UserModel,
317    <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
318{
319    let raw_token = cookie_from_headers(headers)?;
320    let stored_id = hash_token(&raw_token);
321    let row: Option<SessionRow> = umbral::orm::Manager::<SessionRow>::default()
322        .filter(umbral::orm::Predicate::<SessionRow>::col_eq("id", stored_id))
323        .first()
324        .await
325        .ok()
326        .flatten();
327    let row = row?;
328    if row.expires_at < Utc::now() {
329        return None;
330    }
331    row.user_id?.parse().ok()
332}
333
334/// Check whether headers carry a valid authenticated session.
335/// Returns `true` iff a valid, non-expired, non-anonymous session is present.
336pub(crate) async fn is_authenticated(headers: &http::HeaderMap) -> bool {
337    current_session_user_id(headers).await.is_some()
338}
339
340/// Resolve the `umbral_session` cookie in `headers` to the
341/// authenticated user's `i64` PK — the AuthUser-specific shorthand
342/// for [`current_session_user_pk::<AuthUser>`]. Returns `None` for
343/// missing cookie, expired session, anonymous session, a
344/// non-parseable `user_id` (session written by a non-AuthUser
345/// model), or any sqlx error.
346///
347/// This is the primitive `permission_required` (in `umbral-permissions`)
348/// builds on. Callers using a custom user model reach for
349/// [`current_session_user_pk`] (the typed generic) or
350/// [`umbral_sessions::current_user_id_str`] (the raw string)
351/// instead — both stay polymorphic over the active user model's PK.
352pub async fn current_session_user_id(headers: &http::HeaderMap) -> Option<i64> {
353    current_session_user_pk::<crate::AuthUser>(headers).await
354}
355
356/// Private mirror of `umbral_sessions::Session`. Lives here because
357/// `umbral-auth` does not depend on `umbral-sessions` (the dep arrow runs
358/// the other way), but we still need ORM access to the `session` table.
359/// Multiple `Model` impls can target the same table — sea-query treats
360/// the schema as data, not a type-level singleton.
361#[doc(hidden)]
362#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize, serde::Deserialize, umbral::orm::Model)]
363#[umbral(table = "session")]
364pub struct SessionRow {
365    pub id: String,
366    /// Polymorphic user-PK column (gap #59). Stored as the user's PK
367    /// `Display` form — i64 for AuthUser, UUID for custom user models,
368    /// etc. Parse with `<U::PrimaryKey as FromStr>::from_str` on the
369    /// way out.
370    pub user_id: Option<String>,
371    pub data: String,
372    pub created_at: DateTime<Utc>,
373    pub expires_at: DateTime<Utc>,
374}
375
376// =========================================================================
377// LoginRequiredLayer — tower::Layer impl
378// =========================================================================
379
380/// Per-router middleware layer that gates every route in the wrapped subtree.
381///
382/// ```rust,ignore
383/// use umbral_auth::login_required::{login_required, login_required_html};
384///
385/// // REST subtree — 401 JSON on unauthenticated.
386/// let api_router = Router::new()
387///     .route("/api/me", get(me_handler))
388///     .layer(login_required());
389///
390/// // HTML subtree — 302 to /login?next=<uri>.
391/// let app_router = Router::new()
392///     .route("/dashboard", get(dashboard_handler))
393///     .layer(login_required_html("/login"));
394/// ```
395///
396/// The layer also inserts the [`LoginRequired`] config into request
397/// extensions so nested [`LoggedIn<U>`] extractors pick it up without
398/// re-declaration.
399#[derive(Clone)]
400pub struct LoginRequiredLayer {
401    config: LoginRequired,
402}
403
404impl LoginRequiredLayer {
405    /// Build a layer with an explicit config.
406    pub fn new(config: LoginRequired) -> Self {
407        Self { config }
408    }
409
410    /// Apply this layer to a Router, returning the gated router.
411    ///
412    /// ```rust,ignore
413    /// let gated = LoginRequiredLayer::new(LoginRequired::html("/login"))
414    ///     .apply(my_router);
415    /// ```
416    pub fn apply(self, router: axum::Router) -> axum::Router {
417        router.layer(self)
418    }
419}
420
421impl<S> Layer<S> for LoginRequiredLayer {
422    type Service = LoginRequiredService<S>;
423
424    fn layer(&self, inner: S) -> Self::Service {
425        LoginRequiredService {
426            inner,
427            config: self.config.clone(),
428        }
429    }
430}
431
432/// The tower `Service` produced by [`LoginRequiredLayer`].
433#[derive(Clone)]
434pub struct LoginRequiredService<S> {
435    inner: S,
436    config: LoginRequired,
437}
438
439impl<S> Service<axum::extract::Request> for LoginRequiredService<S>
440where
441    S: Service<axum::extract::Request, Response = Response> + Clone + Send + 'static,
442    S::Future: Send + 'static,
443{
444    type Response = Response;
445    type Error = S::Error;
446    type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send + 'static>>;
447
448    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
449        self.inner.poll_ready(cx)
450    }
451
452    fn call(&mut self, mut req: axum::extract::Request) -> Self::Future {
453        let config = self.config.clone();
454        // Clone inner for the async block — `self.inner` is consumed
455        // by `call()` semantically and must be driven after `poll_ready`.
456        let mut inner = self.inner.clone();
457
458        Box::pin(async move {
459            let uri = req.uri().clone();
460
461            if !is_authenticated(req.headers()).await {
462                return Ok(config.rejection_response(&uri));
463            }
464
465            // Insert config so LoggedIn<U> extractors can find it.
466            req.extensions_mut().insert(config);
467
468            inner.call(req).await
469        })
470    }
471}
472
473// =========================================================================
474// Convenience constructors
475// =========================================================================
476
477/// Returns a [`LoginRequiredLayer`] configured for REST/API use (401 JSON).
478///
479/// ```rust,ignore
480/// Router::new()
481///     .route("/api/me", get(me_handler))
482///     .layer(login_required())
483/// ```
484pub fn login_required() -> LoginRequiredLayer {
485    LoginRequiredLayer::new(LoginRequired::API)
486}
487
488/// Returns a [`LoginRequiredLayer`] configured for HTML use (302 redirect).
489///
490/// ```rust,ignore
491/// Router::new()
492///     .route("/dashboard", get(dashboard_handler))
493///     .layer(login_required_html("/login"))
494/// ```
495pub fn login_required_html(login_url: impl Into<String>) -> LoginRequiredLayer {
496    LoginRequiredLayer::new(LoginRequired::html(login_url))
497}