Skip to main content

h11/
_events.rs

1use crate::_abnf::{METHOD, REASON_PHRASE, REQUEST_TARGET};
2use crate::{_headers::Headers, _util::ProtocolError};
3use lazy_static::lazy_static;
4use regex::bytes::Regex;
5use std::fmt::{self, Formatter};
6
7lazy_static! {
8    static ref HTTP_VERSION_RE: Regex = Regex::new(r"^[0-9]\.[0-9]$").unwrap();
9    static ref METHOD_RE: Regex = Regex::new(&format!(r"^{}$", *METHOD)).unwrap();
10    static ref REASON_RE: Regex = Regex::new(&format!(r"^{}$", *REASON_PHRASE)).unwrap();
11    static ref REQUEST_TARGET_RE: Regex = Regex::new(&format!(r"^{}$", *REQUEST_TARGET)).unwrap();
12}
13
14/// HTTP request head event.
15///
16/// Use [`Request::new`] or [`Request::new_http11`] for fallible construction.
17/// Direct struct literals are possible, but callers should run
18/// [`Request::validate`] before sending values built from untrusted input.
19#[derive(Clone, PartialEq, Eq, Default)]
20pub struct Request {
21    /// Request method bytes, for example `GET` or `POST`.
22    pub method: Vec<u8>,
23    /// Normalized request headers with original casing preserved.
24    pub headers: Headers,
25    /// Request target bytes.
26    pub target: Vec<u8>,
27    /// HTTP version without the `HTTP/` prefix, for example `1.1`.
28    pub http_version: Vec<u8>,
29}
30
31impl Request {
32    /// Builds and validates a request with an explicit HTTP version.
33    pub fn new<M, T, V>(
34        method: M,
35        headers: Headers,
36        target: T,
37        http_version: V,
38    ) -> Result<Self, ProtocolError>
39    where
40        M: AsRef<[u8]>,
41        T: AsRef<[u8]>,
42        V: AsRef<[u8]>,
43    {
44        let request = Self {
45            method: method.as_ref().to_vec(),
46            headers,
47            target: target.as_ref().to_vec(),
48            http_version: http_version.as_ref().to_vec(),
49        };
50        request.validate()?;
51        Ok(request)
52    }
53
54    /// Builds and validates an HTTP/1.1 request.
55    pub fn new_http11<M, T>(method: M, headers: Headers, target: T) -> Result<Self, ProtocolError>
56    where
57        M: AsRef<[u8]>,
58        T: AsRef<[u8]>,
59    {
60        Self::new(method, headers, target, b"1.1")
61    }
62
63    /// Validates request method, target, HTTP version, and Host header rules.
64    pub fn validate(&self) -> Result<(), ProtocolError> {
65        let mut host_count = 0;
66        for (name, _) in self.headers.iter() {
67            if name == b"host" {
68                host_count += 1;
69            }
70        }
71        if !HTTP_VERSION_RE.is_match(&self.http_version) {
72            return Err(ProtocolError::LocalProtocolError(
73                ("Illegal HTTP version".to_string(), 400).into(),
74            ));
75        }
76        if self.http_version == b"1.1" && host_count == 0 {
77            return Err(ProtocolError::LocalProtocolError(
78                ("Missing mandatory Host: header".to_string(), 400).into(),
79            ));
80        }
81        if host_count > 1 {
82            return Err(ProtocolError::LocalProtocolError(
83                ("Found multiple Host: headers".to_string(), 400).into(),
84            ));
85        }
86
87        if !METHOD_RE.is_match(&self.method) {
88            return Err(ProtocolError::LocalProtocolError(
89                ("Illegal method characters".to_string(), 400).into(),
90            ));
91        }
92        if !REQUEST_TARGET_RE.is_match(&self.target) {
93            return Err(ProtocolError::LocalProtocolError(
94                ("Illegal target characters".to_string(), 400).into(),
95            ));
96        }
97
98        Ok(())
99    }
100}
101
102impl std::fmt::Debug for Request {
103    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
104        f.debug_struct("Request")
105            .field("method", &String::from_utf8_lossy(&self.method))
106            .field("headers", &self.headers)
107            .field("target", &String::from_utf8_lossy(&self.target))
108            .field("http_version", &String::from_utf8_lossy(&self.http_version))
109            .finish()
110    }
111}
112
113/// HTTP response head event.
114///
115/// The same struct is used for informational and final responses. Use the
116/// range-checked constructors or [`Event::informational_response`] /
117/// [`Event::normal_response`] when that distinction matters.
118#[derive(Debug, Clone, PartialEq, Eq, Default)]
119pub struct Response {
120    /// Normalized response headers with original casing preserved.
121    pub headers: Headers,
122    /// HTTP version without the `HTTP/` prefix, for example `1.1`.
123    pub http_version: Vec<u8>,
124    /// Reason phrase bytes.
125    pub reason: Vec<u8>,
126    /// Three-digit HTTP status code.
127    pub status_code: u16,
128}
129
130impl Response {
131    /// Builds and validates a response with an explicit HTTP version.
132    pub fn new<R, V>(
133        status_code: u16,
134        headers: Headers,
135        reason: R,
136        http_version: V,
137    ) -> Result<Self, ProtocolError>
138    where
139        R: AsRef<[u8]>,
140        V: AsRef<[u8]>,
141    {
142        let response = Self {
143            headers,
144            http_version: http_version.as_ref().to_vec(),
145            reason: reason.as_ref().to_vec(),
146            status_code,
147        };
148        response.validate()?;
149        Ok(response)
150    }
151
152    /// Builds and validates an HTTP/1.1 response.
153    pub fn new_http11<R>(
154        status_code: u16,
155        headers: Headers,
156        reason: R,
157    ) -> Result<Self, ProtocolError>
158    where
159        R: AsRef<[u8]>,
160    {
161        Self::new(status_code, headers, reason, b"1.1")
162    }
163
164    /// Builds and validates an informational response.
165    ///
166    /// The status code must be in `100..=199`.
167    pub fn new_informational<R, V>(
168        status_code: u16,
169        headers: Headers,
170        reason: R,
171        http_version: V,
172    ) -> Result<Self, ProtocolError>
173    where
174        R: AsRef<[u8]>,
175        V: AsRef<[u8]>,
176    {
177        let response = Self::new(status_code, headers, reason, http_version)?;
178        if !(100..=199).contains(&response.status_code) {
179            return Err(ProtocolError::LocalProtocolError(
180                (
181                    "Informational responses must use status codes in the range 100..=199",
182                    400,
183                )
184                    .into(),
185            ));
186        }
187        Ok(response)
188    }
189
190    /// Builds and validates an HTTP/1.1 informational response.
191    ///
192    /// The status code must be in `100..=199`.
193    pub fn new_informational_http11<R>(
194        status_code: u16,
195        headers: Headers,
196        reason: R,
197    ) -> Result<Self, ProtocolError>
198    where
199        R: AsRef<[u8]>,
200    {
201        Self::new_informational(status_code, headers, reason, b"1.1")
202    }
203
204    /// Builds and validates a final response.
205    ///
206    /// The status code must be `>= 200`.
207    pub fn new_final<R, V>(
208        status_code: u16,
209        headers: Headers,
210        reason: R,
211        http_version: V,
212    ) -> Result<Self, ProtocolError>
213    where
214        R: AsRef<[u8]>,
215        V: AsRef<[u8]>,
216    {
217        let response = Self::new(status_code, headers, reason, http_version)?;
218        if response.status_code < 200 {
219            return Err(ProtocolError::LocalProtocolError(
220                ("Final responses must use status codes >= 200", 400).into(),
221            ));
222        }
223        Ok(response)
224    }
225
226    /// Builds and validates an HTTP/1.1 final response.
227    ///
228    /// The status code must be `>= 200`.
229    pub fn new_final_http11<R>(
230        status_code: u16,
231        headers: Headers,
232        reason: R,
233    ) -> Result<Self, ProtocolError>
234    where
235        R: AsRef<[u8]>,
236    {
237        Self::new_final(status_code, headers, reason, b"1.1")
238    }
239
240    /// Validates response status code, reason phrase, and HTTP version.
241    pub fn validate(&self) -> Result<(), ProtocolError> {
242        if !(100..=999).contains(&self.status_code) {
243            return Err(ProtocolError::LocalProtocolError(
244                ("Illegal status code".to_string(), 400).into(),
245            ));
246        }
247        if !HTTP_VERSION_RE.is_match(&self.http_version) {
248            return Err(ProtocolError::LocalProtocolError(
249                ("Illegal HTTP version".to_string(), 400).into(),
250            ));
251        }
252        if !REASON_RE.is_match(&self.reason) {
253            return Err(ProtocolError::LocalProtocolError(
254                ("Illegal reason phrase".to_string(), 400).into(),
255            ));
256        }
257        Ok(())
258    }
259}
260
/// A piece of an HTTP message body.
///
/// The two flags report transfer-coding chunk boundaries. Both are `false`
/// for a value built with `Default`.
#[derive(Default, Eq, PartialEq, Clone, Debug)]
pub struct Data {
    /// The body bytes carried by this event.
    pub data: Vec<u8>,
    /// `true` when this event opens a transfer-coding chunk.
    pub chunk_start: bool,
    /// `true` when this event closes a transfer-coding chunk.
    pub chunk_end: bool,
}
271
272/// End of the current HTTP message.
273#[derive(Debug, Clone, PartialEq, Eq, Default)]
274pub struct EndOfMessage {
275    /// Trailer fields sent after a chunked body.
276    pub headers: Headers,
277}
278
/// Signals that the underlying connection has been closed. Carries no data.
#[derive(Default, Eq, PartialEq, Clone, Debug)]
pub struct ConnectionClosed {}
282
283/// Protocol events emitted and accepted by [`crate::Connection`].
284#[derive(Debug, Clone, PartialEq, Eq)]
285pub enum Event {
286    /// Request head event.
287    Request(Request),
288    /// Final response head event with status code `>= 200`.
289    NormalResponse(Response),
290    /// Informational response head event with status code in `100..=199`.
291    InformationalResponse(Response),
292    /// Message body data.
293    Data(Data),
294    /// End of a request or response message.
295    EndOfMessage(EndOfMessage),
296    /// Connection close notification.
297    ConnectionClosed(ConnectionClosed),
298    /// More bytes are needed before another inbound event can be produced.
299    NeedData(),
300    /// Inbound data is paused until the current cycle is completed.
301    Paused(),
302}
303
304impl From<Request> for Event {
305    fn from(request: Request) -> Self {
306        Self::Request(request)
307    }
308}
309
310impl From<Response> for Event {
311    fn from(response: Response) -> Self {
312        match response.status_code {
313            100..=199 => Self::InformationalResponse(response),
314            _ => Self::NormalResponse(response),
315        }
316    }
317}
318
319impl Event {
320    /// Converts a validated response into an informational response event.
321    pub fn informational_response(response: Response) -> Result<Self, ProtocolError> {
322        if !(100..=199).contains(&response.status_code) {
323            return Err(ProtocolError::LocalProtocolError(
324                (
325                    "Informational responses must use status codes in the range 100..=199",
326                    400,
327                )
328                    .into(),
329            ));
330        }
331        response.validate()?;
332        Ok(Self::InformationalResponse(response))
333    }
334
335    /// Converts a validated response into a final response event.
336    pub fn normal_response(response: Response) -> Result<Self, ProtocolError> {
337        if response.status_code < 200 {
338            return Err(ProtocolError::LocalProtocolError(
339                ("Normal responses must use status codes >= 200", 400).into(),
340            ));
341        }
342        response.validate()?;
343        Ok(Self::NormalResponse(response))
344    }
345}
346
347impl From<Data> for Event {
348    fn from(data: Data) -> Self {
349        Self::Data(data)
350    }
351}
352
353impl From<EndOfMessage> for Event {
354    fn from(end_of_message: EndOfMessage) -> Self {
355        Self::EndOfMessage(end_of_message)
356    }
357}
358
359impl From<ConnectionClosed> for Event {
360    fn from(connection_closed: ConnectionClosed) -> Self {
361        Self::ConnectionClosed(connection_closed)
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_response_new_rejects_invalid_input() {
        assert!(Response::new(99, Headers::default(), b"OK".to_vec(), b"1.1".to_vec()).is_err());
        assert!(Response::new(1000, Headers::default(), b"OK".to_vec(), b"1.1".to_vec()).is_err());
        assert!(Response::new(
            200,
            Headers::default(),
            b"OK".to_vec(),
            b"HTTP/1.1".to_vec()
        )
        .is_err());
        assert!(Response::new(
            200,
            Headers::default(),
            b"bad\nreason".to_vec(),
            b"1.1".to_vec()
        )
        .is_err());
    }

    #[test]
    fn test_request_new_rejects_invalid_http_version() {
        assert!(Request::new(
            b"GET".to_vec(),
            Headers::new(vec![(b"Host".to_vec(), b"example.com".to_vec())]).unwrap(),
            b"/".to_vec(),
            b"HTTP/1.1".to_vec(),
        )
        .is_err());
    }

    #[test]
    fn test_request_new_accepts_borrowed_inputs_and_http11_default() {
        let request =
            Request::new_http11("GET", Headers::new([("Host", "example.com")]).unwrap(), "/")
                .unwrap();

        assert_eq!(request.method, b"GET");
        assert_eq!(request.target, b"/");
        assert_eq!(request.http_version, b"1.1");
    }

    #[test]
    fn test_request_host_header_rules() {
        // HTTP/1.1 requires a Host header...
        assert!(Request::new_http11("GET", Headers::default(), "/").is_err());
        // ...but HTTP/1.0 does not.
        assert!(Request::new("GET", Headers::default(), "/", "1.0").is_ok());
        // More than one Host header is rejected by `validate`. `Headers::new`
        // may already refuse duplicates, in which case there is nothing more
        // to check here.
        if let Ok(headers) = Headers::new([("Host", "a.example"), ("Host", "b.example")]) {
            assert!(Request::new("GET", headers, "/", "1.0").is_err());
        }
    }

    #[test]
    fn test_request_new_rejects_illegal_method_and_target() {
        let host = || Headers::new([("Host", "example.com")]).unwrap();
        assert!(Request::new_http11("GE T", host(), "/").is_err());
        assert!(Request::new_http11("GET", host(), "/a b").is_err());
    }

    #[test]
    fn test_response_new_accepts_borrowed_inputs_and_http11_default() {
        let response =
            Response::new_http11(200, Headers::new([("Content-Length", "0")]).unwrap(), "OK")
                .unwrap();

        assert_eq!(response.status_code, 200);
        assert_eq!(response.reason, b"OK");
        assert_eq!(response.http_version, b"1.1");
    }

    #[test]
    fn test_response_range_checked_constructors() {
        let informational =
            Response::new_informational_http11(100, Headers::default(), "Continue").unwrap();
        assert_eq!(informational.status_code, 100);

        let final_response = Response::new_final_http11(200, Headers::default(), "OK").unwrap();
        assert_eq!(final_response.status_code, 200);

        assert!(Response::new_informational_http11(200, Headers::default(), "OK").is_err());
        assert!(Response::new_final_http11(199, Headers::default(), "Early").is_err());
    }

    #[test]
    fn test_event_response_constructors_validate_status_ranges() {
        let informational =
            Response::new_informational_http11(100, Headers::default(), "Continue").unwrap();
        assert!(matches!(
            Event::informational_response(informational).unwrap(),
            Event::InformationalResponse(_)
        ));

        let final_response =
            Response::new_final_http11(204, Headers::default(), "No Content").unwrap();
        assert!(matches!(
            Event::normal_response(final_response).unwrap(),
            Event::NormalResponse(_)
        ));

        let informational =
            Response::new_informational_http11(101, Headers::default(), "Switching Protocols")
                .unwrap();
        assert!(Event::normal_response(informational).is_err());

        let final_response = Response::new_final_http11(200, Headers::default(), "OK").unwrap();
        assert!(Event::informational_response(final_response).is_err());
    }

    #[test]
    fn test_event_from_response_dispatches_on_status_code() {
        let informational =
            Response::new_informational_http11(103, Headers::default(), "Early Hints").unwrap();
        assert!(matches!(
            Event::from(informational),
            Event::InformationalResponse(_)
        ));

        let final_response =
            Response::new_final_http11(404, Headers::default(), "Not Found").unwrap();
        assert!(matches!(Event::from(final_response), Event::NormalResponse(_)));
    }

    #[test]
    fn test_event_from_simple_events() {
        let data = Data {
            data: b"hi".to_vec(),
            chunk_start: false,
            chunk_end: false,
        };
        assert_eq!(Event::from(data.clone()), Event::Data(data));
        assert_eq!(
            Event::from(EndOfMessage::default()),
            Event::EndOfMessage(EndOfMessage::default())
        );
        assert_eq!(
            Event::from(ConnectionClosed {}),
            Event::ConnectionClosed(ConnectionClosed {})
        );
    }
}
459}