1use crate::{
5 connection,
6 connection::{id::ConnectionInfo, ProcessingError},
7 crypto::{packet_protection, EncryptedPayload, OneRttHeaderKey, OneRttKey, ProtectedPayload},
8 packet::{
9 decoding::HeaderDecoder,
10 encoding::{PacketEncoder, PacketPayloadEncoder},
11 number::{
12 PacketNumber, PacketNumberLen, PacketNumberSpace, ProtectedPacketNumber,
13 TruncatedPacketNumber,
14 },
15 KeyPhase, ProtectedKeyPhase, Tag,
16 },
17 transport,
18};
19use s2n_codec::{CheckedRange, DecoderBufferMut, DecoderBufferMutResult, Encoder, EncoderValue};
20
/// Pattern matching the tag values that identify a short-header (1-RTT) packet.
///
/// Expands to the inclusive range `0x4..=0x7`, i.e. every value of the form `0b01xx`.
/// NOTE(review): this appears to be matched against the upper nibble of the first header
/// byte (header-form bit clear, fixed bit set) — confirm against the tag decoder.
macro_rules! short_tag {
    () => {
        0x4u8..=0x7u8
    };
}
45
/// Base value of the first byte of every encoded short-header packet.
///
/// Only `0x40` is set; this corresponds to the Fixed Bit of RFC 9000 §17.3.1, while the
/// Header Form bit (`0x80`) stays clear. The spin, key-phase and packet-number-length bits
/// are OR-ed in on top of this value when the header is encoded.
const ENCODING_TAG: u8 = 0x40;

/// Position of the latency spin bit within the first header byte.
const SPIN_BIT_MASK: u8 = 0b0010_0000;

/// Position of the two reserved bits within the first header byte. A packet with either of
/// these bits set is rejected with `PROTOCOL_VIOLATION` once it has been decrypted.
const RESERVED_BITS_MASK: u8 = 0b0001_1000;
59
/// Value of the latency spin bit carried in a short-header packet.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless enum is total,
/// which lets the type be used wherever full equivalence is required.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SpinBit {
    Zero,
    One,
}

