s2n_quic_core/packet/mod.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::connection;
5use s2n_codec::{decoder_invariant, DecoderBufferMut, DecoderBufferMutResult, DecoderError};
6
/// Single-byte packet tag.
///
/// NOTE(review): presumably the type of the first header byte that
/// `PacketDecoder::decode_packet` reads as `tag` and dispatches on via `tag >> 4` — confirm.
pub(crate) type Tag = u8;
8
9#[macro_use]
10pub mod short;
11#[macro_use]
12pub mod version_negotiation;
13#[macro_use]
14pub mod initial;
15#[macro_use]
16pub mod zero_rtt;
17#[macro_use]
18pub mod handshake;
19#[macro_use]
20pub mod retry;
21
22pub mod decoding;
23pub mod encoding;
24pub mod interceptor;
25pub mod key_phase;
26pub mod long;
27
28pub mod number;
29pub mod stateless_reset;
30
31#[cfg(test)]
32mod tests;
33
34pub use key_phase::{KeyPhase, ProtectedKeyPhase};
35
36use connection::id::ConnectionInfo;
37use handshake::ProtectedHandshake;
38use initial::ProtectedInitial;
39use retry::ProtectedRetry;
40use short::ProtectedShort;
41use version_negotiation::ProtectedVersionNegotiation;
42use zero_rtt::ProtectedZeroRtt;
43
44// === API ===
45
/// An optional mutable decoder buffer.
///
/// NOTE(review): going by its name, presumably holds whatever bytes are left after a
/// packet has been decoded (`None` when nothing remains) — confirm against its users.
pub type RemainingBuffer<'a> = Option<DecoderBufferMut<'a>>;
47
/// A decoded packet of any type, as produced by [`ProtectedPacket::decode`].
///
/// There is one variant per packet type, each wrapping the `Protected*` type from the
/// corresponding submodule.
/// NOTE(review): the `Protected` prefix presumably means packet protection has not yet
/// been removed at this stage — confirm against the per-type modules.
#[derive(Debug)]
pub enum ProtectedPacket<'a> {
    /// A short packet (see the `short` module).
    Short(ProtectedShort<'a>),
    /// A version negotiation packet (see the `version_negotiation` module).
    VersionNegotiation(ProtectedVersionNegotiation<'a>),
    /// An Initial packet (see the `initial` module).
    Initial(ProtectedInitial<'a>),
    /// A 0-RTT packet (see the `zero_rtt` module).
    ZeroRtt(ProtectedZeroRtt<'a>),
    /// A Handshake packet (see the `handshake` module).
    Handshake(ProtectedHandshake<'a>),
    /// A Retry packet (see the `retry` module).
    Retry(ProtectedRetry<'a>),
}
57
58impl<'a> ProtectedPacket<'a> {
59    pub fn decode<Validator: connection::id::Validator>(
60        buffer: DecoderBufferMut<'a>,
61        connection_info: &ConnectionInfo,
62        connection_id_validator: &Validator,
63    ) -> DecoderBufferMutResult<'a, Self> {
64        BasicPacketDecoder.decode_packet(buffer, connection_info, connection_id_validator)
65    }
66
67    /// Returns the packet's destination connection ID
68    pub fn destination_connection_id(&self) -> &[u8] {
69        match self {
70            ProtectedPacket::Short(packet) => packet.destination_connection_id(),
71            ProtectedPacket::VersionNegotiation(packet) => packet.destination_connection_id(),
72            ProtectedPacket::Initial(packet) => packet.destination_connection_id(),
73            ProtectedPacket::ZeroRtt(packet) => packet.destination_connection_id(),
74            ProtectedPacket::Handshake(packet) => packet.destination_connection_id(),
75            ProtectedPacket::Retry(packet) => packet.destination_connection_id(),
76        }
77    }
78
79    /// Returns the packet's source connection ID
80    pub fn source_connection_id(&self) -> Option<&[u8]> {
81        match self {
82            ProtectedPacket::Short(_packet) => None,
83            ProtectedPacket::VersionNegotiation(packet) => Some(packet.source_connection_id()),
84            ProtectedPacket::Initial(packet) => Some(packet.source_connection_id()),
85            ProtectedPacket::ZeroRtt(packet) => Some(packet.source_connection_id()),
86            ProtectedPacket::Handshake(packet) => Some(packet.source_connection_id()),
87            ProtectedPacket::Retry(packet) => Some(packet.source_connection_id()),
88        }
89    }
90
91    pub fn version(&self) -> Option<u32> {
92        match self {
93            ProtectedPacket::Short(_) => None,
94            ProtectedPacket::VersionNegotiation(_) => None,
95            ProtectedPacket::Initial(packet) => Some(packet.version),
96            ProtectedPacket::ZeroRtt(packet) => Some(packet.version),
97            ProtectedPacket::Handshake(packet) => Some(packet.version),
98            ProtectedPacket::Retry(packet) => Some(packet.version),
99        }
100    }
101}
102
/// A packet of any type in its cleartext form.
///
/// Mirrors [`ProtectedPacket`]: one variant per packet type, each holding the
/// `Cleartext*` type from the corresponding submodule.
/// NOTE(review): nothing shown here constructs this enum; presumably it is produced
/// elsewhere once packet protection has been removed — confirm.
#[derive(Debug)]
pub enum CleartextPacket<'a> {
    /// A cleartext short packet.
    Short(short::CleartextShort<'a>),
    /// A cleartext version negotiation packet.
    VersionNegotiation(version_negotiation::CleartextVersionNegotiation<'a>),
    /// A cleartext Initial packet.
    Initial(initial::CleartextInitial<'a>),
    /// A cleartext 0-RTT packet.
    ZeroRtt(zero_rtt::CleartextZeroRtt<'a>),
    /// A cleartext Handshake packet.
    Handshake(handshake::CleartextHandshake<'a>),
    /// A cleartext Retry packet.
    Retry(retry::CleartextRetry<'a>),
}
112
/// Stateless [`PacketDecoder`] used by [`ProtectedPacket::decode`]; it wraps each
/// decoded packet in the matching [`ProtectedPacket`] variant.
struct BasicPacketDecoder;
114
115impl<'a> PacketDecoder<'a> for BasicPacketDecoder {
116    type Error = DecoderError;
117    type Output = ProtectedPacket<'a>;
118
119    fn handle_short_packet(
120        &mut self,
121        packet: ProtectedShort<'a>,
122    ) -> Result<Self::Output, DecoderError> {
123        Ok(ProtectedPacket::Short(packet))
124    }
125
126    fn handle_version_negotiation_packet(
127        &mut self,
128        packet: ProtectedVersionNegotiation<'a>,
129    ) -> Result<Self::Output, DecoderError> {
130        Ok(ProtectedPacket::VersionNegotiation(packet))
131    }
132
133    fn handle_initial_packet(
134        &mut self,
135        packet: ProtectedInitial<'a>,
136    ) -> Result<Self::Output, DecoderError> {
137        Ok(ProtectedPacket::Initial(packet))
138    }
139
140    fn handle_zero_rtt_packet(
141        &mut self,
142        packet: ProtectedZeroRtt<'a>,
143    ) -> Result<Self::Output, DecoderError> {
144        Ok(ProtectedPacket::ZeroRtt(packet))
145    }
146
147    fn handle_handshake_packet(
148        &mut self,
149        packet: ProtectedHandshake<'a>,
150    ) -> Result<Self::Output, DecoderError> {
151        Ok(ProtectedPacket::Handshake(packet))
152    }
153
154    fn handle_retry_packet(
155        &mut self,
156        packet: ProtectedRetry<'a>,
157    ) -> Result<Self::Output, DecoderError> {
158        Ok(ProtectedPacket::Retry(packet))
159    }
160}
161
/// Decodes a single packet and hands it to a per-packet-type handler.
///
/// Implementors provide one `handle_*` method per packet type. The provided
/// [`PacketDecoder::decode_packet`] method inspects the packet header, decodes the
/// matching `Protected*` packet type and calls the corresponding handler. Implementors
/// therefore only decide what to do with each kind of packet.
pub trait PacketDecoder<'a> {
    /// Value produced by each handler and returned from `decode_packet`.
    type Output;
    /// Error returned by the handlers and by `decode_packet`.
    ///
    /// Must be convertible from [`DecoderError`] so that header and packet decoding
    /// failures can be propagated with `?`.
    type Error: From<DecoderError>;

    /// Called by `decode_packet` with a decoded short packet.
    fn handle_short_packet(
        &mut self,
        packet: ProtectedShort<'a>,
    ) -> Result<Self::Output, Self::Error>;

    /// Called by `decode_packet` with a decoded version negotiation packet.
    fn handle_version_negotiation_packet(
        &mut self,
        packet: ProtectedVersionNegotiation<'a>,
    ) -> Result<Self::Output, Self::Error>;

    /// Called by `decode_packet` with a decoded Initial packet.
    fn handle_initial_packet(
        &mut self,
        packet: ProtectedInitial<'a>,
    ) -> Result<Self::Output, Self::Error>;

    /// Called by `decode_packet` with a decoded 0-RTT packet.
    fn handle_zero_rtt_packet(
        &mut self,
        packet: ProtectedZeroRtt<'a>,
    ) -> Result<Self::Output, Self::Error>;

    /// Called by `decode_packet` with a decoded Handshake packet.
    fn handle_handshake_packet(
        &mut self,
        packet: ProtectedHandshake<'a>,
    ) -> Result<Self::Output, Self::Error>;

    /// Called by `decode_packet` with a decoded Retry packet.
    fn handle_retry_packet(
        &mut self,
        packet: ProtectedRetry<'a>,
    ) -> Result<Self::Output, Self::Error>;

    /// Decodes the packet at the front of `buffer` and dispatches it to the matching
    /// `handle_*` method.
    ///
    /// The first byte (`tag`) is read, and its upper four bits (`tag >> 4`) select the
    /// packet type via the `*_tag!()` macros:
    ///
    /// * Short packets are decoded using `connection_info` and
    ///   `connection_id_validator`.
    /// * For the long packet types (Initial, 0-RTT, Handshake, Retry), the following
    ///   version field is also read. A version equal to `version_negotiation::VERSION`
    ///   is decoded as a version negotiation packet instead of the type named by the tag.
    /// * Tags matching `version_negotiation_no_fixed_bit_tag!()` must carry
    ///   `version_negotiation::VERSION`.
    ///
    /// On success, returns the handler's output together with the buffer returned by
    /// the per-type `decode` function.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    ///
    /// * the tag or version cannot be read;
    /// * `tag >> 4` matches no known packet type (`"invalid packet"`);
    /// * a `version_negotiation_no_fixed_bit_tag!()` packet carries any other version
    ///   (`"invalid version negotiation packet"`);
    /// * the per-type decoder fails;
    /// * the handler returns an error.
    fn decode_packet<Validator: connection::id::Validator>(
        &mut self,
        buffer: DecoderBufferMut<'a>,
        connection_info: &ConnectionInfo,
        connection_id_validator: &Validator,
    ) -> Result<(Self::Output, DecoderBufferMut<'a>), Self::Error> {
        // Read the header through a peeked view so that `buffer` itself can still be
        // handed to the per-type `decode` calls below.
        let peek = buffer.peek();

        let (tag, peek) = peek.decode()?;

        // Decodes `buffer` as a version negotiation packet and forwards it to the
        // version negotiation handler.
        macro_rules! version_negotiation {
            ($version:ident) => {{
                let (packet, buffer) = ProtectedVersionNegotiation::decode(tag, $version, buffer)?;
                let output = self.handle_version_negotiation_packet(packet)?;
                Ok((output, buffer))
            }};
        }

        // Decodes a long packet as `$struct` and forwards it to `self.$handler`, unless
        // its version marks it as a version negotiation packet.
        macro_rules! long_packet {
            ($struct:ident, $handler:ident) => {{
                let (version, _peek) = peek.decode()?;
                // Version negotiation is identified by the version value, whatever
                // long-packet tag the first byte carries.
                if version == version_negotiation::VERSION {
                    version_negotiation!(version)
                } else {
                    let (packet, buffer) = $struct::decode(tag, version, buffer)?;
                    let output = self.$handler(packet)?;
                    Ok((output, buffer))
                }
            }};
        }

        // Only the upper four bits of the tag select the packet type.
        match tag >> 4 {
            short_tag!() => {
                let (packet, buffer) = short::ProtectedShort::decode(
                    tag,
                    buffer,
                    connection_info,
                    connection_id_validator,
                )?;
                let output = self.handle_short_packet(packet)?;
                Ok((output, buffer))
            }
            version_negotiation_no_fixed_bit_tag!() => {
                let (version, _peek) = peek.decode()?;
                // This tag is only accepted when the version really is the
                // version negotiation value.
                decoder_invariant!(
                    version_negotiation::VERSION == version,
                    "invalid version negotiation packet"
                );
                version_negotiation!(version)
            }
            initial_tag!() => long_packet!(ProtectedInitial, handle_initial_packet),
            zero_rtt_tag!() => long_packet!(ProtectedZeroRtt, handle_zero_rtt_packet),
            handshake_tag!() => long_packet!(ProtectedHandshake, handle_handshake_packet),
            retry_tag!() => long_packet!(ProtectedRetry, handle_retry_packet),
            // No tag macro above matched these upper bits.
            _ => Err(DecoderError::InvariantViolation("invalid packet").into()),
        }
    }
}
254
#[cfg(test)]
mod snapshots {
    use super::*;

    /// Generates a `#[test]` named `$name` that round-trips the sample file
    /// `src/packet/test_samples/$name.bin` through `ProtectedPacket::decode`.
    macro_rules! snapshot {
        ($name:ident) => {
            #[test]
            fn $name() {
                s2n_codec::assert_codec_round_trip_sample_file!(
                    crate::packet::ProtectedPacket,
                    concat!("src/packet/test_samples/", stringify!($name), ".bin"),
                    |sample| {
                        // Connection info is built from a default remote address.
                        let addr = crate::inet::ip::SocketAddress::default();
                        let info = crate::connection::id::ConnectionInfo::new(&addr);
                        let validator = &long::DESTINATION_CONNECTION_ID_MAX_LEN;
                        crate::packet::ProtectedPacket::decode(sample, &info, validator)
                            .unwrap()
                    }
                );
            }
        };
    }

    snapshot!(short);
    snapshot!(initial);
    snapshot!(zero_rtt);
    snapshot!(handshake);
    snapshot!(retry);
    snapshot!(version_negotiation);
}