1use std::{
2 net::SocketAddr,
3 time::{Duration, SystemTime},
4};
5
6use boring::sha::Sha256;
7use quiche::{ConnectionId, Header, RecvInfo};
8
9use crate::{Error, Result, random_conn_id};
10
11pub trait AddressValidator {
13 fn mint_retry_token(
15 &self,
16 scid: &ConnectionId<'_>,
17 dcid: &ConnectionId<'_>,
18 new_scid: &ConnectionId<'_>,
19 src: &SocketAddr,
20 ) -> Result<Vec<u8>>;
21
22 fn validate_address<'a>(
24 &self,
25 scid: &ConnectionId<'_>,
26 dcid: &ConnectionId<'_>,
27 src: &SocketAddr,
28 token: &'a [u8],
29 ) -> Option<ConnectionId<'a>>;
30}
31
/// Default [`AddressValidator`].
///
/// It binds each retry token to the client's IP address and issue time, and
/// protects the token with a digest keyed by a per-instance secret.
pub struct SimpleAddressValidator(
    // Secret key mixed into every retry token's integrity tag.
    [u8; 20],
    // Maximum token age that `validate_address` still accepts.
    Duration,
);
34
35impl SimpleAddressValidator {
36 pub fn new(expiration_interval: Duration) -> Self {
38 let mut seed = [0; 20];
39 boring::rand::rand_bytes(&mut seed).unwrap();
40 Self(seed, expiration_interval)
41 }
42}
43
44impl AddressValidator for SimpleAddressValidator {
45 fn mint_retry_token(
46 &self,
47 _scid: &ConnectionId<'_>,
48 dcid: &ConnectionId<'_>,
49 new_scid: &ConnectionId<'_>,
50 src: &SocketAddr,
51 ) -> Result<Vec<u8>> {
52 let mut token = vec![];
53 match src.ip() {
55 std::net::IpAddr::V4(ipv4_addr) => token.extend_from_slice(&ipv4_addr.octets()),
56 std::net::IpAddr::V6(ipv6_addr) => token.extend_from_slice(&ipv6_addr.octets()),
57 };
58
59 let timestamp = SystemTime::now()
60 .duration_since(SystemTime::UNIX_EPOCH)
61 .unwrap()
62 .as_secs();
63
64 token.extend_from_slice(×tamp.to_be_bytes());
66 token.extend_from_slice(dcid);
68
69 let mut hasher = Sha256::new();
71 hasher.update(&self.0);
73 hasher.update(&token);
75 hasher.update(&new_scid);
77
78 token.extend_from_slice(&hasher.finish());
79
80 Ok(token)
81 }
82
83 fn validate_address<'a>(
84 &self,
85 _: &ConnectionId<'_>,
86 dcid: &ConnectionId<'_>,
87 src: &SocketAddr,
88 token: &'a [u8],
89 ) -> Option<ConnectionId<'a>> {
90 let addr = match src.ip() {
91 std::net::IpAddr::V4(a) => a.octets().to_vec(),
92 std::net::IpAddr::V6(a) => a.octets().to_vec(),
93 };
94
95 if addr.len() + 40 > token.len() {
97 return None;
98 }
99
100 if addr != &token[..addr.len()] {
102 return None;
103 }
104
105 let timestamp = Duration::from_secs(u64::from_be_bytes(
106 token[addr.len()..addr.len() + 8].try_into().unwrap(),
107 ));
108 let now = SystemTime::now()
109 .duration_since(SystemTime::UNIX_EPOCH)
110 .unwrap();
111
112 if now - timestamp > self.1 {
114 return None;
115 }
116
117 let sha256 = &token[token.len() - 32..];
118
119 let mut hasher = Sha256::new();
121 hasher.update(&self.0);
123 hasher.update(&token[..token.len() - 32]);
125 hasher.update(&dcid);
127
128 if sha256 != hasher.finish() {
130 return None;
131 }
132
133 Some(ConnectionId::from_ref(
134 &token[addr.len() + 8..token.len() - 32],
135 ))
136 }
137}
138
139pub enum Handshake {
141 Handshake(usize),
142 Accept(quiche::Connection),
143}
144
145pub struct Acceptor {
147 config: quiche::Config,
149 address_validator: Box<dyn AddressValidator + Send>,
151}
152
153impl Acceptor {
154 pub fn new<A: AddressValidator + Send + 'static>(
156 config: quiche::Config,
157 address_validator: A,
158 ) -> Self {
159 Self {
160 config,
161 address_validator: Box::new(address_validator),
162 }
163 }
164
165 pub fn handshake(
167 &mut self,
168 header: &Header<'_>,
169 buf: &mut [u8],
170 read_size: usize,
171 recv_info: RecvInfo,
172 ) -> Result<Handshake> {
173 if !quiche::version_is_supported(header.version) {
175 return self.negotiate_version(header, buf, read_size, recv_info);
176 }
177
178 let token = header.token.as_ref().unwrap();
180
181 if token.is_empty() {
183 return self.retry(header, buf, read_size, recv_info);
184 }
185
186 let odcid = match self.address_validator.validate_address(
187 &header.scid,
188 &header.dcid,
189 &recv_info.from,
190 token,
191 ) {
192 Some(odcid) => odcid,
193 None => {
194 log::error!(
195 "failed to validate address, from={:?}, to={}, scid={:?}, dcid={:?}",
196 recv_info.from,
197 recv_info.to,
198 header.scid,
199 header.dcid
200 );
201 return Err(Error::ValidateAddress);
202 }
203 };
204
205 let quiche_conn = match quiche::accept(
206 &header.dcid,
207 Some(&odcid),
208 recv_info.to,
209 recv_info.from,
210 &mut self.config,
211 ) {
212 Ok(conn) => {
213 log::trace!(
214 "QuicServer(initial) accept new conn, from={:?}, to={}, scid={:?}, dcid={:?}, odcid={:?}",
215 recv_info.from,
216 recv_info.to,
217 header.scid,
218 header.dcid,
219 odcid
220 );
221 conn
222 }
223 Err(err) => {
224 log::error!(
225 "failed to accept connection, from={:?}, to={}, scid={:?}, dcid={:?}, err={}",
226 recv_info.from,
227 recv_info.to,
228 header.scid,
229 header.dcid,
230 err
231 );
232 return Err(Error::Quiche(err));
233 }
234 };
235
236 Ok(Handshake::Accept(quiche_conn))
237 }
238
239 fn retry(
240 &self,
241 header: &Header<'_>,
242 buf: &mut [u8],
243 _recv_size: usize,
244 recv_info: RecvInfo,
245 ) -> Result<Handshake> {
246 let new_scid = random_conn_id();
247
248 log::trace!(
249 "retry, from={:?}, to={}, scid={:?}, dcid={:?}, new_scid={:?}",
250 recv_info.from,
251 recv_info.to,
252 header.scid,
253 header.dcid,
254 new_scid
255 );
256
257 let token = self.address_validator.mint_retry_token(
258 &header.scid,
259 &header.dcid,
260 &new_scid,
261 &recv_info.from,
262 )?;
263
264 let send_size = match quiche::retry(
265 &header.scid,
266 &header.dcid,
267 &new_scid,
268 &token,
269 header.version,
270 buf,
271 ) {
272 Ok(send_size) => send_size,
273 Err(err) => {
274 log::error!(
275 "failed to generate retry packet, from={:?}, to={}, scid={:?}, dcid={:?}, err={}",
276 recv_info.from,
277 recv_info.to,
278 header.scid,
279 header.dcid,
280 err
281 );
282 return Err(Error::Quiche(err));
283 }
284 };
285
286 Ok(Handshake::Handshake(send_size))
287 }
288
289 fn negotiate_version(
290 &self,
291 header: &Header<'_>,
292 buf: &mut [u8],
293 _recv_size: usize,
294 recv_info: RecvInfo,
295 ) -> Result<Handshake> {
296 log::trace!(
297 "negotiate_version, from={:?}, to={}, scid={:?}, dcid={:?}",
298 recv_info.from,
299 recv_info.to,
300 header.scid,
301 header.dcid
302 );
303
304 let send_size = match quiche::negotiate_version(&header.scid, &header.dcid, buf) {
305 Ok(send_size) => send_size,
306 Err(err) => {
307 log::error!(
308 "failed to generate negotiation_version packet, from={:?}, to={}, scid={:?}, dcid={:?}, err={}",
309 recv_info.from,
310 recv_info.to,
311 header.scid,
312 header.dcid,
313 err
314 );
315 return Err(Error::Quiche(err));
316 }
317 };
318
319 Ok(Handshake::Handshake(send_size))
320 }
321}
322
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, thread::sleep, time::Duration};

    use super::*;

    /// Round-trips retry tokens through `SimpleAddressValidator` and checks
    /// the rejection paths: wrong destination connection id, wrong client IP
    /// and an expired token.
    #[test]
    fn test_default_address_validator() {
        let validator = SimpleAddressValidator::new(Duration::from_secs(100));

        let client_scid = random_conn_id();
        let original_dcid = random_conn_id();
        let server_scid = random_conn_id();

        let client: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let token = validator
            .mint_retry_token(&client_scid, &original_dcid, &server_scid, &client)
            .unwrap();

        // When presented with the server's new scid as dcid, the token yields
        // the original dcid...
        assert_eq!(
            validator.validate_address(&client_scid, &server_scid, &client, &token),
            Some(original_dcid.clone())
        );

        // ...but it is rejected for any other dcid.
        assert_eq!(
            validator.validate_address(&client_scid, &original_dcid, &client, &token),
            None
        );

        // Validating does not use the token up.
        assert_eq!(
            validator.validate_address(&client_scid, &server_scid, &client, &token),
            Some(original_dcid.clone())
        );

        // A token minted for one IP is rejected for another.
        let other_client: SocketAddr = "0.0.0.0:1234".parse().unwrap();

        assert_eq!(
            validator.validate_address(&client_scid, &server_scid, &other_client, &token),
            None
        );

        // Tokens expire once older than the configured interval.
        let short_lived = SimpleAddressValidator::new(Duration::from_secs(1));

        let token = short_lived
            .mint_retry_token(&client_scid, &original_dcid, &server_scid, &other_client)
            .unwrap();

        assert_eq!(
            short_lived.validate_address(&client_scid, &server_scid, &other_client, &token),
            Some(original_dcid.clone())
        );

        sleep(Duration::from_secs(2));

        assert_eq!(
            short_lived.validate_address(&client_scid, &server_scid, &other_client, &token),
            None
        );
    }
}