Skip to main content

tork_core/extract/
header.rs

1//! Header-based extractors.
2
3use http::header::{HeaderName, AUTHORIZATION};
4
5use crate::constants::BEARER_PREFIX;
6use crate::error::{Error, Result};
7use crate::extract::{FromRequest, RequestContext};
8
9/// Header carrying the id of the last Server-Sent Event a client received.
10const LAST_EVENT_ID_HEADER: &str = "last-event-id";
11
12/// A bearer token extracted from the `Authorization` header.
13///
14/// Resolving this extractor requires a well-formed `Authorization: Bearer
15/// <token>` header; otherwise the request is rejected with `401 Unauthorized`.
16/// This extractor only parses the header. Verifying the token (signature,
17/// expiry, claims) is layered on top by application code, typically by a
18/// `#[tork::dependency]` that takes a `BearerToken`.
19pub struct BearerToken(http::HeaderValue);
20
21impl BearerToken {
22    /// Returns the raw token, the part of the header after the `Bearer ` prefix.
23    pub fn token(&self) -> &str {
24        self.0
25            .to_str()
26            .ok()
27            .and_then(|value| value.strip_prefix(BEARER_PREFIX))
28            .expect("BearerToken invariant: stored header must stay a valid bearer token")
29    }
30}
31
32impl FromRequest for BearerToken {
33    fn from_request(
34        ctx: &RequestContext,
35    ) -> impl std::future::Future<Output = Result<Self>> + Send {
36        let resolved = resolve(ctx);
37        async move { resolved }
38    }
39}
40
/// The `Last-Event-ID` header sent by a client resuming an SSE stream.
///
/// Resolving this extractor never fails: a missing header, or one whose value
/// cannot be read as text (see `http::HeaderValue::to_str`), yields `None`.
/// Use it to resume a stream from where the client left off.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LastEventId(Option<String>);

impl LastEventId {
    /// Returns the last event id, if the client sent one.
    pub fn as_str(&self) -> Option<&str> {
        self.0.as_deref()
    }

    /// Consumes the extractor, returning the last event id if present.
    pub fn into_inner(self) -> Option<String> {
        self.0
    }
}
58
59impl FromRequest for LastEventId {
60    fn from_request(
61        ctx: &RequestContext,
62    ) -> impl std::future::Future<Output = Result<Self>> + Send {
63        let id = ctx
64            .headers()
65            .get(HeaderName::from_static(LAST_EVENT_ID_HEADER))
66            .and_then(|value| value.to_str().ok())
67            .map(str::to_owned);
68        async move { Ok(LastEventId(id)) }
69    }
70}
71
/// The `Last-Event-ID` header parsed into a typed resume cursor.
///
/// A thin, typed layer over [`LastEventId`] for resuming an SSE stream: the
/// header value is parsed into `T` (a parse failure yields `None`, as does a
/// missing header). Resolving never fails.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseResume<T>(Option<T>);

impl<T> SseResume<T> {
    /// Returns the parsed last event id, if the client sent a valid one.
    pub fn last_id(&self) -> Option<&T> {
        self.0.as_ref()
    }

