1use crate::{
5 connection,
6 connection::{id::ConnectionInfo, ProcessingError},
7 crypto::{packet_protection, EncryptedPayload, OneRttHeaderKey, OneRttKey, ProtectedPayload},
8 packet::{
9 decoding::HeaderDecoder,
10 encoding::{PacketEncoder, PacketPayloadEncoder},
11 number::{
12 PacketNumber, PacketNumberLen, PacketNumberSpace, ProtectedPacketNumber,
13 TruncatedPacketNumber,
14 },
15 KeyPhase, ProtectedKeyPhase, Tag,
16 },
17 transport,
18};
19use s2n_codec::{CheckedRange, DecoderBufferMut, DecoderBufferMutResult, Encoder, EncoderValue};
20
/// Pattern for the high nibble of a short-header first byte: header form
/// bit `0` followed by fixed bit `1`, with the remaining two bits free
/// (i.e. `0b0100..=0b0111`).
///
/// NOTE(review): presumably matched against `tag >> 4` by the caller; confirm.
macro_rules! short_tag {
    () => {
        0x4u8..=0x7u8
    };
}
45
/// Base value of the first byte of an outgoing short-header packet: header
/// form `0`, fixed bit `1`; all other bits are OR-ed in by `encode_header`.
const ENCODING_TAG: u8 = 0x40;

/// Bit of the first byte that carries the spin bit.
const SPIN_BIT_MASK: u8 = 0b0010_0000;

