tork_core/extract/
header.rs1use http::header::{HeaderName, AUTHORIZATION};
4
5use crate::constants::BEARER_PREFIX;
6use crate::error::{Error, Result};
7use crate::extract::{FromRequest, RequestContext};
8
/// Name of the SSE `Last-Event-ID` request header.
///
/// Kept lowercase because it is passed to `HeaderName::from_static`, which requires a
/// lowercase header name and panics otherwise.
const LAST_EVENT_ID_HEADER: &str = "last-event-id";
11
12pub struct BearerToken(http::HeaderValue);
20
21impl BearerToken {
22 pub fn token(&self) -> &str {
24 self.0
25 .to_str()
26 .ok()
27 .and_then(|value| value.strip_prefix(BEARER_PREFIX))
28 .expect("BearerToken invariant: stored header must stay a valid bearer token")
29 }
30}
31
32impl FromRequest for BearerToken {
33 fn from_request(
34 ctx: &RequestContext,
35 ) -> impl std::future::Future<Output = Result<Self>> + Send {
36 let resolved = resolve(ctx);
37 async move { resolved }
38 }
39}
40
/// The `Last-Event-ID` request header value, when present and readable as a string.
pub struct LastEventId(Option<String>);

impl LastEventId {
    /// Borrows the header value, or `None` when no usable header was present.
    pub fn as_str(&self) -> Option<&str> {
        self.0.as_ref().map(String::as_str)
    }

    /// Consumes the extractor and returns the owned header value.
    pub fn into_inner(self) -> Option<String> {
        let LastEventId(id) = self;
        id
    }
}
58
59impl FromRequest for LastEventId {
60 fn from_request(
61 ctx: &RequestContext,
62 ) -> impl std::future::Future<Output = Result<Self>> + Send {
63 let id = ctx
64 .headers()
65 .get(HeaderName::from_static(LAST_EVENT_ID_HEADER))
66 .and_then(|value| value.to_str().ok())
67 .map(str::to_owned);
68 async move { Ok(LastEventId(id)) }
69 }
70}
71
/// The `Last-Event-ID` request header parsed into a typed resume cursor `T`.
pub struct SseResume<T>(Option<T>);

impl<T> SseResume<T> {
    /// Borrows the parsed cursor, if one was available.
    pub fn last_id(&self) -> Option<&T> {
        let SseResume(cursor) = self;
        cursor.as_ref()
    }