impl Default for SpinBit {
    /// The spin bit defaults to [`SpinBit::Zero`].
    fn default() -> Self {
        Self::Zero
    }
}
71
72impl SpinBit {
73 fn from_tag(tag: Tag) -> Self {
74 if tag & SPIN_BIT_MASK == SPIN_BIT_MASK {
75 Self::One
76 } else {
77 Self::Zero
78 }
79 }
80
81 fn into_packet_tag_mask(self) -> u8 {
82 match self {
83 Self::One => SPIN_BIT_MASK,
84 Self::Zero => 0,
85 }
86 }
87}
88
89#[derive(Debug)]
117pub struct Short<DCID, KeyPhase, PacketNumber, Payload> {
118 pub spin_bit: SpinBit,
119 pub key_phase: KeyPhase,
120 pub destination_connection_id: DCID,
121 pub packet_number: PacketNumber,
122 pub payload: Payload,
123}
124
125pub type ProtectedShort<'a> =
126 Short<CheckedRange, ProtectedKeyPhase, ProtectedPacketNumber, ProtectedPayload<'a>>;
127pub type EncryptedShort<'a> = Short<CheckedRange, KeyPhase, PacketNumber, EncryptedPayload<'a>>;
128pub type CleartextShort<'a> = Short<&'a [u8], KeyPhase, PacketNumber, DecoderBufferMut<'a>>;
129
130impl<'a> ProtectedShort<'a> {
131 #[inline]
132 pub(crate) fn decode<Validator: connection::id::Validator>(
133 tag: Tag,
134 buffer: DecoderBufferMut<'a>,
135 connection_info: &ConnectionInfo,
136 destination_connection_id_decoder: &Validator,
137 ) -> DecoderBufferMutResult<'a, ProtectedShort<'a>> {
138 let mut decoder = HeaderDecoder::new_short(&buffer);
139
140 let spin_bit = SpinBit::from_tag(tag);
141 let key_phase = ProtectedKeyPhase;
142
143 let destination_connection_id = decoder.decode_short_destination_connection_id(
144 &buffer,
145 connection_info,
146 destination_connection_id_decoder,
147 )?;
148
149 let (payload, packet_number, remaining) =
150 decoder.finish_short()?.split_off_packet(buffer)?;
151
152 let packet = Short {
153 spin_bit,
154 key_phase,
155 destination_connection_id,
156 packet_number,
157 payload,
158 };
159
160 Ok((packet, remaining))
161 }
162
163 pub fn unprotect<H: OneRttHeaderKey>(
164 self,
165 header_key: &H,
166 largest_acknowledged_packet_number: PacketNumber,
167 ) -> Result<EncryptedShort<'a>, packet_protection::Error> {
168 let Short {
169 spin_bit,
170 destination_connection_id,
171 payload,
172 ..
173 } = self;
174
175 let (truncated_packet_number, payload) =
176 crate::crypto::unprotect(header_key, PacketNumberSpace::ApplicationData, payload)?;
177
178 let key_phase = KeyPhase::from_tag(payload.get_tag());
179
180 let packet_number = truncated_packet_number.expand(largest_acknowledged_packet_number);
181
182 Ok(Short {
183 spin_bit,
184 key_phase,
185 destination_connection_id,
186 packet_number,
187 payload,
188 })
189 }
190
191 #[inline]
192 pub fn destination_connection_id(&self) -> &[u8] {
193 self.payload
194 .get_checked_range(&self.destination_connection_id)
195 .into_less_safe_slice()
196 }
197}
198
199impl<'a> EncryptedShort<'a> {
200 pub fn decrypt<C: OneRttKey>(self, crypto: &C) -> Result<CleartextShort<'a>, ProcessingError> {
201 let Short {
202 spin_bit,
203 key_phase,
204 destination_connection_id,
205 packet_number,
206 payload,
207 } = self;
208
209 let (header, payload) = crate::crypto::decrypt(crypto, packet_number, payload)?;
210
211 let header = header.into_less_safe_slice();
212
213 if header[0] & RESERVED_BITS_MASK != 0 {
219 return Err(transport::Error::PROTOCOL_VIOLATION
220 .with_reason("reserved bits are non-zero")
221 .into());
222 }
223
224 let destination_connection_id = destination_connection_id.get(header);
225
226 Ok(Short {
227 spin_bit,
228 key_phase,
229 destination_connection_id,
230 packet_number,
231 payload,
232 })
233 }
234
235 #[inline]
236 pub fn key_phase(&self) -> KeyPhase {
237 self.key_phase
238 }
239
240 #[inline]
241 pub fn destination_connection_id(&self) -> &[u8] {
242 self.payload
243 .get_checked_range(&self.destination_connection_id)
244 .into_less_safe_slice()
245 }
246}
247
248impl CleartextShort<'_> {
249 #[inline]
250 pub fn destination_connection_id(&self) -> &[u8] {
251 self.destination_connection_id
252 }
253}
254
255impl<DCID: EncoderValue, Payload: EncoderValue> EncoderValue
256 for Short<DCID, KeyPhase, TruncatedPacketNumber, Payload>
257{
258 #[inline]
259 fn encode<E: Encoder>(&self, encoder: &mut E) {
260 self.encode_header(self.packet_number.len(), encoder);
261 self.packet_number.encode(encoder);
262 self.payload.encode(encoder);
263 }
264}
265
266impl<DCID: EncoderValue, PacketNumber, Payload> Short<DCID, KeyPhase, PacketNumber, Payload> {
267 #[inline]
268 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
269 (ENCODING_TAG
270 | self.spin_bit.into_packet_tag_mask()
271 | self.key_phase.into_packet_tag_mask()
272 | packet_number_len.into_packet_tag_mask())
273 .encode(encoder);
274
275 self.destination_connection_id.encode(encoder);
276 }
277}
278
279impl<DCID: EncoderValue, Payload: PacketPayloadEncoder, K: OneRttKey, H: OneRttHeaderKey>
280 PacketEncoder<K, H, Payload> for Short<DCID, KeyPhase, PacketNumber, Payload>
281{
282 type PayloadLenCursor = ();
283
284 #[inline]
285 fn packet_number(&self) -> PacketNumber {
286 self.packet_number
287 }
288
289 #[inline]
290 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
291 Short::encode_header(self, packet_number_len, encoder);
292 }
293
294 #[inline]
295 fn payload(&mut self) -> &mut Payload {
296 &mut self.payload
297 }
298}