    /// Consumes the extractor, returning the parsed last event id if present.
    pub fn into_inner(self) -> Option<T> {
        self.0
    }
}
90
91impl<T> FromRequest for SseResume<T>
92where
93    T: std::str::FromStr + Send,
94{
95    fn from_request(
96        ctx: &RequestContext,
97    ) -> impl std::future::Future<Output = Result<Self>> + Send {
98        let parsed = ctx
99            .headers()
100            .get(HeaderName::from_static(LAST_EVENT_ID_HEADER))
101            .and_then(|value| value.to_str().ok())
102            .and_then(|value| value.parse::<T>().ok());
103        async move { Ok(SseResume(parsed)) }
104    }
105}
106
107/// Parses the bearer token out of the request's `Authorization` header.
108fn resolve(ctx: &RequestContext) -> Result<BearerToken> {
109    let header = ctx
110        .headers()
111        .get(AUTHORIZATION)
112        .ok_or_else(|| Error::unauthorized("missing Authorization header"))?;
113
114    let value = header
115        .to_str()
116        .map_err(|_| Error::unauthorized("invalid Authorization header"))?;
117
118    value
119        .strip_prefix(BEARER_PREFIX)
120        .ok_or_else(|| Error::unauthorized("expected a bearer token"))?;
121
122    Ok(BearerToken(header.clone()))
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::error::ErrorKind;
    use crate::extract::PathParams;
    use crate::state::StateMap;
    use bytes::Bytes;
    use http_body_util::Full;
    use std::sync::Arc;

    /// Turns a request builder into a `RequestContext` with an empty body,
    /// no path parameters and an empty state map.
    fn context_from(builder: http::request::Builder) -> RequestContext {
        let head = builder.body(()).unwrap().into_parts().0;
        let body = box_body(Full::new(Bytes::new()));
        RequestContext::new(head, PathParams::new(), Arc::new(StateMap::new()), body)
    }

    /// Builds a context carrying at most one header with a textual value.
    fn context_with(header: Option<(&'static str, &'static str)>) -> RequestContext {
        let mut builder = http::Request::builder();
        if let Some((name, value)) = header {
            builder = builder.header(name, value);
        }
        context_from(builder)
    }

    /// Builds a context carrying one header with an arbitrary, possibly
    /// non-text, value.
    fn context_with_raw(name: &'static str, value: http::HeaderValue) -> RequestContext {
        context_from(http::Request::builder().header(name, value))
    }

    /// A header value accepted by `HeaderValue::from_bytes` but rejected by
    /// `HeaderValue::to_str`.
    fn non_utf8_value() -> http::HeaderValue {
        http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap()
    }

    /// Resolves a `BearerToken` that is expected to be rejected, panicking
    /// with `why` if it resolves instead.
    ///
    /// This matches on the result by hand rather than calling `unwrap_err`,
    /// so it does not require `BearerToken: Debug`.
    async fn rejection(ctx: &RequestContext, why: &str) -> crate::error::Error {
        match BearerToken::from_request(ctx).await {
            Ok(_) => panic!("{why}"),
            Err(error) => error,
        }
    }

    #[tokio::test]
    async fn last_event_id_reads_the_header() {
        let ctx = context_with(Some(("last-event-id", "42")));
        let id = LastEventId::from_request(&ctx).await.unwrap();
        assert_eq!(id.as_str(), Some("42"));
    }

    #[tokio::test]
    async fn last_event_id_is_none_when_absent() {
        let ctx = context_with(None);
        let id = LastEventId::from_request(&ctx).await.unwrap();
        assert_eq!(id.into_inner(), None);
    }

    #[tokio::test]
    async fn last_event_id_non_utf8_is_none() {
        let ctx = context_with_raw("last-event-id", non_utf8_value());
        let id = LastEventId::from_request(&ctx).await.unwrap();
        assert_eq!(id.as_str(), None);
    }

    #[tokio::test]
    async fn sse_resume_parses_a_typed_cursor() {
        let ctx = context_with(Some(("last-event-id", "42")));
        let resume = SseResume::<i64>::from_request(&ctx).await.unwrap();
        assert_eq!(resume.last_id().copied(), Some(42));

        // A non-numeric value yields None for an i64 cursor.
        let ctx = context_with(Some(("last-event-id", "abc")));
        let resume = SseResume::<i64>::from_request(&ctx).await.unwrap();
        assert_eq!(resume.into_inner(), None);
    }

    #[tokio::test]
    async fn sse_resume_non_utf8_is_none() {
        let ctx = context_with_raw("last-event-id", non_utf8_value());
        let resume = SseResume::<u32>::from_request(&ctx).await.unwrap();
        assert_eq!(resume.into_inner(), None);
    }

    #[tokio::test]
    async fn bearer_token_happy_path() {
        let ctx = context_with(Some(("Authorization", "Bearer abc123")));
        let token = BearerToken::from_request(&ctx).await.unwrap();
        assert_eq!(token.token(), "abc123");
    }

    #[tokio::test]
    async fn bearer_token_missing_header_is_unauthorized() {
        let ctx = context_with(None);
        let error = rejection(&ctx, "missing header must fail").await;
        assert_eq!(error.kind(), ErrorKind::Unauthorized);
        assert_eq!(error.message(), "missing Authorization header");
    }

    #[tokio::test]
    async fn bearer_token_invalid_utf8_header_is_unauthorized() {
        let ctx = context_with_raw("Authorization", non_utf8_value());
        let error = rejection(&ctx, "non-utf8 must fail").await;
        assert_eq!(error.kind(), ErrorKind::Unauthorized);
        assert_eq!(error.message(), "invalid Authorization header");
    }

    #[tokio::test]
    async fn bearer_token_wrong_scheme_is_unauthorized() {
        for scheme in ["Basic dXNlcjpwYXNz", "Token xyz", "BearerLower xyz", ""] {
            let ctx = context_with(Some(("Authorization", scheme)));
            let error = rejection(&ctx, &format!("scheme `{scheme}` must fail")).await;
            assert_eq!(error.kind(), ErrorKind::Unauthorized);
            assert_eq!(error.message(), "expected a bearer token");
        }
    }

    #[tokio::test]
    async fn last_event_id_into_inner_some_branch() {
        let ctx = context_with(Some(("last-event-id", "hello")));
        let id = LastEventId::from_request(&ctx).await.unwrap();
        assert_eq!(id.into_inner(), Some("hello".to_owned()));
    }

    #[tokio::test]
    async fn sse_resume_missing_header_yields_none() {
        let ctx = context_with(None);
        let resume = SseResume::<u32>::from_request(&ctx).await.unwrap();
        assert_eq!(resume.last_id(), None);
        assert_eq!(resume.into_inner(), None);
    }

    #[tokio::test]
    async fn sse_resume_valid_value_is_accessible_via_both_accessors() {
        let ctx = context_with(Some(("last-event-id", "42")));
        let resume = SseResume::<u32>::from_request(&ctx).await.unwrap();
        assert_eq!(resume.last_id().copied(), Some(42));
        assert_eq!(resume.into_inner(), Some(42));
    }
}