    /// Consumes the extractor and returns the parsed cursor.
    pub fn into_inner(self) -> Option<T> {
        let SseResume(cursor) = self;
        cursor
    }
}
90
91impl<T> FromRequest for SseResume<T>
92where
93 T: std::str::FromStr + Send,
94{
95 fn from_request(
96 ctx: &RequestContext,
97 ) -> impl std::future::Future<Output = Result<Self>> + Send {
98 let parsed = ctx
99 .headers()
100 .get(HeaderName::from_static(LAST_EVENT_ID_HEADER))
101 .and_then(|value| value.to_str().ok())
102 .and_then(|value| value.parse::<T>().ok());
103 async move { Ok(SseResume(parsed)) }
104 }
105}
106
107fn resolve(ctx: &RequestContext) -> Result<BearerToken> {
109 let header = ctx
110 .headers()
111 .get(AUTHORIZATION)
112 .ok_or_else(|| Error::unauthorized("missing Authorization header"))?;
113
114 let value = header
115 .to_str()
116 .map_err(|_| Error::unauthorized("invalid Authorization header"))?;
117
118 value
119 .strip_prefix(BEARER_PREFIX)
120 .ok_or_else(|| Error::unauthorized("expected a bearer token"))?;
121
122 Ok(BearerToken(header.clone()))
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::error::ErrorKind;
    use crate::extract::PathParams;
    use crate::state::StateMap;
    use bytes::Bytes;
    use http_body_util::Full;
    use std::sync::Arc;

    /// Finishes `builder` into a request head and wraps it in a `RequestContext` with
    /// no path params, empty state, and an empty body.
    fn context_from(builder: http::request::Builder) -> RequestContext {
        let (head, ()) = builder.body(()).unwrap().into_parts();
        let body = box_body(Full::new(Bytes::new()));
        RequestContext::new(head, PathParams::new(), Arc::new(StateMap::new()), body)
    }

    /// Builds a context carrying at most one header.
    fn context_with(header: Option<(&'static str, &'static str)>) -> RequestContext {
        let builder = http::Request::builder();
        context_from(match header {
            Some((name, value)) => builder.header(name, value),
            None => builder,
        })
    }

    /// Asserts that extracting a `BearerToken` from `ctx` fails as unauthorized with
    /// `expected` as its message; panics with `why` if extraction succeeds instead.
    async fn assert_bearer_rejected(ctx: &RequestContext, why: &str, expected: &str) {
        let Err(error) = BearerToken::from_request(ctx).await else {
            panic!("{why}");
        };
        assert_eq!(error.kind(), ErrorKind::Unauthorized);
        assert_eq!(error.message(), expected);
    }

    #[tokio::test]
    async fn last_event_id_reads_the_header() {
        let ctx = context_with(Some(("last-event-id", "42")));
        let extracted = LastEventId::from_request(&ctx).await.unwrap();
        assert_eq!(extracted.as_str(), Some("42"));
    }

    #[tokio::test]
    async fn last_event_id_is_none_when_absent() {
        let ctx = context_with(None);
        let extracted = LastEventId::from_request(&ctx).await.unwrap();
        assert_eq!(extracted.into_inner(), None);
    }

    #[tokio::test]
    async fn sse_resume_parses_a_typed_cursor() {
        let numeric = context_with(Some(("last-event-id", "42")));
        let resume = SseResume::<i64>::from_request(&numeric).await.unwrap();
        assert_eq!(resume.last_id().copied(), Some(42));

        // A cursor that does not parse as `i64` degrades to `None` instead of failing.
        let garbage = context_with(Some(("last-event-id", "abc")));
        let resume = SseResume::<i64>::from_request(&garbage).await.unwrap();
        assert_eq!(resume.into_inner(), None);
    }

    #[tokio::test]
    async fn bearer_token_happy_path() {
        let ctx = context_with(Some(("Authorization", "Bearer abc123")));
        let bearer = BearerToken::from_request(&ctx).await.unwrap();
        assert_eq!(bearer.token(), "abc123");
    }

    #[tokio::test]
    async fn bearer_token_missing_header_is_unauthorized() {
        let ctx = context_with(None);
        assert_bearer_rejected(&ctx, "missing header must fail", "missing Authorization header")
            .await;
    }

    #[tokio::test]
    async fn bearer_token_invalid_utf8_header_is_unauthorized() {
        let non_utf8 = http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap();
        let ctx = context_from(http::Request::builder().header("Authorization", non_utf8));
        assert_bearer_rejected(&ctx, "non-utf8 must fail", "invalid Authorization header").await;
    }

    #[tokio::test]
    async fn bearer_token_wrong_scheme_is_unauthorized() {
        let schemes = ["Basic dXNlcjpwYXNz", "Token xyz", "BearerLower xyz", ""];
        for scheme in schemes {
            let ctx = context_with(Some(("Authorization", scheme)));
            let why = format!("scheme `{scheme}` must fail");
            assert_bearer_rejected(&ctx, &why, "expected a bearer token").await;
        }
    }

    #[tokio::test]
    async fn last_event_id_into_inner_some_branch() {
        let ctx = context_with(Some(("last-event-id", "hello")));
        let extracted = LastEventId::from_request(&ctx).await.unwrap();
        assert_eq!(extracted.into_inner().as_deref(), Some("hello"));
    }

    #[tokio::test]
    async fn sse_resume_missing_header_yields_none() {
        let ctx = context_with(None);
        let resume = SseResume::<u32>::from_request(&ctx).await.unwrap();
        assert!(resume.last_id().is_none());
        assert_eq!(resume.into_inner(), None);
    }

    #[tokio::test]
    async fn sse_resume_valid_value_is_accessible_via_both_accessors() {
        let ctx = context_with(Some(("last-event-id", "42")));
        let resume = SseResume::<u32>::from_request(&ctx).await.unwrap();
        assert_eq!(resume.last_id(), Some(&42));
        assert_eq!(resume.into_inner(), Some(42));
    }
}