zerortt_api/
acceptor.rs

1use std::{
2    net::SocketAddr,
3    time::{Duration, SystemTime},
4};
5
6use boring::sha::Sha256;
7use quiche::{ConnectionId, Header, RecvInfo};
8
9use crate::{Error, Result, random_conn_id};
10
11/// Address validation trait.
12pub trait AddressValidator {
13    /// Create a retry-token.
14    fn mint_retry_token(
15        &self,
16        scid: &ConnectionId<'_>,
17        dcid: &ConnectionId<'_>,
18        new_scid: &ConnectionId<'_>,
19        src: &SocketAddr,
20    ) -> Result<Vec<u8>>;
21
22    /// Validate the source address.
23    fn validate_address<'a>(
24        &self,
25        scid: &ConnectionId<'_>,
26        dcid: &ConnectionId<'_>,
27        src: &SocketAddr,
28        token: &'a [u8],
29    ) -> Option<ConnectionId<'a>>;
30}
31
/// The default [`AddressValidator`] implementation.
///
/// Retry tokens are bound to the client's IP address (the port is not included). They carry
/// the time they were minted and are authenticated with a SHA-256 hash keyed by a random
/// per-instance seed. Because of the seed, tokens minted by one instance (or before a
/// restart) are not accepted by another.
pub struct SimpleAddressValidator(
    /// Random secret seed mixed into every token hash; generated by `new`.
    [u8; 20],
    /// Maximum token age accepted by validation.
    Duration,
);
34
35impl SimpleAddressValidator {
36    /// Create a new `SimpleAddressValidator` instance with token expiration interval.
37    pub fn new(expiration_interval: Duration) -> Self {
38        let mut seed = [0; 20];
39        boring::rand::rand_bytes(&mut seed).unwrap();
40        Self(seed, expiration_interval)
41    }
42}
43
44impl AddressValidator for SimpleAddressValidator {
45    fn mint_retry_token(
46        &self,
47        _scid: &ConnectionId<'_>,
48        dcid: &ConnectionId<'_>,
49        new_scid: &ConnectionId<'_>,
50        src: &SocketAddr,
51    ) -> Result<Vec<u8>> {
52        let mut token = vec![];
53        // ip
54        match src.ip() {
55            std::net::IpAddr::V4(ipv4_addr) => token.extend_from_slice(&ipv4_addr.octets()),
56            std::net::IpAddr::V6(ipv6_addr) => token.extend_from_slice(&ipv6_addr.octets()),
57        };
58
59        let timestamp = SystemTime::now()
60            .duration_since(SystemTime::UNIX_EPOCH)
61            .unwrap()
62            .as_secs();
63
64        // timestamp
65        token.extend_from_slice(&timestamp.to_be_bytes());
66        // odcid
67        token.extend_from_slice(dcid);
68
69        // sha256
70        let mut hasher = Sha256::new();
71        // seed
72        hasher.update(&self.0);
73        // ip + timestamp + odcid
74        hasher.update(&token);
75        // new_scid
76        hasher.update(&new_scid);
77
78        token.extend_from_slice(&hasher.finish());
79
80        Ok(token)
81    }
82
83    fn validate_address<'a>(
84        &self,
85        _: &ConnectionId<'_>,
86        dcid: &ConnectionId<'_>,
87        src: &SocketAddr,
88        token: &'a [u8],
89    ) -> Option<ConnectionId<'a>> {
90        let addr = match src.ip() {
91            std::net::IpAddr::V4(a) => a.octets().to_vec(),
92            std::net::IpAddr::V6(a) => a.octets().to_vec(),
93        };
94
95        // token length is too short.
96        if addr.len() + 40 > token.len() {
97            return None;
98        }
99
100        // invalid address.
101        if addr != &token[..addr.len()] {
102            return None;
103        }
104
105        let timestamp = Duration::from_secs(u64::from_be_bytes(
106            token[addr.len()..addr.len() + 8].try_into().unwrap(),
107        ));
108        let now = SystemTime::now()
109            .duration_since(SystemTime::UNIX_EPOCH)
110            .unwrap();
111
112        // timeout
113        if now - timestamp > self.1 {
114            return None;
115        }
116
117        let sha256 = &token[token.len() - 32..];
118
119        // sha256
120        let mut hasher = Sha256::new();
121        // seed
122        hasher.update(&self.0);
123        // ip + timestamp + odcid
124        hasher.update(&token[..token.len() - 32]);
125        // new_scid
126        hasher.update(&dcid);
127
128        // sha256 check error.
129        if sha256 != hasher.finish() {
130            return None;
131        }
132
133        Some(ConnectionId::from_ref(
134            &token[addr.len() + 8..token.len() - 32],
135        ))
136    }
137}
138
139/// Handshake result, returns by [`handshake`](Acceptor::handshake) function.
140pub enum Handshake {
141    Handshake(usize),
142    Accept(quiche::Connection),
143}
144
145/// Accept new inbound `quic` connection.
146pub struct Acceptor {
147    /// configuration for quiche `Connection`.
148    config: quiche::Config,
149    /// Algorithem for address validation.
150    address_validator: Box<dyn AddressValidator + Send>,
151}
152
153impl Acceptor {
154    /// Create a new acceptor with custom `quiche::Config` and `AddressValidator`
155    pub fn new<A: AddressValidator + Send + 'static>(
156        config: quiche::Config,
157        address_validator: A,
158    ) -> Self {
159        Self {
160            config,
161            address_validator: Box::new(address_validator),
162        }
163    }
164
165    /// Process quic handshake
166    pub fn handshake(
167        &mut self,
168        header: &Header<'_>,
169        buf: &mut [u8],
170        read_size: usize,
171        recv_info: RecvInfo,
172    ) -> Result<Handshake> {
173        // send Version negotiation packet.
174        if !quiche::version_is_supported(header.version) {
175            return self.negotiate_version(header, buf, read_size, recv_info);
176        }
177
178        // Safety: present in `Initial` packet.
179        let token = header.token.as_ref().unwrap();
180
181        // send retry packet.
182        if token.is_empty() {
183            return self.retry(header, buf, read_size, recv_info);
184        }
185
186        let odcid = match self.address_validator.validate_address(
187            &header.scid,
188            &header.dcid,
189            &recv_info.from,
190            token,
191        ) {
192            Some(odcid) => odcid,
193            None => {
194                log::error!(
195                    "failed to validate address, from={:?}, to={}, scid={:?}, dcid={:?}",
196                    recv_info.from,
197                    recv_info.to,
198                    header.scid,
199                    header.dcid
200                );
201                return Err(Error::ValidateAddress);
202            }
203        };
204
205        let quiche_conn = match quiche::accept(
206            &header.dcid,
207            Some(&odcid),
208            recv_info.to,
209            recv_info.from,
210            &mut self.config,
211        ) {
212            Ok(conn) => {
213                log::trace!(
214                    "QuicServer(initial) accept new conn, from={:?}, to={}, scid={:?}, dcid={:?}, odcid={:?}",
215                    recv_info.from,
216                    recv_info.to,
217                    header.scid,
218                    header.dcid,
219                    odcid
220                );
221                conn
222            }
223            Err(err) => {
224                log::error!(
225                    "failed to accept connection, from={:?}, to={}, scid={:?}, dcid={:?}, err={}",
226                    recv_info.from,
227                    recv_info.to,
228                    header.scid,
229                    header.dcid,
230                    err
231                );
232                return Err(Error::Quiche(err));
233            }
234        };
235
236        Ok(Handshake::Accept(quiche_conn))
237    }
238
239    fn retry(
240        &self,
241        header: &Header<'_>,
242        buf: &mut [u8],
243        _recv_size: usize,
244        recv_info: RecvInfo,
245    ) -> Result<Handshake> {
246        let new_scid = random_conn_id();
247
248        log::trace!(
249            "retry, from={:?}, to={}, scid={:?}, dcid={:?}, new_scid={:?}",
250            recv_info.from,
251            recv_info.to,
252            header.scid,
253            header.dcid,
254            new_scid
255        );
256
257        let token = self.address_validator.mint_retry_token(
258            &header.scid,
259            &header.dcid,
260            &new_scid,
261            &recv_info.from,
262        )?;
263
264        let send_size = match quiche::retry(
265            &header.scid,
266            &header.dcid,
267            &new_scid,
268            &token,
269            header.version,
270            buf,
271        ) {
272            Ok(send_size) => send_size,
273            Err(err) => {
274                log::error!(
275                    "failed to generate retry packet, from={:?}, to={}, scid={:?}, dcid={:?}, err={}",
276                    recv_info.from,
277                    recv_info.to,
278                    header.scid,
279                    header.dcid,
280                    err
281                );
282                return Err(Error::Quiche(err));
283            }
284        };
285
286        Ok(Handshake::Handshake(send_size))
287    }
288
289    fn negotiate_version(
290        &self,
291        header: &Header<'_>,
292        buf: &mut [u8],
293        _recv_size: usize,
294        recv_info: RecvInfo,
295    ) -> Result<Handshake> {
296        log::trace!(
297            "negotiate_version, from={:?}, to={}, scid={:?}, dcid={:?}",
298            recv_info.from,
299            recv_info.to,
300            header.scid,
301            header.dcid
302        );
303
304        let send_size = match quiche::negotiate_version(&header.scid, &header.dcid, buf) {
305            Ok(send_size) => send_size,
306            Err(err) => {
307                log::error!(
308                    "failed to generate negotiation_version packet, from={:?}, to={}, scid={:?}, dcid={:?}, err={}",
309                    recv_info.from,
310                    recv_info.to,
311                    header.scid,
312                    header.dcid,
313                    err
314                );
315                return Err(Error::Quiche(err));
316            }
317        };
318
319        Ok(Handshake::Handshake(send_size))
320    }
321}
322
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, thread::sleep, time::Duration};

    use super::*;

    #[test]
    fn test_default_address_validator() {
        let validator = SimpleAddressValidator::new(Duration::from_secs(100));

        let scid = random_conn_id();
        let dcid = random_conn_id();
        let new_scid = random_conn_id();

        let src: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let token = validator
            .mint_retry_token(&scid, &dcid, &new_scid, &src)
            .unwrap();

        // A valid token is checked against the Retry's new scid and yields the original dcid.
        assert_eq!(
            validator.validate_address(&scid, &new_scid, &src, &token),
            Some(dcid.clone())
        );

        // Any other dcid is rejected.
        assert_eq!(
            validator.validate_address(&scid, &dcid, &src, &token),
            None
        );

        // Validation does not consume the token.
        assert_eq!(
            validator.validate_address(&scid, &new_scid, &src, &token),
            Some(dcid.clone())
        );

        // A tampered tag is rejected.
        let mut tampered = token.clone();
        *tampered.last_mut().unwrap() ^= 0x01;
        assert_eq!(
            validator.validate_address(&scid, &new_scid, &src, &tampered),
            None
        );

        // A truncated token is rejected.
        assert_eq!(
            validator.validate_address(&scid, &new_scid, &src, &token[..10]),
            None
        );

        // A token presented from a different source address is rejected.
        let other_src: SocketAddr = "0.0.0.0:1234".parse().unwrap();
        assert_eq!(
            validator.validate_address(&scid, &new_scid, &other_src, &token),
            None
        );

        // timeout: a token is rejected once it is older than the expiration interval.
        let validator = SimpleAddressValidator::new(Duration::from_secs(1));

        let token = validator
            .mint_retry_token(&scid, &dcid, &new_scid, &other_src)
            .unwrap();

        assert_eq!(
            validator.validate_address(&scid, &new_scid, &other_src, &token),
            Some(dcid.clone())
        );

        sleep(Duration::from_secs(2));

        assert_eq!(
            validator.validate_address(&scid, &new_scid, &other_src, &token),
            None
        );
    }
}