/// Two reserved bits of the first byte; `decrypt` rejects packets in which
/// either is set.
const RESERVED_BITS_MASK: u8 = 0b0001_1000;
59
/// Value of the spin bit carried in the first byte of a short-header packet.
///
/// Defaults to [`SpinBit::Zero`]. Equality is total for this fieldless enum,
/// so it derives `Eq` alongside `PartialEq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum SpinBit {
    /// The spin bit is clear.
    #[default]
    Zero,
    /// The spin bit is set.
    One,
}
66
67impl SpinBit {
68 fn from_tag(tag: Tag) -> Self {
69 if tag & SPIN_BIT_MASK == SPIN_BIT_MASK {
70 Self::One
71 } else {
72 Self::Zero
73 }
74 }
75
76 fn into_packet_tag_mask(self) -> u8 {
77 match self {
78 Self::One => SPIN_BIT_MASK,
79 Self::Zero => 0,
80 }
81 }
82}
83
84#[derive(Debug)]
112pub struct Short<DCID, KeyPhase, PacketNumber, Payload> {
113 pub spin_bit: SpinBit,
114 pub key_phase: KeyPhase,
115 pub destination_connection_id: DCID,
116 pub packet_number: PacketNumber,
117 pub payload: Payload,
118}
119
120pub type ProtectedShort<'a> =
121 Short<CheckedRange, ProtectedKeyPhase, ProtectedPacketNumber, ProtectedPayload<'a>>;
122pub type EncryptedShort<'a> = Short<CheckedRange, KeyPhase, PacketNumber, EncryptedPayload<'a>>;
123pub type CleartextShort<'a> = Short<&'a [u8], KeyPhase, PacketNumber, DecoderBufferMut<'a>>;
124
125impl<'a> ProtectedShort<'a> {
126 #[inline]
127 pub(crate) fn decode<Validator: connection::id::Validator>(
128 tag: Tag,
129 buffer: DecoderBufferMut<'a>,
130 connection_info: &ConnectionInfo,
131 destination_connection_id_decoder: &Validator,
132 ) -> DecoderBufferMutResult<'a, ProtectedShort<'a>> {
133 let mut decoder = HeaderDecoder::new_short(&buffer);
134
135 let spin_bit = SpinBit::from_tag(tag);
136 let key_phase = ProtectedKeyPhase;
137
138 let destination_connection_id = decoder.decode_short_destination_connection_id(
139 &buffer,
140 connection_info,
141 destination_connection_id_decoder,
142 )?;
143
144 let (payload, packet_number, remaining) =
145 decoder.finish_short()?.split_off_packet(buffer)?;
146
147 let packet = Short {
148 spin_bit,
149 key_phase,
150 destination_connection_id,
151 packet_number,
152 payload,
153 };
154
155 Ok((packet, remaining))
156 }
157
158 pub fn unprotect<H: OneRttHeaderKey>(
159 self,
160 header_key: &H,
161 largest_acknowledged_packet_number: PacketNumber,
162 ) -> Result<EncryptedShort<'a>, packet_protection::Error> {
163 let Short {
164 spin_bit,
165 destination_connection_id,
166 payload,
167 ..
168 } = self;
169
170 let (truncated_packet_number, payload) =
171 crate::crypto::unprotect(header_key, PacketNumberSpace::ApplicationData, payload)?;
172
173 let key_phase = KeyPhase::from_tag(payload.get_tag());
174
175 let packet_number = truncated_packet_number.expand(largest_acknowledged_packet_number);
176
177 Ok(Short {
178 spin_bit,
179 key_phase,
180 destination_connection_id,
181 packet_number,
182 payload,
183 })
184 }
185
186 #[inline]
187 pub fn destination_connection_id(&self) -> &[u8] {
188 self.payload
189 .get_checked_range(&self.destination_connection_id)
190 .into_less_safe_slice()
191 }
192}
193
194impl<'a> EncryptedShort<'a> {
195 pub fn decrypt<C: OneRttKey>(self, crypto: &C) -> Result<CleartextShort<'a>, ProcessingError> {
196 let Short {
197 spin_bit,
198 key_phase,
199 destination_connection_id,
200 packet_number,
201 payload,
202 } = self;
203
204 let (header, payload) = crate::crypto::decrypt(crypto, packet_number, payload)?;
205
206 let header = header.into_less_safe_slice();
207
208 if header[0] & RESERVED_BITS_MASK != 0 {
214 return Err(transport::Error::PROTOCOL_VIOLATION
215 .with_reason("reserved bits are non-zero")
216 .into());
217 }
218
219 let destination_connection_id = destination_connection_id.get(header);
220
221 Ok(Short {
222 spin_bit,
223 key_phase,
224 destination_connection_id,
225 packet_number,
226 payload,
227 })
228 }
229
230 #[inline]
231 pub fn key_phase(&self) -> KeyPhase {
232 self.key_phase
233 }
234
235 #[inline]
236 pub fn destination_connection_id(&self) -> &[u8] {
237 self.payload
238 .get_checked_range(&self.destination_connection_id)
239 .into_less_safe_slice()
240 }
241}
242
243impl CleartextShort<'_> {
244 #[inline]
245 pub fn destination_connection_id(&self) -> &[u8] {
246 self.destination_connection_id
247 }
248}
249
250impl<DCID: EncoderValue, Payload: EncoderValue> EncoderValue
251 for Short<DCID, KeyPhase, TruncatedPacketNumber, Payload>
252{
253 #[inline]
254 fn encode<E: Encoder>(&self, encoder: &mut E) {
255 self.encode_header(self.packet_number.len(), encoder);
256 self.packet_number.encode(encoder);
257 self.payload.encode(encoder);
258 }
259}
260
261impl<DCID: EncoderValue, PacketNumber, Payload> Short<DCID, KeyPhase, PacketNumber, Payload> {
262 #[inline]
263 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
264 (ENCODING_TAG
265 | self.spin_bit.into_packet_tag_mask()
266 | self.key_phase.into_packet_tag_mask()
267 | packet_number_len.into_packet_tag_mask())
268 .encode(encoder);
269
270 self.destination_connection_id.encode(encoder);
271 }
272}
273
274impl<DCID: EncoderValue, Payload: PacketPayloadEncoder, K: OneRttKey, H: OneRttHeaderKey>
275 PacketEncoder<K, H, Payload> for Short<DCID, KeyPhase, PacketNumber, Payload>
276{
277 type PayloadLenCursor = ();
278
279 #[inline]
280 fn packet_number(&self) -> PacketNumber {
281 self.packet_number
282 }
283
284 #[inline]
285 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
286 Short::encode_header(self, packet_number_len, encoder);
287 }
288
289 #[inline]
290 fn payload(&mut self) -> &mut Payload {
291 &mut self.payload
292 }
293}