1use crate::{
5 connection::ProcessingError,
6 crypto::{
7 packet_protection, EncryptedPayload, HandshakeHeaderKey, HandshakeKey, ProtectedPayload,
8 },
9 packet::{
10 decoding::HeaderDecoder,
11 encoding::{PacketEncoder, PacketPayloadEncoder},
12 long::{
13 DestinationConnectionIdLen, LongPayloadEncoder, LongPayloadLenCursor,
14 SourceConnectionIdLen, Version,
15 },
16 number::{
17 PacketNumber, PacketNumberLen, PacketNumberSpace, ProtectedPacketNumber,
18 TruncatedPacketNumber,
19 },
20 KeyPhase, Tag,
21 },
22 transport,
23 varint::VarInt,
24};
25use s2n_codec::{CheckedRange, DecoderBufferMut, DecoderBufferMutResult, Encoder, EncoderValue};
26
/// Type nibble for Handshake packets.
///
/// `encode_header` shifts this value into the upper four bits of the first
/// header byte and then ORs in the packet-number-length bits. Read against
/// RFC 9000 §17.2 / §17.2.4, the nibble is: header form `1`, fixed bit `1`,
/// long packet type `0b10` (Handshake).
macro_rules! handshake_tag {
    () => {
        0x0E_u8
    };
}
51
52#[derive(Debug)]
53pub struct Handshake<DCID, SCID, PacketNumber, Payload> {
54 pub version: Version,
55 pub destination_connection_id: DCID,
56 pub source_connection_id: SCID,
57 pub packet_number: PacketNumber,
58 pub payload: Payload,
59}
60
61pub type ProtectedHandshake<'a> =
62 Handshake<CheckedRange, CheckedRange, ProtectedPacketNumber, ProtectedPayload<'a>>;
63pub type EncryptedHandshake<'a> =
64 Handshake<CheckedRange, CheckedRange, PacketNumber, EncryptedPayload<'a>>;
65pub type CleartextHandshake<'a> = Handshake<&'a [u8], &'a [u8], PacketNumber, DecoderBufferMut<'a>>;
66
67impl<'a> ProtectedHandshake<'a> {
68 #[inline]
69 pub(crate) fn decode(
70 _tag: Tag,
71 version: Version,
72 buffer: DecoderBufferMut,
73 ) -> DecoderBufferMutResult<ProtectedHandshake> {
74 let mut decoder = HeaderDecoder::new_long(&buffer);
75
76 let destination_connection_id = decoder.decode_destination_connection_id(&buffer)?;
80 let source_connection_id = decoder.decode_source_connection_id(&buffer)?;
81
82 let (payload, packet_number, remaining) =
83 decoder.finish_long()?.split_off_packet(buffer)?;
84
85 let packet = Handshake {
86 version,
87 destination_connection_id,
88 source_connection_id,
89 packet_number,
90 payload,
91 };
92
93 Ok((packet, remaining))
94 }
95
96 pub fn unprotect<K: HandshakeHeaderKey>(
97 self,
98 key: &K,
99 largest_acknowledged_packet_number: PacketNumber,
100 ) -> Result<EncryptedHandshake<'a>, packet_protection::Error> {
101 let Handshake {
102 version,
103 destination_connection_id,
104 source_connection_id,
105 payload,
106 ..
107 } = self;
108
109 let (truncated_packet_number, payload) =
110 crate::crypto::unprotect(key, PacketNumberSpace::Handshake, payload)?;
111
112 let packet_number = truncated_packet_number.expand(largest_acknowledged_packet_number);
113
114 Ok(Handshake {
115 version,
116 destination_connection_id,
117 source_connection_id,
118 packet_number,
119 payload,
120 })
121 }
122
123 #[inline]
124 pub fn destination_connection_id(&self) -> &[u8] {
125 self.payload
126 .get_checked_range(&self.destination_connection_id)
127 .into_less_safe_slice()
128 }
129
130 #[inline]
131 pub fn source_connection_id(&self) -> &[u8] {
132 self.payload
133 .get_checked_range(&self.source_connection_id)
134 .into_less_safe_slice()
135 }
136}
137
138impl<'a> EncryptedHandshake<'a> {
139 pub fn decrypt<C: HandshakeKey>(
140 self,
141 crypto: &C,
142 ) -> Result<CleartextHandshake<'a>, ProcessingError> {
143 let Handshake {
144 version,
145 destination_connection_id,
146 source_connection_id,
147 packet_number,
148 payload,
149 } = self;
150
151 let (header, payload) = crate::crypto::decrypt(crypto, packet_number, payload)?;
152
153 let header = header.into_less_safe_slice();
154
155 if header[0] & super::long::RESERVED_BITS_MASK != 0 {
162 return Err(transport::Error::PROTOCOL_VIOLATION
163 .with_reason("reserved bits are non-zero")
164 .into());
165 }
166
167 let destination_connection_id = destination_connection_id.get(header);
168 let source_connection_id = source_connection_id.get(header);
169
170 Ok(Handshake {
171 version,
172 destination_connection_id,
173 source_connection_id,
174 packet_number,
175 payload,
176 })
177 }
178
179 #[inline]
180 pub fn destination_connection_id(&self) -> &[u8] {
181 self.payload
182 .get_checked_range(&self.destination_connection_id)
183 .into_less_safe_slice()
184 }
185
186 #[inline]
187 pub fn source_connection_id(&self) -> &[u8] {
188 self.payload
189 .get_checked_range(&self.source_connection_id)
190 .into_less_safe_slice()
191 }
192
193 #[inline]
195 pub fn key_phase(&self) -> KeyPhase {
196 KeyPhase::Zero
197 }
198}
199
200impl CleartextHandshake<'_> {
201 #[inline]
202 pub fn destination_connection_id(&self) -> &[u8] {
203 self.destination_connection_id
204 }
205
206 #[inline]
207 pub fn source_connection_id(&self) -> &[u8] {
208 self.source_connection_id
209 }
210}
211
212impl<DCID: EncoderValue, SCID: EncoderValue, Payload: EncoderValue> EncoderValue
213 for Handshake<DCID, SCID, TruncatedPacketNumber, Payload>
214{
215 fn encode<E: Encoder>(&self, encoder: &mut E) {
216 self.encode_header(self.packet_number.len(), encoder);
217 LongPayloadEncoder {
218 packet_number: self.packet_number,
219 payload: &self.payload,
220 }
221 .encode_with_len_prefix::<VarInt, E>(encoder)
222 }
223}
224
225impl<DCID: EncoderValue, SCID: EncoderValue, PacketNumber, Payload>
226 Handshake<DCID, SCID, PacketNumber, Payload>
227{
228 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
229 let mut tag: u8 = handshake_tag!() << 4;
230 tag |= packet_number_len.into_packet_tag_mask();
231 tag.encode(encoder);
232
233 self.version.encode(encoder);
234 self.destination_connection_id
235 .encode_with_len_prefix::<DestinationConnectionIdLen, E>(encoder);
236 self.source_connection_id
237 .encode_with_len_prefix::<SourceConnectionIdLen, E>(encoder);
238 }
239}
240
241impl<
242 DCID: EncoderValue,
243 SCID: EncoderValue,
244 Payload: PacketPayloadEncoder,
245 K: HandshakeKey,
246 H: HandshakeHeaderKey,
247 > PacketEncoder<K, H, Payload> for Handshake<DCID, SCID, PacketNumber, Payload>
248{
249 type PayloadLenCursor = LongPayloadLenCursor;
250
251 fn packet_number(&self) -> PacketNumber {
252 self.packet_number
253 }
254
255 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
256 Handshake::encode_header(self, packet_number_len, encoder);
257 }
258
259 fn payload(&mut self) -> &mut Payload {
260 &mut self.payload
261 }
262}