s2n_quic_core/packet/decoding.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    connection,
6    connection::id::ConnectionInfo,
7    crypto::ProtectedPayload,
8    packet::{
9        long::{
10            validate_destination_connection_id_range, validate_source_connection_id_range,
11            DestinationConnectionIdLen, SourceConnectionIdLen, Version,
12        },
13        number::ProtectedPacketNumber,
14        Tag,
15    },
16    varint::VarInt,
17};
18use core::mem::size_of;
19use s2n_codec::{CheckedRange, DecoderBuffer, DecoderBufferMut, DecoderError, DecoderValue};
20
21pub struct HeaderDecoder<'a> {
22    initial_buffer_len: usize,
23    peek: DecoderBuffer<'a>,
24}
25
26impl<'a> HeaderDecoder<'a> {
27    pub fn new_long<'b>(buffer: &'a DecoderBufferMut<'b>) -> Self {
28        let initial_buffer_len = buffer.len();
29        let peek = buffer.peek();
30        let peek = peek
31            .skip(size_of::<Tag>() + size_of::<Version>())
32            .expect("tag and version already verified");
33        Self {
34            initial_buffer_len,
35            peek,
36        }
37    }
38
39    pub fn new_short<'b>(buffer: &'a DecoderBufferMut<'b>) -> Self {
40        let initial_buffer_len = buffer.len();
41        let peek = buffer.peek();
42        let peek = peek.skip(size_of::<Tag>()).expect("tag already verified");
43        Self {
44            initial_buffer_len,
45            peek,
46        }
47    }
48
49    pub fn decode_destination_connection_id(
50        &mut self,
51        buffer: &DecoderBufferMut<'_>,
52    ) -> Result<CheckedRange, DecoderError> {
53        let destination_connection_id =
54            self.decode_checked_range::<DestinationConnectionIdLen>(buffer)?;
55        validate_destination_connection_id_range(&destination_connection_id)?;
56        Ok(destination_connection_id)
57    }
58
59    pub fn decode_short_destination_connection_id<Validator: connection::id::Validator>(
60        &mut self,
61        buffer: &DecoderBufferMut<'_>,
62        connection_info: &ConnectionInfo,
63        connection_id_validator: &Validator,
64    ) -> Result<CheckedRange, DecoderError> {
65        let destination_connection_id_len = if let Some(len) = connection_id_validator
66            .validate(connection_info, self.peek.peek().into_less_safe_slice())
67        {
68            len
69        } else {
70            return Err(DecoderError::InvariantViolation("invalid connection id"));
71        };
72
73        let (destination_connection_id, peek) = self
74            .peek
75            .skip_into_range(destination_connection_id_len, buffer)?;
76        self.peek = peek;
77        validate_destination_connection_id_range(&destination_connection_id)?;
78        Ok(destination_connection_id)
79    }
80
81    pub fn decode_source_connection_id(
82        &mut self,
83        buffer: &DecoderBufferMut<'_>,
84    ) -> Result<CheckedRange, DecoderError> {
85        let source_connection_id = self.decode_checked_range::<SourceConnectionIdLen>(buffer)?;
86        validate_source_connection_id_range(&source_connection_id)?;
87        Ok(source_connection_id)
88    }
89
90    pub fn decode_checked_range<Len: DecoderValue<'a> + TryInto<usize>>(
91        &mut self,
92        buffer: &DecoderBufferMut<'_>,
93    ) -> Result<CheckedRange, DecoderError> {
94        let (value, peek) = self.peek.skip_into_range_with_len_prefix::<Len>(buffer)?;
95        self.peek = peek;
96        Ok(value)
97    }
98
99    pub fn finish_long(mut self) -> Result<HeaderDecoderResult, DecoderError> {
100        let (payload_len, peek) = self.peek.decode::<VarInt>()?;
101        self.peek = peek;
102        let header_len = self.decoded_len();
103
104        self.peek = peek.skip(*payload_len as usize)?;
105        let packet_len = self.decoded_len();
106
107        Ok(HeaderDecoderResult {
108            packet_len,
109            header_len,
110        })
111    }
112
113    pub fn finish_short(self) -> Result<HeaderDecoderResult, DecoderError> {
114        let header_len = self.decoded_len();
115        let packet_len = self.initial_buffer_len;
116
117        Ok(HeaderDecoderResult {
118            packet_len,
119            header_len,
120        })
121    }
122
123    pub fn decoded_len(&self) -> usize {
124        self.initial_buffer_len - self.peek.len()
125    }
126}
127
/// Lengths produced by a finished header decode, used to split a packet off its buffer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct HeaderDecoderResult {
    /// Number of bytes, from the start of the buffer, that belong to this packet.
    pub packet_len: usize,
    /// Number of leading bytes of the packet that were decoded as the cleartext header; passed
    /// to `ProtectedPayload::new` as the header length.
    pub header_len: usize,
}
133
134impl HeaderDecoderResult {
135    pub fn split_off_packet<'a>(
136        &self,
137        buffer: DecoderBufferMut<'a>,
138    ) -> Result<
139        (
140            ProtectedPayload<'a>,
141            ProtectedPacketNumber,
142            DecoderBufferMut<'a>,
143        ),
144        DecoderError,
145    > {
146        let (payload, remaining) = buffer.decode_slice(self.packet_len)?;
147        let packet_number = ProtectedPacketNumber;
148        let payload = ProtectedPayload::new(self.header_len, payload.into_less_safe_slice());
149
150        Ok((payload, packet_number, remaining))
151    }
152}