1use crate::{
5 crypto::{packet_protection, EncryptedPayload, ProtectedPayload, ZeroRttHeaderKey, ZeroRttKey},
6 packet::{
7 decoding::HeaderDecoder,
8 encoding::{PacketEncoder, PacketPayloadEncoder},
9 long::{
10 DestinationConnectionIdLen, LongPayloadEncoder, LongPayloadLenCursor,
11 SourceConnectionIdLen, Version,
12 },
13 number::{
14 PacketNumber, PacketNumberLen, PacketNumberSpace, ProtectedPacketNumber,
15 TruncatedPacketNumber,
16 },
17 Tag,
18 },
19 varint::VarInt,
20};
21use s2n_codec::{CheckedRange, DecoderBufferMut, DecoderBufferMutResult, Encoder, EncoderValue};
22
// High nibble of the first byte of a 0-RTT long-header packet (RFC 9000 §17.2.3):
// Header Form = 1, Fixed Bit = 1, Long Packet Type = 0b01, i.e. 0b1101.
// Callers shift it left by 4 and OR in the packet-number-length bits.
macro_rules! zero_rtt_tag {
    () => {
        0x0Du8
    };
}
47
/// A QUIC 0-RTT long-header packet, generic over how each field is represented.
///
/// The type parameters let one struct describe the packet at each processing stage;
/// see the `ProtectedZeroRtt`, `EncryptedZeroRtt`, and `CleartextZeroRtt` aliases.
#[derive(Debug)]
pub struct ZeroRtt<DCID, SCID, PacketNumber, Payload> {
    /// Version field from the long header.
    pub version: Version,
    /// Destination connection ID (a range into the packet buffer or a borrowed slice,
    /// depending on the stage).
    pub destination_connection_id: DCID,
    /// Source connection ID (same representation rules as the destination ID).
    pub source_connection_id: SCID,
    /// Packet number in the form used by the current stage (protected, full, or truncated).
    pub packet_number: PacketNumber,
    /// Packet payload in the form used by the current stage.
    pub payload: Payload,
}
56
/// A 0-RTT packet as decoded from the wire: connection IDs are ranges into the
/// buffer and the packet number is still header-protected.
pub type ProtectedZeroRtt<'a> =
    ZeroRtt<CheckedRange, CheckedRange, ProtectedPacketNumber, ProtectedPayload<'a>>;
/// A 0-RTT packet after header protection has been removed (the full packet number
/// is known) but whose payload is still encrypted.
pub type EncryptedZeroRtt<'a> =
    ZeroRtt<CheckedRange, CheckedRange, PacketNumber, EncryptedPayload<'a>>;
/// A decrypted 0-RTT packet: connection IDs are borrowed byte slices and the payload
/// is a mutable decoder over the plaintext.
pub type CleartextZeroRtt<'a> = ZeroRtt<&'a [u8], &'a [u8], PacketNumber, DecoderBufferMut<'a>>;
62
63impl<'a> ProtectedZeroRtt<'a> {
64 #[inline]
65 pub(crate) fn decode(
66 _tag: Tag,
67 version: Version,
68 buffer: DecoderBufferMut,
69 ) -> DecoderBufferMutResult<ProtectedZeroRtt> {
70 let mut decoder = HeaderDecoder::new_long(&buffer);
71
72 let destination_connection_id = decoder.decode_destination_connection_id(&buffer)?;
76 let source_connection_id = decoder.decode_source_connection_id(&buffer)?;
77
78 let (payload, packet_number, remaining) =
79 decoder.finish_long()?.split_off_packet(buffer)?;
80
81 let packet = ZeroRtt {
82 version,
83 destination_connection_id,
84 source_connection_id,
85 packet_number,
86 payload,
87 };
88
89 Ok((packet, remaining))
90 }
91
92 pub fn unprotect<H: ZeroRttHeaderKey>(
93 self,
94 header_key: &H,
95 largest_acknowledged_packet_number: PacketNumber,
96 ) -> Result<EncryptedZeroRtt<'a>, packet_protection::Error> {
97 let ZeroRtt {
98 version,
99 destination_connection_id,
100 source_connection_id,
101 payload,
102 ..
103 } = self;
104
105 let (truncated_packet_number, payload) =
106 crate::crypto::unprotect(header_key, PacketNumberSpace::ApplicationData, payload)?;
107
108 let packet_number = truncated_packet_number.expand(largest_acknowledged_packet_number);
109
110 Ok(ZeroRtt {
111 version,
112 destination_connection_id,
113 source_connection_id,
114 packet_number,
115 payload,
116 })
117 }
118
119 #[inline]
120 pub fn destination_connection_id(&self) -> &[u8] {
121 self.payload
122 .get_checked_range(&self.destination_connection_id)
123 .into_less_safe_slice()
124 }
125
126 #[inline]
127 pub fn source_connection_id(&self) -> &[u8] {
128 self.payload
129 .get_checked_range(&self.source_connection_id)
130 .into_less_safe_slice()
131 }
132}
133
134impl<'a> EncryptedZeroRtt<'a> {
135 pub fn decrypt<C: ZeroRttKey>(
136 self,
137 crypto: &C,
138 ) -> Result<CleartextZeroRtt<'a>, packet_protection::Error> {
139 let ZeroRtt {
140 version,
141 destination_connection_id,
142 source_connection_id,
143 packet_number,
144 payload,
145 } = self;
146
147 let (header, payload) = crate::crypto::decrypt(crypto, packet_number, payload)?;
148
149 let header = header.into_less_safe_slice();
150
151 let destination_connection_id = destination_connection_id.get(header);
152 let source_connection_id = source_connection_id.get(header);
153
154 Ok(ZeroRtt {
155 version,
156 destination_connection_id,
157 source_connection_id,
158 packet_number,
159 payload,
160 })
161 }
162
163 #[inline]
164 pub fn destination_connection_id(&self) -> &[u8] {
165 self.payload
166 .get_checked_range(&self.destination_connection_id)
167 .into_less_safe_slice()
168 }
169
170 #[inline]
171 pub fn source_connection_id(&self) -> &[u8] {
172 self.payload
173 .get_checked_range(&self.source_connection_id)
174 .into_less_safe_slice()
175 }
176}
177
178impl CleartextZeroRtt<'_> {
179 #[inline]
180 pub fn destination_connection_id(&self) -> &[u8] {
181 self.destination_connection_id
182 }
183
184 #[inline]
185 pub fn source_connection_id(&self) -> &[u8] {
186 self.source_connection_id
187 }
188}
189
190impl<DCID: EncoderValue, SCID: EncoderValue, Payload: EncoderValue> EncoderValue
191 for ZeroRtt<DCID, SCID, TruncatedPacketNumber, Payload>
192{
193 fn encode<E: Encoder>(&self, encoder: &mut E) {
194 self.encode_header(self.packet_number.len(), encoder);
195 LongPayloadEncoder {
196 packet_number: self.packet_number,
197 payload: &self.payload,
198 }
199 .encode_with_len_prefix::<VarInt, E>(encoder)
200 }
201}
202
203impl<DCID: EncoderValue, SCID: EncoderValue, PacketNumber, Payload>
204 ZeroRtt<DCID, SCID, PacketNumber, Payload>
205{
206 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
207 let mut tag: u8 = zero_rtt_tag!() << 4;
208 tag |= packet_number_len.into_packet_tag_mask();
209 tag.encode(encoder);
210
211 self.version.encode(encoder);
212 self.destination_connection_id
213 .encode_with_len_prefix::<DestinationConnectionIdLen, E>(encoder);
214 self.source_connection_id
215 .encode_with_len_prefix::<SourceConnectionIdLen, E>(encoder);
216 }
217}
218
219impl<
220 DCID: EncoderValue,
221 SCID: EncoderValue,
222 Payload: PacketPayloadEncoder,
223 K: ZeroRttKey,
224 H: ZeroRttHeaderKey,
225 > PacketEncoder<K, H, Payload> for ZeroRtt<DCID, SCID, PacketNumber, Payload>
226{
227 type PayloadLenCursor = LongPayloadLenCursor;
228
229 fn packet_number(&self) -> PacketNumber {
230 self.packet_number
231 }
232
233 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
234 ZeroRtt::encode_header(self, packet_number_len, encoder);
235 }
236
237 fn payload(&mut self) -> &mut Payload {
238 &mut self.payload
239 }
240}